From 73ef85bc5db590c22689e11be20737a3dd88168f Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Wed, 28 Dec 2016 21:18:36 +0000 Subject: Update dependencies --- vendor/github.com/go-sql-driver/mysql/packets.go | 111 ++++++++++++++++------- 1 file changed, 76 insertions(+), 35 deletions(-) (limited to 'vendor/github.com/go-sql-driver/mysql/packets.go') diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index f06752b..aafe979 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -25,9 +25,9 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { - var payload []byte + var prevData []byte for { - // Read packet header + // read packet header data, err := mc.buf.readNext(4) if err != nil { errLog.Print(err) @@ -35,16 +35,10 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { return nil, driver.ErrBadConn } - // Packet Length [24 bit] + // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if pktLen < 1 { - errLog.Print(ErrMalformPkt) - mc.Close() - return nil, driver.ErrBadConn - } - - // Check Packet Sync [8 bit] + // check packet sync [8 bit] if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul @@ -53,7 +47,20 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } mc.sequence++ - // Read packet body [pktLen bytes] + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, driver.ErrBadConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { errLog.Print(err) @@ -61,18 +68,17 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { return nil, driver.ErrBadConn } - isLastPacket := (pktLen < maxPacketSize) + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - // Zero allocations for non-splitting packets - if isLastPacket && payload == nil { - return data, nil + return append(prevData, data...), nil } - payload = append(payload, data...) - - if isLastPacket { - return payload, nil - } + prevData = append(prevData, data...) } } @@ -372,6 +378,26 @@ func (mc *mysqlConn) writeClearAuthPacket() error { return mc.writePacket(data) } +// Native password authentication method +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { + scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + + // Calculate the packet length and add a tailing 0 + pktLen := len(scrambleBuff) + data := mc.buf.takeSmallBuffer(4 + pktLen) + if data == nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(ErrBusyBuffer) + return driver.ErrBadConn + } + + // Add the scramble + copy(data[4:], scrambleBuff) + + return mc.writePacket(data) +} + /****************************************************************************** * Command Packets * ******************************************************************************/ @@ -445,36 +471,43 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { ******************************************************************************/ // Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() error { +func (mc *mysqlConn) readResultOK() ([]byte, error) { data, err := mc.readPacket() if err == nil { // packet indicator switch data[0] { case iOK: - return mc.handleOkPacket(data) + return nil, mc.handleOkPacket(data) case iEOF: if len(data) > 1 { - plugin := string(data[1:bytes.IndexByte(data, 0x00)]) + pluginEndIndex := bytes.IndexByte(data, 0x00) + plugin := string(data[1:pluginEndIndex]) + cipher := data[pluginEndIndex+1 : len(data)-1] + if plugin == "mysql_old_password" { // using old_passwords - return ErrOldPassword + return cipher, ErrOldPassword } else if plugin == "mysql_clear_password" { // using clear text password - return ErrCleartextPassword + return cipher, ErrCleartextPassword + } else if plugin == "mysql_native_password" { + // using mysql default authentication method + return cipher, ErrNativePassword } else { - return ErrUnknownPlugin + return cipher, ErrUnknownPlugin } } else { - return ErrOldPassword + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, ErrOldPassword } default: // Error otherwise - return mc.handleErrorPacket(data) + return nil, mc.handleErrorPacket(data) } } - return err + return nil, err } // Result Set Header Packet @@ -674,11 +707,15 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() } rows.mc = nil - return io.EOF + return err } if data[0] == iERR { rows.mc = nil @@ -1079,11 +1116,15 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + err = rows.mc.discardResults() + if err == nil { + err = io.EOF + } else { + // connection unusable + rows.mc.Close() } rows.mc = nil - return io.EOF + return err } rows.mc = nil -- cgit v1.2.3