aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-sql-driver/mysql/packets.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/packets.go111
1 files changed, 76 insertions, 35 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go
index f06752b..aafe979 100644
--- a/vendor/github.com/go-sql-driver/mysql/packets.go
+++ b/vendor/github.com/go-sql-driver/mysql/packets.go
@@ -25,9 +25,9 @@ import (
// Read packet to buffer 'data'
func (mc *mysqlConn) readPacket() ([]byte, error) {
- var payload []byte
+ var prevData []byte
for {
- // Read packet header
+ // read packet header
data, err := mc.buf.readNext(4)
if err != nil {
errLog.Print(err)
@@ -35,16 +35,10 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
return nil, driver.ErrBadConn
}
- // Packet Length [24 bit]
+ // packet length [24 bit]
pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
- if pktLen < 1 {
- errLog.Print(ErrMalformPkt)
- mc.Close()
- return nil, driver.ErrBadConn
- }
-
- // Check Packet Sync [8 bit]
+ // check packet sync [8 bit]
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
@@ -53,7 +47,20 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
mc.sequence++
- // Read packet body [pktLen bytes]
+ // packets with length 0 terminate a previous packet which is a
+ // multiple of (2^24)−1 bytes long
+ if pktLen == 0 {
+ // there was no previous packet
+ if prevData == nil {
+ errLog.Print(ErrMalformPkt)
+ mc.Close()
+ return nil, driver.ErrBadConn
+ }
+
+ return prevData, nil
+ }
+
+ // read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
errLog.Print(err)
@@ -61,18 +68,17 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
return nil, driver.ErrBadConn
}
- isLastPacket := (pktLen < maxPacketSize)
+ // return data if this was the last packet
+ if pktLen < maxPacketSize {
+ // zero allocations for non-split packets
+ if prevData == nil {
+ return data, nil
+ }
- // Zero allocations for non-splitting packets
- if isLastPacket && payload == nil {
- return data, nil
+ return append(prevData, data...), nil
}
- payload = append(payload, data...)
-
- if isLastPacket {
- return payload, nil
- }
+ prevData = append(prevData, data...)
}
}
@@ -372,6 +378,26 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
return mc.writePacket(data)
}
+// Native password authentication method
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
+func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
+ scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
+
+ // Calculate the packet length and add a tailing 0
+ pktLen := len(scrambleBuff)
+ data := mc.buf.takeSmallBuffer(4 + pktLen)
+ if data == nil {
+ // can not take the buffer. Something must be wrong with the connection
+ errLog.Print(ErrBusyBuffer)
+ return driver.ErrBadConn
+ }
+
+ // Add the scramble
+ copy(data[4:], scrambleBuff)
+
+ return mc.writePacket(data)
+}
+
/******************************************************************************
* Command Packets *
******************************************************************************/
@@ -445,36 +471,43 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
******************************************************************************/
// Returns error if Packet is not an 'Result OK'-Packet
-func (mc *mysqlConn) readResultOK() error {
+func (mc *mysqlConn) readResultOK() ([]byte, error) {
data, err := mc.readPacket()
if err == nil {
// packet indicator
switch data[0] {
case iOK:
- return mc.handleOkPacket(data)
+ return nil, mc.handleOkPacket(data)
case iEOF:
if len(data) > 1 {
- plugin := string(data[1:bytes.IndexByte(data, 0x00)])
+ pluginEndIndex := bytes.IndexByte(data, 0x00)
+ plugin := string(data[1:pluginEndIndex])
+ cipher := data[pluginEndIndex+1 : len(data)-1]
+
if plugin == "mysql_old_password" {
// using old_passwords
- return ErrOldPassword
+ return cipher, ErrOldPassword
} else if plugin == "mysql_clear_password" {
// using clear text password
- return ErrCleartextPassword
+ return cipher, ErrCleartextPassword
+ } else if plugin == "mysql_native_password" {
+ // using mysql default authentication method
+ return cipher, ErrNativePassword
} else {
- return ErrUnknownPlugin
+ return cipher, ErrUnknownPlugin
}
} else {
- return ErrOldPassword
+ // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+ return nil, ErrOldPassword
}
default: // Error otherwise
- return mc.handleErrorPacket(data)
+ return nil, mc.handleErrorPacket(data)
}
}
- return err
+ return nil, err
}
// Result Set Header Packet
@@ -674,11 +707,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
- if err := rows.mc.discardResults(); err != nil {
- return err
+ err = rows.mc.discardResults()
+ if err == nil {
+ err = io.EOF
+ } else {
+ // connection unusable
+ rows.mc.Close()
}
rows.mc = nil
- return io.EOF
+ return err
}
if data[0] == iERR {
rows.mc = nil
@@ -1079,11 +1116,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
- if err := rows.mc.discardResults(); err != nil {
- return err
+ err = rows.mc.discardResults()
+ if err == nil {
+ err = io.EOF
+ } else {
+ // connection unusable
+ rows.mc.Close()
}
rows.mc = nil
- return io.EOF
+ return err
}
rows.mc = nil