aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-sql-driver/mysql/rows.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/rows.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/rows.go111
1 files changed, 71 insertions, 40 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go
index 900f548..18f4169 100644
--- a/vendor/github.com/go-sql-driver/mysql/rows.go
+++ b/vendor/github.com/go-sql-driver/mysql/rows.go
@@ -11,34 +11,24 @@ package mysql
import (
"database/sql/driver"
"io"
+ "math"
+ "reflect"
)
-type mysqlField struct {
- tableName string
- name string
- flags fieldFlag
- fieldType byte
- decimals byte
-}
-
type resultSet struct {
- columns []mysqlField
- done bool
+ columns []mysqlField
+ columnNames []string
+ done bool
}
type mysqlRows struct {
- mc *mysqlConn
- rs resultSet
+ mc *mysqlConn
+ rs resultSet
+ finish func()
}
type binaryRows struct {
mysqlRows
- // stmtCols is a pointer to the statement's cached columns for different
- // result sets.
- stmtCols *[][]mysqlField
- // i is a number of the current result set. It is used to fetch proper
- // columns from stmtCols.
- i int
}
type textRows struct {
@@ -46,6 +36,10 @@ type textRows struct {
}
func (rows *mysqlRows) Columns() []string {
+ if rows.rs.columnNames != nil {
+ return rows.rs.columnNames
+ }
+
columns := make([]string, len(rows.rs.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
for i := range columns {
@@ -60,16 +54,64 @@ func (rows *mysqlRows) Columns() []string {
columns[i] = rows.rs.columns[i].name
}
}
+
+ rows.rs.columnNames = columns
return columns
}
+func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
+ if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok {
+ return name
+ }
+ return ""
+}
+
+// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
+// return int64(rows.rs.columns[i].length), true
+// }
+
+func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
+ return rows.rs.columns[i].flags&flagNotNULL == 0, true
+}
+
+func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
+ column := rows.rs.columns[i]
+ decimals := int64(column.decimals)
+
+ switch column.fieldType {
+ case fieldTypeDecimal, fieldTypeNewDecimal:
+ if decimals > 0 {
+ return int64(column.length) - 2, decimals, true
+ }
+ return int64(column.length) - 1, decimals, true
+ case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
+ return decimals, decimals, true
+ case fieldTypeFloat, fieldTypeDouble:
+ if decimals == 0x1f {
+ return math.MaxInt64, math.MaxInt64, true
+ }
+ return math.MaxInt64, decimals, true
+ }
+
+ return 0, 0, false
+}
+
+func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
+ return rows.rs.columns[i].scanType()
+}
+
func (rows *mysqlRows) Close() (err error) {
+ if f := rows.finish; f != nil {
+ f()
+ rows.finish = nil
+ }
+
mc := rows.mc
if mc == nil {
return nil
}
- if mc.netConn == nil {
- return ErrInvalidConn
+ if err := mc.error(); err != nil {
+ return err
}
// Remove unread packets from stream
@@ -97,8 +139,8 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
if rows.mc == nil {
return 0, io.EOF
}
- if rows.mc.netConn == nil {
- return 0, ErrInvalidConn
+ if err := rows.mc.error(); err != nil {
+ return 0, err
}
// Remove unread packets from stream
@@ -132,31 +174,20 @@ func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
}
}
-func (rows *binaryRows) NextResultSet() (err error) {
+func (rows *binaryRows) NextResultSet() error {
resLen, err := rows.nextNotEmptyResultSet()
if err != nil {
return err
}
- // get columns, if not cached, read them and cache them.
- if rows.i >= len(*rows.stmtCols) {
- rows.rs.columns, err = rows.mc.readColumns(resLen)
- *rows.stmtCols = append(*rows.stmtCols, rows.rs.columns)
- } else {
- rows.rs.columns = (*rows.stmtCols)[rows.i]
- if err := rows.mc.readUntilEOF(); err != nil {
- return err
- }
- }
-
- rows.i++
- return nil
+ rows.rs.columns, err = rows.mc.readColumns(resLen)
+ return err
}
func (rows *binaryRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
- if mc.netConn == nil {
- return ErrInvalidConn
+ if err := mc.error(); err != nil {
+ return err
}
// Fetch next row from stream
@@ -177,8 +208,8 @@ func (rows *textRows) NextResultSet() (err error) {
func (rows *textRows) Next(dest []driver.Value) error {
if mc := rows.mc; mc != nil {
- if mc.netConn == nil {
- return ErrInvalidConn
+ if err := mc.error(); err != nil {
+ return err
}
// Fetch next row from stream