aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/go-sql-driver/mysql/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/connection.go')
-rw-r--r--vendor/github.com/go-sql-driver/mysql/connection.go372
1 files changed, 372 insertions, 0 deletions
diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go
new file mode 100644
index 0000000..c3899de
--- /dev/null
+++ b/vendor/github.com/go-sql-driver/mysql/connection.go
@@ -0,0 +1,372 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+ "database/sql/driver"
+ "net"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type mysqlConn struct {
+ buf buffer
+ netConn net.Conn
+ affectedRows uint64
+ insertId uint64
+ cfg *Config
+ maxPacketAllowed int
+ maxWriteSize int
+ writeTimeout time.Duration
+ flags clientFlag
+ status statusFlag
+ sequence uint8
+ parseTime bool
+ strict bool
+}
+
+// Handles parameters set in DSN after the connection is established
+func (mc *mysqlConn) handleParams() (err error) {
+ for param, val := range mc.cfg.Params {
+ switch param {
+ // Charset
+ case "charset":
+ charsets := strings.Split(val, ",")
+ for i := range charsets {
+ // ignore errors here - a charset may not exist
+ err = mc.exec("SET NAMES " + charsets[i])
+ if err == nil {
+ break
+ }
+ }
+ if err != nil {
+ return
+ }
+
+ // System Vars
+ default:
+ err = mc.exec("SET " + param + "=" + val + "")
+ if err != nil {
+ return
+ }
+ }
+ }
+
+ return
+}
+
+func (mc *mysqlConn) Begin() (driver.Tx, error) {
+ if mc.netConn == nil {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ err := mc.exec("START TRANSACTION")
+ if err == nil {
+ return &mysqlTx{mc}, err
+ }
+
+ return nil, err
+}
+
+func (mc *mysqlConn) Close() (err error) {
+ // Makes Close idempotent
+ if mc.netConn != nil {
+ err = mc.writeCommandPacket(comQuit)
+ }
+
+ mc.cleanup()
+
+ return
+}
+
+// Closes the network connection and unsets internal variables. Do not call this
+// function after successfully authentication, call Close instead. This function
+// is called before auth or on auth failure because MySQL will have already
+// closed the network connection.
+func (mc *mysqlConn) cleanup() {
+ // Makes cleanup idempotent
+ if mc.netConn != nil {
+ if err := mc.netConn.Close(); err != nil {
+ errLog.Print(err)
+ }
+ mc.netConn = nil
+ }
+ mc.cfg = nil
+ mc.buf.nc = nil
+}
+
+func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+ if mc.netConn == nil {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ // Send command
+ err := mc.writeCommandPacketStr(comStmtPrepare, query)
+ if err != nil {
+ return nil, err
+ }
+
+ stmt := &mysqlStmt{
+ mc: mc,
+ }
+
+ // Read Result
+ columnCount, err := stmt.readPrepareResultPacket()
+ if err == nil {
+ if stmt.paramCount > 0 {
+ if err = mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+ }
+
+ if columnCount > 0 {
+ err = mc.readUntilEOF()
+ }
+ }
+
+ return stmt, err
+}
+
+func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
+ buf := mc.buf.takeCompleteBuffer()
+ if buf == nil {
+ // can not take the buffer. Something must be wrong with the connection
+ errLog.Print(ErrBusyBuffer)
+ return "", driver.ErrBadConn
+ }
+ buf = buf[:0]
+ argPos := 0
+
+ for i := 0; i < len(query); i++ {
+ q := strings.IndexByte(query[i:], '?')
+ if q == -1 {
+ buf = append(buf, query[i:]...)
+ break
+ }
+ buf = append(buf, query[i:i+q]...)
+ i += q
+
+ arg := args[argPos]
+ argPos++
+
+ if arg == nil {
+ buf = append(buf, "NULL"...)
+ continue
+ }
+
+ switch v := arg.(type) {
+ case int64:
+ buf = strconv.AppendInt(buf, v, 10)
+ case float64:
+ buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+ case bool:
+ if v {
+ buf = append(buf, '1')
+ } else {
+ buf = append(buf, '0')
+ }
+ case time.Time:
+ if v.IsZero() {
+ buf = append(buf, "'0000-00-00'"...)
+ } else {
+ v := v.In(mc.cfg.Loc)
+ v = v.Add(time.Nanosecond * 500) // To round under microsecond
+ year := v.Year()
+ year100 := year / 100
+ year1 := year % 100
+ month := v.Month()
+ day := v.Day()
+ hour := v.Hour()
+ minute := v.Minute()
+ second := v.Second()
+ micro := v.Nanosecond() / 1000
+
+ buf = append(buf, []byte{
+ '\'',
+ digits10[year100], digits01[year100],
+ digits10[year1], digits01[year1],
+ '-',
+ digits10[month], digits01[month],
+ '-',
+ digits10[day], digits01[day],
+ ' ',
+ digits10[hour], digits01[hour],
+ ':',
+ digits10[minute], digits01[minute],
+ ':',
+ digits10[second], digits01[second],
+ }...)
+
+ if micro != 0 {
+ micro10000 := micro / 10000
+ micro100 := micro / 100 % 100
+ micro1 := micro % 100
+ buf = append(buf, []byte{
+ '.',
+ digits10[micro10000], digits01[micro10000],
+ digits10[micro100], digits01[micro100],
+ digits10[micro1], digits01[micro1],
+ }...)
+ }
+ buf = append(buf, '\'')
+ }
+ case []byte:
+ if v == nil {
+ buf = append(buf, "NULL"...)
+ } else {
+ buf = append(buf, "_binary'"...)
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeBytesBackslash(buf, v)
+ } else {
+ buf = escapeBytesQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ }
+ case string:
+ buf = append(buf, '\'')
+ if mc.status&statusNoBackslashEscapes == 0 {
+ buf = escapeStringBackslash(buf, v)
+ } else {
+ buf = escapeStringQuotes(buf, v)
+ }
+ buf = append(buf, '\'')
+ default:
+ return "", driver.ErrSkip
+ }
+
+ if len(buf)+4 > mc.maxPacketAllowed {
+ return "", driver.ErrSkip
+ }
+ }
+ if argPos != len(args) {
+ return "", driver.ErrSkip
+ }
+ return string(buf), nil
+}
+
+func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ if mc.netConn == nil {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
+ }
+ // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ args = nil
+ }
+ mc.affectedRows = 0
+ mc.insertId = 0
+
+ err := mc.exec(query)
+ if err == nil {
+ return &mysqlResult{
+ affectedRows: int64(mc.affectedRows),
+ insertId: int64(mc.insertId),
+ }, err
+ }
+ return nil, err
+}
+
+// Internal function to execute commands
+func (mc *mysqlConn) exec(query string) error {
+ // Send command
+ err := mc.writeCommandPacketStr(comQuery, query)
+ if err != nil {
+ return err
+ }
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err == nil && resLen > 0 {
+ if err = mc.readUntilEOF(); err != nil {
+ return err
+ }
+
+ err = mc.readUntilEOF()
+ }
+
+ return err
+}
+
+func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ if mc.netConn == nil {
+ errLog.Print(ErrInvalidConn)
+ return nil, driver.ErrBadConn
+ }
+ if len(args) != 0 {
+ if !mc.cfg.InterpolateParams {
+ return nil, driver.ErrSkip
+ }
+ // try client-side prepare to reduce roundtrip
+ prepared, err := mc.interpolateParams(query, args)
+ if err != nil {
+ return nil, err
+ }
+ query = prepared
+ args = nil
+ }
+ // Send command
+ err := mc.writeCommandPacketStr(comQuery, query)
+ if err == nil {
+ // Read Result
+ var resLen int
+ resLen, err = mc.readResultSetHeaderPacket()
+ if err == nil {
+ rows := new(textRows)
+ rows.mc = mc
+
+ if resLen == 0 {
+ // no columns, no more data
+ return emptyRows{}, nil
+ }
+ // Columns
+ rows.columns, err = mc.readColumns(resLen)
+ return rows, err
+ }
+ }
+ return nil, err
+}
+
+// Gets the value of the given MySQL System Variable
+// The returned byte slice is only valid until the next read
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
+ // Send command
+ if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+ return nil, err
+ }
+
+ // Read Result
+ resLen, err := mc.readResultSetHeaderPacket()
+ if err == nil {
+ rows := new(textRows)
+ rows.mc = mc
+ rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
+
+ if resLen > 0 {
+ // Columns
+ if err := mc.readUntilEOF(); err != nil {
+ return nil, err
+ }
+ }
+
+ dest := make([]driver.Value, resLen)
+ if err = rows.readRow(dest); err == nil {
+ return dest[0].([]byte), mc.readUntilEOF()
+ }
+ }
+ return nil, err
+}