diff options
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/rows.go')
| -rw-r--r-- | vendor/github.com/go-sql-driver/mysql/rows.go | 111 | 
1 files changed, 71 insertions, 40 deletions
| diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index 900f548..18f4169 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -11,34 +11,24 @@ package mysql  import (  	"database/sql/driver"  	"io" +	"math" +	"reflect"  ) -type mysqlField struct { -	tableName string -	name      string -	flags     fieldFlag -	fieldType byte -	decimals  byte -} -  type resultSet struct { -	columns []mysqlField -	done    bool +	columns     []mysqlField +	columnNames []string +	done        bool  }  type mysqlRows struct { -	mc *mysqlConn -	rs resultSet +	mc     *mysqlConn +	rs     resultSet +	finish func()  }  type binaryRows struct {  	mysqlRows -	// stmtCols is a pointer to the statement's cached columns for different -	// result sets. -	stmtCols *[][]mysqlField -	// i is a number of the current result set. It is used to fetch proper -	// columns from stmtCols. -	i int  }  type textRows struct { @@ -46,6 +36,10 @@ type textRows struct {  }  func (rows *mysqlRows) Columns() []string { +	if rows.rs.columnNames != nil { +		return rows.rs.columnNames +	} +  	columns := make([]string, len(rows.rs.columns))  	if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {  		for i := range columns { @@ -60,16 +54,64 @@ func (rows *mysqlRows) Columns() []string {  			columns[i] = rows.rs.columns[i].name  		}  	} + +	rows.rs.columnNames = columns  	return columns  } +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { +	if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok { +		return name +	} +	return "" +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// 	return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { +	return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { +	column := rows.rs.columns[i] +	decimals := int64(column.decimals) + +	switch column.fieldType { +	case fieldTypeDecimal, fieldTypeNewDecimal: +		if decimals > 0 { +			return int64(column.length) - 2, decimals, true +		} +		return int64(column.length) - 1, decimals, true +	case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: +		return decimals, decimals, true +	case fieldTypeFloat, fieldTypeDouble: +		if decimals == 0x1f { +			return math.MaxInt64, math.MaxInt64, true +		} +		return math.MaxInt64, decimals, true +	} + +	return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { +	return rows.rs.columns[i].scanType() +} +  func (rows *mysqlRows) Close() (err error) { +	if f := rows.finish; f != nil { +		f() +		rows.finish = nil +	} +  	mc := rows.mc  	if mc == nil {  		return nil  	} -	if mc.netConn == nil { -		return ErrInvalidConn +	if err := mc.error(); err != nil { +		return err  	}  	// Remove unread packets from stream @@ -97,8 +139,8 @@ func (rows *mysqlRows) nextResultSet() (int, error) {  	if rows.mc == nil {  		return 0, io.EOF  	} -	if rows.mc.netConn == nil { -		return 0, ErrInvalidConn +	if err := rows.mc.error(); err != nil { +		return 0, err  	}  	// Remove unread packets from stream @@ -132,31 +174,20 @@ func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {  	}  } -func (rows *binaryRows) NextResultSet() (err error) { +func (rows *binaryRows) NextResultSet() error {  	resLen, err := rows.nextNotEmptyResultSet()  	if err != nil {  		return err  	} -	// get columns, if not cached, read them and cache them. -	if rows.i >= len(*rows.stmtCols) { -		rows.rs.columns, err = rows.mc.readColumns(resLen) -		*rows.stmtCols = append(*rows.stmtCols, rows.rs.columns) -	} else { -		rows.rs.columns = (*rows.stmtCols)[rows.i] -		if err := rows.mc.readUntilEOF(); err != nil { -			return err -		} -	} - -	rows.i++ -	return nil +	rows.rs.columns, err = rows.mc.readColumns(resLen) +	return err  }  func (rows *binaryRows) Next(dest []driver.Value) error {  	if mc := rows.mc; mc != nil { -		if mc.netConn == nil { -			return ErrInvalidConn +		if err := mc.error(); err != nil { +			return err  		}  		// Fetch next row from stream @@ -177,8 +208,8 @@ func (rows *textRows) NextResultSet() (err error) {  func (rows *textRows) Next(dest []driver.Value) error {  	if mc := rows.mc; mc != nil { -		if mc.netConn == nil { -			return ErrInvalidConn +		if err := mc.error(); err != nil { +			return err  		}  		// Fetch next row from stream | 
