diff options
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
| -rw-r--r-- | vendor/github.com/go-sql-driver/mysql/packets.go | 242 | 
1 files changed, 117 insertions, 125 deletions
| diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index f63d250..f99934e 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -149,24 +149,29 @@ func (mc *mysqlConn) writePacket(data []byte) error {  }  /****************************************************************************** -*                           Initialisation Process                            * +*                           Initialization Process                            *  ******************************************************************************/  // Handshake Initialization Packet  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { +func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {  	data, err := mc.readPacket()  	if err != nil { -		return nil, err +		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since +		// in connection initialization we don't risk retrying non-idempotent actions. +		if err == ErrInvalidConn { +			return nil, "", driver.ErrBadConn +		} +		return nil, "", err  	}  	if data[0] == iERR { -		return nil, mc.handleErrorPacket(data) +		return nil, "", mc.handleErrorPacket(data)  	}  	// protocol version [1 byte]  	if data[0] < minProtocolVersion { -		return nil, fmt.Errorf( +		return nil, "", fmt.Errorf(  			"unsupported protocol version %d. Version %d or higher is required",  			data[0],  			minProtocolVersion, @@ -178,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {  	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4  	// first part of the password cipher [8 bytes] -	cipher := data[pos : pos+8] +	authData := data[pos : pos+8]  	// (filler) always 0x00 [1 byte]  	pos += 8 + 1 @@ -186,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {  	// capability flags (lower 2 bytes) [2 bytes]  	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))  	if mc.flags&clientProtocol41 == 0 { -		return nil, ErrOldProtocol +		return nil, "", ErrOldProtocol  	}  	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { -		return nil, ErrNoTLS +		return nil, "", ErrNoTLS  	}  	pos += 2 +	plugin := ""  	if len(data) > pos {  		// character set [1 byte]  		// status flags [2 bytes] @@ -213,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {  		//  		// The official Python library uses the fixed length 12  		// which seems to work but technically could have a hidden bug. -		cipher = append(cipher, data[pos:pos+12]...) +		authData = append(authData, data[pos:pos+12]...) +		pos += 13 -		// TODO: Verify string termination  		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)  		// \NUL otherwise -		// -		//if data[len(data)-1] == 0 { -		//	return -		//} -		//return ErrMalformPkt +		if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { +			plugin = string(data[pos : pos+end]) +		} else { +			plugin = string(data[pos:]) +		}  		// make a memory safe copy of the cipher slice  		var b [20]byte -		copy(b[:], cipher) -		return b[:], nil +		copy(b[:], authData) +		return b[:], plugin, nil  	} +	plugin = defaultAuthPlugin +  	// make a memory safe copy of the cipher slice  	var b [8]byte -	copy(b[:], cipher) -	return b[:], nil +	copy(b[:], authData) +	return b[:], plugin, nil  }  // Client Authentication Packet  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {  	// Adjust client flags based on server support  	clientFlags := clientProtocol41 |  		clientSecureConn | @@ -262,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {  		clientFlags |= clientMultiStatements  	} -	// User Password -	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) +	// encode length of the auth plugin data +	var authRespLEIBuf [9]byte +	authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) +	if len(authRespLEI) > 1 { +		// if the length can not be written in 1 byte, it must be written as a +		// length encoded integer +		clientFlags |= clientPluginAuthLenEncClientData +	} -	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 +	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 +	if addNUL { +		pktLen++ +	}  	// To specify a db name  	if n := len(mc.cfg.DBName); n > 0 { @@ -276,7 +293,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {  	// Calculate packet length and get buffer with that size  	data := mc.buf.takeSmallBuffer(pktLen + 4)  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} @@ -333,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {  	data[pos] = 0x00  	pos++ -	// ScrambleBuffer [length encoded integer] -	data[pos] = byte(len(scrambleBuff)) -	pos += 1 + copy(data[pos+1:], scrambleBuff) +	// Auth Data [length encoded integer] +	pos += copy(data[pos:], authRespLEI) +	pos += copy(data[pos:], authResp) +	if addNUL { +		data[pos] = 0x00 +		pos++ +	}  	// Databasename [null terminated string]  	if len(mc.cfg.DBName) > 0 { @@ -344,76 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {  		pos++  	} -	// Assume native client during response -	pos += copy(data[pos:], "mysql_native_password") +	pos += copy(data[pos:], plugin)  	data[pos] = 0x00  	// Send Auth packet  	return mc.writePacket(data)  } -//  Client old authentication packet  // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { -	// User password -	// https://dev.mysql.com/doc/internals/en/old-password-authentication.html -	// Old password authentication only need and will need 8-byte challenge. -	scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd)) - -	// Calculate the packet length and add a tailing 0 -	pktLen := len(scrambleBuff) + 1 -	data := mc.buf.takeSmallBuffer(4 + pktLen) -	if data == nil { -		// can not take the buffer. Something must be wrong with the connection -		errLog.Print(ErrBusyBuffer) -		return errBadConnNoWrite +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +	pktLen := 4 + len(authData) +	if addNUL { +		pktLen++  	} - -	// Add the scrambled password [null terminated string] -	copy(data[4:], scrambleBuff) -	data[4+pktLen-1] = 0x00 - -	return mc.writePacket(data) -} - -//  Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { -	// Calculate the packet length and add a tailing 0 -	pktLen := len(mc.cfg.Passwd) + 1 -	data := mc.buf.takeSmallBuffer(4 + pktLen) +	data := mc.buf.takeSmallBuffer(pktLen)  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} -	// Add the clear password [null terminated string] -	copy(data[4:], mc.cfg.Passwd) -	data[4+pktLen-1] = 0x00 - -	return mc.writePacket(data) -} - -//  Native password authentication method -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { -	// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html -	// Native password authentication only need and will need 20-byte challenge. -	scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd)) - -	// Calculate the packet length and add a tailing 0 -	pktLen := len(scrambleBuff) -	data := mc.buf.takeSmallBuffer(4 + pktLen) -	if data == nil { -		// can not take the buffer. Something must be wrong with the connection -		errLog.Print(ErrBusyBuffer) -		return errBadConnNoWrite +	// Add the auth data [EOF] +	copy(data[4:], authData) +	if addNUL { +		data[pktLen-1] = 0x00  	} -	// Add the scramble -	copy(data[4:], scrambleBuff) -  	return mc.writePacket(data)  } @@ -427,7 +404,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {  	data := mc.buf.takeSmallBuffer(4 + 1)  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} @@ -446,7 +423,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {  	pktLen := 1 + len(arg)  	data := mc.buf.takeBuffer(pktLen + 4)  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} @@ -467,7 +444,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {  	data := mc.buf.takeSmallBuffer(4 + 1 + 4)  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} @@ -489,45 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {  *                              Result Packets                                 *  ******************************************************************************/ -// Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() ([]byte, error) { +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {  	data, err := mc.readPacket() -	if err == nil { -		// packet indicator -		switch data[0] { +	if err != nil { +		return nil, "", err +	} -		case iOK: -			return nil, mc.handleOkPacket(data) +	// packet indicator +	switch data[0] { -		case iEOF: -			if len(data) > 1 { -				pluginEndIndex := bytes.IndexByte(data, 0x00) -				plugin := string(data[1:pluginEndIndex]) -				cipher := data[pluginEndIndex+1:] - -				switch plugin { -				case "mysql_old_password": -					// using old_passwords -					return cipher, ErrOldPassword -				case "mysql_clear_password": -					// using clear text password -					return cipher, ErrCleartextPassword -				case "mysql_native_password": -					// using mysql default authentication method -					return cipher, ErrNativePassword -				default: -					return cipher, ErrUnknownPlugin -				} -			} +	case iOK: +		return nil, "", mc.handleOkPacket(data) -			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest -			return nil, ErrOldPassword +	case iAuthMoreData: +		return data[1:], "", err -		default: // Error otherwise -			return nil, mc.handleErrorPacket(data) +	case iEOF: +		if len(data) < 1 { +			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest +			return nil, "mysql_old_password", nil  		} +		pluginEndIndex := bytes.IndexByte(data, 0x00) +		if pluginEndIndex < 0 { +			return nil, "", ErrMalformPkt +		} +		plugin := string(data[1:pluginEndIndex]) +		authData := data[pluginEndIndex+1:] +		return authData, plugin, nil + +	default: // Error otherwise +		return nil, "", mc.handleErrorPacket(data)  	} -	return nil, err +} + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { +	data, err := mc.readPacket() +	if err != nil { +		return err +	} + +	if data[0] == iOK { +		return mc.handleOkPacket(data) +	} +	return mc.handleErrorPacket(data)  }  // Result Set Header Packet @@ -697,10 +679,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {  		if err != nil {  			return nil, err  		} +		pos += n  		// Filler [uint8] +		pos++ +  		// Charset [charset, collation uint8] -		pos += n + 1 + 2 +		columns[i].charSet = data[pos] +		pos += 2  		// Length [uint32]  		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) @@ -857,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {  	// 2 bytes paramID  	const dataOffset = 1 + 4 + 2 -	// Can not use the write buffer since +	// Cannot use the write buffer since  	// a) the buffer is too small  	// b) it is in use  	data := make([]byte, 4+1+4+2+len(arg)) @@ -912,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {  	const minPktLen = 4 + 1 + 4 + 1 + 4  	mc := stmt.mc +	// Determine threshould dynamically to avoid packet size shortage. +	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) +	if longDataSize < 64 { +		longDataSize = 64 +	} +  	// Reset packet-sequence  	mc.sequence = 0 @@ -923,7 +915,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {  		data = mc.buf.takeCompleteBuffer()  	}  	if data == nil { -		// can not take the buffer. Something must be wrong with the connection +		// cannot take the buffer. Something must be wrong with the connection  		errLog.Print(ErrBusyBuffer)  		return errBadConnNoWrite  	} @@ -1039,7 +1031,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {  					paramTypes[i+i] = byte(fieldTypeString)  					paramTypes[i+i+1] = 0x00 -					if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { +					if len(v) < longDataSize {  						paramValues = appendLengthEncodedInteger(paramValues,  							uint64(len(v)),  						) @@ -1061,7 +1053,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {  				paramTypes[i+i] = byte(fieldTypeString)  				paramTypes[i+i+1] = 0x00 -				if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { +				if len(v) < longDataSize {  					paramValues = appendLengthEncodedInteger(paramValues,  						uint64(len(v)),  					) @@ -1091,7 +1083,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {  				paramValues = append(paramValues, b...)  			default: -				return fmt.Errorf("can not convert type: %T", arg) +				return fmt.Errorf("cannot convert type: %T", arg)  			}  		} @@ -1269,7 +1261,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {  						rows.rs.columns[i].decimals,  					)  				} -				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) +				dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)  			case rows.mc.parseTime:  				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)  			default: @@ -1289,7 +1281,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {  						)  					}  				} -				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) +				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)  			}  			if err == nil { | 
