From 7b320119ba532fd409ec7dade7ad02011c309599 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Wed, 18 Oct 2017 13:15:14 +0100 Subject: Update dependencies --- .../github.com/go-sql-driver/mysql/connection.go | 114 +++++++++++++++++---- 1 file changed, 92 insertions(+), 22 deletions(-) (limited to 'vendor/github.com/go-sql-driver/mysql/connection.go') diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index 08e5fad..e570614 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -17,6 +17,16 @@ import ( "time" ) +// a copy of context.Context for Go 1.7 and earlier +type mysqlContext interface { + Done() <-chan struct{} + Err() error + + // defined in context.Context, but not used in this driver: + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} +} + type mysqlConn struct { buf buffer netConn net.Conn @@ -30,7 +40,14 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - strict bool + + // for context support (Go 1.8+) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established @@ -63,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) } @@ -92,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { + if !mc.closed.TrySet(true) { + return + } + // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err } - mc.netConn = nil + return ErrInvalidConn } - mc.cfg = nil - mc.buf.nc = nil + return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + return nil, mc.markBadConn(err) } stmt := &mysqlStmt{ @@ -145,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -258,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -272,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, err } query = prepared - args = nil } mc.affectedRows = 0 mc.insertId = 0 @@ -284,14 +332,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { - return err + return mc.markBadConn(err) } // Read Result @@ -316,7 +364,11 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -330,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } query = prepared - args = nil } // Send command err := mc.writeCommandPacketStr(comQuery, query) @@ -352,12 +403,13 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } } + // Columns rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable @@ -389,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} -- cgit v1.2.3