diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2017-04-10 21:18:42 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2017-04-10 21:38:33 +0100 |
commit | 30802e07b2d84fbc213b490d3402707dffe60096 (patch) | |
tree | 934aecb8f3582325dfd1aa6652193adac87d00db /vendor/github.com/go-sql-driver/mysql/statement.go | |
parent | da7638dc112c4c106e8929601b642d2ca4596cba (diff) |
update dependencies
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/statement.go')
-rw-r--r-- | vendor/github.com/go-sql-driver/mysql/statement.go | 56 |
1 files changed, 36 insertions, 20 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index 7f9b045..b887716 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -11,6 +11,7 @@ package mysql import ( "database/sql/driver" "fmt" + "io" "reflect" "strconv" ) @@ -19,7 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int - columns []mysqlField // cached from the first query + columns [][]mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { @@ -62,26 +63,30 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - // Rows - err = mc.readUntilEOF() + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err } } - return nil, err + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { @@ -104,18 +109,29 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } rows := new(binaryRows) + rows.stmtCols = &stmt.columns if resLen > 0 { rows.mc = mc + rows.i++ // Columns // If not cached, read them and cache them - if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns + if len(stmt.columns) == 0 { + rows.rs.columns, err = mc.readColumns(resLen) + stmt.columns = append(stmt.columns, rows.rs.columns) } else { - rows.columns = stmt.columns + rows.rs.columns = stmt.columns[0] err = mc.readUntilEOF() } + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } return rows, err |