From 30802e07b2d84fbc213b490d3402707dffe60096 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Mon, 10 Apr 2017 21:18:42 +0100 Subject: update dependencies --- vendor/github.com/go-sql-driver/mysql/rows.go | 122 +++++++++++++++++++++----- 1 file changed, 99 insertions(+), 23 deletions(-) (limited to 'vendor/github.com/go-sql-driver/mysql/rows.go') diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index c08255e..900f548 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -21,40 +21,49 @@ type mysqlField struct { decimals byte } -type mysqlRows struct { - mc *mysqlConn +type resultSet struct { columns []mysqlField + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs resultSet } type binaryRows struct { mysqlRows + // stmtCols is a pointer to the statement's cached columns for different + // result sets. + stmtCols *[][]mysqlField + // i is a number of the current result set. It is used to fetch proper + // columns from stmtCols. + i int } type textRows struct { mysqlRows } -type emptyRows struct{} - func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) + columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } return columns } -func (rows *mysqlRows) Close() error { +func (rows *mysqlRows) Close() (err error) { mc := rows.mc if mc == nil { return nil @@ -64,7 +73,9 @@ func (rows *mysqlRows) Close() error { } // Remove unread packets from stream - err := mc.readUntilEOF() + if !rows.rs.done { + err = mc.readUntilEOF() + } if err == nil { if err = mc.discardResults(); err != nil { return err @@ -75,6 +86,73 @@ func (rows *mysqlRows) Close() error { return err } +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if rows.mc.netConn == nil { + return 0, ErrInvalidConn + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + // get columns, if not cached, read them and cache them. + if rows.i >= len(*rows.stmtCols) { + rows.rs.columns, err = rows.mc.readColumns(resLen) + *rows.stmtCols = append(*rows.stmtCols, rows.rs.columns) + } else { + rows.rs.columns = (*rows.stmtCols)[rows.i] + if err := rows.mc.readUntilEOF(); err != nil { + return err + } + } + + rows.i++ + return nil +} + func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { if mc.netConn == nil { @@ -87,6 +165,16 @@ func (rows *binaryRows) Next(dest []driver.Value) error { return io.EOF } +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { if mc.netConn == nil { @@ -98,15 +186,3 @@ func (rows *textRows) Next(dest []driver.Value) error { } return io.EOF } - -func (rows emptyRows) Columns() []string { - return nil -} - -func (rows emptyRows) Close() error { - return nil -} - -func (rows emptyRows) Next(dest []driver.Value) error { - return io.EOF -} -- cgit v1.2.3