aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-sql-driver/mysql/packets.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/packets.go242
1 files changed, 117 insertions, 125 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go
index f63d250..f99934e 100644
--- a/vendor/github.com/go-sql-driver/mysql/packets.go
+++ b/vendor/github.com/go-sql-driver/mysql/packets.go
@@ -149,24 +149,29 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}
/******************************************************************************
-* Initialisation Process *
+* Initialization Process *
******************************************************************************/
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() ([]byte, error) {
+func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
data, err := mc.readPacket()
if err != nil {
- return nil, err
+ // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
+ // in connection initialization we don't risk retrying non-idempotent actions.
+ if err == ErrInvalidConn {
+ return nil, "", driver.ErrBadConn
+ }
+ return nil, "", err
}
if data[0] == iERR {
- return nil, mc.handleErrorPacket(data)
+ return nil, "", mc.handleErrorPacket(data)
}
// protocol version [1 byte]
if data[0] < minProtocolVersion {
- return nil, fmt.Errorf(
+ return nil, "", fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
data[0],
minProtocolVersion,
@@ -178,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
// first part of the password cipher [8 bytes]
- cipher := data[pos : pos+8]
+ authData := data[pos : pos+8]
// (filler) always 0x00 [1 byte]
pos += 8 + 1
@@ -186,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
- return nil, ErrOldProtocol
+ return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
- return nil, ErrNoTLS
+ return nil, "", ErrNoTLS
}
pos += 2
+ plugin := ""
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
@@ -213,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
//
// The official Python library uses the fixed length 12
// which seems to work but technically could have a hidden bug.
- cipher = append(cipher, data[pos:pos+12]...)
+ authData = append(authData, data[pos:pos+12]...)
+ pos += 13
- // TODO: Verify string termination
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
// \NUL otherwise
- //
- //if data[len(data)-1] == 0 {
- // return
- //}
- //return ErrMalformPkt
+ if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
+ plugin = string(data[pos : pos+end])
+ } else {
+ plugin = string(data[pos:])
+ }
// make a memory safe copy of the cipher slice
var b [20]byte
- copy(b[:], cipher)
- return b[:], nil
+ copy(b[:], authData)
+ return b[:], plugin, nil
}
+ plugin = defaultAuthPlugin
+
// make a memory safe copy of the cipher slice
var b [8]byte
- copy(b[:], cipher)
- return b[:], nil
+ copy(b[:], authData)
+ return b[:], plugin, nil
}
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
@@ -262,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientFlags |= clientMultiStatements
}
- // User Password
- scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
+ // encode length of the auth plugin data
+ var authRespLEIBuf [9]byte
+ authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
+ if len(authRespLEI) > 1 {
+ // if the length can not be written in 1 byte, it must be written as a
+ // length encoded integer
+ clientFlags |= clientPluginAuthLenEncClientData
+ }
- pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
+ pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
+ if addNUL {
+ pktLen++
+ }
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
@@ -276,7 +293,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@@ -333,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
data[pos] = 0x00
pos++
- // ScrambleBuffer [length encoded integer]
- data[pos] = byte(len(scrambleBuff))
- pos += 1 + copy(data[pos+1:], scrambleBuff)
+ // Auth Data [length encoded integer]
+ pos += copy(data[pos:], authRespLEI)
+ pos += copy(data[pos:], authResp)
+ if addNUL {
+ data[pos] = 0x00
+ pos++
+ }
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
@@ -344,76 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
pos++
}
- // Assume native client during response
- pos += copy(data[pos:], "mysql_native_password")
+ pos += copy(data[pos:], plugin)
data[pos] = 0x00
// Send Auth packet
return mc.writePacket(data)
}
-// Client old authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
- // User password
- // https://dev.mysql.com/doc/internals/en/old-password-authentication.html
- // Old password authentication only need and will need 8-byte challenge.
- scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd))
-
- // Calculate the packet length and add a tailing 0
- pktLen := len(scrambleBuff) + 1
- data := mc.buf.takeSmallBuffer(4 + pktLen)
- if data == nil {
- // can not take the buffer. Something must be wrong with the connection
- errLog.Print(ErrBusyBuffer)
- return errBadConnNoWrite
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
+ pktLen := 4 + len(authData)
+ if addNUL {
+ pktLen++
}
-
- // Add the scrambled password [null terminated string]
- copy(data[4:], scrambleBuff)
- data[4+pktLen-1] = 0x00
-
- return mc.writePacket(data)
-}
-
-// Client clear text authentication packet
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeClearAuthPacket() error {
- // Calculate the packet length and add a tailing 0
- pktLen := len(mc.cfg.Passwd) + 1
- data := mc.buf.takeSmallBuffer(4 + pktLen)
+ data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
- // Add the clear password [null terminated string]
- copy(data[4:], mc.cfg.Passwd)
- data[4+pktLen-1] = 0x00
-
- return mc.writePacket(data)
-}
-
-// Native password authentication method
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
- // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
- // Native password authentication only need and will need 20-byte challenge.
- scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd))
-
- // Calculate the packet length and add a tailing 0
- pktLen := len(scrambleBuff)
- data := mc.buf.takeSmallBuffer(4 + pktLen)
- if data == nil {
- // can not take the buffer. Something must be wrong with the connection
- errLog.Print(ErrBusyBuffer)
- return errBadConnNoWrite
+ // Add the auth data [EOF]
+ copy(data[4:], authData)
+ if addNUL {
+ data[pktLen-1] = 0x00
}
- // Add the scramble
- copy(data[4:], scrambleBuff)
-
return mc.writePacket(data)
}
@@ -427,7 +404,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@@ -446,7 +423,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@@ -467,7 +444,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@@ -489,45 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
* Result Packets *
******************************************************************************/
-// Returns error if Packet is not an 'Result OK'-Packet
-func (mc *mysqlConn) readResultOK() ([]byte, error) {
+func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
data, err := mc.readPacket()
- if err == nil {
- // packet indicator
- switch data[0] {
+ if err != nil {
+ return nil, "", err
+ }
- case iOK:
- return nil, mc.handleOkPacket(data)
+ // packet indicator
+ switch data[0] {
- case iEOF:
- if len(data) > 1 {
- pluginEndIndex := bytes.IndexByte(data, 0x00)
- plugin := string(data[1:pluginEndIndex])
- cipher := data[pluginEndIndex+1:]
-
- switch plugin {
- case "mysql_old_password":
- // using old_passwords
- return cipher, ErrOldPassword
- case "mysql_clear_password":
- // using clear text password
- return cipher, ErrCleartextPassword
- case "mysql_native_password":
- // using mysql default authentication method
- return cipher, ErrNativePassword
- default:
- return cipher, ErrUnknownPlugin
- }
- }
+ case iOK:
+ return nil, "", mc.handleOkPacket(data)
- // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
- return nil, ErrOldPassword
+ case iAuthMoreData:
+ return data[1:], "", err
- default: // Error otherwise
- return nil, mc.handleErrorPacket(data)
+ case iEOF:
+ if len(data) < 1 {
+ // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+ return nil, "mysql_old_password", nil
}
+ pluginEndIndex := bytes.IndexByte(data, 0x00)
+ if pluginEndIndex < 0 {
+ return nil, "", ErrMalformPkt
+ }
+ plugin := string(data[1:pluginEndIndex])
+ authData := data[pluginEndIndex+1:]
+ return authData, plugin, nil
+
+ default: // Error otherwise
+ return nil, "", mc.handleErrorPacket(data)
}
- return nil, err
+}
+
+// Returns error if Packet is not an 'Result OK'-Packet
+func (mc *mysqlConn) readResultOK() error {
+ data, err := mc.readPacket()
+ if err != nil {
+ return err
+ }
+
+ if data[0] == iOK {
+ return mc.handleOkPacket(data)
+ }
+ return mc.handleErrorPacket(data)
}
// Result Set Header Packet
@@ -697,10 +679,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
if err != nil {
return nil, err
}
+ pos += n
// Filler [uint8]
+ pos++
+
// Charset [charset, collation uint8]
- pos += n + 1 + 2
+ columns[i].charSet = data[pos]
+ pos += 2
// Length [uint32]
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
@@ -857,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
// 2 bytes paramID
const dataOffset = 1 + 4 + 2
- // Can not use the write buffer since
+ // Cannot use the write buffer since
// a) the buffer is too small
// b) it is in use
data := make([]byte, 4+1+4+2+len(arg))
@@ -912,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
+ // Determine threshould dynamically to avoid packet size shortage.
+ longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
+ if longDataSize < 64 {
+ longDataSize = 64
+ }
+
// Reset packet-sequence
mc.sequence = 0
@@ -923,7 +915,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
data = mc.buf.takeCompleteBuffer()
}
if data == nil {
- // can not take the buffer. Something must be wrong with the connection
+ // cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@@ -1039,7 +1031,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
- if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
+ if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
@@ -1061,7 +1053,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
paramTypes[i+i] = byte(fieldTypeString)
paramTypes[i+i+1] = 0x00
- if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
+ if len(v) < longDataSize {
paramValues = appendLengthEncodedInteger(paramValues,
uint64(len(v)),
)
@@ -1091,7 +1083,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
paramValues = append(paramValues, b...)
default:
- return fmt.Errorf("can not convert type: %T", arg)
+ return fmt.Errorf("cannot convert type: %T", arg)
}
}
@@ -1269,7 +1261,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
rows.rs.columns[i].decimals,
)
}
- dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
+ dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
@@ -1289,7 +1281,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
)
}
}
- dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
+ dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
}
if err == nil {