diff options
| author | Niall Sheridan <nsheridan@gmail.com> | 2017-10-18 13:15:14 +0100 | 
|---|---|---|
| committer | Niall Sheridan <niall@intercom.io> | 2017-10-18 13:25:46 +0100 | 
| commit | 7b320119ba532fd409ec7dade7ad02011c309599 (patch) | |
| tree | a39860f35b55e6cc499f8f5bfa969138c5dd6b73 /vendor/github.com/go-sql-driver/mysql/dsn.go | |
| parent | 7c99874c7a3e7a89716f3ee0cdf696532e35ae35 (diff) | |
Update dependencies
Diffstat (limited to 'vendor/github.com/go-sql-driver/mysql/dsn.go')
| -rw-r--r-- | vendor/github.com/go-sql-driver/mysql/dsn.go | 119 | 
1 files changed, 79 insertions, 40 deletions
| diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index ac00dce..3ade963 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -15,6 +15,7 @@ import (  	"fmt"  	"net"  	"net/url" +	"sort"  	"strconv"  	"strings"  	"time" @@ -27,7 +28,9 @@ var (  	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")  ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values.  type Config struct {  	User             string            // Username  	Passwd           string            // Password (requires User) @@ -53,7 +56,45 @@ type Config struct {  	InterpolateParams       bool // Interpolate placeholders into query string  	MultiStatements         bool // Allow multiple statements in one query  	ParseTime               bool // Parse time values to time.Time -	Strict                  bool // Return warnings as errors +	RejectReadOnly          bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { +	return &Config{ +		Collation:            defaultCollation, +		Loc:                  time.UTC, +		MaxAllowedPacket:     defaultMaxAllowedPacket, +		AllowNativePasswords: true, +	} +} + +func (cfg *Config) normalize() error { +	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { +		return errInvalidDSNUnsafeCollation +	} + +	// Set default network if empty +	if cfg.Net == "" { +		cfg.Net = "tcp" +	} + +	// Set default address if empty +	if cfg.Addr == "" { +		switch cfg.Net { +		case "tcp": +			cfg.Addr = "127.0.0.1:3306" +		case "unix": +			cfg.Addr = "/tmp/mysql.sock" +		default: +			return errors.New("default addr for network '" + cfg.Net + "' unknown") +		} + +	} else if cfg.Net == "tcp" { +		cfg.Addr = ensureHavePort(cfg.Addr) +	} + +	return nil  }  // FormatDSN formats the given Config into a DSN string which can be passed to @@ -102,12 +143,12 @@ func (cfg *Config) FormatDSN() string {  		}  	} -	if cfg.AllowNativePasswords { +	if !cfg.AllowNativePasswords {  		if hasParam { -			buf.WriteString("&allowNativePasswords=true") +			buf.WriteString("&allowNativePasswords=false")  		} else {  			hasParam = true -			buf.WriteString("?allowNativePasswords=true") +			buf.WriteString("?allowNativePasswords=false")  		}  	} @@ -195,12 +236,12 @@ func (cfg *Config) FormatDSN() string {  		buf.WriteString(cfg.ReadTimeout.String())  	} -	if cfg.Strict { +	if cfg.RejectReadOnly {  		if hasParam { -			buf.WriteString("&strict=true") +			buf.WriteString("&rejectReadOnly=true")  		} else {  			hasParam = true -			buf.WriteString("?strict=true") +			buf.WriteString("?rejectReadOnly=true")  		}  	} @@ -234,7 +275,7 @@ func (cfg *Config) FormatDSN() string {  		buf.WriteString(cfg.WriteTimeout.String())  	} -	if cfg.MaxAllowedPacket > 0 { +	if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {  		if hasParam {  			buf.WriteString("&maxAllowedPacket=")  		} else { @@ -247,7 +288,12 @@ func (cfg *Config) FormatDSN() string {  	// other params  	if cfg.Params != nil { -		for param, value := range cfg.Params { +		var params []string +		for param := range cfg.Params { +			params = append(params, param) +		} +		sort.Strings(params) +		for _, param := range params {  			if hasParam {  				buf.WriteByte('&')  			} else { @@ -257,7 +303,7 @@ func (cfg *Config) FormatDSN() string {  			buf.WriteString(param)  			buf.WriteByte('=') -			buf.WriteString(url.QueryEscape(value)) +			buf.WriteString(url.QueryEscape(cfg.Params[param]))  		}  	} @@ -267,10 +313,7 @@ func (cfg *Config) FormatDSN() string {  // ParseDSN parses the DSN string to a Config  func ParseDSN(dsn string) (cfg *Config, err error) {  	// New config with some default values -	cfg = &Config{ -		Loc:       time.UTC, -		Collation: defaultCollation, -	} +	cfg = NewConfig()  	// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]  	// Find the last '/' (since the password or the net addr might contain a '/') @@ -338,28 +381,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) {  		return nil, errInvalidDSNNoSlash  	} -	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { -		return nil, errInvalidDSNUnsafeCollation +	if err = cfg.normalize(); err != nil { +		return nil, err  	} - -	// Set default network if empty -	if cfg.Net == "" { -		cfg.Net = "tcp" -	} - -	// Set default address if empty -	if cfg.Addr == "" { -		switch cfg.Net { -		case "tcp": -			cfg.Addr = "127.0.0.1:3306" -		case "unix": -			cfg.Addr = "/tmp/mysql.sock" -		default: -			return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") -		} - -	} -  	return  } @@ -472,14 +496,18 @@ func parseDSNParams(cfg *Config, params string) (err error) {  				return  			} -		// Strict mode -		case "strict": +		// Reject read-only connections +		case "rejectReadOnly":  			var isBool bool -			cfg.Strict, isBool = readBool(value) +			cfg.RejectReadOnly, isBool = readBool(value)  			if !isBool {  				return errors.New("invalid bool value: " + value)  			} +		// Strict mode +		case "strict": +			panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") +  		// Dial Timeout  		case "timeout":  			cfg.Timeout, err = time.ParseDuration(value) @@ -494,6 +522,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {  				if boolValue {  					cfg.TLSConfig = "true"  					cfg.tls = &tls.Config{} +					host, _, err := net.SplitHostPort(cfg.Addr) +					if err == nil { +						cfg.tls.ServerName = host +					}  				} else {  					cfg.TLSConfig = "false"  				} @@ -506,7 +538,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {  					return fmt.Errorf("invalid value for TLS config name: %v", err)  				} -				if tlsConfig, ok := tlsConfigRegister[name]; ok { +				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {  					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {  						host, _, err := net.SplitHostPort(cfg.Addr)  						if err == nil { @@ -546,3 +578,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {  	return  } + +func ensureHavePort(addr string) string { +	if _, _, err := net.SplitHostPort(addr); err != nil { +		return net.JoinHostPort(addr, "3306") +	} +	return addr +} | 
