diff options
39 files changed, 11870 insertions, 14 deletions
diff --git a/.travis.yml b/.travis.yml index 38fa579..c5cd158 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,10 @@ language: go services: - mysql + - mongodb env: - - GO15VENDOREXPERIMENT=1 MYSQL_TEST_CONFIG="mysql:user:passwd:localhost" + - GO15VENDOREXPERIMENT=1 MYSQL_TEST_CONFIG="mysql:user:passwd:localhost" MONGO_TEST_CONFIG="mongo:user:passwd:localhost" go: - 1.5.4 @@ -19,6 +20,7 @@ before_install: before_script: - go run ./cmd/dbinit/dbinit.go -db_user user -db_password passwd + - go run ./cmd/dbinit/dbinit.go -db_type mongo -admin_user '' -db_user user -db_password passwd sudo: false script: @@ -87,10 +87,11 @@ Configuration is divided into different sections: `server`, `auth`, `ssh`, and ` #### Datastore Datastores contain a record of issued certificates for audit and revocation purposes. The connection string is of the form `engine:username:password:host[:port]`. -Currently two engines are supported: `mysql` and `mem`. +Supported database providers: `mysql`, `mongo` and `mem`. `mem` is an in-memory database intended for testing and takes no additional config options. -`mysql` is the MySQL database and the `username`, `password` and `host` arguments are required. `port` is assumed to be 3306 unless otherwise specified. +`mysql` is the MySQL database and accepts `username`, `password` and `host` arguments. Only `username` and `host` arguments are required. `port` is assumed to be 3306 unless otherwise specified. +`mongo` is MongoDB and accepts `username`, `password` and `host` arguments. All arguments are optional and multiple hosts can be specified using comma-separated values: `mongo:dbuser:dbpasswd:host1,host2`. If no datastore is specified the `mem` store is used. @@ -100,11 +101,14 @@ Examples: server { datastore = "mem" # use the in-memory database. datastore = "mysql:root::localhost" # mysql running on localhost with the user 'root' and no password. - datastore = "mysql:cashier:aMaZiNgPaSsWoRd:mydbprovider.example.com:5150" # mysql running on a remote host on port 5150 + datastore = "mysql:cashier:PaSsWoRd:mydbprovider.example.com:5150" # mysql running on a remote host on port 5150 + datastore = "mongo:cashier:PaSsWoRd:mydbprovider.example.com:27018" # mongo running on a remote host on port 27018 + datastore = "mongo:cashier:PaSsWoRd:server1.example.com:27018,server2.example.com:27018" # mongo running on multiple servers on port 27018 } ``` -Prior to using the MySQL datastore, you need to create the database and tables using the [dbinit tool](cmd/dbinit/dbinit.go). +Prior to using MySQL or MongoDB datastores you need to create the database and tables using the [dbinit tool](cmd/dbinit/dbinit.go). +Note that dbinit has no support for replica sets. ### auth - `provider` : string. Name of the oauth provider. At present the only valid value is "google". diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index 8295c74..67af7cf 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -298,6 +298,8 @@ func certStore(config string) (store.CertStorer, error) { switch engine { case "mysql": cstore, err = store.NewMySQLStore(config) + case "mongo": + cstore, err = store.NewMongoStore(config) case "mem": cstore = store.NewMemoryStore() default: diff --git a/cmd/dbinit/dbinit.go b/cmd/dbinit/dbinit.go index de06cbb..46484d1 100644 --- a/cmd/dbinit/dbinit.go +++ b/cmd/dbinit/dbinit.go @@ -7,6 +7,8 @@ import ( "log" "strings" + mgo "gopkg.in/mgo.v2" + "github.com/go-sql-driver/mysql" ) @@ -14,16 +16,19 @@ var ( host = flag.String("host", "localhost", "host[:port]") adminUser = flag.String("admin_user", "root", "Admin user") adminPasswd = flag.String("admin_password", "", "Admin password") - dbUser = flag.String("db_user", "root", "Database user") - dbPasswd = flag.String("db_password", "", "Admin password") + dbUser = flag.String("db_user", "user", "Database user") + dbPasswd = flag.String("db_password", "passwd", "Admin password") + dbType = flag.String("db_type", "mysql", "Database engine (\"mysql\" or \"mongo\")") + + certsDB = "certs" + issuedTable = "issued_certs" ) -func main() { - flag.Parse() +func initMySQL() { var createTableStmt = []string{ - `CREATE DATABASE IF NOT EXISTS certs DEFAULT CHARACTER SET = 'utf8' DEFAULT COLLATE 'utf8_general_ci';`, - `USE certs;`, - `CREATE TABLE IF NOT EXISTS issued_certs ( + `CREATE DATABASE IF NOT EXISTS ` + certsDB + ` DEFAULT CHARACTER SET = 'utf8' DEFAULT COLLATE 'utf8_general_ci';`, + `USE ` + certsDB + `;`, + `CREATE TABLE IF NOT EXISTS ` + issuedTable + ` ( key_id VARCHAR(255) NOT NULL, principals VARCHAR(255) NULL, created_at DATETIME NULL, @@ -59,3 +64,44 @@ func main() { } } } + +func initMongo() { + di := &mgo.DialInfo{ + Addrs: strings.Split(*host, ","), + Username: *adminUser, + Password: *adminPasswd, + } + session, err := mgo.DialWithInfo(di) + if err != nil { + log.Fatalln(err) + } + defer session.Close() + d := session.DB(certsDB) + if err := d.UpsertUser(&mgo.User{ + Username: *dbUser, + Password: *dbPasswd, + Roles: []mgo.Role{mgo.RoleReadWrite}, + }); err != nil { + log.Fatalln(err) + } + c := d.C(issuedTable) + i := mgo.Index{ + Key: []string{"keyid"}, + Unique: true, + } + if err != c.EnsureIndex(i) { + log.Fatalln(err) + } +} + +func main() { + flag.Parse() + switch *dbType { + case "mysql": + initMySQL() + case "mongo": + initMongo() + default: + log.Fatalf("Invalid database type") + } +} diff --git a/server/store/config_test.go b/server/store/config_test.go new file mode 100644 index 0000000..8e283f5 --- /dev/null +++ b/server/store/config_test.go @@ -0,0 +1,68 @@ +package store + +import ( + "reflect" + "testing" + "time" + + mgo "gopkg.in/mgo.v2" +) + +func TestMySQLConfig(t *testing.T) { + var tests = []struct { + in string + out string + }{ + {"mysql:user:passwd:localhost", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}, + {"mysql:user:passwd:localhost:13306", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}, + {"mysql:root::localhost", "root@tcp(localhost:3306)/certs?parseTime=true"}, + } + for _, tt := range tests { + result := parseMySQLConfig(tt.in) + if result != tt.out { + t.Errorf("want %s, got %s", tt.out, result) + } + } +} + +func TestMongoConfig(t *testing.T) { + var tests = []struct { + in string + out *mgo.DialInfo + }{ + {"mongo:user:passwd:host", &mgo.DialInfo{ + Username: "user", + Password: "passwd", + Addrs: []string{"host"}, + Database: "certs", + Timeout: 5 * time.Second, + }}, + {"mongo:user:passwd:host1,host2", &mgo.DialInfo{ + Username: "user", + Password: "passwd", + Addrs: []string{"host1", "host2"}, + Database: "certs", + Timeout: 5 * time.Second, + }}, + {"mongo:user:passwd:host1:27017,host2:27017", &mgo.DialInfo{ + Username: "user", + Password: "passwd", + Addrs: []string{"host1:27017", "host2:27017"}, + Database: "certs", + Timeout: 5 * time.Second, + }}, + {"mongo:user:passwd:host1,host2:27017", &mgo.DialInfo{ + Username: "user", + Password: "passwd", + Addrs: []string{"host1", "host2:27017"}, + Database: "certs", + Timeout: 5 * time.Second, + }}, + } + for _, tt := range tests { + result := parseMongoConfig(tt.in) + if !reflect.DeepEqual(result, tt.out) { + t.Errorf("want:\n%+v\ngot:\n%+v", tt.out, result) + } + } +} diff --git a/server/store/mongo.go b/server/store/mongo.go new file mode 100644 index 0000000..752d405 --- /dev/null +++ b/server/store/mongo.go @@ -0,0 +1,80 @@ +package store + +import ( + "strings" + "time" + + "golang.org/x/crypto/ssh" + + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +var ( + certsDB = "certs" + issuedTable = "issued_certs" +) + +type mongoDB struct { + conn *mgo.Collection +} + +func parseMongoConfig(config string) *mgo.DialInfo { + s := strings.SplitN(config, ":", 4) + _, user, passwd, hosts := s[0], s[1], s[2], s[3] + d := &mgo.DialInfo{ + Addrs: strings.Split(hosts, ","), + Username: user, + Password: passwd, + Database: certsDB, + Timeout: time.Second * 5, + } + return d +} + +func NewMongoStore(config string) (CertStorer, error) { + session, err := mgo.DialWithInfo(parseMongoConfig(config)) + if err != nil { + return nil, err + } + c := session.DB(certsDB).C(issuedTable) + return &mongoDB{ + conn: c, + }, nil +} + +func (m *mongoDB) Get(id string) (*CertRecord, error) { + c := &CertRecord{} + err := m.conn.Find(bson.M{"keyid": id}).One(c) + return c, err +} + +func (m *mongoDB) SetCert(cert *ssh.Certificate) error { + r := parseCertificate(cert) + return m.SetRecord(r) +} + +func (m *mongoDB) SetRecord(record *CertRecord) error { + return m.conn.Insert(record) +} + +func (m *mongoDB) List() ([]*CertRecord, error) { + var result []*CertRecord + m.conn.Find(nil).All(&result) + return result, nil +} + +func (m *mongoDB) Revoke(id string) error { + return m.conn.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) +} + +func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { + var result []*CertRecord + err := m.conn.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) + return result, err +} + +func (m *mongoDB) Close() error { + m.conn.Database.Session.Close() + return nil +} diff --git a/server/store/mysql.go b/server/store/mysql.go index a62af6b..7a0b111 100644 --- a/server/store/mysql.go +++ b/server/store/mysql.go @@ -22,7 +22,7 @@ type mysqlDB struct { revoked *sql.Stmt } -func parseConfig(config string) string { +func parseMySQLConfig(config string) string { s := strings.Split(config, ":") if len(s) == 4 { s = append(s, "3306") @@ -41,7 +41,7 @@ func parseConfig(config string) string { // NewMySQLStore returns a MySQL CertStorer. func NewMySQLStore(config string) (CertStorer, error) { - conn, err := sql.Open("mysql", parseConfig(config)) + conn, err := sql.Open("mysql", parseMySQLConfig(config)) if err != nil { return nil, fmt.Errorf("mysql: could not get a connection: %v", err) } diff --git a/server/store/store_test.go b/server/store/store_test.go index ee80241..bf16fa6 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -100,3 +100,15 @@ func TestMySQLStore(t *testing.T) { } testStore(t, db) } + +func TestMongoStore(t *testing.T) { + config := os.Getenv("MONGO_TEST_CONFIG") + if config == "" { + t.Skip("No MONGO_TEST_CONFIG environment variable") + } + db, err := NewMongoStore(config) + if err != nil { + t.Error(err) + } + testStore(t, db) +} diff --git a/vendor/gopkg.in/mgo.v2/LICENSE b/vendor/gopkg.in/mgo.v2/LICENSE new file mode 100644 index 0000000..770c767 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/LICENSE @@ -0,0 +1,25 @@ +mgo - MongoDB driver for Go + +Copyright (c) 2010-2013 - Gustavo Niemeyer <gustavo@niemeyer.net> + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/mgo.v2/Makefile b/vendor/gopkg.in/mgo.v2/Makefile new file mode 100644 index 0000000..51bee73 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/Makefile @@ -0,0 +1,5 @@ +startdb: + @testdb/setup.sh start + +stopdb: + @testdb/setup.sh stop diff --git a/vendor/gopkg.in/mgo.v2/README.md b/vendor/gopkg.in/mgo.v2/README.md new file mode 100644 index 0000000..f4e452c --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/README.md @@ -0,0 +1,4 @@ +The MongoDB driver for Go +------------------------- + +Please go to [http://labix.org/mgo](http://labix.org/mgo) for all project details. diff --git a/vendor/gopkg.in/mgo.v2/auth.go b/vendor/gopkg.in/mgo.v2/auth.go new file mode 100644 index 0000000..dc26e52 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/auth.go @@ -0,0 +1,467 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "errors" + "fmt" + "sync" + + "gopkg.in/mgo.v2/bson" + "gopkg.in/mgo.v2/internal/scram" +) + +type authCmd struct { + Authenticate int + + Nonce string + User string + Key string +} + +type startSaslCmd struct { + StartSASL int `bson:"startSasl"` +} + +type authResult struct { + ErrMsg string + Ok bool +} + +type getNonceCmd struct { + GetNonce int +} + +type getNonceResult struct { + Nonce string + Err string "$err" + Code int +} + +type logoutCmd struct { + Logout int +} + +type saslCmd struct { + Start int `bson:"saslStart,omitempty"` + Continue int `bson:"saslContinue,omitempty"` + ConversationId int `bson:"conversationId,omitempty"` + Mechanism string `bson:"mechanism,omitempty"` + Payload []byte +} + +type saslResult struct { + Ok bool `bson:"ok"` + NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?) + Done bool + + ConversationId int `bson:"conversationId"` + Payload []byte + ErrMsg string +} + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +func (socket *mongoSocket) getNonce() (nonce string, err error) { + socket.Lock() + for socket.cachedNonce == "" && socket.dead == nil { + debugf("Socket %p to %s: waiting for nonce", socket, socket.addr) + socket.gotNonce.Wait() + } + if socket.cachedNonce == "mongos" { + socket.Unlock() + return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth") + } + debugf("Socket %p to %s: got nonce", socket, socket.addr) + nonce, err = socket.cachedNonce, socket.dead + socket.cachedNonce = "" + socket.Unlock() + if err != nil { + nonce = "" + } + return +} + +func (socket *mongoSocket) resetNonce() { + debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr) + op := &queryOp{} + op.query = &getNonceCmd{GetNonce: 1} + op.collection = "admin.$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + if err != nil { + socket.kill(errors.New("getNonce: "+err.Error()), true) + return + } + result := &getNonceResult{} + err = bson.Unmarshal(docData, &result) + if err != nil { + socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true) + return + } + debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result) + if result.Code == 13390 { + // mongos doesn't yet support auth (see http://j.mp/mongos-auth) + result.Nonce = "mongos" + } else if result.Nonce == "" { + var msg string + if result.Err != "" { + msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code) + } else { + msg = "Got an empty nonce" + } + socket.kill(errors.New(msg), true) + return + } + socket.Lock() + if socket.cachedNonce != "" { + socket.Unlock() + panic("resetNonce: nonce already cached") + } + socket.cachedNonce = result.Nonce + socket.gotNonce.Signal() + socket.Unlock() + } + err := socket.Query(op) + if err != nil { + socket.kill(errors.New("resetNonce: "+err.Error()), true) + } +} + +func (socket *mongoSocket) Login(cred Credential) error { + socket.Lock() + if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 { + cred.Mechanism = "SCRAM-SHA-1" + } + for _, sockCred := range socket.creds { + if sockCred == cred { + debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username) + socket.Unlock() + return nil + } + } + if socket.dropLogout(cred) { + debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + } + socket.Unlock() + + debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username) + + var err error + switch cred.Mechanism { + case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501. + err = socket.loginClassic(cred) + case "PLAIN": + err = socket.loginPlain(cred) + case "MONGODB-X509": + err = socket.loginX509(cred) + default: + // Try SASL for everything else, if it is available. + err = socket.loginSASL(cred) + } + + if err != nil { + debugf("Socket %p to %s: login error: %s", socket, socket.addr, err) + } else { + debugf("Socket %p to %s: login successful", socket, socket.addr) + } + return err +} + +func (socket *mongoSocket) loginClassic(cred Credential) error { + // Note that this only works properly because this function is + // synchronous, which means the nonce won't get reset while we're + // using it and any other login requests will block waiting for a + // new nonce provided in the defer call below. + nonce, err := socket.getNonce() + if err != nil { + return err + } + defer socket.resetNonce() + + psum := md5.New() + psum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + + ksum := md5.New() + ksum.Write([]byte(nonce + cred.Username)) + ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil)))) + + key := hex.EncodeToString(ksum.Sum(nil)) + + cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +type authX509Cmd struct { + Authenticate int + User string + Mechanism string +} + +func (socket *mongoSocket) loginX509(cred Credential) error { + cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginPlain(cred Credential) error { + cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginSASL(cred Credential) error { + var sasl saslStepper + var err error + if cred.Mechanism == "SCRAM-SHA-1" { + // SCRAM is handled without external libraries. + sasl = saslNewScram(cred) + } else if len(cred.ServiceHost) > 0 { + sasl, err = saslNew(cred, cred.ServiceHost) + } else { + sasl, err = saslNew(cred, socket.Server().Addr) + } + if err != nil { + return err + } + defer sasl.Close() + + // The goal of this logic is to carry a locked socket until the + // local SASL step confirms the auth is valid; the socket needs to be + // locked so that concurrent action doesn't leave the socket in an + // auth state that doesn't reflect the operations that took place. + // As a simple case, imagine inverting login=>logout to logout=>login. + // + // The logic below works because the lock func isn't called concurrently. + locked := false + lock := func(b bool) { + if locked != b { + locked = b + if b { + socket.Lock() + } else { + socket.Unlock() + } + } + } + + lock(true) + defer lock(false) + + start := 1 + cmd := saslCmd{} + res := saslResult{} + for { + payload, done, err := sasl.Step(res.Payload) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + lock(false) + + cmd = saslCmd{ + Start: start, + Continue: 1 - start, + ConversationId: res.ConversationId, + Mechanism: cred.Mechanism, + Payload: payload, + } + start = 0 + err = socket.loginRun(cred.Source, &cmd, &res, func() error { + // See the comment on lock for why this is necessary. + lock(true) + if !res.Ok || res.NotOk { + return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg) + } + return nil + }) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + } + + return nil +} + +func saslNewScram(cred Credential) *saslScram { + credsum := md5.New() + credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil))) + return &saslScram{cred: cred, client: client} +} + +type saslScram struct { + cred Credential + client *scram.Client +} + +func (s *saslScram) Close() {} + +func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) { + more := s.client.Step(serverData) + return s.client.Out(), !more, s.client.Err() +} + +func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error { + var mutex sync.Mutex + var replyErr error + mutex.Lock() + + op := queryOp{} + op.query = query + op.collection = db + ".$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + defer mutex.Unlock() + + if err != nil { + replyErr = err + return + } + + err = bson.Unmarshal(docData, result) + if err != nil { + replyErr = err + } else { + // Must handle this within the read loop for the socket, so + // that concurrent login requests are properly ordered. + replyErr = f() + } + } + + err := socket.Query(&op) + if err != nil { + return err + } + mutex.Lock() // Wait. + return replyErr +} + +func (socket *mongoSocket) Logout(db string) { + socket.Lock() + cred, found := socket.dropAuth(db) + if found { + debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db) + socket.logout = append(socket.logout, cred) + } + socket.Unlock() +} + +func (socket *mongoSocket) LogoutAll() { + socket.Lock() + if l := len(socket.creds); l > 0 { + debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l) + socket.logout = append(socket.logout, socket.creds...) + socket.creds = socket.creds[0:0] + } + socket.Unlock() +} + +func (socket *mongoSocket) flushLogout() (ops []interface{}) { + socket.Lock() + if l := len(socket.logout); l > 0 { + debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l) + for i := 0; i != l; i++ { + op := queryOp{} + op.query = &logoutCmd{1} + op.collection = socket.logout[i].Source + ".$cmd" + op.limit = -1 + ops = append(ops, &op) + } + socket.logout = socket.logout[0:0] + } + socket.Unlock() + return +} + +func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) { + for i, sockCred := range socket.creds { + if sockCred.Source == db { + copy(socket.creds[i:], socket.creds[i+1:]) + socket.creds = socket.creds[:len(socket.creds)-1] + return sockCred, true + } + } + return cred, false +} + +func (socket *mongoSocket) dropLogout(cred Credential) (found bool) { + for i, sockCred := range socket.logout { + if sockCred == cred { + copy(socket.logout[i:], socket.logout[i+1:]) + socket.logout = socket.logout[:len(socket.logout)-1] + return true + } + } + return false +} diff --git a/vendor/gopkg.in/mgo.v2/bson/LICENSE b/vendor/gopkg.in/mgo.v2/bson/LICENSE new file mode 100644 index 0000000..8903260 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/mgo.v2/bson/bson.go b/vendor/gopkg.in/mgo.v2/bson/bson.go new file mode 100644 index 0000000..579aec1 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/bson.go @@ -0,0 +1,721 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. +package bson + +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// -------------------------------------------------------------------------- +// The public API. + +// A value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type Getter interface { + GetBSON() (interface{}, error) +} + +// A value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// Unlike the usual behavior, unmarshalling onto a value that implements a +// Setter interface will NOT reset the value to its zero state. This allows +// the value to decide by itself how to be unmarshalled. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.Raw) error { +// return raw.Unmarshal(s) +// } +// +type Setter interface { + SetBSON(raw Raw) error +} + +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + +// M is a convenient alias for a map[string]interface{} map, useful for +// dealing with BSON in a native way. For instance: +// +// bson.M{"a": 1, "b": true} +// +// There's no special handling for this type in addition to what's done anyway +// for an equivalent map type. Elements in the map will be dumped in an +// undefined ordered. See also the bson.D type for an ordered alternative. +type M map[string]interface{} + +// D represents a BSON document containing ordered elements. For example: +// +// bson.D{{"a", 1}, {"b", true}} +// +// In some situations, such as when creating indexes for MongoDB, the order in +// which the elements are defined is important. If the order is not important, +// using a map is generally more comfortable. See bson.M and bson.RawD. +type D []DocElem + +// DocElem is an element of the bson.D document representation. +type DocElem struct { + Name string + Value interface{} +} + +// Map returns a map out of the ordered element name/value pairs in d. +func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + +// The Raw type represents raw unprocessed BSON documents and elements. +// Kind is the kind of element as defined per the BSON specification, and +// Data is the raw unprocessed data for the respective element. +// Using this type it is possible to unmarshal or marshal values partially. +// +// Relevant documentation: +// +// http://bsonspec.org/#/specification +// +type Raw struct { + Kind byte + Data []byte +} + +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw +} + +// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes +// long. MongoDB objects by default have such a property set in their "_id" +// property. +// +// http://www.mongodb.org/display/DOCS/Object+IDs +type ObjectId string + +// ObjectIdHex returns an ObjectId from the provided hex representation. +// Calling this function with an invalid hex representation will +// cause a runtime panic. See the IsObjectIdHex function. +func ObjectIdHex(s string) ObjectId { + d, err := hex.DecodeString(s) + if err != nil || len(d) != 12 { + panic(fmt.Sprintf("invalid input to ObjectIdHex: %q", s)) + } + return ObjectId(d) +} + +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. +func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + +// objectIdCounter is atomically incremented when generating a new ObjectId +// using NewObjectId() function. It's used as a counter part of an id. +var objectIdCounter uint32 = readRandomUint32() + +// readRandomUint32 returns a random objectIdCounter. +func readRandomUint32() uint32 { + var b [4]byte + _, err := io.ReadFull(rand.Reader, b[:]) + if err != nil { + panic(fmt.Errorf("cannot read random object id: %v", err)) + } + return uint32((uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)) +} + +// machineId stores machine id generated once and used in subsequent calls +// to NewObjectId function. +var machineId = readMachineId() + +// readMachineId generates and returns a machine id. +// If this function fails to get the hostname it will cause a runtime error. +func readMachineId() []byte { + var sum [3]byte + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +// NewObjectId returns a new unique ObjectId. +func NewObjectId() ObjectId { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineId[0] + b[5] = machineId[1] + b[6] = machineId[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + pid := os.Getpid() + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&objectIdCounter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return ObjectId(b[:]) +} + +// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled +// with the provided number of seconds from epoch UTC, and all other parts +// filled with zeroes. It's not safe to insert a document with an id generated +// by this method, it is useful only for queries to find documents with ids +// generated before or after the specified timestamp. +func NewObjectIdWithTime(t time.Time) ObjectId { + var b [12]byte + binary.BigEndian.PutUint32(b[:4], uint32(t.Unix())) + return ObjectId(string(b[:])) +} + +// String returns a hex string representation of the id. +// Example: ObjectIdHex("4d88e15b60f486e428412dc9"). +func (id ObjectId) String() string { + return fmt.Sprintf(`ObjectIdHex("%x")`, string(id)) +} + +// Hex returns a hex representation of the ObjectId. +func (id ObjectId) Hex() string { + return hex.EncodeToString([]byte(id)) +} + +// MarshalJSON turns a bson.ObjectId into a json.Marshaller. +func (id ObjectId) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%x"`, string(id))), nil +} + +var nullBytes = []byte("null") + +// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller. +func (id *ObjectId) UnmarshalJSON(data []byte) error { + if len(data) == 2 && data[0] == '"' && data[1] == '"' || bytes.Equal(data, nullBytes) { + *id = "" + return nil + } + if len(data) != 26 || data[0] != '"' || data[25] != '"' { + return errors.New(fmt.Sprintf("invalid ObjectId in JSON: %s", string(data))) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[1:25]) + if err != nil { + return errors.New(fmt.Sprintf("invalid ObjectId in JSON: %s (%s)", string(data), err)) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// MarshalText turns bson.ObjectId into an encoding.TextMarshaler. +func (id ObjectId) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("%x", string(id))), nil +} + +// UnmarshalText turns *bson.ObjectId into an encoding.TextUnmarshaler. +func (id *ObjectId) UnmarshalText(data []byte) error { + if len(data) == 1 && data[0] == ' ' || len(data) == 0 { + *id = "" + return nil + } + if len(data) != 24 { + return fmt.Errorf("invalid ObjectId: %s", data) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[:]) + if err != nil { + return fmt.Errorf("invalid ObjectId: %s (%s)", data, err) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// Valid returns true if id is valid. A valid id must contain exactly 12 bytes. +func (id ObjectId) Valid() bool { + return len(id) == 12 +} + +// byteSlice returns byte slice of id from start to end. +// Calling this function with an invalid id will cause a runtime panic. +func (id ObjectId) byteSlice(start, end int) []byte { + if len(id) != 12 { + panic(fmt.Sprintf("invalid ObjectId: %q", string(id))) + } + return []byte(string(id)[start:end]) +} + +// Time returns the timestamp part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Time() time.Time { + // First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch. + secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4))) + return time.Unix(secs, 0) +} + +// Machine returns the 3-byte machine id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Machine() []byte { + return id.byteSlice(4, 7) +} + +// Pid returns the process id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Pid() uint16 { + return binary.BigEndian.Uint16(id.byteSlice(7, 9)) +} + +// Counter returns the incrementing value part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Counter() int32 { + b := id.byteSlice(9, 12) + // Counter is stored as big-endian 3-byte value + return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])) +} + +// The Symbol type is similar to a string and is used in languages with a +// distinct symbol type. +type Symbol string + +// Now returns the current time with millisecond precision. MongoDB stores +// timestamps with the same precision, so a Time returned from this method +// will not change after a roundtrip to the database. That's the only reason +// why this function exists. Using the time.Now function also works fine +// otherwise. +func Now() time.Time { + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) +} + +// MongoTimestamp is a special internal type used by MongoDB that for some +// strange reason has its own datatype defined in BSON. +type MongoTimestamp int64 + +type orderKey int64 + +// MaxKey is a special value that compares higher than all other possible BSON +// values in a MongoDB database. +var MaxKey = orderKey(1<<63 - 1) + +// MinKey is a special value that compares lower than all other possible BSON +// values in a MongoDB database. +var MinKey = orderKey(-1 << 63) + +type undefined struct{} + +// Undefined represents the undefined BSON value. +var Undefined undefined + +// Binary is a representation for non-standard binary values. Any kind should +// work, but the following are known as of this writing: +// +// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}. +// 0x01 - Function (!?) +// 0x02 - Obsolete generic. +// 0x03 - UUID +// 0x05 - MD5 +// 0x80 - User defined. +// +type Binary struct { + Kind byte + Data []byte +} + +// RegEx represents a regular expression. The Options field may contain +// individual characters defining the way in which the pattern should be +// applied, and must be sorted. Valid options as of this writing are 'i' for +// case insensitive matching, 'm' for multi-line matching, 'x' for verbose +// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all +// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match +// unicode. The value of the Options parameter is not verified before being +// marshaled into the BSON format. +type RegEx struct { + Pattern string + Options string +} + +// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it +// will be marshaled as a mapping from identifiers to values that may be +// used when evaluating the provided Code. +type JavaScript struct { + Code string + Scope interface{} +} + +// DBPointer refers to a document id in a namespace. +// +// This type is deprecated in the BSON specification and should not be used +// except for backwards compatibility with ancient applications. +type DBPointer struct { + Namespace string + Id ObjectId +} + +const initialBufferSize = 64 + +func handleErr(err *error) { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } else if _, ok := r.(externalPanic); ok { + panic(r) + } else if s, ok := r.(string); ok { + *err = errors.New(s) + } else if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// Marshal serializes the in value, which may be a map or a struct value. +// In the case of struct values, only exported fields will be serialized, +// and the order of serialized fields will match that of the struct itself. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[<key>][,<flag1>[,<flag2>]]" +// +// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. +// +// Some examples: +// +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := &encoder{make([]byte, 0, initialBufferSize)} + e.addDoc(reflect.ValueOf(in)) + return e.out, nil +} + +// Unmarshal deserializes data from in into the out value. The out value +// must be a map, a pointer to a struct, or a pointer to a bson.D value. +// In the case of struct values, only exported fields will be deserialized. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[<key>][,<flag1>[,<flag2>]]" +// +// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. +// +// The target field or element types of out may not necessarily match +// the BSON values of the provided data. The following conversions are +// made automatically: +// +// - Numeric types are converted if at least the integer part of the +// value would be preserved correctly +// - Bools are converted to numeric types as 1 or 0 +// - Numeric types are converted to bools as true if not 0 or false otherwise +// - Binary and string BSON data is converted to a string, array or byte slice +// +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. +func Unmarshal(in []byte, out interface{}) (err error) { + if raw, ok := out.(*Raw); ok { + raw.Kind = 3 + raw.Data = in + return nil + } + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + fallthrough + case reflect.Map: + d := newDecoder(in) + d.readDocTo(v) + case reflect.Struct: + return errors.New("Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Unmarshal needs a map or a pointer to a struct.") + } + return nil +} + +// Unmarshal deserializes raw into the out value. If the out value type +// is not compatible with raw, a *bson.TypeError is returned. +// +// See the Unmarshal function documentation for more details on the +// unmarshalling process. +func (raw Raw) Unmarshal(out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + fallthrough + case reflect.Map: + d := newDecoder(raw.Data) + good := d.readElemTo(v, raw.Kind) + if !good { + return &TypeError{v.Type(), raw.Kind} + } + case reflect.Struct: + return errors.New("Raw Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Raw Unmarshal needs a map or a valid pointer.") + } + return nil +} + +type TypeError struct { + Type reflect.Type + Kind byte +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String()) +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + InlineMap int + Zero reflect.Value +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + MinSize bool + Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +type externalPanic string + +func (e externalPanic) String() string { + return string(e) +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + structMapMutex.RLock() + sinfo, found := structMap[st] + structMapMutex.RUnlock() + if found { + return sinfo, nil + } + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" && !field.Anonymous { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("bson") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "minsize": + info.MinSize = true + case "inline": + inline = true + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st) + panic(externalPanic(msg)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) + } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "Duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + sinfo = &structInfo{ + fieldsMap, + fieldsList, + inlineMap, + reflect.New(st).Elem(), + } + structMapMutex.Lock() + structMap[st] = sinfo + structMapMutex.Unlock() + return sinfo, nil +} diff --git a/vendor/gopkg.in/mgo.v2/bson/decode.go b/vendor/gopkg.in/mgo.v2/bson/decode.go new file mode 100644 index 0000000..9bd73f9 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/decode.go @@ -0,0 +1,844 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "sync" + "time" +) + +type decoder struct { + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} +} + +// -------------------------------------------------------------------------- +// Some helper functions. + +func corrupted() { + panic("Document is corrupted") +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +// -------------------------------------------------------------------------- +// Unmarshaling of documents. + +const ( + setterUnknown = iota + setterNone + setterType + setterAddr +) + +var setterStyles map[reflect.Type]int +var setterIface reflect.Type +var setterMutex sync.RWMutex + +func init() { + var iface Setter + setterIface = reflect.TypeOf(&iface).Elem() + setterStyles = make(map[reflect.Type]int) +} + +func setterStyle(outt reflect.Type) int { + setterMutex.RLock() + style := setterStyles[outt] + setterMutex.RUnlock() + if style == setterUnknown { + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + setterStyles[outt] = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + setterStyles[outt] = setterAddr + } else { + setterStyles[outt] = setterNone + } + style = setterStyles[outt] + } + return style +} + +func getSetter(outt reflect.Type, out reflect.Value) Setter { + style := setterStyle(outt) + if style == setterNone { + return nil + } + if style == setterAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(Setter) +} + +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + +func (d *decoder) readDocTo(out reflect.Value) { + var elemType reflect.Type + outt := out.Type() + outk := outt.Kind() + + for { + if outk == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + if setter := getSetter(outt, out); setter != nil { + var raw Raw + d.readDocTo(reflect.ValueOf(&raw)) + err := setter.SetBSON(raw) + if _, ok := err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if outk == reflect.Ptr { + out = out.Elem() + outt = out.Type() + outk = out.Kind() + continue + } + break + } + + var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value + start := d.i + + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv + } + outt = out.Type() + outk = outt.Kind() + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { + case reflect.Map: + keyType = outt.Key() + if keyType.Kind() != reflect.String { + panic("BSON map must have string keys. Got: " + outt.String()) + } + if keyType != typeString { + convertKey = true + } + elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } + if out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + } else if out.Len() > 0 { + clearMap(out) + } + case reflect.Struct: + if outt != typeRaw { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + fieldsMap = sinfo.FieldsMap + out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } + } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough + default: + panic("Unsupported document type for unmarshalling: " + out.Type().String()) + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + + switch outk { + case reflect.Map: + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) + } + case reflect.Struct: + if outt == typeRaw { + d.dropElem(kind) + } else { + if info, ok := fieldsMap[name]; ok { + if info.Inline == nil { + d.readElemTo(out.Field(info.Num), kind) + } else { + d.readElemTo(out.FieldByIndex(info.Inline), kind) + } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } + } else { + d.dropElem(kind) + } + } + case reflect.Slice: + } + + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + d.docType = docType + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) + } +} + +func (d *decoder) readArrayDocTo(out reflect.Value) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + i := 0 + l := out.Len() + for d.in[d.i] != '\x00' { + if i >= l { + panic("Length mismatch on array field") + } + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + d.readElemTo(out.Index(i), kind) + if d.i >= end { + corrupted() + } + i++ + } + if i != l { + panic("Length mismatch on array field") + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +func (d *decoder) readSliceDoc(t reflect.Type) interface{} { + tmp := make([]reflect.Value, 0, 8) + elemType := t.Elem() + if elemType == typeRawDocElem { + d.dropElem(0x04) + return reflect.Zero(t).Interface() + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + tmp = append(tmp, e) + } + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + + n := len(tmp) + slice := reflect.MakeSlice(t, n, n) + for i := 0; i != n; i++ { + slice.Index(i).Set(tmp[i]) + } + return slice.Interface() +} + +var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() + +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]DocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := DocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readDocWith(f func(kind byte, name string)) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + f(kind, name) + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +// -------------------------------------------------------------------------- +// Unmarshaling of individual elements within a document. + +var blackHole = settableValueOf(struct{}{}) + +func (d *decoder) dropElem(kind byte) { + d.readElemTo(blackHole, kind) +} + +// Attempt to decode an element from the document and put it into out. +// If the types are not compatible, the returned ok value will be +// false and out will be unchanged. +func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { + + start := d.i + + if kind == 0x03 { + // Delegate unmarshaling of documents. + outt := out.Type() + outk := out.Kind() + switch outk { + case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: + d.readDocTo(out) + return true + } + if setterStyle(outt) != setterNone { + d.readDocTo(out) + return true + } + if outk == reflect.Slice { + switch outt.Elem() { + case typeDocElem: + out.Set(d.readDocElems(outt)) + case typeRawDocElem: + out.Set(d.readRawDocElems(outt)) + default: + d.readDocTo(blackHole) + } + return true + } + d.readDocTo(blackHole) + return true + } + + var in interface{} + + switch kind { + case 0x01: // Float64 + in = d.readFloat64() + case 0x02: // UTF-8 string + in = d.readStr() + case 0x03: // Document + panic("Can't happen. Handled above.") + case 0x04: // Array + outt := out.Type() + if setterStyle(outt) != setterNone { + // Skip the value so its data is handed to the setter below. + d.dropElem(kind) + break + } + for outt.Kind() == reflect.Ptr { + outt = outt.Elem() + } + switch outt.Kind() { + case reflect.Array: + d.readArrayDocTo(out) + return true + case reflect.Slice: + in = d.readSliceDoc(outt) + default: + in = d.readSliceDoc(typeSlice) + } + case 0x05: // Binary + b := d.readBinary() + if b.Kind == 0x00 || b.Kind == 0x02 { + in = b.Data + } else { + in = b + } + case 0x06: // Undefined (obsolete, but still seen in the wild) + in = Undefined + case 0x07: // ObjectId + in = ObjectId(d.readBytes(12)) + case 0x08: // Bool + in = d.readBool() + case 0x09: // Timestamp + // MongoDB handles timestamps as milliseconds. + i := d.readInt64() + if i == -62135596800000 { + in = time.Time{} // In UTC for convenience. + } else { + in = time.Unix(i/1e3, i%1e3*1e6) + } + case 0x0A: // Nil + in = nil + case 0x0B: // RegEx + in = d.readRegEx() + case 0x0C: + in = DBPointer{Namespace: d.readStr(), Id: ObjectId(d.readBytes(12))} + case 0x0D: // JavaScript without scope + in = JavaScript{Code: d.readStr()} + case 0x0E: // Symbol + in = Symbol(d.readStr()) + case 0x0F: // JavaScript with scope + d.i += 4 // Skip length + js := JavaScript{d.readStr(), make(M)} + d.readDocTo(reflect.ValueOf(js.Scope)) + in = js + case 0x10: // Int32 + in = int(d.readInt32()) + case 0x11: // Mongo-specific timestamp + in = MongoTimestamp(d.readInt64()) + case 0x12: // Int64 + in = d.readInt64() + case 0x7F: // Max key + in = MaxKey + case 0xFF: // Min key + in = MinKey + default: + panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + } + + outt := out.Type() + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) + return true + } + + if setter := getSetter(outt, out); setter != nil { + err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + + if in == nil { + out.Set(reflect.Zero(outt)) + return true + } + + outk := outt.Kind() + + // Dereference and initialize pointer if necessary. + first := true + for outk == reflect.Ptr { + if !out.IsNil() { + out = out.Elem() + } else { + elem := reflect.New(outt.Elem()) + if first { + // Only set if value is compatible. + first = false + defer func(out, elem reflect.Value) { + if good { + out.Set(elem) + } + }(out, elem) + } else { + out.Set(elem) + } + out = elem + } + outt = out.Type() + outk = outt.Kind() + } + + inv := reflect.ValueOf(in) + if outt == inv.Type() { + out.Set(inv) + return true + } + + switch outk { + case reflect.Interface: + out.Set(inv) + return true + case reflect.String: + switch inv.Kind() { + case reflect.String: + out.SetString(inv.String()) + return true + case reflect.Slice: + if b, ok := in.([]byte); ok { + out.SetString(string(b)) + return true + } + case reflect.Int, reflect.Int64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatInt(inv.Int(), 10)) + return true + } + case reflect.Float64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatFloat(inv.Float(), 'f', -1, 64)) + return true + } + } + case reflect.Slice, reflect.Array: + // Remember, array (0x04) slices are built with the correct + // element type. If we are here, must be a cross BSON kind + // conversion (e.g. 0x05 unmarshalling on string). + if outt.Elem().Kind() != reflect.Uint8 { + break + } + switch inv.Kind() { + case reflect.String: + slice := []byte(inv.String()) + out.Set(reflect.ValueOf(slice)) + return true + case reflect.Slice: + switch outt.Kind() { + case reflect.Array: + reflect.Copy(out, inv) + case reflect.Slice: + out.SetBytes(inv.Bytes()) + } + return true + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetInt(inv.Int()) + return true + case reflect.Float32, reflect.Float64: + out.SetInt(int64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetInt(1) + } else { + out.SetInt(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("can't happen: no uint types in BSON (!?)") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetUint(uint64(inv.Int())) + return true + case reflect.Float32, reflect.Float64: + out.SetUint(uint64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetUint(1) + } else { + out.SetUint(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON.") + } + case reflect.Float32, reflect.Float64: + switch inv.Kind() { + case reflect.Float32, reflect.Float64: + out.SetFloat(inv.Float()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetFloat(float64(inv.Int())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetFloat(1) + } else { + out.SetFloat(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Bool: + switch inv.Kind() { + case reflect.Bool: + out.SetBool(inv.Bool()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetBool(inv.Int() != 0) + return true + case reflect.Float32, reflect.Float64: + out.SetBool(inv.Float() != 0) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Struct: + if outt == typeURL && inv.Kind() == reflect.String { + u, err := url.Parse(inv.String()) + if err != nil { + panic(err) + } + out.Set(reflect.ValueOf(u).Elem()) + return true + } + if outt == typeBinary { + if b, ok := in.([]byte); ok { + out.Set(reflect.ValueOf(Binary{Data: b})) + return true + } + } + } + + return false +} + +// -------------------------------------------------------------------------- +// Parsers of basic types. + +func (d *decoder) readRegEx() RegEx { + re := RegEx{} + re.Pattern = d.readCStr() + re.Options = d.readCStr() + return re +} + +func (d *decoder) readBinary() Binary { + l := d.readInt32() + b := Binary{} + b.Kind = d.readByte() + b.Data = d.readBytes(l) + if b.Kind == 0x02 && len(b.Data) >= 4 { + // Weird obsolete format with redundant length. + b.Data = b.Data[4:] + } + return b +} + +func (d *decoder) readStr() string { + l := d.readInt32() + b := d.readBytes(l - 1) + if d.readByte() != '\x00' { + corrupted() + } + return string(b) +} + +func (d *decoder) readCStr() string { + start := d.i + end := start + l := len(d.in) + for ; end != l; end++ { + if d.in[end] == '\x00' { + break + } + } + d.i = end + 1 + if d.i > l { + corrupted() + } + return string(d.in[start:end]) +} + +func (d *decoder) readBool() bool { + b := d.readByte() + if b == 0 { + return false + } + if b == 1 { + return true + } + panic(fmt.Sprintf("encoded boolean must be 1 or 0, found %d", b)) +} + +func (d *decoder) readFloat64() float64 { + return math.Float64frombits(uint64(d.readInt64())) +} + +func (d *decoder) readInt32() int32 { + b := d.readBytes(4) + return int32((uint32(b[0]) << 0) | + (uint32(b[1]) << 8) | + (uint32(b[2]) << 16) | + (uint32(b[3]) << 24)) +} + +func (d *decoder) readInt64() int64 { + b := d.readBytes(8) + return int64((uint64(b[0]) << 0) | + (uint64(b[1]) << 8) | + (uint64(b[2]) << 16) | + (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | + (uint64(b[5]) << 40) | + (uint64(b[6]) << 48) | + (uint64(b[7]) << 56)) +} + +func (d *decoder) readByte() byte { + i := d.i + d.i++ + if d.i > len(d.in) { + corrupted() + } + return d.in[i] +} + +func (d *decoder) readBytes(length int32) []byte { + if length < 0 { + corrupted() + } + start := d.i + d.i += int(length) + if d.i < start || d.i > len(d.in) { + corrupted() + } + return d.in[start : start+int(length)] +} diff --git a/vendor/gopkg.in/mgo.v2/bson/encode.go b/vendor/gopkg.in/mgo.v2/bson/encode.go new file mode 100644 index 0000000..c228e28 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/encode.go @@ -0,0 +1,509 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "encoding/json" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" +) + +// -------------------------------------------------------------------------- +// Some internal infrastructure. + +var ( + typeBinary = reflect.TypeOf(Binary{}) + typeObjectId = reflect.TypeOf(ObjectId("")) + typeDBPointer = reflect.TypeOf(DBPointer{"", ObjectId("")}) + typeSymbol = reflect.TypeOf(Symbol("")) + typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) + typeOrderKey = reflect.TypeOf(MinKey) + typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) + typeRaw = reflect.TypeOf(Raw{}) + typeURL = reflect.TypeOf(url.URL{}) + typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") + typeJSONNumber = reflect.TypeOf(json.Number("")) +) + +const itoaCacheSize = 32 + +var itoaCache []string + +func init() { + itoaCache = make([]string, itoaCacheSize) + for i := 0; i != itoaCacheSize; i++ { + itoaCache[i] = strconv.Itoa(i) + } +} + +func itoa(i int) string { + if i < itoaCacheSize { + return itoaCache[i] + } + return strconv.Itoa(i) +} + +// -------------------------------------------------------------------------- +// Marshaling of the document value itself. + +type encoder struct { + out []byte +} + +func (e *encoder) addDoc(v reflect.Value) { + for { + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + continue + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + continue + } + break + } + + if v.Type() == typeRaw { + raw := v.Interface().(Raw) + if raw.Kind != 0x03 && raw.Kind != 0x00 { + panic("Attempted to marshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document") + } + if len(raw.Data) == 0 { + panic("Attempted to marshal empty Raw document") + } + e.addBytes(raw.Data...) + return + } + + start := e.reserveInt32() + + switch v.Kind() { + case reflect.Map: + e.addMap(v) + case reflect.Struct: + e.addStruct(v) + case reflect.Array, reflect.Slice: + e.addSlice(v) + default: + panic("Can't marshal " + v.Type().String() + " as a BSON document") + } + + e.addBytes(0) + e.setInt32(start, int32(len(e.out)-start)) +} + +func (e *encoder) addMap(v reflect.Value) { + for _, k := range v.MapKeys() { + e.addElem(k.String(), v.MapIndex(k), false) + } +} + +func (e *encoder) addStruct(v reflect.Value) { + sinfo, err := getStructInfo(v.Type()) + if err != nil { + panic(err) + } + var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } + for _, info := range sinfo.FieldsList { + if info.Inline == nil { + value = v.Field(info.Num) + } else { + value = v.FieldByIndex(info.Inline) + } + if info.OmitEmpty && isZero(value) { + continue + } + e.addElem(info.Key, value, info.MinSize) + } +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return len(v.String()) == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + vt := v.Type() + if vt == typeTime { + return v.Interface().(time.Time).IsZero() + } + for i := 0; i < v.NumField(); i++ { + if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { + continue // Private field + } + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} + +func (e *encoder) addSlice(v reflect.Value) { + vi := v.Interface() + if d, ok := vi.(D); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) + } +} + +// -------------------------------------------------------------------------- +// Marshaling of elements in a document. + +func (e *encoder) addElemName(kind byte, name string) { + e.addBytes(kind) + e.addBytes([]byte(name)...) + e.addBytes(0) +} + +func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { + + if !v.IsValid() { + e.addElemName('\x0A', name) + return + } + + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + e.addElem(name, reflect.ValueOf(getv), minSize) + return + } + + switch v.Kind() { + + case reflect.Interface: + e.addElem(name, v.Elem(), minSize) + + case reflect.Ptr: + e.addElem(name, v.Elem(), minSize) + + case reflect.String: + s := v.String() + switch v.Type() { + case typeObjectId: + if len(s) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s)) + ")") + } + e.addElemName('\x07', name) + e.addBytes([]byte(s)...) + case typeSymbol: + e.addElemName('\x0E', name) + e.addStr(s) + case typeJSONNumber: + n := v.Interface().(json.Number) + if i, err := n.Int64(); err == nil { + e.addElemName('\x12', name) + e.addInt64(i) + } else if f, err := n.Float64(); err == nil { + e.addElemName('\x01', name) + e.addFloat64(f) + } else { + panic("failed to convert json.Number to a number: " + s) + } + default: + e.addElemName('\x02', name) + e.addStr(s) + } + + case reflect.Float32, reflect.Float64: + e.addElemName('\x01', name) + e.addFloat64(v.Float()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + u := v.Uint() + if int64(u) < 0 { + panic("BSON has no uint64 type, and value is too large to fit correctly in an int64") + } else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) { + e.addElemName('\x10', name) + e.addInt32(int32(u)) + } else { + e.addElemName('\x12', name) + e.addInt64(int64(u)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) + + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } + + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. + e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) + } + } + + case reflect.Bool: + e.addElemName('\x08', name) + if v.Bool() { + e.addBytes(1) + } else { + e.addBytes(0) + } + + case reflect.Map: + e.addElemName('\x03', name) + e.addDoc(v) + + case reflect.Slice: + vt := v.Type() + et := vt.Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Bytes()) + } else if et == typeDocElem || et == typeRawDocElem { + e.addElemName('\x03', name) + e.addDoc(v) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Array: + et := v.Type().Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + if v.CanAddr() { + e.addBinary('\x00', v.Slice(0, v.Len()).Interface().([]byte)) + } else { + n := v.Len() + e.addInt32(int32(n)) + e.addBytes('\x00') + for i := 0; i < n; i++ { + el := v.Index(i) + e.addBytes(byte(el.Uint())) + } + } + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Struct: + switch s := v.Interface().(type) { + + case Raw: + kind := s.Kind + if kind == 0x00 { + kind = 0x03 + } + if len(s.Data) == 0 && kind != 0x06 && kind != 0x0A && kind != 0xFF && kind != 0x7F { + panic("Attempted to marshal empty Raw document") + } + e.addElemName(kind, name) + e.addBytes(s.Data...) + + case Binary: + e.addElemName('\x05', name) + e.addBinary(s.Kind, s.Data) + + case DBPointer: + e.addElemName('\x0C', name) + e.addStr(s.Namespace) + if len(s.Id) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s.Id)) + ")") + } + e.addBytes([]byte(s.Id)...) + + case RegEx: + e.addElemName('\x0B', name) + e.addCStr(s.Pattern) + e.addCStr(s.Options) + + case JavaScript: + if s.Scope == nil { + e.addElemName('\x0D', name) + e.addStr(s.Code) + } else { + e.addElemName('\x0F', name) + start := e.reserveInt32() + e.addStr(s.Code) + e.addDoc(reflect.ValueOf(s.Scope)) + e.setInt32(start, int32(len(e.out)-start)) + } + + case time.Time: + // MongoDB handles timestamps as milliseconds. + e.addElemName('\x09', name) + e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6)) + + case url.URL: + e.addElemName('\x02', name) + e.addStr(s.String()) + + case undefined: + e.addElemName('\x06', name) + + default: + e.addElemName('\x03', name) + e.addDoc(v) + } + + default: + panic("Can't marshal " + v.Type().String() + " in a BSON document") + } +} + +// -------------------------------------------------------------------------- +// Marshaling of base types. + +func (e *encoder) addBinary(subtype byte, v []byte) { + if subtype == 0x02 { + // Wonder how that brilliant idea came to life. Obsolete, luckily. + e.addInt32(int32(len(v) + 4)) + e.addBytes(subtype) + e.addInt32(int32(len(v))) + } else { + e.addInt32(int32(len(v))) + e.addBytes(subtype) + } + e.addBytes(v...) +} + +func (e *encoder) addStr(v string) { + e.addInt32(int32(len(v) + 1)) + e.addCStr(v) +} + +func (e *encoder) addCStr(v string) { + e.addBytes([]byte(v)...) + e.addBytes(0) +} + +func (e *encoder) reserveInt32() (pos int) { + pos = len(e.out) + e.addBytes(0, 0, 0, 0) + return pos +} + +func (e *encoder) setInt32(pos int, v int32) { + e.out[pos+0] = byte(v) + e.out[pos+1] = byte(v >> 8) + e.out[pos+2] = byte(v >> 16) + e.out[pos+3] = byte(v >> 24) +} + +func (e *encoder) addInt32(v int32) { + u := uint32(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24)) +} + +func (e *encoder) addInt64(v int64) { + u := uint64(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24), + byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56)) +} + +func (e *encoder) addFloat64(v float64) { + e.addInt64(int64(math.Float64bits(v))) +} + +func (e *encoder) addBytes(v ...byte) { + e.out = append(e.out, v...) +} diff --git a/vendor/gopkg.in/mgo.v2/bulk.go b/vendor/gopkg.in/mgo.v2/bulk.go new file mode 100644 index 0000000..072a520 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bulk.go @@ -0,0 +1,351 @@ +package mgo + +import ( + "bytes" + "sort" + + "gopkg.in/mgo.v2/bson" +) + +// Bulk represents an operation that can be prepared with several +// orthogonal changes before being delivered to the server. +// +// MongoDB servers older than version 2.6 do not have proper support for bulk +// operations, so the driver attempts to map its API as much as possible into +// the functionality that works. In particular, in those releases updates and +// removals are sent individually, and inserts are sent in bulk but have +// suboptimal error reporting compared to more recent versions of the server. +// See the documentation of BulkErrorCase for details on that. +// +// Relevant documentation: +// +// http://blog.mongodb.org/post/84922794768/mongodbs-new-bulk-api +// +type Bulk struct { + c *Collection + opcount int + actions []bulkAction + ordered bool +} + +type bulkOp int + +const ( + bulkInsert bulkOp = iota + 1 + bulkUpdate + bulkUpdateAll + bulkRemove +) + +type bulkAction struct { + op bulkOp + docs []interface{} + idxs []int +} + +type bulkUpdateOp []interface{} +type bulkDeleteOp []interface{} + +// BulkResult holds the results for a bulk operation. +type BulkResult struct { + Matched int + Modified int // Available only for MongoDB 2.6+ + + // Be conservative while we understand exactly how to report these + // results in a useful and convenient way, and also how to emulate + // them with prior servers. + private bool +} + +// BulkError holds an error returned from running a Bulk operation. +// Individual errors may be obtained and inspected via the Cases method. +type BulkError struct { + ecases []BulkErrorCase +} + +func (e *BulkError) Error() string { + if len(e.ecases) == 0 { + return "invalid BulkError instance: no errors" + } + if len(e.ecases) == 1 { + return e.ecases[0].Err.Error() + } + msgs := make([]string, 0, len(e.ecases)) + seen := make(map[string]bool) + for _, ecase := range e.ecases { + msg := ecase.Err.Error() + if !seen[msg] { + seen[msg] = true + msgs = append(msgs, msg) + } + } + if len(msgs) == 1 { + return msgs[0] + } + var buf bytes.Buffer + buf.WriteString("multiple errors in bulk operation:\n") + for _, msg := range msgs { + buf.WriteString(" - ") + buf.WriteString(msg) + buf.WriteByte('\n') + } + return buf.String() +} + +type bulkErrorCases []BulkErrorCase + +func (slice bulkErrorCases) Len() int { return len(slice) } +func (slice bulkErrorCases) Less(i, j int) bool { return slice[i].Index < slice[j].Index } +func (slice bulkErrorCases) Swap(i, j int) { slice[i], slice[j] = slice[j], slice[i] } + +// BulkErrorCase holds an individual error found while attempting a single change +// within a bulk operation, and the position in which it was enqueued. +// +// MongoDB servers older than version 2.6 do not have proper support for bulk +// operations, so the driver attempts to map its API as much as possible into +// the functionality that works. In particular, only the last error is reported +// for bulk inserts and without any positional information, so the Index +// field is set to -1 in these cases. +type BulkErrorCase struct { + Index int // Position of operation that failed, or -1 if unknown. + Err error +} + +// Cases returns all individual errors found while attempting the requested changes. +// +// See the documentation of BulkErrorCase for limitations in older MongoDB releases. +func (e *BulkError) Cases() []BulkErrorCase { + return e.ecases +} + +// Bulk returns a value to prepare the execution of a bulk operation. +func (c *Collection) Bulk() *Bulk { + return &Bulk{c: c, ordered: true} +} + +// Unordered puts the bulk operation in unordered mode. +// +// In unordered mode the indvidual operations may be sent +// out of order, which means latter operations may proceed +// even if prior ones have failed. +func (b *Bulk) Unordered() { + b.ordered = false +} + +func (b *Bulk) action(op bulkOp, opcount int) *bulkAction { + var action *bulkAction + if len(b.actions) > 0 && b.actions[len(b.actions)-1].op == op { + action = &b.actions[len(b.actions)-1] + } else if !b.ordered { + for i := range b.actions { + if b.actions[i].op == op { + action = &b.actions[i] + break + } + } + } + if action == nil { + b.actions = append(b.actions, bulkAction{op: op}) + action = &b.actions[len(b.actions)-1] + } + for i := 0; i < opcount; i++ { + action.idxs = append(action.idxs, b.opcount) + b.opcount++ + } + return action +} + +// Insert queues up the provided documents for insertion. +func (b *Bulk) Insert(docs ...interface{}) { + action := b.action(bulkInsert, len(docs)) + action.docs = append(action.docs, docs...) +} + +// Remove queues up the provided selectors for removing matching documents. +// Each selector will remove only a single matching document. +func (b *Bulk) Remove(selectors ...interface{}) { + action := b.action(bulkRemove, len(selectors)) + for _, selector := range selectors { + if selector == nil { + selector = bson.D{} + } + action.docs = append(action.docs, &deleteOp{ + Collection: b.c.FullName, + Selector: selector, + Flags: 1, + Limit: 1, + }) + } +} + +// RemoveAll queues up the provided selectors for removing all matching documents. +// Each selector will remove all matching documents. +func (b *Bulk) RemoveAll(selectors ...interface{}) { + action := b.action(bulkRemove, len(selectors)) + for _, selector := range selectors { + if selector == nil { + selector = bson.D{} + } + action.docs = append(action.docs, &deleteOp{ + Collection: b.c.FullName, + Selector: selector, + Flags: 0, + Limit: 0, + }) + } +} + +// Update queues up the provided pairs of updating instructions. +// The first element of each pair selects which documents must be +// updated, and the second element defines how to update it. +// Each pair matches exactly one document for updating at most. +func (b *Bulk) Update(pairs ...interface{}) { + if len(pairs)%2 != 0 { + panic("Bulk.Update requires an even number of parameters") + } + action := b.action(bulkUpdate, len(pairs)/2) + for i := 0; i < len(pairs); i += 2 { + selector := pairs[i] + if selector == nil { + selector = bson.D{} + } + action.docs = append(action.docs, &updateOp{ + Collection: b.c.FullName, + Selector: selector, + Update: pairs[i+1], + }) + } +} + +// UpdateAll queues up the provided pairs of updating instructions. +// The first element of each pair selects which documents must be +// updated, and the second element defines how to update it. +// Each pair updates all documents matching the selector. +func (b *Bulk) UpdateAll(pairs ...interface{}) { + if len(pairs)%2 != 0 { + panic("Bulk.UpdateAll requires an even number of parameters") + } + action := b.action(bulkUpdate, len(pairs)/2) + for i := 0; i < len(pairs); i += 2 { + selector := pairs[i] + if selector == nil { + selector = bson.D{} + } + action.docs = append(action.docs, &updateOp{ + Collection: b.c.FullName, + Selector: selector, + Update: pairs[i+1], + Flags: 2, + Multi: true, + }) + } +} + +// Upsert queues up the provided pairs of upserting instructions. +// The first element of each pair selects which documents must be +// updated, and the second element defines how to update it. +// Each pair matches exactly one document for updating at most. +func (b *Bulk) Upsert(pairs ...interface{}) { + if len(pairs)%2 != 0 { + panic("Bulk.Update requires an even number of parameters") + } + action := b.action(bulkUpdate, len(pairs)/2) + for i := 0; i < len(pairs); i += 2 { + selector := pairs[i] + if selector == nil { + selector = bson.D{} + } + action.docs = append(action.docs, &updateOp{ + Collection: b.c.FullName, + Selector: selector, + Update: pairs[i+1], + Flags: 1, + Upsert: true, + }) + } +} + +// Run runs all the operations queued up. +// +// If an error is reported on an unordered bulk operation, the error value may +// be an aggregation of all issues observed. As an exception to that, Insert +// operations running on MongoDB versions prior to 2.6 will report the last +// error only due to a limitation in the wire protocol. +func (b *Bulk) Run() (*BulkResult, error) { + var result BulkResult + var berr BulkError + var failed bool + for i := range b.actions { + action := &b.actions[i] + var ok bool + switch action.op { + case bulkInsert: + ok = b.runInsert(action, &result, &berr) + case bulkUpdate: + ok = b.runUpdate(action, &result, &berr) + case bulkRemove: + ok = b.runRemove(action, &result, &berr) + default: + panic("unknown bulk operation") + } + if !ok { + failed = true + if b.ordered { + break + } + } + } + if failed { + sort.Sort(bulkErrorCases(berr.ecases)) + return nil, &berr + } + return &result, nil +} + +func (b *Bulk) runInsert(action *bulkAction, result *BulkResult, berr *BulkError) bool { + op := &insertOp{b.c.FullName, action.docs, 0} + if !b.ordered { + op.flags = 1 // ContinueOnError + } + lerr, err := b.c.writeOp(op, b.ordered) + return b.checkSuccess(action, berr, lerr, err) +} + +func (b *Bulk) runUpdate(action *bulkAction, result *BulkResult, berr *BulkError) bool { + lerr, err := b.c.writeOp(bulkUpdateOp(action.docs), b.ordered) + if lerr != nil { + result.Matched += lerr.N + result.Modified += lerr.modified + } + return b.checkSuccess(action, berr, lerr, err) +} + +func (b *Bulk) runRemove(action *bulkAction, result *BulkResult, berr *BulkError) bool { + lerr, err := b.c.writeOp(bulkDeleteOp(action.docs), b.ordered) + if lerr != nil { + result.Matched += lerr.N + result.Modified += lerr.modified + } + return b.checkSuccess(action, berr, lerr, err) +} + +func (b *Bulk) checkSuccess(action *bulkAction, berr *BulkError, lerr *LastError, err error) bool { + if lerr != nil && len(lerr.ecases) > 0 { + for i := 0; i < len(lerr.ecases); i++ { + // Map back from the local error index into the visible one. + ecase := lerr.ecases[i] + idx := ecase.Index + if idx >= 0 { + idx = action.idxs[idx] + } + berr.ecases = append(berr.ecases, BulkErrorCase{idx, ecase.Err}) + } + return false + } else if err != nil { + for i := 0; i < len(action.idxs); i++ { + berr.ecases = append(berr.ecases, BulkErrorCase{action.idxs[i], err}) + } + return false + } + return true +} diff --git a/vendor/gopkg.in/mgo.v2/cluster.go b/vendor/gopkg.in/mgo.v2/cluster.go new file mode 100644 index 0000000..e28af5b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/cluster.go @@ -0,0 +1,679 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo cluster encapsulation. +// +// A cluster enables the communication with one or more servers participating +// in a mongo cluster. This works with individual servers, a replica set, +// a replica pair, one or multiple mongos routers, etc. + +type mongoCluster struct { + sync.RWMutex + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + direct bool + failFast bool + syncCount uint + setName string + cachedIndex map[string]bool + sync chan bool + dial dialer +} + +func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string) *mongoCluster { + cluster := &mongoCluster{ + userSeeds: userSeeds, + references: 1, + direct: direct, + failFast: failFast, + dial: dial, + setName: setName, + } + cluster.serverSynced.L = cluster.RWMutex.RLocker() + cluster.sync = make(chan bool, 1) + stats.cluster(+1) + go cluster.syncServersLoop() + return cluster +} + +// Acquire increases the reference count for the cluster. +func (cluster *mongoCluster) Acquire() { + cluster.Lock() + cluster.references++ + debugf("Cluster %p acquired (refs=%d)", cluster, cluster.references) + cluster.Unlock() +} + +// Release decreases the reference count for the cluster. Once +// it reaches zero, all servers will be closed. +func (cluster *mongoCluster) Release() { + cluster.Lock() + if cluster.references == 0 { + panic("cluster.Release() with references == 0") + } + cluster.references-- + debugf("Cluster %p released (refs=%d)", cluster, cluster.references) + if cluster.references == 0 { + for _, server := range cluster.servers.Slice() { + server.Close() + } + // Wake up the sync loop so it can die. + cluster.syncServers() + stats.cluster(-1) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) LiveServers() (servers []string) { + cluster.RLock() + for _, serv := range cluster.servers.Slice() { + servers = append(servers, serv.Addr) + } + cluster.RUnlock() + return servers +} + +func (cluster *mongoCluster) removeServer(server *mongoServer) { + cluster.Lock() + cluster.masters.Remove(server) + other := cluster.servers.Remove(server) + cluster.Unlock() + if other != nil { + other.Close() + log("Removed server ", server.Addr, " from cluster.") + } + server.Close() +} + +type isMasterResult struct { + IsMaster bool + Secondary bool + Primary string + Hosts []string + Passives []string + Tags bson.D + Msg string + SetName string `bson:"setName"` + MaxWireVersion int `bson:"maxWireVersion"` +} + +func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { + // Monotonic let's it talk to a slave and still hold the socket. + session := newSession(Monotonic, cluster, 10*time.Second) + session.setSocket(socket) + err := session.Run("ismaster", result) + session.Close() + return err +} + +type possibleTimeout interface { + Timeout() bool +} + +var syncSocketTimeout = 5 * time.Second + +func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { + var syncTimeout time.Duration + if raceDetector { + // This variable is only ever touched by tests. + globalMutex.Lock() + syncTimeout = syncSocketTimeout + globalMutex.Unlock() + } else { + syncTimeout = syncSocketTimeout + } + + addr := server.Addr + log("SYNC Processing ", addr, "...") + + // Retry a few times to avoid knocking a server down for a hiccup. + var result isMasterResult + var tryerr error + for retry := 0; ; retry++ { + if retry == 3 || retry == 1 && cluster.failFast { + return nil, nil, tryerr + } + if retry > 0 { + // Don't abuse the server needlessly if there's something actually wrong. + if err, ok := tryerr.(possibleTimeout); ok && err.Timeout() { + // Give a chance for waiters to timeout as well. + cluster.serverSynced.Broadcast() + } + time.Sleep(syncShortDelay) + } + + // It's not clear what would be a good timeout here. Is it + // better to wait longer or to retry? + socket, _, err := server.AcquireSocket(0, syncTimeout) + if err != nil { + tryerr = err + logf("SYNC Failed to get socket to %s: %v", addr, err) + continue + } + err = cluster.isMaster(socket, &result) + socket.Release() + if err != nil { + tryerr = err + logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) + continue + } + debugf("SYNC Result of 'ismaster' from %s: %#v", addr, result) + break + } + + if cluster.setName != "" && result.SetName != cluster.setName { + logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName) + return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName) + } + + if result.IsMaster { + debugf("SYNC %s is a master.", addr) + if !server.info.Master { + // Made an incorrect assumption above, so fix stats. + stats.conn(-1, false) + stats.conn(+1, true) + } + } else if result.Secondary { + debugf("SYNC %s is a slave.", addr) + } else if cluster.direct { + logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) + } else { + logf("SYNC %s is neither a master nor a slave.", addr) + // Let stats track it as whatever was known before. + return nil, nil, errors.New(addr + " is not a master nor slave") + } + + info = &mongoServerInfo{ + Master: result.IsMaster, + Mongos: result.Msg == "isdbgrid", + Tags: result.Tags, + SetName: result.SetName, + MaxWireVersion: result.MaxWireVersion, + } + + hosts = make([]string, 0, 1+len(result.Hosts)+len(result.Passives)) + if result.Primary != "" { + // First in the list to speed up master discovery. + hosts = append(hosts, result.Primary) + } + hosts = append(hosts, result.Hosts...) + hosts = append(hosts, result.Passives...) + + debugf("SYNC %s knows about the following peers: %#v", addr, hosts) + return info, hosts, nil +} + +type syncKind bool + +const ( + completeSync syncKind = true + partialSync syncKind = false +) + +func (cluster *mongoCluster) addServer(server *mongoServer, info *mongoServerInfo, syncKind syncKind) { + cluster.Lock() + current := cluster.servers.Search(server.ResolvedAddr) + if current == nil { + if syncKind == partialSync { + cluster.Unlock() + server.Close() + log("SYNC Discarding unknown server ", server.Addr, " due to partial sync.") + return + } + cluster.servers.Add(server) + if info.Master { + cluster.masters.Add(server) + log("SYNC Adding ", server.Addr, " to cluster as a master.") + } else { + log("SYNC Adding ", server.Addr, " to cluster as a slave.") + } + } else { + if server != current { + panic("addServer attempting to add duplicated server") + } + if server.Info().Master != info.Master { + if info.Master { + log("SYNC Server ", server.Addr, " is now a master.") + cluster.masters.Add(server) + } else { + log("SYNC Server ", server.Addr, " is now a slave.") + cluster.masters.Remove(server) + } + } + } + server.SetInfo(info) + debugf("SYNC Broadcasting availability of server %s", server.Addr) + cluster.serverSynced.Broadcast() + cluster.Unlock() +} + +func (cluster *mongoCluster) getKnownAddrs() []string { + cluster.RLock() + max := len(cluster.userSeeds) + len(cluster.dynaSeeds) + cluster.servers.Len() + seen := make(map[string]bool, max) + known := make([]string, 0, max) + + add := func(addr string) { + if _, found := seen[addr]; !found { + seen[addr] = true + known = append(known, addr) + } + } + + for _, addr := range cluster.userSeeds { + add(addr) + } + for _, addr := range cluster.dynaSeeds { + add(addr) + } + for _, serv := range cluster.servers.Slice() { + add(serv.Addr) + } + cluster.RUnlock() + + return known +} + +// syncServers injects a value into the cluster.sync channel to force +// an iteration of the syncServersLoop function. +func (cluster *mongoCluster) syncServers() { + select { + case cluster.sync <- true: + default: + } +} + +// How long to wait for a checkup of the cluster topology if nothing +// else kicks a synchronization before that. +const syncServersDelay = 30 * time.Second +const syncShortDelay = 500 * time.Millisecond + +// syncServersLoop loops while the cluster is alive to keep its idea of +// the server topology up-to-date. It must be called just once from +// newCluster. The loop iterates once syncServersDelay has passed, or +// if somebody injects a value into the cluster.sync channel to force a +// synchronization. A loop iteration will contact all servers in +// parallel, ask them about known peers and their own role within the +// cluster, and then attempt to do the same with all the peers +// retrieved. +func (cluster *mongoCluster) syncServersLoop() { + for { + debugf("SYNC Cluster %p is starting a sync loop iteration.", cluster) + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.references++ // Keep alive while syncing. + direct := cluster.direct + cluster.Unlock() + + cluster.syncServersIteration(direct) + + // We just synchronized, so consume any outstanding requests. + select { + case <-cluster.sync: + default: + } + + cluster.Release() + + // Hold off before allowing another sync. No point in + // burning CPU looking for down servers. + if !cluster.failFast { + time.Sleep(syncShortDelay) + } + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.syncCount++ + // Poke all waiters so they have a chance to timeout or + // restart syncing if they wish to. + cluster.serverSynced.Broadcast() + // Check if we have to restart immediately either way. + restart := !direct && cluster.masters.Empty() || cluster.servers.Empty() + cluster.Unlock() + + if restart { + log("SYNC No masters found. Will synchronize again.") + time.Sleep(syncShortDelay) + continue + } + + debugf("SYNC Cluster %p waiting for next requested or scheduled sync.", cluster) + + // Hold off until somebody explicitly requests a synchronization + // or it's time to check for a cluster topology change again. + select { + case <-cluster.sync: + case <-time.After(syncServersDelay): + } + } + debugf("SYNC Cluster %p is stopping its sync loop.", cluster) +} + +func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { + cluster.RLock() + server := cluster.servers.Search(tcpaddr.String()) + cluster.RUnlock() + if server != nil { + return server + } + return newServer(addr, tcpaddr, cluster.sync, cluster.dial) +} + +func resolveAddr(addr string) (*net.TCPAddr, error) { + // Simple cases that do not need actual resolution. Works with IPv4 and v6. + if host, port, err := net.SplitHostPort(addr); err == nil { + if port, _ := strconv.Atoi(port); port > 0 { + zone := "" + if i := strings.LastIndex(host, "%"); i >= 0 { + zone = host[i+1:] + host = host[:i] + } + ip := net.ParseIP(host) + if ip != nil { + return &net.TCPAddr{IP: ip, Port: port, Zone: zone}, nil + } + } + } + + // Attempt to resolve IPv4 and v6 concurrently. + addrChan := make(chan *net.TCPAddr, 2) + for _, network := range []string{"udp4", "udp6"} { + network := network + go func() { + // The unfortunate UDP dialing hack allows having a timeout on address resolution. + conn, err := net.DialTimeout(network, addr, 10*time.Second) + if err != nil { + addrChan <- nil + } else { + addrChan <- (*net.TCPAddr)(conn.RemoteAddr().(*net.UDPAddr)) + conn.Close() + } + }() + } + + // Wait for the result of IPv4 and v6 resolution. Use IPv4 if available. + tcpaddr := <-addrChan + if tcpaddr == nil || len(tcpaddr.IP) != 4 { + var timeout <-chan time.Time + if tcpaddr != nil { + // Don't wait too long if an IPv6 address is known. + timeout = time.After(50 * time.Millisecond) + } + select { + case <-timeout: + case tcpaddr2 := <-addrChan: + if tcpaddr == nil || tcpaddr2 != nil { + // It's an IPv4 address or the only known address. Use it. + tcpaddr = tcpaddr2 + } + } + } + + if tcpaddr == nil { + log("SYNC Failed to resolve server address: ", addr) + return nil, errors.New("failed to resolve server address: " + addr) + } + if tcpaddr.String() != addr { + debug("SYNC Address ", addr, " resolved as ", tcpaddr.String()) + } + return tcpaddr, nil +} + +type pendingAdd struct { + server *mongoServer + info *mongoServerInfo +} + +func (cluster *mongoCluster) syncServersIteration(direct bool) { + log("SYNC Starting full topology synchronization...") + + var wg sync.WaitGroup + var m sync.Mutex + notYetAdded := make(map[string]pendingAdd) + addIfFound := make(map[string]bool) + seen := make(map[string]bool) + syncKind := partialSync + + var spawnSync func(addr string, byMaster bool) + spawnSync = func(addr string, byMaster bool) { + wg.Add(1) + go func() { + defer wg.Done() + + tcpaddr, err := resolveAddr(addr) + if err != nil { + log("SYNC Failed to start sync of ", addr, ": ", err.Error()) + return + } + resolvedAddr := tcpaddr.String() + + m.Lock() + if byMaster { + if pending, ok := notYetAdded[resolvedAddr]; ok { + delete(notYetAdded, resolvedAddr) + m.Unlock() + cluster.addServer(pending.server, pending.info, completeSync) + return + } + addIfFound[resolvedAddr] = true + } + if seen[resolvedAddr] { + m.Unlock() + return + } + seen[resolvedAddr] = true + m.Unlock() + + server := cluster.server(addr, tcpaddr) + info, hosts, err := cluster.syncServer(server) + if err != nil { + cluster.removeServer(server) + return + } + + m.Lock() + add := direct || info.Master || addIfFound[resolvedAddr] + if add { + syncKind = completeSync + } else { + notYetAdded[resolvedAddr] = pendingAdd{server, info} + } + m.Unlock() + if add { + cluster.addServer(server, info, completeSync) + } + if !direct { + for _, addr := range hosts { + spawnSync(addr, info.Master) + } + } + }() + } + + knownAddrs := cluster.getKnownAddrs() + for _, addr := range knownAddrs { + spawnSync(addr, false) + } + wg.Wait() + + if syncKind == completeSync { + logf("SYNC Synchronization was complete (got data from primary).") + for _, pending := range notYetAdded { + cluster.removeServer(pending.server) + } + } else { + logf("SYNC Synchronization was partial (cannot talk to primary).") + for _, pending := range notYetAdded { + cluster.addServer(pending.server, pending.info, partialSync) + } + } + + cluster.Lock() + mastersLen := cluster.masters.Len() + logf("SYNC Synchronization completed: %d master(s) and %d slave(s) alive.", mastersLen, cluster.servers.Len()-mastersLen) + + // Update dynamic seeds, but only if we have any good servers. Otherwise, + // leave them alone for better chances of a successful sync in the future. + if syncKind == completeSync { + dynaSeeds := make([]string, cluster.servers.Len()) + for i, server := range cluster.servers.Slice() { + dynaSeeds[i] = server.Addr + } + cluster.dynaSeeds = dynaSeeds + debugf("SYNC New dynamic seeds: %#v\n", dynaSeeds) + } + cluster.Unlock() +} + +// AcquireSocket returns a socket to a server in the cluster. If slaveOk is +// true, it will attempt to return a socket to a slave server. If it is +// false, the socket will necessarily be to a master server. +func (cluster *mongoCluster) AcquireSocket(mode Mode, slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { + var started time.Time + var syncCount uint + warnedLimit := false + for { + cluster.RLock() + for { + mastersLen := cluster.masters.Len() + slavesLen := cluster.servers.Len() - mastersLen + debugf("Cluster has %d known masters and %d known slaves.", mastersLen, slavesLen) + if !(slaveOk && mode == Secondary) && mastersLen > 0 || slaveOk && slavesLen > 0 { + break + } + if started.IsZero() { + // Initialize after fast path above. + started = time.Now() + syncCount = cluster.syncCount + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + cluster.RUnlock() + return nil, errors.New("no reachable servers") + } + log("Waiting for servers to synchronize...") + cluster.syncServers() + + // Remember: this will release and reacquire the lock. + cluster.serverSynced.Wait() + } + + var server *mongoServer + if slaveOk { + server = cluster.servers.BestFit(mode, serverTags) + } else { + server = cluster.masters.BestFit(mode, nil) + } + cluster.RUnlock() + + if server == nil { + // Must have failed the requested tags. Sleep to avoid spinning. + time.Sleep(1e8) + continue + } + + s, abended, err := server.AcquireSocket(poolLimit, socketTimeout) + if err == errPoolLimit { + if !warnedLimit { + warnedLimit = true + log("WARNING: Per-server connection limit reached.") + } + time.Sleep(100 * time.Millisecond) + continue + } + if err != nil { + cluster.removeServer(server) + cluster.syncServers() + continue + } + if abended && !slaveOk { + var result isMasterResult + err := cluster.isMaster(s, &result) + if err != nil || !result.IsMaster { + logf("Cannot confirm server %s as master (%v)", server.Addr, err) + s.Release() + cluster.syncServers() + time.Sleep(100 * time.Millisecond) + continue + } + } + return s, nil + } + panic("unreached") +} + +func (cluster *mongoCluster) CacheIndex(cacheKey string, exists bool) { + cluster.Lock() + if cluster.cachedIndex == nil { + cluster.cachedIndex = make(map[string]bool) + } + if exists { + cluster.cachedIndex[cacheKey] = true + } else { + delete(cluster.cachedIndex, cacheKey) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) HasCachedIndex(cacheKey string) (result bool) { + cluster.RLock() + if cluster.cachedIndex != nil { + result = cluster.cachedIndex[cacheKey] + } + cluster.RUnlock() + return +} + +func (cluster *mongoCluster) ResetIndexCache() { + cluster.Lock() + cluster.cachedIndex = make(map[string]bool) + cluster.Unlock() +} diff --git a/vendor/gopkg.in/mgo.v2/doc.go b/vendor/gopkg.in/mgo.v2/doc.go new file mode 100644 index 0000000..859fd9b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/doc.go @@ -0,0 +1,31 @@ +// Package mgo offers a rich MongoDB driver for Go. +// +// Details about the mgo project (pronounced as "mango") are found +// in its web page: +// +// http://labix.org/mgo +// +// Usage of the driver revolves around the concept of sessions. To +// get started, obtain a session using the Dial function: +// +// session, err := mgo.Dial(url) +// +// This will establish one or more connections with the cluster of +// servers defined by the url parameter. From then on, the cluster +// may be queried with multiple consistency rules (see SetMode) and +// documents retrieved with statements such as: +// +// c := session.DB(database).C(collection) +// err := c.Find(query).One(&result) +// +// New sessions are typically created by calling session.Copy on the +// initial session obtained at dial time. These new sessions will share +// the same cluster information and connection pool, and may be easily +// handed into other methods and functions for organizing logic. +// Every session created must have its Close method called at the end +// of its life time, so its resources may be put back in the pool or +// collected, depending on the case. +// +// For more details, see the documentation for the types and methods. +// +package mgo diff --git a/vendor/gopkg.in/mgo.v2/gridfs.go b/vendor/gopkg.in/mgo.v2/gridfs.go new file mode 100644 index 0000000..2ac4ff5 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/gridfs.go @@ -0,0 +1,761 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "hash" + "io" + "os" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type GridFS struct { + Files *Collection + Chunks *Collection +} + +type gfsFileMode int + +const ( + gfsClosed gfsFileMode = 0 + gfsReading gfsFileMode = 1 + gfsWriting gfsFileMode = 2 +) + +type GridFile struct { + m sync.Mutex + c sync.Cond + gfs *GridFS + mode gfsFileMode + err error + + chunk int + offset int64 + + wpending int + wbuf []byte + wsum hash.Hash + + rbuf []byte + rcache *gfsCachedChunk + + doc gfsFile +} + +type gfsFile struct { + Id interface{} "_id" + ChunkSize int "chunkSize" + UploadDate time.Time "uploadDate" + Length int64 ",minsize" + MD5 string + Filename string ",omitempty" + ContentType string "contentType,omitempty" + Metadata *bson.Raw ",omitempty" +} + +type gfsChunk struct { + Id interface{} "_id" + FilesId interface{} "files_id" + N int + Data []byte +} + +type gfsCachedChunk struct { + wait sync.Mutex + n int + data []byte + err error +} + +func newGridFS(db *Database, prefix string) *GridFS { + return &GridFS{db.C(prefix + ".files"), db.C(prefix + ".chunks")} +} + +func (gfs *GridFS) newFile() *GridFile { + file := &GridFile{gfs: gfs} + file.c.L = &file.m + //runtime.SetFinalizer(file, finalizeFile) + return file +} + +func finalizeFile(file *GridFile) { + file.Close() +} + +// Create creates a new file with the provided name in the GridFS. If the file +// name already exists, a new version will be inserted with an up-to-date +// uploadDate that will cause it to be atomically visible to the Open and +// OpenId methods. If the file name is not important, an empty name may be +// provided and the file Id used instead. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// A simple example inserting a new file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// n, err := file.Write([]byte("Hello world!")) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes written\n", n) +// +// The io.Writer interface is implemented by *GridFile and may be used to +// help on the file creation. For example: +// +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// messages, err := os.Open("/var/log/messages") +// check(err) +// defer messages.Close() +// err = io.Copy(file, messages) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Create(name string) (file *GridFile, err error) { + file = gfs.newFile() + file.mode = gfsWriting + file.wsum = md5.New() + file.doc = gfsFile{Id: bson.NewObjectId(), ChunkSize: 255 * 1024, Filename: name} + return +} + +// OpenId returns the file with the provided id, for reading. +// If the file isn't found, err will be set to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// The following example will print the first 8192 bytes from the file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) OpenId(id interface{}) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"_id": id}).One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// Open returns the most recently uploaded file with the provided +// name, for reading. If the file isn't found, err will be set +// to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// The following example will print the first 8192 bytes from the file: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Open(name string) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"filename": name}).Sort("-uploadDate").One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// OpenNext opens the next file from iter for reading, sets *file to it, +// and returns true on the success case. If no more documents are available +// on iter or an error occurred, *file is set to nil and the result is false. +// Errors will be available via iter.Err(). +// +// The iter parameter must be an iterator on the GridFS files collection. +// Using the GridFS.Find method is an easy way to obtain such an iterator, +// but any iterator on the collection will work. +// +// If the provided *file is non-nil, OpenNext will close it before attempting +// to iterate to the next element. This means that in a loop one only +// has to worry about closing files when breaking out of the loop early +// (break, return, or panic). +// +// For example: +// +// gfs := db.GridFS("fs") +// query := gfs.Find(nil).Sort("filename") +// iter := query.Iter() +// var f *mgo.GridFile +// for gfs.OpenNext(iter, &f) { +// fmt.Printf("Filename: %s\n", f.Name()) +// } +// if iter.Close() != nil { +// panic(iter.Close()) +// } +// +func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { + if *file != nil { + // Ignoring the error here shouldn't be a big deal + // as we're reading the file and the loop iteration + // for this file is finished. + _ = (*file).Close() + } + var doc gfsFile + if !iter.Next(&doc) { + *file = nil + return false + } + f := gfs.newFile() + f.mode = gfsReading + f.doc = doc + *file = f + return true +} + +// Find runs query on GridFS's files collection and returns +// the resulting Query. +// +// This logic: +// +// gfs := db.GridFS("fs") +// iter := gfs.Find(nil).Iter() +// +// Is equivalent to: +// +// files := db.C("fs" + ".files") +// iter := files.Find(nil).Iter() +// +func (gfs *GridFS) Find(query interface{}) *Query { + return gfs.Files.Find(query) +} + +// RemoveId deletes the file with the provided id from the GridFS. +func (gfs *GridFS) RemoveId(id interface{}) error { + err := gfs.Files.Remove(bson.M{"_id": id}) + if err != nil { + return err + } + _, err = gfs.Chunks.RemoveAll(bson.D{{"files_id", id}}) + return err +} + +type gfsDocId struct { + Id interface{} "_id" +} + +// Remove deletes all files with the provided name from the GridFS. +func (gfs *GridFS) Remove(name string) (err error) { + iter := gfs.Files.Find(bson.M{"filename": name}).Select(bson.M{"_id": 1}).Iter() + var doc gfsDocId + for iter.Next(&doc) { + if e := gfs.RemoveId(doc.Id); e != nil { + err = e + } + } + if err == nil { + err = iter.Close() + } + return err +} + +func (file *GridFile) assertMode(mode gfsFileMode) { + switch file.mode { + case mode: + return + case gfsWriting: + panic("GridFile is open for writing") + case gfsReading: + panic("GridFile is open for reading") + case gfsClosed: + panic("GridFile is closed") + default: + panic("internal error: missing GridFile mode") + } +} + +// SetChunkSize sets size of saved chunks. Once the file is written to, it +// will be split in blocks of that size and each block saved into an +// independent chunk document. The default chunk size is 256kb. +// +// It is a runtime error to call this function once the file has started +// being written to. +func (file *GridFile) SetChunkSize(bytes int) { + file.assertMode(gfsWriting) + debugf("GridFile %p: setting chunk size to %d", file, bytes) + file.m.Lock() + file.doc.ChunkSize = bytes + file.m.Unlock() +} + +// Id returns the current file Id. +func (file *GridFile) Id() interface{} { + return file.doc.Id +} + +// SetId changes the current file Id. +// +// It is a runtime error to call this function once the file has started +// being written to, or when the file is not open for writing. +func (file *GridFile) SetId(id interface{}) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.Id = id + file.m.Unlock() +} + +// Name returns the optional file name. An empty string will be returned +// in case it is unset. +func (file *GridFile) Name() string { + return file.doc.Filename +} + +// SetName changes the optional file name. An empty string may be used to +// unset it. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetName(name string) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.Filename = name + file.m.Unlock() +} + +// ContentType returns the optional file content type. An empty string will be +// returned in case it is unset. +func (file *GridFile) ContentType() string { + return file.doc.ContentType +} + +// ContentType changes the optional file content type. An empty string may be +// used to unset it. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetContentType(ctype string) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.ContentType = ctype + file.m.Unlock() +} + +// GetMeta unmarshals the optional "metadata" field associated with the +// file into the result parameter. The meaning of keys under that field +// is user-defined. For example: +// +// result := struct{ INode int }{} +// err = file.GetMeta(&result) +// if err != nil { +// panic(err.String()) +// } +// fmt.Printf("inode: %d\n", result.INode) +// +func (file *GridFile) GetMeta(result interface{}) (err error) { + file.m.Lock() + if file.doc.Metadata != nil { + err = bson.Unmarshal(file.doc.Metadata.Data, result) + } + file.m.Unlock() + return +} + +// SetMeta changes the optional "metadata" field associated with the +// file. The meaning of keys under that field is user-defined. +// For example: +// +// file.SetMeta(bson.M{"inode": inode}) +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetMeta(metadata interface{}) { + file.assertMode(gfsWriting) + data, err := bson.Marshal(metadata) + file.m.Lock() + if err != nil && file.err == nil { + file.err = err + } else { + file.doc.Metadata = &bson.Raw{Data: data} + } + file.m.Unlock() +} + +// Size returns the file size in bytes. +func (file *GridFile) Size() (bytes int64) { + file.m.Lock() + bytes = file.doc.Length + file.m.Unlock() + return +} + +// MD5 returns the file MD5 as a hex-encoded string. +func (file *GridFile) MD5() (md5 string) { + return file.doc.MD5 +} + +// UploadDate returns the file upload time. +func (file *GridFile) UploadDate() time.Time { + return file.doc.UploadDate +} + +// SetUploadDate changes the file upload time. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetUploadDate(t time.Time) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.UploadDate = t + file.m.Unlock() +} + +// Close flushes any pending changes in case the file is being written +// to, waits for any background operations to finish, and closes the file. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +func (file *GridFile) Close() (err error) { + file.m.Lock() + defer file.m.Unlock() + if file.mode == gfsWriting { + if len(file.wbuf) > 0 && file.err == nil { + file.insertChunk(file.wbuf) + file.wbuf = file.wbuf[0:0] + } + file.completeWrite() + } else if file.mode == gfsReading && file.rcache != nil { + file.rcache.wait.Lock() + file.rcache = nil + } + file.mode = gfsClosed + debugf("GridFile %p: closed", file) + return file.err +} + +func (file *GridFile) completeWrite() { + for file.wpending > 0 { + debugf("GridFile %p: waiting for %d pending chunks to complete file write", file, file.wpending) + file.c.Wait() + } + if file.err == nil { + hexsum := hex.EncodeToString(file.wsum.Sum(nil)) + if file.doc.UploadDate.IsZero() { + file.doc.UploadDate = bson.Now() + } + file.doc.MD5 = hexsum + file.err = file.gfs.Files.Insert(file.doc) + } + if file.err != nil { + file.gfs.Chunks.RemoveAll(bson.D{{"files_id", file.doc.Id}}) + } + if file.err == nil { + index := Index{ + Key: []string{"files_id", "n"}, + Unique: true, + } + file.err = file.gfs.Chunks.EnsureIndex(index) + } +} + +// Abort cancels an in-progress write, preventing the file from being +// automically created and ensuring previously written chunks are +// removed when the file is closed. +// +// It is a runtime error to call Abort when the file was not opened +// for writing. +func (file *GridFile) Abort() { + if file.mode != gfsWriting { + panic("file.Abort must be called on file opened for writing") + } + file.err = errors.New("write aborted") +} + +// Write writes the provided data to the file and returns the +// number of bytes written and an error in case something +// wrong happened. +// +// The file will internally cache the data so that all but the last +// chunk sent to the database have the size defined by SetChunkSize. +// This also means that errors may be deferred until a future call +// to Write or Close. +// +// The parameters and behavior of this function turn the file +// into an io.Writer. +func (file *GridFile) Write(data []byte) (n int, err error) { + file.assertMode(gfsWriting) + file.m.Lock() + debugf("GridFile %p: writing %d bytes", file, len(data)) + defer file.m.Unlock() + + if file.err != nil { + return 0, file.err + } + + n = len(data) + file.doc.Length += int64(n) + chunkSize := file.doc.ChunkSize + + if len(file.wbuf)+len(data) < chunkSize { + file.wbuf = append(file.wbuf, data...) + return + } + + // First, flush file.wbuf complementing with data. + if len(file.wbuf) > 0 { + missing := chunkSize - len(file.wbuf) + if missing > len(data) { + missing = len(data) + } + file.wbuf = append(file.wbuf, data[:missing]...) + data = data[missing:] + file.insertChunk(file.wbuf) + file.wbuf = file.wbuf[0:0] + } + + // Then, flush all chunks from data without copying. + for len(data) > chunkSize { + size := chunkSize + if size > len(data) { + size = len(data) + } + file.insertChunk(data[:size]) + data = data[size:] + } + + // And append the rest for a future call. + file.wbuf = append(file.wbuf, data...) + + return n, file.err +} + +func (file *GridFile) insertChunk(data []byte) { + n := file.chunk + file.chunk++ + debugf("GridFile %p: adding to checksum: %q", file, string(data)) + file.wsum.Write(data) + + for file.doc.ChunkSize*file.wpending >= 1024*1024 { + // Hold on.. we got a MB pending. + file.c.Wait() + if file.err != nil { + return + } + } + + file.wpending++ + + debugf("GridFile %p: inserting chunk %d with %d bytes", file, n, len(data)) + + // We may not own the memory of data, so rather than + // simply copying it, we'll marshal the document ahead of time. + data, err := bson.Marshal(gfsChunk{bson.NewObjectId(), file.doc.Id, n, data}) + if err != nil { + file.err = err + return + } + + go func() { + err := file.gfs.Chunks.Insert(bson.Raw{Data: data}) + file.m.Lock() + file.wpending-- + if err != nil && file.err == nil { + file.err = err + } + file.c.Broadcast() + file.m.Unlock() + }() +} + +// Seek sets the offset for the next Read or Write on file to +// offset, interpreted according to whence: 0 means relative to +// the origin of the file, 1 means relative to the current offset, +// and 2 means relative to the end. It returns the new offset and +// an error, if any. +func (file *GridFile) Seek(offset int64, whence int) (pos int64, err error) { + file.m.Lock() + debugf("GridFile %p: seeking for %s (whence=%d)", file, offset, whence) + defer file.m.Unlock() + switch whence { + case os.SEEK_SET: + case os.SEEK_CUR: + offset += file.offset + case os.SEEK_END: + offset += file.doc.Length + default: + panic("unsupported whence value") + } + if offset > file.doc.Length { + return file.offset, errors.New("seek past end of file") + } + if offset == file.doc.Length { + // If we're seeking to the end of the file, + // no need to read anything. This enables + // a client to find the size of the file using only the + // io.ReadSeeker interface with low overhead. + file.offset = offset + return file.offset, nil + } + chunk := int(offset / int64(file.doc.ChunkSize)) + if chunk+1 == file.chunk && offset >= file.offset { + file.rbuf = file.rbuf[int(offset-file.offset):] + file.offset = offset + return file.offset, nil + } + file.offset = offset + file.chunk = chunk + file.rbuf = nil + file.rbuf, err = file.getChunk() + if err == nil { + file.rbuf = file.rbuf[int(file.offset-int64(chunk)*int64(file.doc.ChunkSize)):] + } + return file.offset, err +} + +// Read reads into b the next available data from the file and +// returns the number of bytes written and an error in case +// something wrong happened. At the end of the file, n will +// be zero and err will be set to io.EOF. +// +// The parameters and behavior of this function turn the file +// into an io.Reader. +func (file *GridFile) Read(b []byte) (n int, err error) { + file.assertMode(gfsReading) + file.m.Lock() + debugf("GridFile %p: reading at offset %d into buffer of length %d", file, file.offset, len(b)) + defer file.m.Unlock() + if file.offset == file.doc.Length { + return 0, io.EOF + } + for err == nil { + i := copy(b, file.rbuf) + n += i + file.offset += int64(i) + file.rbuf = file.rbuf[i:] + if i == len(b) || file.offset == file.doc.Length { + break + } + b = b[i:] + file.rbuf, err = file.getChunk() + } + return n, err +} + +func (file *GridFile) getChunk() (data []byte, err error) { + cache := file.rcache + file.rcache = nil + if cache != nil && cache.n == file.chunk { + debugf("GridFile %p: Getting chunk %d from cache", file, file.chunk) + cache.wait.Lock() + data, err = cache.data, cache.err + } else { + debugf("GridFile %p: Fetching chunk %d", file, file.chunk) + var doc gfsChunk + err = file.gfs.Chunks.Find(bson.D{{"files_id", file.doc.Id}, {"n", file.chunk}}).One(&doc) + data = doc.Data + } + file.chunk++ + if int64(file.chunk)*int64(file.doc.ChunkSize) < file.doc.Length { + // Read the next one in background. + cache = &gfsCachedChunk{n: file.chunk} + cache.wait.Lock() + debugf("GridFile %p: Scheduling chunk %d for background caching", file, file.chunk) + // Clone the session to avoid having it closed in between. + chunks := file.gfs.Chunks + session := chunks.Database.Session.Clone() + go func(id interface{}, n int) { + defer session.Close() + chunks = chunks.With(session) + var doc gfsChunk + cache.err = chunks.Find(bson.D{{"files_id", id}, {"n", n}}).One(&doc) + cache.data = doc.Data + cache.wait.Unlock() + }(file.doc.Id, file.chunk) + file.rcache = cache + } + debugf("Returning err: %#v", err) + return +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c new file mode 100644 index 0000000..8be0bc4 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c @@ -0,0 +1,77 @@ +// +build !windows + +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include <sasl/sasl.h> + +static int mgo_sasl_simple(void *context, int id, const char **result, unsigned int *len) +{ + if (!result) { + return SASL_BADPARAM; + } + switch (id) { + case SASL_CB_USER: + *result = (char *)context; + break; + case SASL_CB_AUTHNAME: + *result = (char *)context; + break; + case SASL_CB_LANGUAGE: + *result = NULL; + break; + default: + return SASL_BADPARAM; + } + if (len) { + *len = *result ? strlen(*result) : 0; + } + return SASL_OK; +} + +typedef int (*callback)(void); + +static int mgo_sasl_secret(sasl_conn_t *conn, void *context, int id, sasl_secret_t **result) +{ + if (!conn || !result || id != SASL_CB_PASS) { + return SASL_BADPARAM; + } + *result = (sasl_secret_t *)context; + return SASL_OK; +} + +sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password) +{ + sasl_callback_t *cb = malloc(4 * sizeof(sasl_callback_t)); + int n = 0; + + size_t len = strlen(password); + sasl_secret_t *secret = (sasl_secret_t*)malloc(sizeof(sasl_secret_t) + len); + if (!secret) { + free(cb); + return NULL; + } + strcpy((char *)secret->data, password); + secret->len = len; + + cb[n].id = SASL_CB_PASS; + cb[n].proc = (callback)&mgo_sasl_secret; + cb[n].context = secret; + n++; + + cb[n].id = SASL_CB_USER; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_AUTHNAME; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_LIST_END; + cb[n].proc = NULL; + cb[n].context = NULL; + + return cb; +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go new file mode 100644 index 0000000..8375ddd --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go @@ -0,0 +1,138 @@ +// Package sasl is an implementation detail of the mgo package. +// +// This package is not meant to be used by itself. +// + +// +build !windows + +package sasl + +// #cgo LDFLAGS: -lsasl2 +// +// struct sasl_conn {}; +// +// #include <stdlib.h> +// #include <sasl/sasl.h> +// +// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password); +// +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + conn *C.sasl_conn_t + step int + mech string + + cstrings []*C.char + callbacks *C.sasl_callback_t +} + +var initError error +var initOnce sync.Once + +func initSASL() { + rc := C.sasl_client_init(nil) + if rc != C.SASL_OK { + initError = saslError(rc, nil, "cannot initialize SASL library") + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSASL) + if initError != nil { + return nil, initError + } + + ss := &saslSession{mech: mechanism} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password)) + rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn) + if rc != C.SASL_OK { + ss.Close() + return nil, saslError(rc, nil, "cannot create new SASL client") + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.cstrings = append(ss.cstrings, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.cstrings { + C.free(unsafe.Pointer(cstr)) + } + ss.cstrings = nil + + if ss.callbacks != nil { + C.free(unsafe.Pointer(ss.callbacks)) + } + + // The documentation of SASL dispose makes it clear that this should only + // be done when the connection is done, not when the authentication phase + // is done, because an encryption layer may have been negotiated. + // Even then, we'll do this for now, because it's simpler and prevents + // keeping track of this state for every socket. If it breaks, we'll fix it. + C.sasl_dispose(&ss.conn) +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SASL steps without authentication") + } + var cclientData *C.char + var cclientDataLen C.uint + var rc C.int + if ss.step == 1 { + var mechanism *C.char // ignored - must match cred + rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism) + } else { + var cserverData *C.char + var cserverDataLen C.uint + if len(serverData) > 0 { + cserverData = (*C.char)(unsafe.Pointer(&serverData[0])) + cserverDataLen = C.uint(len(serverData)) + } + rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen) + } + if cclientData != nil && cclientDataLen > 0 { + clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen)) + } + if rc == C.SASL_OK { + return clientData, true, nil + } + if rc == C.SASL_CONTINUE { + return clientData, false, nil + } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") +} + +func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error { + var detail string + if conn == nil { + detail = C.GoString(C.sasl_errstring(rc, nil, nil)) + } else { + detail = C.GoString(C.sasl_errdetail(conn)) + } + return fmt.Errorf(msg + ": " + detail) +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c new file mode 100644 index 0000000..dd6a88a --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c @@ -0,0 +1,118 @@ +#include "sasl_windows.h" + +static const LPSTR SSPI_PACKAGE_NAME = "kerberos"; + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle *cred_handle, char *username, char *password, char *domain) +{ + SEC_WINNT_AUTH_IDENTITY auth_identity; + SECURITY_INTEGER ignored; + + auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI; + auth_identity.User = (LPSTR) username; + auth_identity.UserLength = strlen(username); + auth_identity.Password = (LPSTR) password; + auth_identity.PasswordLength = strlen(password); + auth_identity.Domain = (LPSTR) domain; + auth_identity.DomainLength = strlen(domain); + return call_sspi_acquire_credentials_handle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, cred_handle, &ignored); +} + +int sspi_step(CredHandle *cred_handle, int has_context, CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *target) +{ + SecBufferDesc inbuf; + SecBuffer in_bufs[1]; + SecBufferDesc outbuf; + SecBuffer out_bufs[1]; + + if (has_context > 0) { + // If we already have a context, we now have data to send. + // Put this data in an inbuf. + inbuf.ulVersion = SECBUFFER_VERSION; + inbuf.cBuffers = 1; + inbuf.pBuffers = in_bufs; + in_bufs[0].pvBuffer = *buffer; + in_bufs[0].cbBuffer = *buffer_length; + in_bufs[0].BufferType = SECBUFFER_TOKEN; + } + + outbuf.ulVersion = SECBUFFER_VERSION; + outbuf.cBuffers = 1; + outbuf.pBuffers = out_bufs; + out_bufs[0].pvBuffer = NULL; + out_bufs[0].cbBuffer = 0; + out_bufs[0].BufferType = SECBUFFER_TOKEN; + + ULONG context_attr = 0; + + int ret = call_sspi_initialize_security_context(cred_handle, + has_context > 0 ? context : NULL, + (LPSTR) target, + ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH, + 0, + SECURITY_NETWORK_DREP, + has_context > 0 ? &inbuf : NULL, + 0, + context, + &outbuf, + &context_attr, + NULL); + + *buffer = malloc(out_bufs[0].cbBuffer); + *buffer_length = out_bufs[0].cbBuffer; + memcpy(*buffer, out_bufs[0].pvBuffer, *buffer_length); + + return ret; +} + +int sspi_send_client_authz_id(CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *user_plus_realm) +{ + SecPkgContext_Sizes sizes; + SECURITY_STATUS status = call_sspi_query_context_attributes(context, SECPKG_ATTR_SIZES, &sizes); + + if (status != SEC_E_OK) { + return status; + } + + size_t user_plus_realm_length = strlen(user_plus_realm); + int msgSize = 4 + user_plus_realm_length; + char *msg = malloc((sizes.cbSecurityTrailer + msgSize + sizes.cbBlockSize) * sizeof(char)); + msg[sizes.cbSecurityTrailer + 0] = 1; + msg[sizes.cbSecurityTrailer + 1] = 0; + msg[sizes.cbSecurityTrailer + 2] = 0; + msg[sizes.cbSecurityTrailer + 3] = 0; + memcpy(&msg[sizes.cbSecurityTrailer + 4], user_plus_realm, user_plus_realm_length); + + SecBuffer wrapBufs[3]; + SecBufferDesc wrapBufDesc; + wrapBufDesc.cBuffers = 3; + wrapBufDesc.pBuffers = wrapBufs; + wrapBufDesc.ulVersion = SECBUFFER_VERSION; + + wrapBufs[0].cbBuffer = sizes.cbSecurityTrailer; + wrapBufs[0].BufferType = SECBUFFER_TOKEN; + wrapBufs[0].pvBuffer = msg; + + wrapBufs[1].cbBuffer = msgSize; + wrapBufs[1].BufferType = SECBUFFER_DATA; + wrapBufs[1].pvBuffer = msg + sizes.cbSecurityTrailer; + + wrapBufs[2].cbBuffer = sizes.cbBlockSize; + wrapBufs[2].BufferType = SECBUFFER_PADDING; + wrapBufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + msgSize; + + status = call_sspi_encrypt_message(context, SECQOP_WRAP_NO_ENCRYPT, &wrapBufDesc, 0); + if (status != SEC_E_OK) { + free(msg); + return status; + } + + *buffer_length = wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer + wrapBufs[2].cbBuffer; + *buffer = malloc(*buffer_length); + + memcpy(*buffer, wrapBufs[0].pvBuffer, wrapBufs[0].cbBuffer); + memcpy(*buffer + wrapBufs[0].cbBuffer, wrapBufs[1].pvBuffer, wrapBufs[1].cbBuffer); + memcpy(*buffer + wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer, wrapBufs[2].pvBuffer, wrapBufs[2].cbBuffer); + + free(msg); + return SEC_E_OK; +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go new file mode 100644 index 0000000..3302cfe --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go @@ -0,0 +1,140 @@ +package sasl + +// #include "sasl_windows.h" +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + // Credentials + mech string + service string + host string + userPlusRealm string + target string + domain string + + // Internal state + authComplete bool + errored bool + step int + + // C internal state + credHandle C.CredHandle + context C.CtxtHandle + hasContext C.int + + // Keep track of pointers we need to explicitly free + stringsToFree []*C.char +} + +var initError error +var initOnce sync.Once + +func initSSPI() { + rc := C.load_secur32_dll() + if rc != 0 { + initError = fmt.Errorf("Error loading libraries: %v", rc) + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSSPI) + ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.service = service + ss.host = host + + usernameComponents := strings.Split(username, "@") + if len(usernameComponents) < 2 { + return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username) + } + user := usernameComponents[0] + ss.domain = usernameComponents[1] + ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host) + + var status C.SECURITY_STATUS + // Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle + if len(password) > 0 { + status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain)) + } else { + status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain)) + } + if status != C.SEC_E_OK { + ss.errored = true + return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status) + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.stringsToFree = append(ss.stringsToFree, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.stringsToFree { + C.free(unsafe.Pointer(cstr)) + } +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SSPI steps without authentication") + } + var buffer C.PVOID + var bufferLength C.ULONG + if len(serverData) > 0 { + buffer = (C.PVOID)(unsafe.Pointer(&serverData[0])) + bufferLength = C.ULONG(len(serverData)) + } + var status C.int + if ss.authComplete { + // Step 3: last bit of magic to use the correct server credentials + status = C.sspi_send_client_authz_id(&ss.context, &buffer, &bufferLength, ss.cstr(ss.userPlusRealm)) + } else { + // Step 1 + Step 2: set up security context with the server and TGT + status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, &buffer, &bufferLength, ss.cstr(ss.target)) + } + if buffer != C.PVOID(nil) { + defer C.free(unsafe.Pointer(buffer)) + } + if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED { + ss.errored = true + return nil, false, ss.handleSSPIErrorCode(status) + } + + clientData = C.GoBytes(unsafe.Pointer(buffer), C.int(bufferLength)) + if status == C.SEC_E_OK { + ss.authComplete = true + return clientData, true, nil + } else { + ss.hasContext = 1 + return clientData, false, nil + } +} + +func (ss *saslSession) handleSSPIErrorCode(code C.int) error { + switch { + case code == C.SEC_E_TARGET_UNKNOWN: + return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain) + } + return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code) +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h new file mode 100644 index 0000000..94321b2 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h @@ -0,0 +1,7 @@ +#include <windows.h> + +#include "sspi_windows.h" + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle* cred_handle, char* username, char* password, char* domain); +int sspi_step(CredHandle* cred_handle, int has_context, CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* target); +int sspi_send_client_authz_id(CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* user_plus_realm); diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c new file mode 100644 index 0000000..63f9a6f --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c @@ -0,0 +1,96 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.c +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#include <stdlib.h> + +#include "sspi_windows.h" + +static HINSTANCE sspi_secur32_dll = NULL; + +int load_secur32_dll() +{ + sspi_secur32_dll = LoadLibrary("secur32.dll"); + if (sspi_secur32_dll == NULL) { + return GetLastError(); + } + return 0; +} + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + encryptMessage_fn pfn_encryptMessage = (encryptMessage_fn) GetProcAddress(sspi_secur32_dll, "EncryptMessage"); + if (!pfn_encryptMessage) { + return -2; + } + return (*pfn_encryptMessage)(phContext, fQOP, pMessage, MessageSeqNo); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + acquireCredentialsHandle_fn pfn_acquireCredentialsHandle; +#ifdef _UNICODE + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleW"); +#else + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleA"); +#endif + if (!pfn_acquireCredentialsHandle) { + return -2; + } + return (*pfn_acquireCredentialsHandle)( + pszPrincipal, pszPackage, fCredentialUse, pvLogonId, pAuthData, + pGetKeyFn, pvGetKeyArgument, phCredential, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, + unsigned long fContextReq, unsigned long Reserved1, unsigned long TargetDataRep, + PSecBufferDesc pInput, unsigned long Reserved2, PCtxtHandle phNewContext, + PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + initializeSecurityContext_fn pfn_initializeSecurityContext; +#ifdef _UNICODE + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextW"); +#else + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextA"); +#endif + if (!pfn_initializeSecurityContext) { + return -2; + } + return (*pfn_initializeSecurityContext)( + phCredential, phContext, pszTargetName, fContextReq, Reserved1, TargetDataRep, + pInput, Reserved2, phNewContext, pOutput, pfContextAttr, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes(PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + queryContextAttributes_fn pfn_queryContextAttributes; +#ifdef _UNICODE + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesW"); +#else + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesA"); +#endif + if (!pfn_queryContextAttributes) { + return -2; + } + return (*pfn_queryContextAttributes)(phContext, ulAttribute, pBuffer); +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h new file mode 100644 index 0000000..d283270 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h @@ -0,0 +1,70 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.h +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#ifndef SSPI_WINDOWS_H +#define SSPI_WINDOWS_H + +#define SECURITY_WIN32 1 + +#include <windows.h> +#include <sspi.h> + +int load_secur32_dll(); + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo); + +typedef DWORD (WINAPI *encryptMessage_fn)(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo); + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, // Name of principal + LPSTR pszPackage, // Name of package + unsigned long fCredentialUse, // Flags indicating use + void *pvLogonId, // Pointer to logon ID + void *pAuthData, // Package specific data + SEC_GET_KEY_FN pGetKeyFn, // Pointer to GetKey() func + void *pvGetKeyArgument, // Value to pass to GetKey() + PCredHandle phCredential, // (out) Cred Handle + PTimeStamp ptsExpiry // (out) Lifetime (optional) +); + +typedef DWORD (WINAPI *acquireCredentialsHandle_fn)( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry +); + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, // Cred to base context + PCtxtHandle phContext, // Existing context (OPT) + LPSTR pszTargetName, // Name of target + unsigned long fContextReq, // Context Requirements + unsigned long Reserved1, // Reserved, MBZ + unsigned long TargetDataRep, // Data rep of target + PSecBufferDesc pInput, // Input Buffers + unsigned long Reserved2, // Reserved, MBZ + PCtxtHandle phNewContext, // (out) New Context handle + PSecBufferDesc pOutput, // (inout) Output Buffers + unsigned long *pfContextAttr, // (out) Context attrs + PTimeStamp ptsExpiry // (out) Life span (OPT) +); + +typedef DWORD (WINAPI *initializeSecurityContext_fn)( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, unsigned long fContextReq, + unsigned long Reserved1, unsigned long TargetDataRep, PSecBufferDesc pInput, unsigned long Reserved2, + PCtxtHandle phNewContext, PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry); + +SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes( + PCtxtHandle phContext, // Context to query + unsigned long ulAttribute, // Attribute to query + void *pBuffer // Buffer for attributes +); + +typedef DWORD (WINAPI *queryContextAttributes_fn)( + PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer); + +#endif // SSPI_WINDOWS_H diff --git a/vendor/gopkg.in/mgo.v2/internal/scram/scram.go b/vendor/gopkg.in/mgo.v2/internal/scram/scram.go new file mode 100644 index 0000000..80cda91 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/scram/scram.go @@ -0,0 +1,266 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2014 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-1, for example, use: +// +// client := scram.NewClient(sha1.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that ocurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. +func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 6 + buf := make([]byte, nonceLen + b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/vendor/gopkg.in/mgo.v2/log.go b/vendor/gopkg.in/mgo.v2/log.go new file mode 100644 index 0000000..53eb423 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/log.go @@ -0,0 +1,133 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "fmt" + "sync" +) + +// --------------------------------------------------------------------------- +// Logging integration. + +// Avoid importing the log type information unnecessarily. There's a small cost +// associated with using an interface rather than the type. Depending on how +// often the logger is plugged in, it would be worth using the type instead. +type log_Logger interface { + Output(calldepth int, s string) error +} + +var ( + globalLogger log_Logger + globalDebug bool + globalMutex sync.Mutex +) + +// RACE WARNING: There are known data races when logging, which are manually +// silenced when the race detector is in use. These data races won't be +// observed in typical use, because logging is supposed to be set up once when +// the application starts. Having raceDetector as a constant, the compiler +// should elide the locks altogether in actual use. + +// Specify the *log.Logger object where log messages should be sent to. +func SetLogger(logger log_Logger) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalLogger = logger +} + +// Enable the delivery of debug messages to the logger. Only meaningful +// if a logger is also set. +func SetDebug(debug bool) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalDebug = debug +} + +func log(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func logln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func logf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} + +func debug(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func debugln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func debugf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} diff --git a/vendor/gopkg.in/mgo.v2/queue.go b/vendor/gopkg.in/mgo.v2/queue.go new file mode 100644 index 0000000..e9245de --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/queue.go @@ -0,0 +1,91 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +type queue struct { + elems []interface{} + nelems, popi, pushi int +} + +func (q *queue) Len() int { + return q.nelems +} + +func (q *queue) Push(elem interface{}) { + //debugf("Pushing(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + if q.nelems == len(q.elems) { + q.expand() + } + q.elems[q.pushi] = elem + q.nelems++ + q.pushi = (q.pushi + 1) % len(q.elems) + //debugf(" Pushed(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) +} + +func (q *queue) Pop() (elem interface{}) { + //debugf("Popping(pushi=%d popi=%d cap=%d)\n", + // q.pushi, q.popi, len(q.elems)) + if q.nelems == 0 { + return nil + } + elem = q.elems[q.popi] + q.elems[q.popi] = nil // Help GC. + q.nelems-- + q.popi = (q.popi + 1) % len(q.elems) + //debugf(" Popped(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + return elem +} + +func (q *queue) expand() { + curcap := len(q.elems) + var newcap int + if curcap == 0 { + newcap = 8 + } else if curcap < 1024 { + newcap = curcap * 2 + } else { + newcap = curcap + (curcap / 4) + } + elems := make([]interface{}, newcap) + + if q.popi == 0 { + copy(elems, q.elems) + q.pushi = curcap + } else { + newpopi := newcap - (curcap - q.popi) + copy(elems, q.elems[:q.popi]) + copy(elems[newpopi:], q.elems[q.popi:]) + q.popi = newpopi + } + for i := range q.elems { + q.elems[i] = nil // Help GC. + } + q.elems = elems +} diff --git a/vendor/gopkg.in/mgo.v2/raceoff.go b/vendor/gopkg.in/mgo.v2/raceoff.go new file mode 100644 index 0000000..e60b141 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/raceoff.go @@ -0,0 +1,5 @@ +// +build !race + +package mgo + +const raceDetector = false diff --git a/vendor/gopkg.in/mgo.v2/raceon.go b/vendor/gopkg.in/mgo.v2/raceon.go new file mode 100644 index 0000000..737b08e --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/raceon.go @@ -0,0 +1,5 @@ +// +build race + +package mgo + +const raceDetector = true diff --git a/vendor/gopkg.in/mgo.v2/saslimpl.go b/vendor/gopkg.in/mgo.v2/saslimpl.go new file mode 100644 index 0000000..0d25f25 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/saslimpl.go @@ -0,0 +1,11 @@ +//+build sasl + +package mgo + +import ( + "gopkg.in/mgo.v2/internal/sasl" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return sasl.New(cred.Username, cred.Password, cred.Mechanism, cred.Service, host) +} diff --git a/vendor/gopkg.in/mgo.v2/saslstub.go b/vendor/gopkg.in/mgo.v2/saslstub.go new file mode 100644 index 0000000..6e9e309 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/saslstub.go @@ -0,0 +1,11 @@ +//+build !sasl + +package mgo + +import ( + "fmt" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return nil, fmt.Errorf("SASL support not enabled during build (-tags sasl)") +} diff --git a/vendor/gopkg.in/mgo.v2/server.go b/vendor/gopkg.in/mgo.v2/server.go new file mode 100644 index 0000000..f677359 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/server.go @@ -0,0 +1,452 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "net" + "sort" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo server encapsulation. + +type mongoServer struct { + sync.RWMutex + Addr string + ResolvedAddr string + tcpaddr *net.TCPAddr + unusedSockets []*mongoSocket + liveSockets []*mongoSocket + closed bool + abended bool + sync chan bool + dial dialer + pingValue time.Duration + pingIndex int + pingCount uint32 + pingWindow [6]time.Duration + info *mongoServerInfo +} + +type dialer struct { + old func(addr net.Addr) (net.Conn, error) + new func(addr *ServerAddr) (net.Conn, error) +} + +func (dial dialer) isSet() bool { + return dial.old != nil || dial.new != nil +} + +type mongoServerInfo struct { + Master bool + Mongos bool + Tags bson.D + MaxWireVersion int + SetName string +} + +var defaultServerInfo mongoServerInfo + +func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer { + server := &mongoServer{ + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: sync, + dial: dial, + info: &defaultServerInfo, + pingValue: time.Hour, // Push it back before an actual ping. + } + go server.pinger(true) + return server +} + +var errPoolLimit = errors.New("per-server connection limit reached") +var errServerClosed = errors.New("server was closed") + +// AcquireSocket returns a socket for communicating with the server. +// This will attempt to reuse an old connection, if one is available. Otherwise, +// it will establish a new one. The returned socket is owned by the call site, +// and will return to the cache when the socket has its Release method called +// the same number of times as AcquireSocket + Acquire were called for it. +// If the poolLimit argument is greater than zero and the number of sockets in +// use in this server is greater than the provided limit, errPoolLimit is +// returned. +func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { + for { + server.Lock() + abended = server.abended + if server.closed { + server.Unlock() + return nil, abended, errServerClosed + } + n := len(server.unusedSockets) + if poolLimit > 0 && len(server.liveSockets)-n >= poolLimit { + server.Unlock() + return nil, false, errPoolLimit + } + if n > 0 { + socket = server.unusedSockets[n-1] + server.unusedSockets[n-1] = nil // Help GC. + server.unusedSockets = server.unusedSockets[:n-1] + info := server.info + server.Unlock() + err = socket.InitialAcquire(info, timeout) + if err != nil { + continue + } + } else { + server.Unlock() + socket, err = server.Connect(timeout) + if err == nil { + server.Lock() + // We've waited for the Connect, see if we got + // closed in the meantime + if server.closed { + server.Unlock() + socket.Release() + socket.Close() + return nil, abended, errServerClosed + } + server.liveSockets = append(server.liveSockets, socket) + server.Unlock() + } + } + return + } + panic("unreachable") +} + +// Connect establishes a new connection to the server. This should +// generally be done through server.AcquireSocket(). +func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { + server.RLock() + master := server.info.Master + dial := server.dial + server.RUnlock() + + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + var conn net.Conn + var err error + switch { + case !dial.isSet(): + // Cannot do this because it lacks timeout support. :-( + //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + if tcpconn, ok := conn.(*net.TCPConn); ok { + tcpconn.SetKeepAlive(true) + } else if err == nil { + panic("internal error: obtained TCP connection is not a *net.TCPConn!?") + } + case dial.old != nil: + conn, err = dial.old(server.tcpaddr) + case dial.new != nil: + conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr}) + default: + panic("dialer is set, but both dial.old and dial.new are nil") + } + if err != nil { + logf("Connection to %s failed: %v", server.Addr, err.Error()) + return nil, err + } + logf("Connection to %s established.", server.Addr) + + stats.conn(+1, master) + return newSocket(server, conn, timeout), nil +} + +// Close forces closing all sockets that are alive, whether +// they're currently in use or not. +func (server *mongoServer) Close() { + server.Lock() + server.closed = true + liveSockets := server.liveSockets + unusedSockets := server.unusedSockets + server.liveSockets = nil + server.unusedSockets = nil + server.Unlock() + logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets)) + for i, s := range liveSockets { + s.Close() + liveSockets[i] = nil + } + for i := range unusedSockets { + unusedSockets[i] = nil + } +} + +// RecycleSocket puts socket back into the unused cache. +func (server *mongoServer) RecycleSocket(socket *mongoSocket) { + server.Lock() + if !server.closed { + server.unusedSockets = append(server.unusedSockets, socket) + } + server.Unlock() +} + +func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket { + for i, s := range sockets { + if s == socket { + copy(sockets[i:], sockets[i+1:]) + n := len(sockets) - 1 + sockets[n] = nil + sockets = sockets[:n] + break + } + } + return sockets +} + +// AbendSocket notifies the server that the given socket has terminated +// abnormally, and thus should be discarded rather than cached. +func (server *mongoServer) AbendSocket(socket *mongoSocket) { + server.Lock() + server.abended = true + if server.closed { + server.Unlock() + return + } + server.liveSockets = removeSocket(server.liveSockets, socket) + server.unusedSockets = removeSocket(server.unusedSockets, socket) + server.Unlock() + // Maybe just a timeout, but suggest a cluster sync up just in case. + select { + case server.sync <- true: + default: + } +} + +func (server *mongoServer) SetInfo(info *mongoServerInfo) { + server.Lock() + server.info = info + server.Unlock() +} + +func (server *mongoServer) Info() *mongoServerInfo { + server.Lock() + info := server.info + server.Unlock() + return info +} + +func (server *mongoServer) hasTags(serverTags []bson.D) bool { +NextTagSet: + for _, tags := range serverTags { + NextReqTag: + for _, req := range tags { + for _, has := range server.info.Tags { + if req.Name == has.Name { + if req.Value == has.Value { + continue NextReqTag + } + continue NextTagSet + } + } + continue NextTagSet + } + return true + } + return false +} + +var pingDelay = 15 * time.Second + +func (server *mongoServer) pinger(loop bool) { + var delay time.Duration + if raceDetector { + // This variable is only ever touched by tests. + globalMutex.Lock() + delay = pingDelay + globalMutex.Unlock() + } else { + delay = pingDelay + } + op := queryOp{ + collection: "admin.$cmd", + query: bson.D{{"ping", 1}}, + flags: flagSlaveOk, + limit: -1, + } + for { + if loop { + time.Sleep(delay) + } + op := op + socket, _, err := server.AcquireSocket(0, delay) + if err == nil { + start := time.Now() + _, _ = socket.SimpleQuery(&op) + delay := time.Now().Sub(start) + + server.pingWindow[server.pingIndex] = delay + server.pingIndex = (server.pingIndex + 1) % len(server.pingWindow) + server.pingCount++ + var max time.Duration + for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ { + if server.pingWindow[i] > max { + max = server.pingWindow[i] + } + } + socket.Release() + server.Lock() + if server.closed { + loop = false + } + server.pingValue = max + server.Unlock() + logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond) + } else if err == errServerClosed { + return + } + if !loop { + return + } + } +} + +type mongoServerSlice []*mongoServer + +func (s mongoServerSlice) Len() int { + return len(s) +} + +func (s mongoServerSlice) Less(i, j int) bool { + return s[i].ResolvedAddr < s[j].ResolvedAddr +} + +func (s mongoServerSlice) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s mongoServerSlice) Sort() { + sort.Sort(s) +} + +func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) { + n := len(s) + i = sort.Search(n, func(i int) bool { + return s[i].ResolvedAddr >= resolvedAddr + }) + return i, i != n && s[i].ResolvedAddr == resolvedAddr +} + +type mongoServers struct { + slice mongoServerSlice +} + +func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) { + if i, ok := servers.slice.Search(resolvedAddr); ok { + return servers.slice[i] + } + return nil +} + +func (servers *mongoServers) Add(server *mongoServer) { + servers.slice = append(servers.slice, server) + servers.slice.Sort() +} + +func (servers *mongoServers) Remove(other *mongoServer) (server *mongoServer) { + if i, found := servers.slice.Search(other.ResolvedAddr); found { + server = servers.slice[i] + copy(servers.slice[i:], servers.slice[i+1:]) + n := len(servers.slice) - 1 + servers.slice[n] = nil // Help GC. + servers.slice = servers.slice[:n] + } + return +} + +func (servers *mongoServers) Slice() []*mongoServer { + return ([]*mongoServer)(servers.slice) +} + +func (servers *mongoServers) Get(i int) *mongoServer { + return servers.slice[i] +} + +func (servers *mongoServers) Len() int { + return len(servers.slice) +} + +func (servers *mongoServers) Empty() bool { + return len(servers.slice) == 0 +} + +// BestFit returns the best guess of what would be the most interesting +// server to perform operations on at this point in time. +func (servers *mongoServers) BestFit(mode Mode, serverTags []bson.D) *mongoServer { + var best *mongoServer + for _, next := range servers.slice { + if best == nil { + best = next + best.RLock() + if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + best.RUnlock() + best = nil + } + continue + } + next.RLock() + swap := false + switch { + case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + // Must have requested tags. + case next.info.Master != best.info.Master && mode != Nearest: + // Prefer slaves, unless the mode is PrimaryPreferred. + swap = (mode == PrimaryPreferred) != best.info.Master + case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond: + // Prefer nearest server. + swap = next.pingValue < best.pingValue + case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets): + // Prefer servers with less connections. + swap = true + } + if swap { + best.RUnlock() + best = next + } else { + next.RUnlock() + } + } + if best != nil { + best.RUnlock() + } + return best +} + +func absDuration(d time.Duration) time.Duration { + if d < 0 { + return -d + } + return d +} diff --git a/vendor/gopkg.in/mgo.v2/session.go b/vendor/gopkg.in/mgo.v2/session.go new file mode 100644 index 0000000..a8ad115 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/session.go @@ -0,0 +1,4722 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "math" + "net" + "net/url" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type Mode int + +const ( + // Relevant documentation on read preference modes: + // + // http://docs.mongodb.org/manual/reference/read-preference/ + // + Primary Mode = 2 // Default mode. All operations read from the current replica set primary. + PrimaryPreferred Mode = 3 // Read from the primary if available. Read from the secondary otherwise. + Secondary Mode = 4 // Read from one of the nearest secondary members of the replica set. + SecondaryPreferred Mode = 5 // Read from one of the nearest secondaries if available. Read from primary otherwise. + Nearest Mode = 6 // Read from one of the nearest members, irrespective of it being primary or secondary. + + // Read preference modes are specific to mgo: + Eventual Mode = 0 // Same as Nearest, but may change servers between reads. + Monotonic Mode = 1 // Same as SecondaryPreferred before first write. Same as Primary after first write. + Strong Mode = 2 // Same as Primary. +) + +// mgo.v3: Drop Strong mode, suffix all modes with "Mode". + +// When changing the Session type, check if newSession and copySession +// need to be updated too. + +// Session represents a communication session with the database. +// +// All Session methods are concurrency-safe and may be called from multiple +// goroutines. In all session modes but Eventual, using the session from +// multiple goroutines will cause them to share the same underlying socket. +// See the documentation on Session.SetMode for more details. +type Session struct { + m sync.RWMutex + cluster_ *mongoCluster + slaveSocket *mongoSocket + masterSocket *mongoSocket + slaveOk bool + consistency Mode + queryConfig query + safeOp *queryOp + syncTimeout time.Duration + sockTimeout time.Duration + defaultdb string + sourcedb string + dialCred *Credential + creds []Credential + poolLimit int + bypassValidation bool +} + +type Database struct { + Session *Session + Name string +} + +type Collection struct { + Database *Database + Name string // "collection" + FullName string // "db.collection" +} + +type Query struct { + m sync.Mutex + session *Session + query // Enables default settings in session. +} + +type query struct { + op queryOp + prefetch float64 + limit int32 +} + +type getLastError struct { + CmdName int "getLastError,omitempty" + W interface{} "w,omitempty" + WTimeout int "wtimeout,omitempty" + FSync bool "fsync,omitempty" + J bool "j,omitempty" +} + +type Iter struct { + m sync.Mutex + gotReply sync.Cond + session *Session + server *mongoServer + docData queue + err error + op getMoreOp + prefetch float64 + limit int32 + docsToReceive int + docsBeforeMore int + timeout time.Duration + timedout bool + findCmd bool +} + +var ( + ErrNotFound = errors.New("not found") + ErrCursor = errors.New("invalid cursor") +) + +const defaultPrefetch = 0.25 + +// Dial establishes a new session to the cluster identified by the given seed +// server(s). The session will enable communication with all of the servers in +// the cluster, so the seed servers are used only to find out about the cluster +// topology. +// +// Dial will timeout after 10 seconds if a server isn't reached. The returned +// session will timeout operations after one minute by default if servers +// aren't available. To customize the timeout, see DialWithTimeout, +// SetSyncTimeout, and SetSocketTimeout. +// +// This method is generally called just once for a given cluster. Further +// sessions to the same cluster are then established using the New or Copy +// methods on the obtained session. This will make them share the underlying +// cluster, and manage the pool of connections appropriately. +// +// Once the session is not useful anymore, Close must be called to release the +// resources appropriately. +// +// The seed servers must be provided in the following format: +// +// [mongodb://][user:pass@]host1[:port1][,host2[:port2],...][/database][?options] +// +// For example, it may be as simple as: +// +// localhost +// +// Or more involved like: +// +// mongodb://myuser:mypass@localhost:40001,otherhost:40001/mydb +// +// If the port number is not provided for a server, it defaults to 27017. +// +// The username and password provided in the URL will be used to authenticate +// into the database named after the slash at the end of the host names, or +// into the "admin" database if none is provided. The authentication information +// will persist in sessions obtained through the New method as well. +// +// The following connection options are supported after the question mark: +// +// connect=direct +// +// Disables the automatic replica set server discovery logic, and +// forces the use of servers provided only (even if secondaries). +// Note that to talk to a secondary the consistency requirements +// must be relaxed to Monotonic or Eventual via SetMode. +// +// +// connect=replicaSet +// +// Discover replica sets automatically. Default connection behavior. +// +// +// replicaSet=<setname> +// +// If specified will prevent the obtained session from communicating +// with any server which is not part of a replica set with the given name. +// The default is to communicate with any server specified or discovered +// via the servers contacted. +// +// +// authSource=<db> +// +// Informs the database used to establish credentials and privileges +// with a MongoDB server. Defaults to the database name provided via +// the URL path, and "admin" if that's unset. +// +// +// authMechanism=<mechanism> +// +// Defines the protocol for credential negotiation. Defaults to "MONGODB-CR", +// which is the default username/password challenge-response mechanism. +// +// +// gssapiServiceName=<name> +// +// Defines the service name to use when authenticating with the GSSAPI +// mechanism. Defaults to "mongodb". +// +// +// maxPoolSize=<limit> +// +// Defines the per-server socket pool limit. Defaults to 4096. +// See Session.SetPoolLimit for details. +// +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/connection-string/ +// +func Dial(url string) (*Session, error) { + session, err := DialWithTimeout(url, 10*time.Second) + if err == nil { + session.SetSyncTimeout(1 * time.Minute) + session.SetSocketTimeout(1 * time.Minute) + } + return session, err +} + +// DialWithTimeout works like Dial, but uses timeout as the amount of time to +// wait for a server to respond when first connecting and also on follow up +// operations in the session. If timeout is zero, the call may block +// forever waiting for a connection to be made. +// +// See SetSyncTimeout for customizing the timeout for the session. +func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { + info, err := ParseURL(url) + if err != nil { + return nil, err + } + info.Timeout = timeout + return DialWithInfo(info) +} + +// ParseURL parses a MongoDB URL as accepted by the Dial function and returns +// a value suitable for providing into DialWithInfo. +// +// See Dial for more details on the format of url. +func ParseURL(url string) (*DialInfo, error) { + uinfo, err := extractURL(url) + if err != nil { + return nil, err + } + direct := false + mechanism := "" + service := "" + source := "" + setName := "" + poolLimit := 0 + for k, v := range uinfo.options { + switch k { + case "authSource": + source = v + case "authMechanism": + mechanism = v + case "gssapiServiceName": + service = v + case "replicaSet": + setName = v + case "maxPoolSize": + poolLimit, err = strconv.Atoi(v) + if err != nil { + return nil, errors.New("bad value for maxPoolSize: " + v) + } + case "connect": + if v == "direct" { + direct = true + break + } + if v == "replicaSet" { + break + } + fallthrough + default: + return nil, errors.New("unsupported connection URL option: " + k + "=" + v) + } + } + info := DialInfo{ + Addrs: uinfo.addrs, + Direct: direct, + Database: uinfo.db, + Username: uinfo.user, + Password: uinfo.pass, + Mechanism: mechanism, + Service: service, + Source: source, + PoolLimit: poolLimit, + ReplicaSetName: setName, + } + return &info, nil +} + +// DialInfo holds options for establishing a session with a MongoDB cluster. +// To use a URL, see the Dial function. +type DialInfo struct { + // Addrs holds the addresses for the seed servers. + Addrs []string + + // Direct informs whether to establish connections only with the + // specified seed servers, or to obtain information for the whole + // cluster and establish connections with further servers too. + Direct bool + + // Timeout is the amount of time to wait for a server to respond when + // first connecting and on follow up operations in the session. If + // timeout is zero, the call may block forever waiting for a connection + // to be established. Timeout does not affect logic in DialServer. + Timeout time.Duration + + // FailFast will cause connection and query attempts to fail faster when + // the server is unavailable, instead of retrying until the configured + // timeout period. Note that an unavailable server may silently drop + // packets instead of rejecting them, in which case it's impossible to + // distinguish it from a slow server, so the timeout stays relevant. + FailFast bool + + // Database is the default database name used when the Session.DB method + // is called with an empty name, and is also used during the initial + // authentication if Source is unset. + Database string + + // ReplicaSetName, if specified, will prevent the obtained session from + // communicating with any server which is not part of a replica set + // with the given name. The default is to communicate with any server + // specified or discovered via the servers contacted. + ReplicaSetName string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the value of Database, if that is + // set, or "admin" otherwise. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string + + // Username and Password inform the credentials for the initial authentication + // done on the database defined by the Source field. See Session.Login. + Username string + Password string + + // PoolLimit defines the per-server socket pool limit. Defaults to 4096. + // See Session.SetPoolLimit for details. + PoolLimit int + + // DialServer optionally specifies the dial function for establishing + // connections with the MongoDB servers. + DialServer func(addr *ServerAddr) (net.Conn, error) + + // WARNING: This field is obsolete. See DialServer above. + Dial func(addr net.Addr) (net.Conn, error) +} + +// mgo.v3: Drop DialInfo.Dial. + +// ServerAddr represents the address for establishing a connection to an +// individual MongoDB server. +type ServerAddr struct { + str string + tcp *net.TCPAddr +} + +// String returns the address that was provided for the server before resolution. +func (addr *ServerAddr) String() string { + return addr.str +} + +// TCPAddr returns the resolved TCP address for the server. +func (addr *ServerAddr) TCPAddr() *net.TCPAddr { + return addr.tcp +} + +// DialWithInfo establishes a new session to the cluster identified by info. +func DialWithInfo(info *DialInfo) (*Session, error) { + addrs := make([]string, len(info.Addrs)) + for i, addr := range info.Addrs { + p := strings.LastIndexAny(addr, "]:") + if p == -1 || addr[p] != ':' { + // XXX This is untested. The test suite doesn't use the standard port. + addr += ":27017" + } + addrs[i] = addr + } + cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName) + session := newSession(Eventual, cluster, info.Timeout) + session.defaultdb = info.Database + if session.defaultdb == "" { + session.defaultdb = "test" + } + session.sourcedb = info.Source + if session.sourcedb == "" { + session.sourcedb = info.Database + if session.sourcedb == "" { + session.sourcedb = "admin" + } + } + if info.Username != "" { + source := session.sourcedb + if info.Source == "" && + (info.Mechanism == "GSSAPI" || info.Mechanism == "PLAIN" || info.Mechanism == "MONGODB-X509") { + source = "$external" + } + session.dialCred = &Credential{ + Username: info.Username, + Password: info.Password, + Mechanism: info.Mechanism, + Service: info.Service, + ServiceHost: info.ServiceHost, + Source: source, + } + session.creds = []Credential{*session.dialCred} + } + if info.PoolLimit > 0 { + session.poolLimit = info.PoolLimit + } + cluster.Release() + + // People get confused when we return a session that is not actually + // established to any servers yet (e.g. what if url was wrong). So, + // ping the server to ensure there's someone there, and abort if it + // fails. + if err := session.Ping(); err != nil { + session.Close() + return nil, err + } + session.SetMode(Strong, true) + return session, nil +} + +func isOptSep(c rune) bool { + return c == ';' || c == '&' +} + +type urlInfo struct { + addrs []string + user string + pass string + db string + options map[string]string +} + +func extractURL(s string) (*urlInfo, error) { + if strings.HasPrefix(s, "mongodb://") { + s = s[10:] + } + info := &urlInfo{options: make(map[string]string)} + if c := strings.Index(s, "?"); c != -1 { + for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { + l := strings.SplitN(pair, "=", 2) + if len(l) != 2 || l[0] == "" || l[1] == "" { + return nil, errors.New("connection option must be key=value: " + pair) + } + info.options[l[0]] = l[1] + } + s = s[:c] + } + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return nil, errors.New("credentials must be provided as user:pass@host") + } + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("cannot unescape password in URL") + } + } + s = s[c+1:] + } + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] + } + info.addrs = strings.Split(s, ",") + return info, nil +} + +func newSession(consistency Mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { + cluster.Acquire() + session = &Session{ + cluster_: cluster, + syncTimeout: timeout, + sockTimeout: timeout, + poolLimit: 4096, + } + debugf("New session %p on cluster %p", session, cluster) + session.SetMode(consistency, true) + session.SetSafe(&Safe{}) + session.queryConfig.prefetch = defaultPrefetch + return session +} + +func copySession(session *Session, keepCreds bool) (s *Session) { + cluster := session.cluster() + cluster.Acquire() + if session.masterSocket != nil { + session.masterSocket.Acquire() + } + if session.slaveSocket != nil { + session.slaveSocket.Acquire() + } + var creds []Credential + if keepCreds { + creds = make([]Credential, len(session.creds)) + copy(creds, session.creds) + } else if session.dialCred != nil { + creds = []Credential{*session.dialCred} + } + scopy := *session + scopy.m = sync.RWMutex{} + scopy.creds = creds + s = &scopy + debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) + return s +} + +// LiveServers returns a list of server addresses which are +// currently known to be alive. +func (s *Session) LiveServers() (addrs []string) { + s.m.RLock() + addrs = s.cluster().LiveServers() + s.m.RUnlock() + return addrs +} + +// DB returns a value representing the named database. If name +// is empty, the database name provided in the dialed URL is +// used instead. If that is also empty, "test" is used as a +// fallback in a way equivalent to the mongo shell. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (s *Session) DB(name string) *Database { + if name == "" { + name = s.defaultdb + } + return &Database{s, name} +} + +// C returns a value representing the named collection. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (db *Database) C(name string) *Collection { + return &Collection{db, name, db.Name + "." + name} +} + +// With returns a copy of db that uses session s. +func (db *Database) With(s *Session) *Database { + newdb := *db + newdb.Session = s + return &newdb +} + +// With returns a copy of c that uses session s. +func (c *Collection) With(s *Session) *Collection { + newdb := *c.Database + newdb.Session = s + newc := *c + newc.Database = &newdb + return &newc +} + +// GridFS returns a GridFS value representing collections in db that +// follow the standard GridFS specification. +// The provided prefix (sometimes known as root) will determine which +// collections to use, and is usually set to "fs" when there is a +// single GridFS in the database. +// +// See the GridFS Create, Open, and OpenId methods for more details. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/GridFS +// http://www.mongodb.org/display/DOCS/GridFS+Tools +// http://www.mongodb.org/display/DOCS/GridFS+Specification +// +func (db *Database) GridFS(prefix string) *GridFS { + return newGridFS(db, prefix) +} + +// Run issues the provided command on the db database and unmarshals +// its result in the respective argument. The cmd argument may be either +// a string with the command name itself, in which case an empty document of +// the form bson.M{cmd: 1} will be used, or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For privilleged commands typically run on the "admin" database, see +// the Run method in the Session type. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (db *Database) Run(cmd interface{}, result interface{}) error { + socket, err := db.Session.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + // This is an optimized form of db.C("$cmd").Find(cmd).One(result). + return db.run(socket, cmd, result) +} + +// Credential holds details to authenticate with a MongoDB server. +type Credential struct { + // Username and Password hold the basic details for authentication. + // Password is optional with some authentication mechanisms. + Username string + Password string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the default database provided + // during dial, or "admin" if that was unset. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (db *Database) Login(user, pass string) error { + return db.Session.Login(&Credential{Username: user, Password: pass, Source: db.Name}) +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (s *Session) Login(cred *Credential) error { + socket, err := s.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + credCopy := *cred + if cred.Source == "" { + if cred.Mechanism == "GSSAPI" { + credCopy.Source = "$external" + } else { + credCopy.Source = s.sourcedb + } + } + err = socket.Login(credCopy) + if err != nil { + return err + } + + s.m.Lock() + s.creds = append(s.creds, credCopy) + s.m.Unlock() + return nil +} + +func (s *Session) socketLogin(socket *mongoSocket) error { + for _, cred := range s.creds { + if err := socket.Login(cred); err != nil { + return err + } + } + return nil +} + +// Logout removes any established authentication credentials for the database. +func (db *Database) Logout() { + session := db.Session + dbname := db.Name + session.m.Lock() + found := false + for i, cred := range session.creds { + if cred.Source == dbname { + copy(session.creds[i:], session.creds[i+1:]) + session.creds = session.creds[:len(session.creds)-1] + found = true + break + } + } + if found { + if session.masterSocket != nil { + session.masterSocket.Logout(dbname) + } + if session.slaveSocket != nil { + session.slaveSocket.Logout(dbname) + } + } + session.m.Unlock() +} + +// LogoutAll removes all established authentication credentials for the session. +func (s *Session) LogoutAll() { + s.m.Lock() + for _, cred := range s.creds { + if s.masterSocket != nil { + s.masterSocket.Logout(cred.Source) + } + if s.slaveSocket != nil { + s.slaveSocket.Logout(cred.Source) + } + } + s.creds = s.creds[0:0] + s.m.Unlock() +} + +// User represents a MongoDB user. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// http://docs.mongodb.org/manual/reference/user-privileges/ +// +type User struct { + // Username is how the user identifies itself to the system. + Username string `bson:"user"` + + // Password is the plaintext password for the user. If set, + // the UpsertUser method will hash it into PasswordHash and + // unset it before the user is added to the database. + Password string `bson:",omitempty"` + + // PasswordHash is the MD5 hash of Username+":mongo:"+Password. + PasswordHash string `bson:"pwd,omitempty"` + + // CustomData holds arbitrary data admins decide to associate + // with this user, such as the full name or employee id. + CustomData interface{} `bson:"customData,omitempty"` + + // Roles indicates the set of roles the user will be provided. + // See the Role constants. + Roles []Role `bson:"roles"` + + // OtherDBRoles allows assigning roles in other databases from + // user documents inserted in the admin database. This field + // only works in the admin database. + OtherDBRoles map[string][]Role `bson:"otherDBRoles,omitempty"` + + // UserSource indicates where to look for this user's credentials. + // It may be set to a database name, or to "$external" for + // consulting an external resource such as Kerberos. UserSource + // must not be set if Password or PasswordHash are present. + // + // WARNING: This setting was only ever supported in MongoDB 2.4, + // and is now obsolete. + UserSource string `bson:"userSource,omitempty"` +} + +type Role string + +const ( + // Relevant documentation: + // + // http://docs.mongodb.org/manual/reference/user-privileges/ + // + RoleRoot Role = "root" + RoleRead Role = "read" + RoleReadAny Role = "readAnyDatabase" + RoleReadWrite Role = "readWrite" + RoleReadWriteAny Role = "readWriteAnyDatabase" + RoleDBAdmin Role = "dbAdmin" + RoleDBAdminAny Role = "dbAdminAnyDatabase" + RoleUserAdmin Role = "userAdmin" + RoleUserAdminAny Role = "userAdminAnyDatabase" + RoleClusterAdmin Role = "clusterAdmin" +) + +// UpsertUser updates the authentication credentials and the roles for +// a MongoDB user within the db database. If the named user doesn't exist +// it will be created. +// +// This method should only be used from MongoDB 2.4 and on. For older +// MongoDB releases, use the obsolete AddUser method instead. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/user-privileges/ +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// +func (db *Database) UpsertUser(user *User) error { + if user.Username == "" { + return fmt.Errorf("user has no Username") + } + if (user.Password != "" || user.PasswordHash != "") && user.UserSource != "" { + return fmt.Errorf("user has both Password/PasswordHash and UserSource set") + } + if len(user.OtherDBRoles) > 0 && db.Name != "admin" && db.Name != "$external" { + return fmt.Errorf("user with OtherDBRoles is only supported in the admin or $external databases") + } + + // Attempt to run this using 2.6+ commands. + rundb := db + if user.UserSource != "" { + // Compatibility logic for the userSource field of MongoDB <= 2.4.X + rundb = db.Session.DB(user.UserSource) + } + err := rundb.runUserCmd("updateUser", user) + // retry with createUser when isAuthError in order to enable the "localhost exception" + if isNotFound(err) || isAuthError(err) { + return rundb.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command does not exist. Fallback to pre-2.6 behavior. + var set, unset bson.D + if user.Password != "" { + psum := md5.New() + psum.Write([]byte(user.Username + ":mongo:" + user.Password)) + set = append(set, bson.DocElem{"pwd", hex.EncodeToString(psum.Sum(nil))}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } else if user.PasswordHash != "" { + set = append(set, bson.DocElem{"pwd", user.PasswordHash}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } + if user.UserSource != "" { + set = append(set, bson.DocElem{"userSource", user.UserSource}) + unset = append(unset, bson.DocElem{"pwd", 1}) + } + if user.Roles != nil || user.OtherDBRoles != nil { + set = append(set, bson.DocElem{"roles", user.Roles}) + if len(user.OtherDBRoles) > 0 { + set = append(set, bson.DocElem{"otherDBRoles", user.OtherDBRoles}) + } else { + unset = append(unset, bson.DocElem{"otherDBRoles", 1}) + } + } + users := db.C("system.users") + err = users.Update(bson.D{{"user", user.Username}}, bson.D{{"$unset", unset}, {"$set", set}}) + if err == ErrNotFound { + set = append(set, bson.DocElem{"user", user.Username}) + if user.Roles == nil && user.OtherDBRoles == nil { + // Roles must be sent, as it's the way MongoDB distinguishes + // old-style documents from new-style documents in pre-2.6. + set = append(set, bson.DocElem{"roles", user.Roles}) + } + err = users.Insert(set) + } + return err +} + +func isNoCmd(err error) bool { + e, ok := err.(*QueryError) + return ok && (e.Code == 59 || e.Code == 13390 || strings.HasPrefix(e.Message, "no such cmd:")) +} + +func isNotFound(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 11 +} + +func isAuthError(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 13 +} + +func (db *Database) runUserCmd(cmdName string, user *User) error { + cmd := make(bson.D, 0, 16) + cmd = append(cmd, bson.DocElem{cmdName, user.Username}) + if user.Password != "" { + cmd = append(cmd, bson.DocElem{"pwd", user.Password}) + } + var roles []interface{} + for _, role := range user.Roles { + roles = append(roles, role) + } + for db, dbroles := range user.OtherDBRoles { + for _, role := range dbroles { + roles = append(roles, bson.D{{"role", role}, {"db", db}}) + } + } + if roles != nil || user.Roles != nil || cmdName == "createUser" { + cmd = append(cmd, bson.DocElem{"roles", roles}) + } + err := db.Run(cmd, nil) + if !isNoCmd(err) && user.UserSource != "" && (user.UserSource != "$external" || db.Name != "$external") { + return fmt.Errorf("MongoDB 2.6+ does not support the UserSource setting") + } + return err +} + +// AddUser creates or updates the authentication credentials of user within +// the db database. +// +// WARNING: This method is obsolete and should only be used with MongoDB 2.2 +// or earlier. For MongoDB 2.4 and on, use UpsertUser instead. +func (db *Database) AddUser(username, password string, readOnly bool) error { + // Try to emulate the old behavior on 2.6+ + user := &User{Username: username, Password: password} + if db.Name == "admin" { + if readOnly { + user.Roles = []Role{RoleReadAny} + } else { + user.Roles = []Role{RoleReadWriteAny} + } + } else { + if readOnly { + user.Roles = []Role{RoleRead} + } else { + user.Roles = []Role{RoleReadWrite} + } + } + err := db.runUserCmd("updateUser", user) + if isNotFound(err) { + return db.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command doesn't exist. Fallback to pre-2.6 behavior. + psum := md5.New() + psum.Write([]byte(username + ":mongo:" + password)) + digest := hex.EncodeToString(psum.Sum(nil)) + c := db.C("system.users") + _, err = c.Upsert(bson.M{"user": username}, bson.M{"$set": bson.M{"user": username, "pwd": digest, "readOnly": readOnly}}) + return err +} + +// RemoveUser removes the authentication credentials of user from the database. +func (db *Database) RemoveUser(user string) error { + err := db.Run(bson.D{{"dropUser", user}}, nil) + if isNoCmd(err) { + users := db.C("system.users") + return users.Remove(bson.M{"user": user}) + } + if isNotFound(err) { + return ErrNotFound + } + return err +} + +type indexSpec struct { + Name, NS string + Key bson.D + Unique bool ",omitempty" + DropDups bool "dropDups,omitempty" + Background bool ",omitempty" + Sparse bool ",omitempty" + Bits int ",omitempty" + Min, Max float64 ",omitempty" + BucketSize float64 "bucketSize,omitempty" + ExpireAfter int "expireAfterSeconds,omitempty" + Weights bson.D ",omitempty" + DefaultLanguage string "default_language,omitempty" + LanguageOverride string "language_override,omitempty" + TextIndexVersion int "textIndexVersion,omitempty" +} + +type Index struct { + Key []string // Index key fields; prefix name with dash (-) for descending order + Unique bool // Prevent two documents from having the same index key + DropDups bool // Drop documents with the same index key as a previously indexed one + Background bool // Build index in background and return immediately + Sparse bool // Only index documents containing the Key fields + + // If ExpireAfter is defined the server will periodically delete + // documents with indexed time.Time older than the provided delta. + ExpireAfter time.Duration + + // Name holds the stored index name. On creation if this field is unset it is + // computed by EnsureIndex based on the index key. + Name string + + // Properties for spatial indexes. + // + // Min and Max were improperly typed as int when they should have been + // floats. To preserve backwards compatibility they are still typed as + // int and the following two fields enable reading and writing the same + // fields as float numbers. In mgo.v3, these fields will be dropped and + // Min/Max will become floats. + Min, Max int + Minf, Maxf float64 + BucketSize float64 + Bits int + + // Properties for text indexes. + DefaultLanguage string + LanguageOverride string + + // Weights defines the significance of provided fields relative to other + // fields in a text index. The score for a given word in a document is derived + // from the weighted sum of the frequency for each of the indexed fields in + // that document. The default field weight is 1. + Weights map[string]int +} + +// mgo.v3: Drop Minf and Maxf and transform Min and Max to floats. +// mgo.v3: Drop DropDups as it's unsupported past 2.8. + +type indexKeyInfo struct { + name string + key bson.D + weights bson.D +} + +func parseIndexKey(key []string) (*indexKeyInfo, error) { + var keyInfo indexKeyInfo + isText := false + var order interface{} + for _, field := range key { + raw := field + if keyInfo.name != "" { + keyInfo.name += "_" + } + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + keyInfo.name += field + "_" + kind + } else { + field = "\x00" + } + } + switch field[0] { + case 0: + // Logic above failed. Reset and error. + field = "" + case '@': + order = "2d" + field = field[1:] + // The shell used to render this field as key_ instead of key_2d, + // and mgo followed suit. This has been fixed in recent server + // releases, and mgo followed as well. + keyInfo.name += field + "_2d" + case '-': + order = -1 + field = field[1:] + keyInfo.name += field + "_-1" + case '+': + field = field[1:] + fallthrough + default: + if kind == "" { + order = 1 + keyInfo.name += field + "_1" + } else { + order = kind + } + } + } + if field == "" || kind != "" && order != kind { + return nil, fmt.Errorf(`invalid index key: want "[$<kind>:][-]<field name>", got %q`, raw) + } + if kind == "text" { + if !isText { + keyInfo.key = append(keyInfo.key, bson.DocElem{"_fts", "text"}, bson.DocElem{"_ftsx", 1}) + isText = true + } + keyInfo.weights = append(keyInfo.weights, bson.DocElem{field, 1}) + } else { + keyInfo.key = append(keyInfo.key, bson.DocElem{field, order}) + } + } + if keyInfo.name == "" { + return nil, errors.New("invalid index key: no fields provided") + } + return &keyInfo, nil +} + +// EnsureIndexKey ensures an index with the given key exists, creating it +// if necessary. +// +// This example: +// +// err := collection.EnsureIndexKey("a", "b") +// +// Is equivalent to: +// +// err := collection.EnsureIndex(mgo.Index{Key: []string{"a", "b"}}) +// +// See the EnsureIndex method for more details. +func (c *Collection) EnsureIndexKey(key ...string) error { + return c.EnsureIndex(Index{Key: key}) +} + +// EnsureIndex ensures an index with the given key exists, creating it with +// the provided parameters if necessary. EnsureIndex does not modify a previously +// existent index with a matching key. The old index must be dropped first instead. +// +// Once EnsureIndex returns successfully, following requests for the same index +// will not contact the server unless Collection.DropIndex is used to drop the +// same index, or Session.ResetIndexCache is called. +// +// For example: +// +// index := Index{ +// Key: []string{"lastname", "firstname"}, +// Unique: true, +// DropDups: true, +// Background: true, // See notes. +// Sparse: true, +// } +// err := collection.EnsureIndex(index) +// +// The Key value determines which fields compose the index. The index ordering +// will be ascending by default. To obtain an index with a descending order, +// the field name should be prefixed by a dash (e.g. []string{"-time"}). It can +// also be optionally prefixed by an index kind, as in "$text:summary" or +// "$2d:-point". The key string format is: +// +// [$<kind>:][-]<field name> +// +// If the Unique field is true, the index must necessarily contain only a single +// document per Key. With DropDups set to true, documents with the same key +// as a previously indexed one will be dropped rather than an error returned. +// +// If Background is true, other connections will be allowed to proceed using +// the collection without the index while it's being built. Note that the +// session executing EnsureIndex will be blocked for as long as it takes for +// the index to be built. +// +// If Sparse is true, only documents containing the provided Key fields will be +// included in the index. When using a sparse index for sorting, only indexed +// documents will be returned. +// +// If ExpireAfter is non-zero, the server will periodically scan the collection +// and remove documents containing an indexed time.Time field with a value +// older than ExpireAfter. See the documentation for details: +// +// http://docs.mongodb.org/manual/tutorial/expire-data +// +// Other kinds of indexes are also supported through that API. Here is an example: +// +// index := Index{ +// Key: []string{"$2d:loc"}, +// Bits: 26, +// } +// err := collection.EnsureIndex(index) +// +// The example above requests the creation of a "2d" index for the "loc" field. +// +// The 2D index bounds may be changed using the Min and Max attributes of the +// Index value. The default bound setting of (-180, 180) is suitable for +// latitude/longitude pairs. +// +// The Bits parameter sets the precision of the 2D geohash values. If not +// provided, 26 bits are used, which is roughly equivalent to 1 foot of +// precision for the default (-180, 180) index bounds. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Indexes +// http://www.mongodb.org/display/DOCS/Indexing+Advice+and+FAQ +// http://www.mongodb.org/display/DOCS/Indexing+as+a+Background+Operation +// http://www.mongodb.org/display/DOCS/Geospatial+Indexing +// http://www.mongodb.org/display/DOCS/Multikeys +// +func (c *Collection) EnsureIndex(index Index) error { + keyInfo, err := parseIndexKey(index.Key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + if session.cluster().HasCachedIndex(cacheKey) { + return nil + } + + spec := indexSpec{ + Name: keyInfo.name, + NS: c.FullName, + Key: keyInfo.key, + Unique: index.Unique, + DropDups: index.DropDups, + Background: index.Background, + Sparse: index.Sparse, + Bits: index.Bits, + Min: index.Minf, + Max: index.Maxf, + BucketSize: index.BucketSize, + ExpireAfter: int(index.ExpireAfter / time.Second), + Weights: keyInfo.weights, + DefaultLanguage: index.DefaultLanguage, + LanguageOverride: index.LanguageOverride, + } + + if spec.Min == 0 && spec.Max == 0 { + spec.Min = float64(index.Min) + spec.Max = float64(index.Max) + } + + if index.Name != "" { + spec.Name = index.Name + } + +NextField: + for name, weight := range index.Weights { + for i, elem := range spec.Weights { + if elem.Name == name { + spec.Weights[i].Value = weight + continue NextField + } + } + panic("weight provided for field that is not part of index key: " + name) + } + + cloned := session.Clone() + defer cloned.Close() + cloned.SetMode(Strong, false) + cloned.EnsureSafe(&Safe{}) + db := c.Database.With(cloned) + + // Try with a command first. + err = db.Run(bson.D{{"createIndexes", c.Name}, {"indexes", []indexSpec{spec}}}, nil) + if isNoCmd(err) { + // Command not yet supported. Insert into the indexes collection instead. + err = db.C("system.indexes").Insert(&spec) + } + if err == nil { + session.cluster().CacheIndex(cacheKey, true) + } + return err +} + +// DropIndex drops the index with the provided key from the c collection. +// +// See EnsureIndex for details on the accepted key variants. +// +// For example: +// +// err1 := collection.DropIndex("firstField", "-secondField") +// err2 := collection.DropIndex("customIndexName") +// +func (c *Collection) DropIndex(key ...string) error { + keyInfo, err := parseIndexKey(key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + session.cluster().CacheIndex(cacheKey, false) + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + db := c.Database.With(session) + result := struct { + ErrMsg string + Ok bool + }{} + err = db.Run(bson.D{{"dropIndexes", c.Name}, {"index", keyInfo.name}}, &result) + if err != nil { + return err + } + if !result.Ok { + return errors.New(result.ErrMsg) + } + return nil +} + +// DropIndexName removes the index with the provided index name. +// +// For example: +// +// err := collection.DropIndex("customIndexName") +// +func (c *Collection) DropIndexName(name string) error { + session := c.Database.Session + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + c = c.With(session) + + indexes, err := c.Indexes() + if err != nil { + return err + } + + var index Index + for _, idx := range indexes { + if idx.Name == name { + index = idx + break + } + } + + if index.Name != "" { + keyInfo, err := parseIndexKey(index.Key) + if err != nil { + return err + } + + cacheKey := c.FullName + "\x00" + keyInfo.name + session.cluster().CacheIndex(cacheKey, false) + } + + result := struct { + ErrMsg string + Ok bool + }{} + err = c.Database.Run(bson.D{{"dropIndexes", c.Name}, {"index", name}}, &result) + if err != nil { + return err + } + if !result.Ok { + return errors.New(result.ErrMsg) + } + return nil +} + +// nonEventual returns a clone of session and ensures it is not Eventual. +// This guarantees that the server that is used for queries may be reused +// afterwards when a cursor is received. +func (session *Session) nonEventual() *Session { + cloned := session.Clone() + if cloned.consistency == Eventual { + cloned.SetMode(Monotonic, false) + } + return cloned +} + +// Indexes returns a list of all indexes for the collection. +// +// For example, this snippet would drop all available indexes: +// +// indexes, err := collection.Indexes() +// if err != nil { +// return err +// } +// for _, index := range indexes { +// err = collection.DropIndex(index.Key...) +// if err != nil { +// return err +// } +// } +// +// See the EnsureIndex method for more details on indexes. +func (c *Collection) Indexes() (indexes []Index, err error) { + cloned := c.Database.Session.nonEventual() + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + // Try with a command. + var result struct { + Indexes []bson.Raw + Cursor cursorData + } + var iter *Iter + err = c.Database.With(cloned).Run(bson.D{{"listIndexes", c.Name}, {"cursor", bson.D{{"batchSize", batchSize}}}}, &result) + if err == nil { + firstBatch := result.Indexes + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + ns := strings.SplitN(result.Cursor.NS, ".", 2) + if len(ns) < 2 { + iter = c.With(cloned).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } else { + iter = cloned.DB(ns[0]).C(ns[1]).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } + } else if isNoCmd(err) { + // Command not yet supported. Query the database instead. + iter = c.Database.C("system.indexes").Find(bson.M{"ns": c.FullName}).Iter() + } else { + return nil, err + } + + var spec indexSpec + for iter.Next(&spec) { + indexes = append(indexes, indexFromSpec(spec)) + } + if err = iter.Close(); err != nil { + return nil, err + } + sort.Sort(indexSlice(indexes)) + return indexes, nil +} + +func indexFromSpec(spec indexSpec) Index { + index := Index{ + Name: spec.Name, + Key: simpleIndexKey(spec.Key), + Unique: spec.Unique, + DropDups: spec.DropDups, + Background: spec.Background, + Sparse: spec.Sparse, + Minf: spec.Min, + Maxf: spec.Max, + Bits: spec.Bits, + BucketSize: spec.BucketSize, + DefaultLanguage: spec.DefaultLanguage, + LanguageOverride: spec.LanguageOverride, + ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, + } + if float64(int(spec.Min)) == spec.Min && float64(int(spec.Max)) == spec.Max { + index.Min = int(spec.Min) + index.Max = int(spec.Max) + } + if spec.TextIndexVersion > 0 { + index.Key = make([]string, len(spec.Weights)) + index.Weights = make(map[string]int) + for i, elem := range spec.Weights { + index.Key[i] = "$text:" + elem.Name + if w, ok := elem.Value.(int); ok { + index.Weights[elem.Name] = w + } + } + } + return index +} + +type indexSlice []Index + +func (idxs indexSlice) Len() int { return len(idxs) } +func (idxs indexSlice) Less(i, j int) bool { return idxs[i].Name < idxs[j].Name } +func (idxs indexSlice) Swap(i, j int) { idxs[i], idxs[j] = idxs[j], idxs[i] } + +func simpleIndexKey(realKey bson.D) (key []string) { + for i := range realKey { + field := realKey[i].Name + vi, ok := realKey[i].Value.(int) + if !ok { + vf, _ := realKey[i].Value.(float64) + vi = int(vf) + } + if vi == 1 { + key = append(key, field) + continue + } + if vi == -1 { + key = append(key, "-"+field) + continue + } + if vs, ok := realKey[i].Value.(string); ok { + key = append(key, "$"+vs+":"+field) + continue + } + panic("Got unknown index key type for field " + field) + } + return +} + +// ResetIndexCache() clears the cache of previously ensured indexes. +// Following requests to EnsureIndex will contact the server. +func (s *Session) ResetIndexCache() { + s.cluster().ResetIndexCache() +} + +// New creates a new session with the same parameters as the original +// session, including consistency, batch size, prefetching, safety mode, +// etc. The returned session will use sockets from the pool, so there's +// a chance that writes just performed in another session may not yet +// be visible. +// +// Login information from the original session will not be copied over +// into the new session unless it was provided through the initial URL +// for the Dial function. +// +// See the Copy and Clone methods. +// +func (s *Session) New() *Session { + s.m.Lock() + scopy := copySession(s, false) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Copy works just like New, but preserves the exact authentication +// information from the original session. +func (s *Session) Copy() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Clone works just like Copy, but also reuses the same socket as the original +// session, in case it had already reserved one due to its consistency +// guarantees. This behavior ensures that writes performed in the old session +// are necessarily observed when using the new session, as long as it was a +// strong or monotonic session. That said, it also means that long operations +// may cause other goroutines using the original session to wait. +func (s *Session) Clone() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + return scopy +} + +// Close terminates the session. It's a runtime error to use a session +// after it has been closed. +func (s *Session) Close() { + s.m.Lock() + if s.cluster_ != nil { + debugf("Closing session %p", s) + s.unsetSocket() + s.cluster_.Release() + s.cluster_ = nil + } + s.m.Unlock() +} + +func (s *Session) cluster() *mongoCluster { + if s.cluster_ == nil { + panic("Session already closed") + } + return s.cluster_ +} + +// Refresh puts back any reserved sockets in use and restarts the consistency +// guarantees according to the current consistency setting for the session. +func (s *Session) Refresh() { + s.m.Lock() + s.slaveOk = s.consistency != Strong + s.unsetSocket() + s.m.Unlock() +} + +// SetMode changes the consistency mode for the session. +// +// In the Strong consistency mode reads and writes will always be made to +// the primary server using a unique connection so that reads and writes are +// fully consistent, ordered, and observing the most up-to-date data. +// This offers the least benefits in terms of distributing load, but the +// most guarantees. See also Monotonic and Eventual. +// +// In the Monotonic consistency mode reads may not be entirely up-to-date, +// but they will always see the history of changes moving forward, the data +// read will be consistent across sequential queries in the same session, +// and modifications made within the session will be observed in following +// queries (read-your-writes). +// +// In practice, the Monotonic mode is obtained by performing initial reads +// on a unique connection to an arbitrary secondary, if one is available, +// and once the first write happens, the session connection is switched over +// to the primary server. This manages to distribute some of the reading +// load with secondaries, while maintaining some useful guarantees. +// +// In the Eventual consistency mode reads will be made to any secondary in the +// cluster, if one is available, and sequential reads will not necessarily +// be made with the same connection. This means that data may be observed +// out of order. Writes will of course be issued to the primary, but +// independent writes in the same Eventual session may also be made with +// independent connections, so there are also no guarantees in terms of +// write ordering (no read-your-writes guarantees either). +// +// The Eventual mode is the fastest and most resource-friendly, but is +// also the one offering the least guarantees about ordering of the data +// read and written. +// +// If refresh is true, in addition to ensuring the session is in the given +// consistency mode, the consistency guarantees will also be reset (e.g. +// a Monotonic session will be allowed to read from secondaries again). +// This is equivalent to calling the Refresh function. +// +// Shifting between Monotonic and Strong modes will keep a previously +// reserved connection for the session unless refresh is true or the +// connection is unsuitable (to a secondary server in a Strong session). +func (s *Session) SetMode(consistency Mode, refresh bool) { + s.m.Lock() + debugf("Session %p: setting mode %d with refresh=%v (master=%p, slave=%p)", s, consistency, refresh, s.masterSocket, s.slaveSocket) + s.consistency = consistency + if refresh { + s.slaveOk = s.consistency != Strong + s.unsetSocket() + } else if s.consistency == Strong { + s.slaveOk = false + } else if s.masterSocket == nil { + s.slaveOk = true + } + s.m.Unlock() +} + +// Mode returns the current consistency mode for the session. +func (s *Session) Mode() Mode { + s.m.RLock() + mode := s.consistency + s.m.RUnlock() + return mode +} + +// SetSyncTimeout sets the amount of time an operation with this session +// will wait before returning an error in case a connection to a usable +// server can't be established. Set it to zero to wait forever. The +// default value is 7 seconds. +func (s *Session) SetSyncTimeout(d time.Duration) { + s.m.Lock() + s.syncTimeout = d + s.m.Unlock() +} + +// SetSocketTimeout sets the amount of time to wait for a non-responding +// socket to the database before it is forcefully closed. +func (s *Session) SetSocketTimeout(d time.Duration) { + s.m.Lock() + s.sockTimeout = d + if s.masterSocket != nil { + s.masterSocket.SetTimeout(d) + } + if s.slaveSocket != nil { + s.slaveSocket.SetTimeout(d) + } + s.m.Unlock() +} + +// SetCursorTimeout changes the standard timeout period that the server +// enforces on created cursors. The only supported value right now is +// 0, which disables the timeout. The standard server timeout is 10 minutes. +func (s *Session) SetCursorTimeout(d time.Duration) { + s.m.Lock() + if d == 0 { + s.queryConfig.op.flags |= flagNoCursorTimeout + } else { + panic("SetCursorTimeout: only 0 (disable timeout) supported for now") + } + s.m.Unlock() +} + +// SetPoolLimit sets the maximum number of sockets in use in a single server +// before this session will block waiting for a socket to be available. +// The default limit is 4096. +// +// This limit must be set to cover more than any expected workload of the +// application. It is a bad practice and an unsupported use case to use the +// database driver to define the concurrency limit of an application. Prevent +// such concurrency "at the door" instead, by properly restricting the amount +// of used resources and number of goroutines before they are created. +func (s *Session) SetPoolLimit(limit int) { + s.m.Lock() + s.poolLimit = limit + s.m.Unlock() +} + +// SetBypassValidation sets whether the server should bypass the registered +// validation expressions executed when documents are inserted or modified, +// in the interest of preserving invariants in the collection being modified. +// The default is to not bypass, and thus to perform the validation +// expressions registered for modified collections. +// +// Document validation was introuced in MongoDB 3.2. +// +// Relevant documentation: +// +// https://docs.mongodb.org/manual/release-notes/3.2/#bypass-validation +// +func (s *Session) SetBypassValidation(bypass bool) { + s.m.Lock() + s.bypassValidation = bypass + s.m.Unlock() +} + +// SetBatch sets the default batch size used when fetching documents from the +// database. It's possible to change this setting on a per-query basis as +// well, using the Query.Batch method. +// +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. +func (s *Session) SetBatch(n int) { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + s.m.Lock() + s.queryConfig.op.limit = int32(n) + s.m.Unlock() +} + +// SetPrefetch sets the default point at which the next batch of results will be +// requested. When there are p*batch_size remaining documents cached in an +// Iter, the next batch will be requested in background. For instance, when +// using this: +// +// session.SetBatch(200) +// session.SetPrefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-query basis as well, using the Prefetch method of Query. +// +// The default prefetch value is 0.25. +func (s *Session) SetPrefetch(p float64) { + s.m.Lock() + s.queryConfig.prefetch = p + s.m.Unlock() +} + +// See SetSafe for details on the Safe type. +type Safe struct { + W int // Min # of servers to ack before success + WMode string // Write mode for MongoDB 2.0+ (e.g. "majority") + WTimeout int // Milliseconds to wait for W before timing out + FSync bool // Sync via the journal if present, or via data files sync otherwise + J bool // Sync via the journal if present +} + +// Safe returns the current safety mode for the session. +func (s *Session) Safe() (safe *Safe) { + s.m.Lock() + defer s.m.Unlock() + if s.safeOp != nil { + cmd := s.safeOp.query.(*getLastError) + safe = &Safe{WTimeout: cmd.WTimeout, FSync: cmd.FSync, J: cmd.J} + switch w := cmd.W.(type) { + case string: + safe.WMode = w + case int: + safe.W = w + } + } + return +} + +// SetSafe changes the session safety mode. +// +// If the safe parameter is nil, the session is put in unsafe mode, and writes +// become fire-and-forget, without error checking. The unsafe mode is faster +// since operations won't hold on waiting for a confirmation. +// +// If the safe parameter is not nil, any changing query (insert, update, ...) +// will be followed by a getLastError command with the specified parameters, +// to ensure the request was correctly processed. +// +// The safe.W parameter determines how many servers should confirm a write +// before the operation is considered successful. If set to 0 or 1, the +// command will return as soon as the primary is done with the request. +// If safe.WTimeout is greater than zero, it determines how many milliseconds +// to wait for the safe.W servers to respond before returning an error. +// +// Starting with MongoDB 2.0.0 the safe.WMode parameter can be used instead +// of W to request for richer semantics. If set to "majority" the server will +// wait for a majority of members from the replica set to respond before +// returning. Custom modes may also be defined within the server to create +// very detailed placement schemas. See the data awareness documentation in +// the links below for more details (note that MongoDB internally reuses the +// "w" field name for WMode). +// +// If safe.J is true, servers will block until write operations have been +// committed to the journal. Cannot be used in combination with FSync. Prior +// to MongoDB 2.6 this option was ignored if the server was running without +// journaling. Starting with MongoDB 2.6 write operations will fail with an +// exception if this option is used when the server is running without +// journaling. +// +// If safe.FSync is true and the server is running without journaling, blocks +// until the server has synced all data files to disk. If the server is running +// with journaling, this acts the same as the J option, blocking until write +// operations have been committed to the journal. Cannot be used in +// combination with J. +// +// Since MongoDB 2.0.0, the safe.J option can also be used instead of FSync +// to force the server to wait for a group commit in case journaling is +// enabled. The option has no effect if the server has journaling disabled. +// +// For example, the following statement will make the session check for +// errors, without imposing further constraints: +// +// session.SetSafe(&mgo.Safe{}) +// +// The following statement will force the server to wait for a majority of +// members of a replica set to return (MongoDB 2.0+ only): +// +// session.SetSafe(&mgo.Safe{WMode: "majority"}) +// +// The following statement, on the other hand, ensures that at least two +// servers have flushed the change to disk before confirming the success +// of operations: +// +// session.EnsureSafe(&mgo.Safe{W: 2, FSync: true}) +// +// The following statement, on the other hand, disables the verification +// of errors entirely: +// +// session.SetSafe(nil) +// +// See also the EnsureSafe method. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) SetSafe(safe *Safe) { + s.m.Lock() + s.safeOp = nil + s.ensureSafe(safe) + s.m.Unlock() +} + +// EnsureSafe compares the provided safety parameters with the ones +// currently in use by the session and picks the most conservative +// choice for each setting. +// +// That is: +// +// - safe.WMode is always used if set. +// - safe.W is used if larger than the current W and WMode is empty. +// - safe.FSync is always used if true. +// - safe.J is used if FSync is false. +// - safe.WTimeout is used if set and smaller than the current WTimeout. +// +// For example, the following statement will ensure the session is +// at least checking for errors, without enforcing further constraints. +// If a more conservative SetSafe or EnsureSafe call was previously done, +// the following call will be ignored. +// +// session.EnsureSafe(&mgo.Safe{}) +// +// See also the SetSafe method for details on what each option means. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) EnsureSafe(safe *Safe) { + s.m.Lock() + s.ensureSafe(safe) + s.m.Unlock() +} + +func (s *Session) ensureSafe(safe *Safe) { + if safe == nil { + return + } + + var w interface{} + if safe.WMode != "" { + w = safe.WMode + } else if safe.W > 0 { + w = safe.W + } + + var cmd getLastError + if s.safeOp == nil { + cmd = getLastError{1, w, safe.WTimeout, safe.FSync, safe.J} + } else { + // Copy. We don't want to mutate the existing query. + cmd = *(s.safeOp.query.(*getLastError)) + if cmd.W == nil { + cmd.W = w + } else if safe.WMode != "" { + cmd.W = safe.WMode + } else if i, ok := cmd.W.(int); ok && safe.W > i { + cmd.W = safe.W + } + if safe.WTimeout > 0 && safe.WTimeout < cmd.WTimeout { + cmd.WTimeout = safe.WTimeout + } + if safe.FSync { + cmd.FSync = true + cmd.J = false + } else if safe.J && !cmd.FSync { + cmd.J = true + } + } + s.safeOp = &queryOp{ + query: &cmd, + collection: "admin.$cmd", + limit: -1, + } +} + +// Run issues the provided command on the "admin" database and +// and unmarshals its result in the respective argument. The cmd +// argument may be either a string with the command name itself, in +// which case an empty document of the form bson.M{cmd: 1} will be used, +// or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For commands on arbitrary databases, see the Run method in +// the Database type. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (s *Session) Run(cmd interface{}, result interface{}) error { + return s.DB("admin").Run(cmd, result) +} + +// SelectServers restricts communication to servers configured with the +// given tags. For example, the following statement restricts servers +// used for reading operations to those with both tag "disk" set to +// "ssd" and tag "rack" set to 1: +// +// session.SelectServers(bson.D{{"disk", "ssd"}, {"rack", 1}}) +// +// Multiple sets of tags may be provided, in which case the used server +// must match all tags within any one set. +// +// If a connection was previously assigned to the session due to the +// current session mode (see Session.SetMode), the tag selection will +// only be enforced after the session is refreshed. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/tutorial/configure-replica-set-tag-sets +// +func (s *Session) SelectServers(tags ...bson.D) { + s.m.Lock() + s.queryConfig.op.serverTags = tags + s.m.Unlock() +} + +// Ping runs a trivial ping command just to get in touch with the server. +func (s *Session) Ping() error { + return s.Run("ping", nil) +} + +// Fsync flushes in-memory writes to disk on the server the session +// is established with. If async is true, the call returns immediately, +// otherwise it returns after the flush has been made. +func (s *Session) Fsync(async bool) error { + return s.Run(bson.D{{"fsync", 1}, {"async", async}}, nil) +} + +// FsyncLock locks all writes in the specific server the session is +// established with and returns. Any writes attempted to the server +// after it is successfully locked will block until FsyncUnlock is +// called for the same server. +// +// This method works on secondaries as well, preventing the oplog from +// being flushed while the server is locked, but since only the server +// connected to is locked, for locking specific secondaries it may be +// necessary to establish a connection directly to the secondary (see +// Dial's connect=direct option). +// +// As an important caveat, note that once a write is attempted and +// blocks, follow up reads will block as well due to the way the +// lock is internally implemented in the server. More details at: +// +// https://jira.mongodb.org/browse/SERVER-4243 +// +// FsyncLock is often used for performing consistent backups of +// the database files on disk. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/fsync+Command +// http://www.mongodb.org/display/DOCS/Backups +// +func (s *Session) FsyncLock() error { + return s.Run(bson.D{{"fsync", 1}, {"lock", true}}, nil) +} + +// FsyncUnlock releases the server for writes. See FsyncLock for details. +func (s *Session) FsyncUnlock() error { + err := s.Run(bson.D{{"fsyncUnlock", 1}}, nil) + if isNoCmd(err) { + err = s.DB("admin").C("$cmd.sys.unlock").Find(nil).One(nil) // WTF? + } + return err +} + +// Find prepares a query using the provided document. The document may be a +// map or a struct value capable of being marshalled with bson. The map +// may be a generic one using interface{} for its key and/or values, such as +// bson.M, or it may be a properly typed map. Providing nil as the document +// is equivalent to providing an empty document such as bson.M{}. +// +// Further details of the query may be tweaked using the resulting Query value, +// and then executed to retrieve results using methods such as One, For, +// Iter, or Tail. +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Querying +// http://www.mongodb.org/display/DOCS/Advanced+Queries +// +func (c *Collection) Find(query interface{}) *Query { + session := c.Database.Session + session.m.RLock() + q := &Query{session: session, query: session.queryConfig} + session.m.RUnlock() + q.op.query = query + q.op.collection = c.FullName + return q +} + +type repairCmd struct { + RepairCursor string `bson:"repairCursor"` + Cursor *repairCmdCursor ",omitempty" +} + +type repairCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Repair returns an iterator that goes over all recovered documents in the +// collection, in a best-effort manner. This is most useful when there are +// damaged data files. Multiple copies of the same document may be returned +// by the iterator. +// +// Repair is supported in MongoDB 2.7.8 and later. +func (c *Collection) Repair() *Iter { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + session := c.Database.Session + cloned := session.nonEventual() + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + var result struct{ Cursor cursorData } + + cmd := repairCmd{ + RepairCursor: c.Name, + Cursor: &repairCmdCursor{batchSize}, + } + + clonedc := c.With(cloned) + err := clonedc.Database.Run(cmd, &result) + return clonedc.NewIter(session, result.Cursor.FirstBatch, result.Cursor.Id, err) +} + +// FindId is a convenience helper equivalent to: +// +// query := collection.Find(bson.M{"_id": id}) +// +// See the Find method for more details. +func (c *Collection) FindId(id interface{}) *Query { + return c.Find(bson.D{{"_id", id}}) +} + +type Pipe struct { + session *Session + collection *Collection + pipeline interface{} + allowDisk bool + batchSize int +} + +type pipeCmd struct { + Aggregate string + Pipeline interface{} + Cursor *pipeCmdCursor ",omitempty" + Explain bool ",omitempty" + AllowDisk bool "allowDiskUse,omitempty" +} + +type pipeCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Pipe prepares a pipeline to aggregate. The pipeline document +// must be a slice built in terms of the aggregation framework language. +// +// For example: +// +// pipe := collection.Pipe([]bson.M{{"$match": bson.M{"name": "Otavio"}}}) +// iter := pipe.Iter() +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/aggregation +// http://docs.mongodb.org/manual/applications/aggregation +// http://docs.mongodb.org/manual/tutorial/aggregation-examples +// +func (c *Collection) Pipe(pipeline interface{}) *Pipe { + session := c.Database.Session + session.m.RLock() + batchSize := int(session.queryConfig.op.limit) + session.m.RUnlock() + return &Pipe{ + session: session, + collection: c, + pipeline: pipeline, + batchSize: batchSize, + } +} + +// Iter executes the pipeline and returns an iterator capable of going +// over all the generated results. +func (p *Pipe) Iter() *Iter { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + cloned := p.session.nonEventual() + defer cloned.Close() + c := p.collection.With(cloned) + + var result struct { + Result []bson.Raw // 2.4, no cursors. + Cursor cursorData // 2.6+, with cursors. + } + + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Cursor: &pipeCmdCursor{p.batchSize}, + } + err := c.Database.Run(cmd, &result) + if e, ok := err.(*QueryError); ok && e.Message == `unrecognized field "cursor` { + cmd.Cursor = nil + cmd.AllowDisk = false + err = c.Database.Run(cmd, &result) + } + firstBatch := result.Result + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + return c.NewIter(p.session, firstBatch, result.Cursor.Id, err) +} + +// NewIter returns a newly created iterator with the provided parameters. +// Using this method is not recommended unless the desired functionality +// is not yet exposed via a more convenient interface (Find, Pipe, etc). +// +// The optional session parameter associates the lifetime of the returned +// iterator to an arbitrary session. If nil, the iterator will be bound to +// c's session. +// +// Documents in firstBatch will be individually provided by the returned +// iterator before documents from cursorId are made available. If cursorId +// is zero, only the documents in firstBatch are provided. +// +// If err is not nil, the iterator's Err method will report it after +// exhausting documents in firstBatch. +// +// NewIter must be called right after the cursor id is obtained, and must not +// be called on a collection in Eventual mode, because the cursor id is +// associated with the specific server that returned it. The provided session +// parameter may be in any mode or state, though. +// +func (c *Collection) NewIter(session *Session, firstBatch []bson.Raw, cursorId int64, err error) *Iter { + var server *mongoServer + csession := c.Database.Session + csession.m.RLock() + socket := csession.masterSocket + if socket == nil { + socket = csession.slaveSocket + } + if socket != nil { + server = socket.Server() + } + csession.m.RUnlock() + + if server == nil { + if csession.Mode() == Eventual { + panic("Collection.NewIter called in Eventual mode") + } + if err == nil { + err = errors.New("server not available") + } + } + + if session == nil { + session = csession + } + + iter := &Iter{ + session: session, + server: server, + timeout: -1, + err: err, + } + iter.gotReply.L = &iter.m + for _, doc := range firstBatch { + iter.docData.Push(doc.Data) + } + if cursorId != 0 { + iter.op.cursorId = cursorId + iter.op.collection = c.FullName + iter.op.replyFunc = iter.replyFunc() + } + return iter +} + +// All works like Iter.All. +func (p *Pipe) All(result interface{}) error { + return p.Iter().All(result) +} + +// One executes the pipeline and unmarshals the first item from the +// result set into the result parameter. +// It returns ErrNotFound if no items are generated by the pipeline. +func (p *Pipe) One(result interface{}) error { + iter := p.Iter() + if iter.Next(result) { + return nil + } + if err := iter.Err(); err != nil { + return err + } + return ErrNotFound +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested pipeline, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. +// +// For example: +// +// var m bson.M +// err := collection.Pipe(pipeline).Explain(&m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +func (p *Pipe) Explain(result interface{}) error { + c := p.collection + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Explain: true, + } + return c.Database.Run(cmd, result) +} + +// AllowDiskUse enables writing to the "<dbpath>/_tmp" server directory so +// that aggregation pipelines do not have to be held entirely in memory. +func (p *Pipe) AllowDiskUse() *Pipe { + p.allowDisk = true + return p +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. +// +// The default batch size is defined by the database server. +func (p *Pipe) Batch(n int) *Pipe { + p.batchSize = n + return p +} + +// mgo.v3: Use a single user-visible error type. + +type LastError struct { + Err string + Code, N, Waited int + FSyncFiles int `bson:"fsyncFiles"` + WTimeout bool + UpdatedExisting bool `bson:"updatedExisting"` + UpsertedId interface{} `bson:"upserted"` + + modified int + ecases []BulkErrorCase +} + +func (err *LastError) Error() string { + return err.Err +} + +type queryError struct { + Err string "$err" + ErrMsg string + Assertion string + Code int + AssertionCode int "assertionCode" + LastError *LastError "lastErrorObject" +} + +type QueryError struct { + Code int + Message string + Assertion bool +} + +func (err *QueryError) Error() string { + return err.Message +} + +// IsDup returns whether err informs of a duplicate key error because +// a primary key index or a secondary unique index already has an entry +// with the given value. +func IsDup(err error) bool { + // Besides being handy, helps with MongoDB bugs SERVER-7164 and SERVER-11493. + // What follows makes me sad. Hopefully conventions will be more clear over time. + switch e := err.(type) { + case *LastError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 || e.Code == 16460 && strings.Contains(e.Err, " E11000 ") + case *QueryError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 + case *BulkError: + for _, ecase := range e.ecases { + if !IsDup(ecase.Err) { + return false + } + } + return true + } + return false +} + +// Insert inserts one or more documents in the respective collection. In +// case the session is in safe mode (see the SetSafe method) and an error +// happens while inserting the provided documents, the returned error will +// be of type *LastError. +func (c *Collection) Insert(docs ...interface{}) error { + _, err := c.writeOp(&insertOp{c.FullName, docs, 0}, true) + return err +} + +// Update finds a single document matching the provided selector document +// and modifies it according to the update document. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Update(selector interface{}, update interface{}) error { + if selector == nil { + selector = bson.D{} + } + op := updateOp{ + Collection: c.FullName, + Selector: selector, + Update: update, + } + lerr, err := c.writeOp(&op, true) + if err == nil && lerr != nil && !lerr.UpdatedExisting { + return ErrNotFound + } + return err +} + +// UpdateId is a convenience helper equivalent to: +// +// err := collection.Update(bson.M{"_id": id}, update) +// +// See the Update method for more details. +func (c *Collection) UpdateId(id interface{}, update interface{}) error { + return c.Update(bson.D{{"_id", id}}, update) +} + +// ChangeInfo holds details about the outcome of an update operation. +type ChangeInfo struct { + // Updated reports the number of existing documents modified. + // Due to server limitations, this reports the same value as the Matched field when + // talking to MongoDB <= 2.4 and on Upsert and Apply (findAndModify) operations. + Updated int + Removed int // Number of documents removed + Matched int // Number of documents matched but not necessarily changed + UpsertedId interface{} // Upserted _id field, when not explicitly provided +} + +// UpdateAll finds all documents matching the provided selector document +// and modifies them according to the update document. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info or an error of type *LastError when +// some problem is detected. It is not an error for the update to not be +// applied on any documents because the selector doesn't match. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) UpdateAll(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + if selector == nil { + selector = bson.D{} + } + op := updateOp{ + Collection: c.FullName, + Selector: selector, + Update: update, + Flags: 2, + Multi: true, + } + lerr, err := c.writeOp(&op, true) + if err == nil && lerr != nil { + info = &ChangeInfo{Updated: lerr.modified, Matched: lerr.N} + } + return info, err +} + +// Upsert finds a single document matching the provided selector document +// and modifies it according to the update document. If no document matching +// the selector is found, the update document is applied to the selector +// document and the result is inserted in the collection. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info, or an error of type *LastError when +// some problem is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Upsert(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + if selector == nil { + selector = bson.D{} + } + op := updateOp{ + Collection: c.FullName, + Selector: selector, + Update: update, + Flags: 1, + Upsert: true, + } + lerr, err := c.writeOp(&op, true) + if err == nil && lerr != nil { + info = &ChangeInfo{} + if lerr.UpdatedExisting { + info.Matched = lerr.N + info.Updated = lerr.modified + } else { + info.UpsertedId = lerr.UpsertedId + } + } + return info, err +} + +// UpsertId is a convenience helper equivalent to: +// +// info, err := collection.Upsert(bson.M{"_id": id}, update) +// +// See the Upsert method for more details. +func (c *Collection) UpsertId(id interface{}, update interface{}) (info *ChangeInfo, err error) { + return c.Upsert(bson.D{{"_id", id}}, update) +} + +// Remove finds a single document matching the provided selector document +// and removes it from the database. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) Remove(selector interface{}) error { + if selector == nil { + selector = bson.D{} + } + lerr, err := c.writeOp(&deleteOp{c.FullName, selector, 1, 1}, true) + if err == nil && lerr != nil && lerr.N == 0 { + return ErrNotFound + } + return err +} + +// RemoveId is a convenience helper equivalent to: +// +// err := collection.Remove(bson.M{"_id": id}) +// +// See the Remove method for more details. +func (c *Collection) RemoveId(id interface{}) error { + return c.Remove(bson.D{{"_id", id}}) +} + +// RemoveAll finds all documents matching the provided selector document +// and removes them from the database. In case the session is in safe mode +// (see the SetSafe method) and an error happens when attempting the change, +// the returned error will be of type *LastError. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) RemoveAll(selector interface{}) (info *ChangeInfo, err error) { + if selector == nil { + selector = bson.D{} + } + lerr, err := c.writeOp(&deleteOp{c.FullName, selector, 0, 0}, true) + if err == nil && lerr != nil { + info = &ChangeInfo{Removed: lerr.N, Matched: lerr.N} + } + return info, err +} + +// DropDatabase removes the entire database including all of its collections. +func (db *Database) DropDatabase() error { + return db.Run(bson.D{{"dropDatabase", 1}}, nil) +} + +// DropCollection removes the entire collection including all of its documents. +func (c *Collection) DropCollection() error { + return c.Database.Run(bson.D{{"drop", c.Name}}, nil) +} + +// The CollectionInfo type holds metadata about a collection. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +type CollectionInfo struct { + // DisableIdIndex prevents the automatic creation of the index + // on the _id field for the collection. + DisableIdIndex bool + + // ForceIdIndex enforces the automatic creation of the index + // on the _id field for the collection. Capped collections, + // for example, do not have such an index by default. + ForceIdIndex bool + + // If Capped is true new documents will replace old ones when + // the collection is full. MaxBytes must necessarily be set + // to define the size when the collection wraps around. + // MaxDocs optionally defines the number of documents when it + // wraps, but MaxBytes still needs to be set. + Capped bool + MaxBytes int + MaxDocs int + + // Validator contains a validation expression that defines which + // documents should be considered valid for this collection. + Validator interface{} + + // ValidationLevel may be set to "strict" (the default) to force + // MongoDB to validate all documents on inserts and updates, to + // "moderate" to apply the validation rules only to documents + // that already fulfill the validation criteria, or to "off" for + // disabling validation entirely. + ValidationLevel string + + // ValidationAction determines how MongoDB handles documents that + // violate the validation rules. It may be set to "error" (the default) + // to reject inserts or updates that violate the rules, or to "warn" + // to log invalid operations but allow them to proceed. + ValidationAction string + + // StorageEngine allows specifying collection options for the + // storage engine in use. The map keys must hold the storage engine + // name for which options are being specified. + StorageEngine interface{} +} + +// Create explicitly creates the c collection with details of info. +// MongoDB creates collections automatically on use, so this method +// is only necessary when creating collection with non-default +// characteristics, such as capped collections. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +func (c *Collection) Create(info *CollectionInfo) error { + cmd := make(bson.D, 0, 4) + cmd = append(cmd, bson.DocElem{"create", c.Name}) + if info.Capped { + if info.MaxBytes < 1 { + return fmt.Errorf("Collection.Create: with Capped, MaxBytes must also be set") + } + cmd = append(cmd, bson.DocElem{"capped", true}) + cmd = append(cmd, bson.DocElem{"size", info.MaxBytes}) + if info.MaxDocs > 0 { + cmd = append(cmd, bson.DocElem{"max", info.MaxDocs}) + } + } + if info.DisableIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", false}) + } + if info.ForceIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", true}) + } + if info.Validator != nil { + cmd = append(cmd, bson.DocElem{"validator", info.Validator}) + } + if info.ValidationLevel != "" { + cmd = append(cmd, bson.DocElem{"validationLevel", info.ValidationLevel}) + } + if info.ValidationAction != "" { + cmd = append(cmd, bson.DocElem{"validationAction", info.ValidationAction}) + } + if info.StorageEngine != nil { + cmd = append(cmd, bson.DocElem{"storageEngine", info.StorageEngine}) + } + return c.Database.Run(cmd, nil) +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. + +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. +func (q *Query) Batch(n int) *Query { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + q.m.Lock() + q.op.limit = int32(n) + q.m.Unlock() + return q +} + +// Prefetch sets the point at which the next batch of results will be requested. +// When there are p*batch_size remaining documents cached in an Iter, the next +// batch will be requested in background. For instance, when using this: +// +// query.Batch(200).Prefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-session basis as well, using the SetPrefetch method of Session. +// +// The default prefetch value is 0.25. +func (q *Query) Prefetch(p float64) *Query { + q.m.Lock() + q.prefetch = p + q.m.Unlock() + return q +} + +// Skip skips over the n initial documents from the query results. Note that +// this only makes sense with capped collections where documents are naturally +// ordered by insertion time, or with sorted results. +func (q *Query) Skip(n int) *Query { + q.m.Lock() + q.op.skip = int32(n) + q.m.Unlock() + return q +} + +// Limit restricts the maximum number of documents retrieved to n, and also +// changes the batch size to the same value. Once n documents have been +// returned by Next, the following call will return ErrNotFound. +func (q *Query) Limit(n int) *Query { + q.m.Lock() + switch { + case n == 1: + q.limit = 1 + q.op.limit = -1 + case n == math.MinInt32: // -MinInt32 == -MinInt32 + q.limit = math.MaxInt32 + q.op.limit = math.MinInt32 + 1 + case n < 0: + q.limit = int32(-n) + q.op.limit = int32(n) + default: + q.limit = int32(n) + q.op.limit = int32(n) + } + q.m.Unlock() + return q +} + +// Select enables selecting which fields should be retrieved for the results +// found. For example, the following query would only retrieve the name field: +// +// err := collection.Find(nil).Select(bson.M{"name": 1}).One(&result) +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Retrieving+a+Subset+of+Fields +// +func (q *Query) Select(selector interface{}) *Query { + q.m.Lock() + q.op.selector = selector + q.m.Unlock() + return q +} + +// Sort asks the database to order returned documents according to the +// provided field names. A field name may be prefixed by - (minus) for +// it to be sorted in reverse order. +// +// For example: +// +// query1 := collection.Find(nil).Sort("firstname", "lastname") +// query2 := collection.Find(nil).Sort("-age") +// query3 := collection.Find(nil).Sort("$natural") +// query4 := collection.Find(nil).Select(bson.M{"score": bson.M{"$meta": "textScore"}}).Sort("$textScore:score") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order +// +func (q *Query) Sort(fields ...string) *Query { + q.m.Lock() + var order bson.D + for _, field := range fields { + n := 1 + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + } + } + switch field[0] { + case '+': + field = field[1:] + case '-': + n = -1 + field = field[1:] + } + } + if field == "" { + panic("Sort: empty field name") + } + if kind == "textScore" { + order = append(order, bson.DocElem{field, bson.M{"$meta": kind}}) + } else { + order = append(order, bson.DocElem{field, n}) + } + } + q.op.options.OrderBy = order + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested query, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. +// +// For example: +// +// m := bson.M{} +// err := collection.Find(bson.M{"filename": name}).Explain(m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Explain(result interface{}) error { + q.m.Lock() + clone := &Query{session: q.session, query: q.query} + q.m.Unlock() + clone.op.options.Explain = true + clone.op.hasOptions = true + if clone.op.limit > 0 { + clone.op.limit = -q.op.limit + } + iter := clone.Iter() + if iter.Next(result) { + return nil + } + return iter.Close() +} + +// TODO: Add Collection.Explain. See https://goo.gl/1MDlvz. + +// Hint will include an explicit "hint" in the query to force the server +// to use a specified index, potentially improving performance in some +// situations. The provided parameters are the fields that compose the +// key of the index to be used. For details on how the indexKey may be +// built, see the EnsureIndex method. +// +// For example: +// +// query := collection.Find(bson.M{"firstname": "Joe", "lastname": "Winter"}) +// query.Hint("lastname", "firstname") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Hint(indexKey ...string) *Query { + q.m.Lock() + keyInfo, err := parseIndexKey(indexKey) + q.op.options.Hint = keyInfo.key + q.op.hasOptions = true + q.m.Unlock() + if err != nil { + panic(err) + } + return q +} + +// SetMaxScan constrains the query to stop after scanning the specified +// number of documents. +// +// This modifier is generally used to prevent potentially long running +// queries from disrupting performance by scanning through too much data. +func (q *Query) SetMaxScan(n int) *Query { + q.m.Lock() + q.op.options.MaxScan = n + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// SetMaxTime constrains the query to stop after running for the specified time. +// +// When the time limit is reached MongoDB automatically cancels the query. +// This can be used to efficiently prevent and identify unexpectedly slow queries. +// +// A few important notes about the mechanism enforcing this limit: +// +// - Requests can block behind locking operations on the server, and that blocking +// time is not accounted for. In other words, the timer starts ticking only after +// the actual start of the query when it initially acquires the appropriate lock; +// +// - Operations are interrupted only at interrupt points where an operation can be +// safely aborted – the total execution time may exceed the specified value; +// +// - The limit can be applied to both CRUD operations and commands, but not all +// commands are interruptible; +// +// - While iterating over results, computing follow up batches is included in the +// total time and the iteration continues until the alloted time is over, but +// network roundtrips are not taken into account for the limit. +// +// - This limit does not override the inactive cursor timeout for idle cursors +// (default is 10 min). +// +// This mechanism was introduced in MongoDB 2.6. +// +// Relevant documentation: +// +// http://blog.mongodb.org/post/83621787773/maxtimems-and-query-optimizer-introspection-in +// +func (q *Query) SetMaxTime(d time.Duration) *Query { + q.m.Lock() + q.op.options.MaxTimeMS = int(d / time.Millisecond) + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Snapshot will force the performed query to make use of an available +// index on the _id field to prevent the same document from being returned +// more than once in a single iteration. This might happen without this +// setting in situations when the document changes in size and thus has to +// be moved while the iteration is running. +// +// Because snapshot mode traverses the _id index, it may not be used with +// sorting or explicit hints. It also cannot use any other index for the +// query. +// +// Even with snapshot mode, items inserted or deleted during the query may +// or may not be returned; that is, this mode is not a true point-in-time +// snapshot. +// +// The same effect of Snapshot may be obtained by using any unique index on +// field(s) that will not be modified (best to use Hint explicitly too). +// A non-unique index (such as creation time) may be made unique by +// appending _id to the index when creating it. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/How+to+do+Snapshotted+Queries+in+the+Mongo+Database +// +func (q *Query) Snapshot() *Query { + q.m.Lock() + q.op.options.Snapshot = true + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Comment adds a comment to the query to identify it in the database profiler output. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/operator/meta/comment +// http://docs.mongodb.org/manual/reference/command/profile +// http://docs.mongodb.org/manual/administration/analyzing-mongodb-performance/#database-profiling +// +func (q *Query) Comment(comment string) *Query { + q.m.Lock() + q.op.options.Comment = comment + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// LogReplay enables an option that optimizes queries that are typically +// made on the MongoDB oplog for replaying it. This is an internal +// implementation aspect and most likely uninteresting for other uses. +// It has seen at least one use case, though, so it's exposed via the API. +func (q *Query) LogReplay() *Query { + q.m.Lock() + q.op.flags |= flagLogReplay + q.m.Unlock() + return q +} + +func checkQueryError(fullname string, d []byte) error { + l := len(d) + if l < 16 { + return nil + } + if d[5] == '$' && d[6] == 'e' && d[7] == 'r' && d[8] == 'r' && d[9] == '\x00' && d[4] == '\x02' { + goto Error + } + if len(fullname) < 5 || fullname[len(fullname)-5:] != ".$cmd" { + return nil + } + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + goto Error + } + } + return nil + +Error: + result := &queryError{} + bson.Unmarshal(d, result) + if result.LastError != nil { + return result.LastError + } + if result.Err == "" && result.ErrMsg == "" { + return nil + } + if result.AssertionCode != 0 && result.Assertion != "" { + return &QueryError{Code: result.AssertionCode, Message: result.Assertion, Assertion: true} + } + if result.Err != "" { + return &QueryError{Code: result.Code, Message: result.Err} + } + return &QueryError{Code: result.Code, Message: result.ErrMsg} +} + +// One executes the query and unmarshals the first obtained document into the +// result argument. The result must be a struct or map value capable of being +// unmarshalled into by gobson. This function blocks until either a result +// is available or an error happens. For example: +// +// err := collection.Find(bson.M{"a": 1}).One(&result) +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +func (q *Query) One(result interface{}) (err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + socket, err := session.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + op.limit = -1 + + session.prepareQuery(&op) + + expectFindReply := prepareFindOp(socket, &op, 1) + + data, err := socket.SimpleQuery(&op) + if err != nil { + return err + } + if data == nil { + return ErrNotFound + } + if expectFindReply { + var findReply struct { + Ok bool + Code int + Errmsg string + Cursor cursorData + } + err = bson.Unmarshal(data, &findReply) + if err != nil { + return err + } + if !findReply.Ok && findReply.Errmsg != "" { + return &QueryError{Code: findReply.Code, Message: findReply.Errmsg} + } + if len(findReply.Cursor.FirstBatch) == 0 { + return ErrNotFound + } + data = findReply.Cursor.FirstBatch[0].Data + } + if result != nil { + err = bson.Unmarshal(data, result) + if err == nil { + debugf("Query %p document unmarshaled: %#v", q, result) + } else { + debugf("Query %p document unmarshaling failed: %#v", q, err) + return err + } + } + return checkQueryError(op.collection, data) +} + +// prepareFindOp translates op from being an old-style wire protocol query into +// a new-style find command if that's supported by the MongoDB server (3.2+). +// It returns whether to expect a find command result or not. Note op may be +// translated into an explain command, in which case the function returns false. +func prepareFindOp(socket *mongoSocket, op *queryOp, limit int32) bool { + if socket.ServerInfo().MaxWireVersion < 4 || op.collection == "admin.$cmd" { + return false + } + + nameDot := strings.Index(op.collection, ".") + if nameDot < 0 { + panic("invalid query collection name: " + op.collection) + } + + find := findCmd{ + Collection: op.collection[nameDot+1:], + Filter: op.query, + Projection: op.selector, + Sort: op.options.OrderBy, + Skip: op.skip, + Limit: limit, + MaxTimeMS: op.options.MaxTimeMS, + MaxScan: op.options.MaxScan, + Hint: op.options.Hint, + Comment: op.options.Comment, + Snapshot: op.options.Snapshot, + OplogReplay: op.flags&flagLogReplay != 0, + } + if op.limit < 0 { + find.BatchSize = -op.limit + find.SingleBatch = true + } else { + find.BatchSize = op.limit + } + + explain := op.options.Explain + + op.collection = op.collection[:nameDot] + ".$cmd" + op.query = &find + op.skip = 0 + op.limit = -1 + op.options = queryWrapper{} + op.hasOptions = false + + if explain { + op.query = bson.D{{"explain", op.query}} + return false + } + return true +} + +type cursorData struct { + FirstBatch []bson.Raw "firstBatch" + NextBatch []bson.Raw "nextBatch" + NS string + Id int64 +} + +// findCmd holds the command used for performing queries on MongoDB 3.2+. +// +// Relevant documentation: +// +// https://docs.mongodb.org/master/reference/command/find/#dbcmd.find +// +type findCmd struct { + Collection string `bson:"find"` + Filter interface{} `bson:"filter,omitempty"` + Sort interface{} `bson:"sort,omitempty"` + Projection interface{} `bson:"projection,omitempty"` + Hint interface{} `bson:"hint,omitempty"` + Skip interface{} `bson:"skip,omitempty"` + Limit int32 `bson:"limit,omitempty"` + BatchSize int32 `bson:"batchSize,omitempty"` + SingleBatch bool `bson:"singleBatch,omitempty"` + Comment string `bson:"comment,omitempty"` + MaxScan int `bson:"maxScan,omitempty"` + MaxTimeMS int `bson:"maxTimeMS,omitempty"` + ReadConcern interface{} `bson:"readConcern,omitempty"` + Max interface{} `bson:"max,omitempty"` + Min interface{} `bson:"min,omitempty"` + ReturnKey bool `bson:"returnKey,omitempty"` + ShowRecordId bool `bson:"showRecordId,omitempty"` + Snapshot bool `bson:"snapshot,omitempty"` + Tailable bool `bson:"tailable,omitempty"` + AwaitData bool `bson:"awaitData,omitempty"` + OplogReplay bool `bson:"oplogReplay,omitempty"` + NoCursorTimeout bool `bson:"noCursorTimeout,omitempty"` + AllowPartialResults bool `bson:"allowPartialResults,omitempty"` +} + +// getMoreCmd holds the command used for requesting more query results on MongoDB 3.2+. +// +// Relevant documentation: +// +// https://docs.mongodb.org/master/reference/command/getMore/#dbcmd.getMore +// +type getMoreCmd struct { + CursorId int64 `bson:"getMore"` + Collection string `bson:"collection"` + BatchSize int32 `bson:"batchSize,omitempty"` + MaxTimeMS int64 `bson:"maxTimeMS,omitempty"` +} + +// run duplicates the behavior of collection.Find(query).One(&result) +// as performed by Database.Run, specializing the logic for running +// database commands on a given socket. +func (db *Database) run(socket *mongoSocket, cmd, result interface{}) (err error) { + // Database.Run: + if name, ok := cmd.(string); ok { + cmd = bson.D{{name, 1}} + } + + // Collection.Find: + session := db.Session + session.m.RLock() + op := session.queryConfig.op // Copy. + session.m.RUnlock() + op.query = cmd + op.collection = db.Name + ".$cmd" + + // Query.One: + session.prepareQuery(&op) + op.limit = -1 + + data, err := socket.SimpleQuery(&op) + if err != nil { + return err + } + if data == nil { + return ErrNotFound + } + if result != nil { + err = bson.Unmarshal(data, result) + if err == nil { + var res bson.M + bson.Unmarshal(data, &res) + debugf("Run command unmarshaled: %#v, result: %#v", op, res) + } else { + debugf("Run command unmarshaling failed: %#v", op, err) + return err + } + } + return checkQueryError(op.collection, data) +} + +// The DBRef type implements support for the database reference MongoDB +// convention as supported by multiple drivers. This convention enables +// cross-referencing documents between collections and databases using +// a structure which includes a collection name, a document id, and +// optionally a database name. +// +// See the FindRef methods on Session and on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +type DBRef struct { + Collection string `bson:"$ref"` + Id interface{} `bson:"$id"` + Database string `bson:"$db,omitempty"` +} + +// NOTE: Order of fields for DBRef above does matter, per documentation. + +// FindRef returns a query that looks for the document in the provided +// reference. If the reference includes the DB field, the document will +// be retrieved from the respective database. +// +// See also the DBRef type and the FindRef method on Session. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (db *Database) FindRef(ref *DBRef) *Query { + var c *Collection + if ref.Database == "" { + c = db.C(ref.Collection) + } else { + c = db.Session.DB(ref.Database).C(ref.Collection) + } + return c.FindId(ref.Id) +} + +// FindRef returns a query that looks for the document in the provided +// reference. For a DBRef to be resolved correctly at the session level +// it must necessarily have the optional DB field defined. +// +// See also the DBRef type and the FindRef method on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (s *Session) FindRef(ref *DBRef) *Query { + if ref.Database == "" { + panic(errors.New(fmt.Sprintf("Can't resolve database for %#v", ref))) + } + c := s.DB(ref.Database).C(ref.Collection) + return c.FindId(ref.Id) +} + +// CollectionNames returns the collection names present in the db database. +func (db *Database) CollectionNames() (names []string, err error) { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + cloned := db.Session.nonEventual() + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + // Try with a command. + var result struct { + Collections []bson.Raw + Cursor cursorData + } + err = db.With(cloned).Run(bson.D{{"listCollections", 1}, {"cursor", bson.D{{"batchSize", batchSize}}}}, &result) + if err == nil { + firstBatch := result.Collections + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + var iter *Iter + ns := strings.SplitN(result.Cursor.NS, ".", 2) + if len(ns) < 2 { + iter = db.With(cloned).C("").NewIter(nil, firstBatch, result.Cursor.Id, nil) + } else { + iter = cloned.DB(ns[0]).C(ns[1]).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } + var coll struct{ Name string } + for iter.Next(&coll) { + names = append(names, coll.Name) + } + if err := iter.Close(); err != nil { + return nil, err + } + sort.Strings(names) + return names, err + } + if err != nil && !isNoCmd(err) { + return nil, err + } + + // Command not yet supported. Query the database instead. + nameIndex := len(db.Name) + 1 + iter := db.C("system.namespaces").Find(nil).Iter() + var coll struct{ Name string } + for iter.Next(&coll) { + if strings.Index(coll.Name, "$") < 0 || strings.Index(coll.Name, ".oplog.$") >= 0 { + names = append(names, coll.Name[nameIndex:]) + } + } + if err := iter.Close(); err != nil { + return nil, err + } + sort.Strings(names) + return names, nil +} + +type dbNames struct { + Databases []struct { + Name string + Empty bool + } +} + +// DatabaseNames returns the names of non-empty databases present in the cluster. +func (s *Session) DatabaseNames() (names []string, err error) { + var result dbNames + err = s.Run("listDatabases", &result) + if err != nil { + return nil, err + } + for _, db := range result.Databases { + if !db.Empty { + names = append(names, db.Name) + } + } + sort.Strings(names) + return names, nil +} + +// Iter executes the query and returns an iterator capable of going over all +// the results. Results will be returned in batches of configurable +// size (see the Batch method) and more documents will be requested when a +// configurable number of documents is iterated over (see the Prefetch method). +func (q *Query) Iter() *Iter { + q.m.Lock() + session := q.session + op := q.op + prefetch := q.prefetch + limit := q.limit + q.m.Unlock() + + iter := &Iter{ + session: session, + prefetch: prefetch, + limit: limit, + timeout: -1, + } + iter.gotReply.L = &iter.m + iter.op.collection = op.collection + iter.op.limit = op.limit + iter.op.replyFunc = iter.replyFunc() + iter.docsToReceive++ + + socket, err := session.acquireSocket(true) + if err != nil { + iter.err = err + return iter + } + defer socket.Release() + + session.prepareQuery(&op) + op.replyFunc = iter.op.replyFunc + + if prepareFindOp(socket, &op, limit) { + iter.findCmd = true + } + + iter.server = socket.Server() + err = socket.Query(&op) + if err != nil { + // Must lock as the query is already out and it may call replyFunc. + iter.m.Lock() + iter.err = err + iter.m.Unlock() + } + + return iter +} + +// Tail returns a tailable iterator. Unlike a normal iterator, a +// tailable iterator may wait for new values to be inserted in the +// collection once the end of the current result set is reached, +// A tailable iterator may only be used with capped collections. +// +// The timeout parameter indicates how long Next will block waiting +// for a result before timing out. If set to -1, Next will not +// timeout, and will continue waiting for a result for as long as +// the cursor is valid and the session is not closed. If set to 0, +// Next times out as soon as it reaches the end of the result set. +// Otherwise, Next will wait for at least the given number of +// seconds for a new document to be available before timing out. +// +// On timeouts, Next will unblock and return false, and the Timeout +// method will return true if called. In these cases, Next may still +// be called again on the same iterator to check if a new value is +// available at the current cursor position, and again it will block +// according to the specified timeoutSecs. If the cursor becomes +// invalid, though, both Next and Timeout will return false and +// the query must be restarted. +// +// The following example demonstrates timeout handling and query +// restarting: +// +// iter := collection.Find(nil).Sort("$natural").Tail(5 * time.Second) +// for { +// for iter.Next(&result) { +// fmt.Println(result.Id) +// lastId = result.Id +// } +// if iter.Err() != nil { +// return iter.Close() +// } +// if iter.Timeout() { +// continue +// } +// query := collection.Find(bson.M{"_id": bson.M{"$gt": lastId}}) +// iter = query.Sort("$natural").Tail(5 * time.Second) +// } +// iter.Close() +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Tailable+Cursors +// http://www.mongodb.org/display/DOCS/Capped+Collections +// http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order +// +func (q *Query) Tail(timeout time.Duration) *Iter { + q.m.Lock() + session := q.session + op := q.op + prefetch := q.prefetch + q.m.Unlock() + + iter := &Iter{session: session, prefetch: prefetch} + iter.gotReply.L = &iter.m + iter.timeout = timeout + iter.op.collection = op.collection + iter.op.limit = op.limit + iter.op.replyFunc = iter.replyFunc() + iter.docsToReceive++ + session.prepareQuery(&op) + op.replyFunc = iter.op.replyFunc + op.flags |= flagTailable | flagAwaitData + + socket, err := session.acquireSocket(true) + if err != nil { + iter.err = err + } else { + iter.server = socket.Server() + err = socket.Query(&op) + if err != nil { + // Must lock as the query is already out and it may call replyFunc. + iter.m.Lock() + iter.err = err + iter.m.Unlock() + } + socket.Release() + } + return iter +} + +func (s *Session) prepareQuery(op *queryOp) { + s.m.RLock() + op.mode = s.consistency + if s.slaveOk { + op.flags |= flagSlaveOk + } + s.m.RUnlock() + return +} + +// Err returns nil if no errors happened during iteration, or the actual +// error otherwise. +// +// In case a resulting document included a field named $err or errmsg, which are +// standard ways for MongoDB to report an improper query, the returned value has +// a *QueryError type, and includes the Err message and the Code. +func (iter *Iter) Err() error { + iter.m.Lock() + err := iter.err + iter.m.Unlock() + if err == ErrNotFound { + return nil + } + return err +} + +// Close kills the server cursor used by the iterator, if any, and returns +// nil if no errors happened during iteration, or the actual error otherwise. +// +// Server cursors are automatically closed at the end of an iteration, which +// means close will do nothing unless the iteration was interrupted before +// the server finished sending results to the driver. If Close is not called +// in such a situation, the cursor will remain available at the server until +// the default cursor timeout period is reached. No further problems arise. +// +// Close is idempotent. That means it can be called repeatedly and will +// return the same result every time. +// +// In case a resulting document included a field named $err or errmsg, which are +// standard ways for MongoDB to report an improper query, the returned value has +// a *QueryError type. +func (iter *Iter) Close() error { + iter.m.Lock() + cursorId := iter.op.cursorId + iter.op.cursorId = 0 + err := iter.err + iter.m.Unlock() + if cursorId == 0 { + if err == ErrNotFound { + return nil + } + return err + } + socket, err := iter.acquireSocket() + if err == nil { + // TODO Batch kills. + err = socket.Query(&killCursorsOp{[]int64{cursorId}}) + socket.Release() + } + + iter.m.Lock() + if err != nil && (iter.err == nil || iter.err == ErrNotFound) { + iter.err = err + } else if iter.err != ErrNotFound { + err = iter.err + } + iter.m.Unlock() + return err +} + +// Timeout returns true if Next returned false due to a timeout of +// a tailable cursor. In those cases, Next may be called again to continue +// the iteration at the previous cursor position. +func (iter *Iter) Timeout() bool { + iter.m.Lock() + result := iter.timedout + iter.m.Unlock() + return result +} + +// Next retrieves the next document from the result set, blocking if necessary. +// This method will also automatically retrieve another batch of documents from +// the server when the current one is exhausted, or before that in background +// if pre-fetching is enabled (see the Query.Prefetch and Session.SetPrefetch +// methods). +// +// Next returns true if a document was successfully unmarshalled onto result, +// and false at the end of the result set or if an error happened. +// When Next returns false, the Err method should be called to verify if +// there was an error during iteration. +// +// For example: +// +// iter := collection.Find(nil).Iter() +// for iter.Next(&result) { +// fmt.Printf("Result: %v\n", result.Id) +// } +// if err := iter.Close(); err != nil { +// return err +// } +// +func (iter *Iter) Next(result interface{}) bool { + iter.m.Lock() + iter.timedout = false + timeout := time.Time{} + for iter.err == nil && iter.docData.Len() == 0 && (iter.docsToReceive > 0 || iter.op.cursorId != 0) { + if iter.docsToReceive == 0 { + if iter.timeout >= 0 { + if timeout.IsZero() { + timeout = time.Now().Add(iter.timeout) + } + if time.Now().After(timeout) { + iter.timedout = true + iter.m.Unlock() + return false + } + } + iter.getMore() + if iter.err != nil { + break + } + } + iter.gotReply.Wait() + } + + // Exhaust available data before reporting any errors. + if docData, ok := iter.docData.Pop().([]byte); ok { + close := false + if iter.limit > 0 { + iter.limit-- + if iter.limit == 0 { + if iter.docData.Len() > 0 { + iter.m.Unlock() + panic(fmt.Errorf("data remains after limit exhausted: %d", iter.docData.Len())) + } + iter.err = ErrNotFound + close = true + } + } + if iter.op.cursorId != 0 && iter.err == nil { + iter.docsBeforeMore-- + if iter.docsBeforeMore == -1 { + iter.getMore() + } + } + iter.m.Unlock() + + if close { + iter.Close() + } + err := bson.Unmarshal(docData, result) + if err != nil { + debugf("Iter %p document unmarshaling failed: %#v", iter, err) + iter.m.Lock() + if iter.err == nil { + iter.err = err + } + iter.m.Unlock() + return false + } + debugf("Iter %p document unmarshaled: %#v", iter, result) + // XXX Only have to check first document for a query error? + err = checkQueryError(iter.op.collection, docData) + if err != nil { + iter.m.Lock() + if iter.err == nil { + iter.err = err + } + iter.m.Unlock() + return false + } + return true + } else if iter.err != nil { + debugf("Iter %p returning false: %s", iter, iter.err) + iter.m.Unlock() + return false + } else if iter.op.cursorId == 0 { + iter.err = ErrNotFound + debugf("Iter %p exhausted with cursor=0", iter) + iter.m.Unlock() + return false + } + + panic("unreachable") +} + +// All retrieves all documents from the result set into the provided slice +// and closes the iterator. +// +// The result argument must necessarily be the address for a slice. The slice +// may be nil or previously allocated. +// +// WARNING: Obviously, All must not be used with result sets that may be +// potentially large, since it may consume all memory until the system +// crashes. Consider building the query with a Limit clause to ensure the +// result size is bounded. +// +// For instance: +// +// var result []struct{ Value int } +// iter := collection.Find(nil).Limit(100).Iter() +// err := iter.All(&result) +// if err != nil { +// return err +// } +// +func (iter *Iter) All(result interface{}) error { + resultv := reflect.ValueOf(result) + if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice { + panic("result argument must be a slice address") + } + slicev := resultv.Elem() + slicev = slicev.Slice(0, slicev.Cap()) + elemt := slicev.Type().Elem() + i := 0 + for { + if slicev.Len() == i { + elemp := reflect.New(elemt) + if !iter.Next(elemp.Interface()) { + break + } + slicev = reflect.Append(slicev, elemp.Elem()) + slicev = slicev.Slice(0, slicev.Cap()) + } else { + if !iter.Next(slicev.Index(i).Addr().Interface()) { + break + } + } + i++ + } + resultv.Elem().Set(slicev.Slice(0, i)) + return iter.Close() +} + +// All works like Iter.All. +func (q *Query) All(result interface{}) error { + return q.Iter().All(result) +} + +// The For method is obsolete and will be removed in a future release. +// See Iter as an elegant replacement. +func (q *Query) For(result interface{}, f func() error) error { + return q.Iter().For(result, f) +} + +// The For method is obsolete and will be removed in a future release. +// See Iter as an elegant replacement. +func (iter *Iter) For(result interface{}, f func() error) (err error) { + valid := false + v := reflect.ValueOf(result) + if v.Kind() == reflect.Ptr { + v = v.Elem() + switch v.Kind() { + case reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice: + valid = v.IsNil() + } + } + if !valid { + panic("For needs a pointer to nil reference value. See the documentation.") + } + zero := reflect.Zero(v.Type()) + for { + v.Set(zero) + if !iter.Next(result) { + break + } + err = f() + if err != nil { + return err + } + } + return iter.Err() +} + +// acquireSocket acquires a socket from the same server that the iterator +// cursor was obtained from. +// +// WARNING: This method must not be called with iter.m locked. Acquiring the +// socket depends on the cluster sync loop, and the cluster sync loop might +// attempt actions which cause replyFunc to be called, inducing a deadlock. +func (iter *Iter) acquireSocket() (*mongoSocket, error) { + socket, err := iter.session.acquireSocket(true) + if err != nil { + return nil, err + } + if socket.Server() != iter.server { + // Socket server changed during iteration. This may happen + // with Eventual sessions, if a Refresh is done, or if a + // monotonic session gets a write and shifts from secondary + // to primary. Our cursor is in a specific server, though. + iter.session.m.Lock() + sockTimeout := iter.session.sockTimeout + iter.session.m.Unlock() + socket.Release() + socket, _, err = iter.server.AcquireSocket(0, sockTimeout) + if err != nil { + return nil, err + } + err := iter.session.socketLogin(socket) + if err != nil { + socket.Release() + return nil, err + } + } + return socket, nil +} + +func (iter *Iter) getMore() { + // Increment now so that unlocking the iterator won't cause a + // different goroutine to get here as well. + iter.docsToReceive++ + iter.m.Unlock() + socket, err := iter.acquireSocket() + iter.m.Lock() + if err != nil { + iter.err = err + return + } + defer socket.Release() + + debugf("Iter %p requesting more documents", iter) + if iter.limit > 0 { + // The -1 below accounts for the fact docsToReceive was incremented above. + limit := iter.limit - int32(iter.docsToReceive-1) - int32(iter.docData.Len()) + if limit < iter.op.limit { + iter.op.limit = limit + } + } + var op interface{} + if iter.findCmd { + op = iter.getMoreCmd() + } else { + op = &iter.op + } + if err := socket.Query(op); err != nil { + iter.docsToReceive-- + iter.err = err + } +} + +func (iter *Iter) getMoreCmd() *queryOp { + // TODO: Define the query statically in the Iter type, next to getMoreOp. + nameDot := strings.Index(iter.op.collection, ".") + if nameDot < 0 { + panic("invalid query collection name: " + iter.op.collection) + } + + getMore := getMoreCmd{ + CursorId: iter.op.cursorId, + Collection: iter.op.collection[nameDot+1:], + BatchSize: iter.op.limit, + } + + var op queryOp + op.collection = iter.op.collection[:nameDot] + ".$cmd" + op.query = &getMore + op.limit = -1 + op.replyFunc = iter.op.replyFunc + return &op +} + +type countCmd struct { + Count string + Query interface{} + Limit int32 ",omitempty" + Skip int32 ",omitempty" +} + +// Count returns the total number of documents in the result set. +func (q *Query) Count() (n int, err error) { + q.m.Lock() + session := q.session + op := q.op + limit := q.limit + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return 0, errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + query := op.query + if query == nil { + query = bson.D{} + } + result := struct{ N int }{} + err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip}, &result) + return result.N, err +} + +// Count returns the total number of documents in the collection. +func (c *Collection) Count() (n int, err error) { + return c.Find(nil).Count() +} + +type distinctCmd struct { + Collection string "distinct" + Key string + Query interface{} ",omitempty" +} + +// Distinct unmarshals into result the list of distinct values for the given key. +// +// For example: +// +// var result []int +// err := collection.Find(bson.M{"gender": "F"}).Distinct("age", &result) +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Aggregation +// +func (q *Query) Distinct(key string, result interface{}) error { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + var doc struct{ Values bson.Raw } + err := session.DB(dbname).Run(distinctCmd{cname, key, op.query}, &doc) + if err != nil { + return err + } + return doc.Values.Unmarshal(result) +} + +type mapReduceCmd struct { + Collection string "mapreduce" + Map string ",omitempty" + Reduce string ",omitempty" + Finalize string ",omitempty" + Limit int32 ",omitempty" + Out interface{} + Query interface{} ",omitempty" + Sort interface{} ",omitempty" + Scope interface{} ",omitempty" + Verbose bool ",omitempty" +} + +type mapReduceResult struct { + Results bson.Raw + Result bson.Raw + TimeMillis int64 "timeMillis" + Counts struct{ Input, Emit, Output int } + Ok bool + Err string + Timing *MapReduceTime +} + +type MapReduce struct { + Map string // Map Javascript function code (required) + Reduce string // Reduce Javascript function code (required) + Finalize string // Finalize Javascript function code (optional) + Out interface{} // Output collection name or document. If nil, results are inlined into the result parameter. + Scope interface{} // Optional global scope for Javascript functions + Verbose bool +} + +type MapReduceInfo struct { + InputCount int // Number of documents mapped + EmitCount int // Number of times reduce called emit + OutputCount int // Number of documents in resulting collection + Database string // Output database, if results are not inlined + Collection string // Output collection, if results are not inlined + Time int64 // Time to run the job, in nanoseconds + VerboseTime *MapReduceTime // Only defined if Verbose was true +} + +type MapReduceTime struct { + Total int64 // Total time, in nanoseconds + Map int64 "mapTime" // Time within map function, in nanoseconds + EmitLoop int64 "emitLoop" // Time within the emit/map loop, in nanoseconds +} + +// MapReduce executes a map/reduce job for documents covered by the query. +// That kind of job is suitable for very flexible bulk aggregation of data +// performed at the server side via Javascript functions. +// +// Results from the job may be returned as a result of the query itself +// through the result parameter in case they'll certainly fit in memory +// and in a single document. If there's the possibility that the amount +// of data might be too large, results must be stored back in an alternative +// collection or even a separate database, by setting the Out field of the +// provided MapReduce job. In that case, provide nil as the result parameter. +// +// These are some of the ways to set Out: +// +// nil +// Inline results into the result parameter. +// +// bson.M{"replace": "mycollection"} +// The output will be inserted into a collection which replaces any +// existing collection with the same name. +// +// bson.M{"merge": "mycollection"} +// This option will merge new data into the old output collection. In +// other words, if the same key exists in both the result set and the +// old collection, the new key will overwrite the old one. +// +// bson.M{"reduce": "mycollection"} +// If documents exist for a given key in the result set and in the old +// collection, then a reduce operation (using the specified reduce +// function) will be performed on the two values and the result will be +// written to the output collection. If a finalize function was +// provided, this will be run after the reduce as well. +// +// bson.M{...., "db": "mydb"} +// Any of the above options can have the "db" key included for doing +// the respective action in a separate database. +// +// The following is a trivial example which will count the number of +// occurrences of a field named n on each document in a collection, and +// will return results inline: +// +// job := &mgo.MapReduce{ +// Map: "function() { emit(this.n, 1) }", +// Reduce: "function(key, values) { return Array.sum(values) }", +// } +// var result []struct { Id int "_id"; Value int } +// _, err := collection.Find(nil).MapReduce(job, &result) +// if err != nil { +// return err +// } +// for _, item := range result { +// fmt.Println(item.Value) +// } +// +// This function is compatible with MongoDB 1.7.4+. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/MapReduce +// +func (q *Query) MapReduce(job *MapReduce, result interface{}) (info *MapReduceInfo, err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + limit := q.limit + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return nil, errors.New("Bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + cmd := mapReduceCmd{ + Collection: cname, + Map: job.Map, + Reduce: job.Reduce, + Finalize: job.Finalize, + Out: fixMROut(job.Out), + Scope: job.Scope, + Verbose: job.Verbose, + Query: op.query, + Sort: op.options.OrderBy, + Limit: limit, + } + + if cmd.Out == nil { + cmd.Out = bson.D{{"inline", 1}} + } + + var doc mapReduceResult + err = session.DB(dbname).Run(&cmd, &doc) + if err != nil { + return nil, err + } + if doc.Err != "" { + return nil, errors.New(doc.Err) + } + + info = &MapReduceInfo{ + InputCount: doc.Counts.Input, + EmitCount: doc.Counts.Emit, + OutputCount: doc.Counts.Output, + Time: doc.TimeMillis * 1e6, + } + + if doc.Result.Kind == 0x02 { + err = doc.Result.Unmarshal(&info.Collection) + info.Database = dbname + } else if doc.Result.Kind == 0x03 { + var v struct{ Collection, Db string } + err = doc.Result.Unmarshal(&v) + info.Collection = v.Collection + info.Database = v.Db + } + + if doc.Timing != nil { + info.VerboseTime = doc.Timing + info.VerboseTime.Total *= 1e6 + info.VerboseTime.Map *= 1e6 + info.VerboseTime.EmitLoop *= 1e6 + } + + if err != nil { + return nil, err + } + if result != nil { + return info, doc.Results.Unmarshal(result) + } + return info, nil +} + +// The "out" option in the MapReduce command must be ordered. This was +// found after the implementation was accepting maps for a long time, +// so rather than breaking the API, we'll fix the order if necessary. +// Details about the order requirement may be seen in MongoDB's code: +// +// http://goo.gl/L8jwJX +// +func fixMROut(out interface{}) interface{} { + outv := reflect.ValueOf(out) + if outv.Kind() != reflect.Map || outv.Type().Key() != reflect.TypeOf("") { + return out + } + outs := make(bson.D, outv.Len()) + + outTypeIndex := -1 + for i, k := range outv.MapKeys() { + ks := k.String() + outs[i].Name = ks + outs[i].Value = outv.MapIndex(k).Interface() + switch ks { + case "normal", "replace", "merge", "reduce", "inline": + outTypeIndex = i + } + } + if outTypeIndex > 0 { + outs[0], outs[outTypeIndex] = outs[outTypeIndex], outs[0] + } + return outs +} + +// Change holds fields for running a findAndModify MongoDB command via +// the Query.Apply method. +type Change struct { + Update interface{} // The update document + Upsert bool // Whether to insert in case the document isn't found + Remove bool // Whether to remove the document found rather than updating + ReturnNew bool // Should the modified document be returned rather than the old one +} + +type findModifyCmd struct { + Collection string "findAndModify" + Query, Update, Sort, Fields interface{} ",omitempty" + Upsert, Remove, New bool ",omitempty" +} + +type valueResult struct { + Value bson.Raw + LastError LastError "lastErrorObject" +} + +// Apply runs the findAndModify MongoDB command, which allows updating, upserting +// or removing a document matching a query and atomically returning either the old +// version (the default) or the new version of the document (when ReturnNew is true). +// If no objects are found Apply returns ErrNotFound. +// +// The Sort and Select query methods affect the result of Apply. In case +// multiple documents match the query, Sort enables selecting which document to +// act upon by ordering it first. Select enables retrieving only a selection +// of fields of the new or old document. +// +// This simple example increments a counter and prints its new value: +// +// change := mgo.Change{ +// Update: bson.M{"$inc": bson.M{"n": 1}}, +// ReturnNew: true, +// } +// info, err = col.Find(M{"_id": id}).Apply(change, &doc) +// fmt.Println(doc.N) +// +// This method depends on MongoDB >= 2.0 to work properly. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/findAndModify+Command +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. + q.m.Unlock() + + c := strings.Index(op.collection, ".") + if c < 0 { + return nil, errors.New("bad collection name: " + op.collection) + } + + dbname := op.collection[:c] + cname := op.collection[c+1:] + + cmd := findModifyCmd{ + Collection: cname, + Update: change.Update, + Upsert: change.Upsert, + Remove: change.Remove, + New: change.ReturnNew, + Query: op.query, + Sort: op.options.OrderBy, + Fields: op.selector, + } + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + var doc valueResult + err = session.DB(dbname).Run(&cmd, &doc) + if err != nil { + if qerr, ok := err.(*QueryError); ok && qerr.Message == "No matching object found" { + return nil, ErrNotFound + } + return nil, err + } + if doc.LastError.N == 0 { + return nil, ErrNotFound + } + if doc.Value.Kind != 0x0A && result != nil { + err = doc.Value.Unmarshal(result) + if err != nil { + return nil, err + } + } + info = &ChangeInfo{} + lerr := &doc.LastError + if lerr.UpdatedExisting { + info.Updated = lerr.N + info.Matched = lerr.N + } else if change.Remove { + info.Removed = lerr.N + info.Matched = lerr.N + } else if change.Upsert { + info.UpsertedId = lerr.UpsertedId + } + return info, nil +} + +// The BuildInfo type encapsulates details about the running MongoDB server. +// +// Note that the VersionArray field was introduced in MongoDB 2.0+, but it is +// internally assembled from the Version information for previous versions. +// In both cases, VersionArray is guaranteed to have at least 4 entries. +type BuildInfo struct { + Version string + VersionArray []int `bson:"versionArray"` // On MongoDB 2.0+; assembled from Version otherwise + GitVersion string `bson:"gitVersion"` + OpenSSLVersion string `bson:"OpenSSLVersion"` + SysInfo string `bson:"sysInfo"` // Deprecated and empty on MongoDB 3.2+. + Bits int + Debug bool + MaxObjectSize int `bson:"maxBsonObjectSize"` +} + +// VersionAtLeast returns whether the BuildInfo version is greater than or +// equal to the provided version number. If more than one number is +// provided, numbers will be considered as major, minor, and so on. +func (bi *BuildInfo) VersionAtLeast(version ...int) bool { + for i := range version { + if i == len(bi.VersionArray) { + return false + } + if bi.VersionArray[i] < version[i] { + return false + } + } + return true +} + +// BuildInfo retrieves the version and other details about the +// running MongoDB server. +func (s *Session) BuildInfo() (info BuildInfo, err error) { + err = s.Run(bson.D{{"buildInfo", "1"}}, &info) + if len(info.VersionArray) == 0 { + for _, a := range strings.Split(info.Version, ".") { + i, err := strconv.Atoi(a) + if err != nil { + break + } + info.VersionArray = append(info.VersionArray, i) + } + } + for len(info.VersionArray) < 4 { + info.VersionArray = append(info.VersionArray, 0) + } + if i := strings.IndexByte(info.GitVersion, ' '); i >= 0 { + // Strip off the " modules: enterprise" suffix. This is a _git version_. + // That information may be moved to another field if people need it. + info.GitVersion = info.GitVersion[:i] + } + if info.SysInfo == "deprecated" { + info.SysInfo = "" + } + return +} + +// --------------------------------------------------------------------------- +// Internal session handling helpers. + +func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { + + // Read-only lock to check for previously reserved socket. + s.m.RLock() + // If there is a slave socket reserved and its use is acceptable, take it as long + // as there isn't a master socket which would be preferred by the read preference mode. + if s.slaveSocket != nil && s.slaveOk && slaveOk && (s.masterSocket == nil || s.consistency != PrimaryPreferred && s.consistency != Monotonic) { + socket := s.slaveSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + if s.masterSocket != nil { + socket := s.masterSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + s.m.RUnlock() + + // No go. We may have to request a new socket and change the session, + // so try again but with an exclusive lock now. + s.m.Lock() + defer s.m.Unlock() + + if s.slaveSocket != nil && s.slaveOk && slaveOk && (s.masterSocket == nil || s.consistency != PrimaryPreferred && s.consistency != Monotonic) { + s.slaveSocket.Acquire() + return s.slaveSocket, nil + } + if s.masterSocket != nil { + s.masterSocket.Acquire() + return s.masterSocket, nil + } + + // Still not good. We need a new socket. + sock, err := s.cluster().AcquireSocket(s.consistency, slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit) + if err != nil { + return nil, err + } + + // Authenticate the new socket. + if err = s.socketLogin(sock); err != nil { + sock.Release() + return nil, err + } + + // Keep track of the new socket, if necessary. + // Note that, as a special case, if the Eventual session was + // not refreshed (s.slaveSocket != nil), it means the developer + // asked to preserve an existing reserved socket, so we'll + // keep a master one around too before a Refresh happens. + if s.consistency != Eventual || s.slaveSocket != nil { + s.setSocket(sock) + } + + // Switch over a Monotonic session to the master. + if !slaveOk && s.consistency == Monotonic { + s.slaveOk = false + } + + return sock, nil +} + +// setSocket binds socket to this section. +func (s *Session) setSocket(socket *mongoSocket) { + info := socket.Acquire() + if info.Master { + if s.masterSocket != nil { + panic("setSocket(master) with existing master socket reserved") + } + s.masterSocket = socket + } else { + if s.slaveSocket != nil { + panic("setSocket(slave) with existing slave socket reserved") + } + s.slaveSocket = socket + } +} + +// unsetSocket releases any slave and/or master sockets reserved. +func (s *Session) unsetSocket() { + if s.masterSocket != nil { + s.masterSocket.Release() + } + if s.slaveSocket != nil { + s.slaveSocket.Release() + } + s.masterSocket = nil + s.slaveSocket = nil +} + +func (iter *Iter) replyFunc() replyFunc { + return func(err error, op *replyOp, docNum int, docData []byte) { + iter.m.Lock() + iter.docsToReceive-- + if err != nil { + iter.err = err + debugf("Iter %p received an error: %s", iter, err.Error()) + } else if docNum == -1 { + debugf("Iter %p received no documents (cursor=%d).", iter, op.cursorId) + if op != nil && op.cursorId != 0 { + // It's a tailable cursor. + iter.op.cursorId = op.cursorId + } else if op != nil && op.cursorId == 0 && op.flags&1 == 1 { + // Cursor likely timed out. + iter.err = ErrCursor + } else { + iter.err = ErrNotFound + } + } else if iter.findCmd { + debugf("Iter %p received reply document %d/%d (cursor=%d)", iter, docNum+1, int(op.replyDocs), op.cursorId) + var findReply struct { + Ok bool + Code int + Errmsg string + Cursor cursorData + } + if err := bson.Unmarshal(docData, &findReply); err != nil { + iter.err = err + } else if !findReply.Ok && findReply.Errmsg != "" { + iter.err = &QueryError{Code: findReply.Code, Message: findReply.Errmsg} + } else if len(findReply.Cursor.FirstBatch) == 0 && len(findReply.Cursor.NextBatch) == 0 { + iter.err = ErrNotFound + } else { + batch := findReply.Cursor.FirstBatch + if len(batch) == 0 { + batch = findReply.Cursor.NextBatch + } + rdocs := len(batch) + for _, raw := range batch { + iter.docData.Push(raw.Data) + } + iter.docsToReceive = 0 + docsToProcess := iter.docData.Len() + if iter.limit == 0 || int32(docsToProcess) < iter.limit { + iter.docsBeforeMore = docsToProcess - int(iter.prefetch*float64(rdocs)) + } else { + iter.docsBeforeMore = -1 + } + iter.op.cursorId = findReply.Cursor.Id + } + } else { + rdocs := int(op.replyDocs) + if docNum == 0 { + iter.docsToReceive += rdocs - 1 + docsToProcess := iter.docData.Len() + rdocs + if iter.limit == 0 || int32(docsToProcess) < iter.limit { + iter.docsBeforeMore = docsToProcess - int(iter.prefetch*float64(rdocs)) + } else { + iter.docsBeforeMore = -1 + } + iter.op.cursorId = op.cursorId + } + debugf("Iter %p received reply document %d/%d (cursor=%d)", iter, docNum+1, rdocs, op.cursorId) + iter.docData.Push(docData) + } + iter.gotReply.Broadcast() + iter.m.Unlock() + } +} + +type writeCmdResult struct { + Ok bool + N int + NModified int `bson:"nModified"` + Upserted []struct { + Index int + Id interface{} `_id` + } + ConcernError writeConcernError `bson:"writeConcernError"` + Errors []writeCmdError `bson:"writeErrors"` +} + +type writeConcernError struct { + Code int + ErrMsg string +} + +type writeCmdError struct { + Index int + Code int + ErrMsg string +} + +func (r *writeCmdResult) BulkErrorCases() []BulkErrorCase { + ecases := make([]BulkErrorCase, len(r.Errors)) + for i, err := range r.Errors { + ecases[i] = BulkErrorCase{err.Index, &QueryError{Code: err.Code, Message: err.ErrMsg}} + } + return ecases +} + +// writeOp runs the given modifying operation, potentially followed up +// by a getLastError command in case the session is in safe mode. The +// LastError result is made available in lerr, and if lerr.Err is set it +// will also be returned as err. +func (c *Collection) writeOp(op interface{}, ordered bool) (lerr *LastError, err error) { + s := c.Database.Session + socket, err := s.acquireSocket(c.Database.Name == "local") + if err != nil { + return nil, err + } + defer socket.Release() + + s.m.RLock() + safeOp := s.safeOp + bypassValidation := s.bypassValidation + s.m.RUnlock() + + if socket.ServerInfo().MaxWireVersion >= 2 { + // Servers with a more recent write protocol benefit from write commands. + if op, ok := op.(*insertOp); ok && len(op.documents) > 1000 { + var lerr LastError + + // Maximum batch size is 1000. Must split out in separate operations for compatibility. + all := op.documents + for i := 0; i < len(all); i += 1000 { + l := i + 1000 + if l > len(all) { + l = len(all) + } + op.documents = all[i:l] + oplerr, err := c.writeOpCommand(socket, safeOp, op, ordered, bypassValidation) + lerr.N += oplerr.N + lerr.modified += oplerr.modified + if err != nil { + for ei := range lerr.ecases { + oplerr.ecases[ei].Index += i + } + lerr.ecases = append(lerr.ecases, oplerr.ecases...) + if op.flags&1 == 0 { + return &lerr, err + } + } + } + if len(lerr.ecases) != 0 { + return &lerr, lerr.ecases[0].Err + } + return &lerr, nil + } + return c.writeOpCommand(socket, safeOp, op, ordered, bypassValidation) + } else if updateOps, ok := op.(bulkUpdateOp); ok { + var lerr LastError + for i, updateOp := range updateOps { + oplerr, err := c.writeOpQuery(socket, safeOp, updateOp, ordered) + lerr.N += oplerr.N + lerr.modified += oplerr.modified + if err != nil { + lerr.ecases = append(lerr.ecases, BulkErrorCase{i, err}) + if ordered { + break + } + } + } + if len(lerr.ecases) != 0 { + return &lerr, lerr.ecases[0].Err + } + return &lerr, nil + } else if deleteOps, ok := op.(bulkDeleteOp); ok { + var lerr LastError + for i, deleteOp := range deleteOps { + oplerr, err := c.writeOpQuery(socket, safeOp, deleteOp, ordered) + lerr.N += oplerr.N + lerr.modified += oplerr.modified + if err != nil { + lerr.ecases = append(lerr.ecases, BulkErrorCase{i, err}) + if ordered { + break + } + } + } + if len(lerr.ecases) != 0 { + return &lerr, lerr.ecases[0].Err + } + return &lerr, nil + } + return c.writeOpQuery(socket, safeOp, op, ordered) +} + +func (c *Collection) writeOpQuery(socket *mongoSocket, safeOp *queryOp, op interface{}, ordered bool) (lerr *LastError, err error) { + if safeOp == nil { + return nil, socket.Query(op) + } + + var mutex sync.Mutex + var replyData []byte + var replyErr error + mutex.Lock() + query := *safeOp // Copy the data. + query.collection = c.Database.Name + ".$cmd" + query.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + replyData = docData + replyErr = err + mutex.Unlock() + } + err = socket.Query(op, &query) + if err != nil { + return nil, err + } + mutex.Lock() // Wait. + if replyErr != nil { + return nil, replyErr // XXX TESTME + } + if hasErrMsg(replyData) { + // Looks like getLastError itself failed. + err = checkQueryError(query.collection, replyData) + if err != nil { + return nil, err + } + } + result := &LastError{} + bson.Unmarshal(replyData, &result) + debugf("Result from writing query: %#v", result) + if result.Err != "" { + result.ecases = []BulkErrorCase{{Index: 0, Err: result}} + if insert, ok := op.(*insertOp); ok && len(insert.documents) > 1 { + result.ecases[0].Index = -1 + } + return result, result + } + // With MongoDB <2.6 we don't know how many actually changed, so make it the same as matched. + result.modified = result.N + return result, nil +} + +func (c *Collection) writeOpCommand(socket *mongoSocket, safeOp *queryOp, op interface{}, ordered, bypassValidation bool) (lerr *LastError, err error) { + var writeConcern interface{} + if safeOp == nil { + writeConcern = bson.D{{"w", 0}} + } else { + writeConcern = safeOp.query.(*getLastError) + } + + var cmd bson.D + switch op := op.(type) { + case *insertOp: + // http://docs.mongodb.org/manual/reference/command/insert + cmd = bson.D{ + {"insert", c.Name}, + {"documents", op.documents}, + {"writeConcern", writeConcern}, + {"ordered", op.flags&1 == 0}, + } + case *updateOp: + // http://docs.mongodb.org/manual/reference/command/update + cmd = bson.D{ + {"update", c.Name}, + {"updates", []interface{}{op}}, + {"writeConcern", writeConcern}, + {"ordered", ordered}, + } + case bulkUpdateOp: + // http://docs.mongodb.org/manual/reference/command/update + cmd = bson.D{ + {"update", c.Name}, + {"updates", op}, + {"writeConcern", writeConcern}, + {"ordered", ordered}, + } + case *deleteOp: + // http://docs.mongodb.org/manual/reference/command/delete + cmd = bson.D{ + {"delete", c.Name}, + {"deletes", []interface{}{op}}, + {"writeConcern", writeConcern}, + {"ordered", ordered}, + } + case bulkDeleteOp: + // http://docs.mongodb.org/manual/reference/command/delete + cmd = bson.D{ + {"delete", c.Name}, + {"deletes", op}, + {"writeConcern", writeConcern}, + {"ordered", ordered}, + } + } + if bypassValidation { + cmd = append(cmd, bson.DocElem{"bypassDocumentValidation", true}) + } + + var result writeCmdResult + err = c.Database.run(socket, cmd, &result) + debugf("Write command result: %#v (err=%v)", result, err) + ecases := result.BulkErrorCases() + lerr = &LastError{ + UpdatedExisting: result.N > 0 && len(result.Upserted) == 0, + N: result.N, + + modified: result.NModified, + ecases: ecases, + } + if len(result.Upserted) > 0 { + lerr.UpsertedId = result.Upserted[0].Id + } + if len(result.Errors) > 0 { + e := result.Errors[0] + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = lerr + } else if result.ConcernError.Code != 0 { + e := result.ConcernError + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = lerr + } + + if err == nil && safeOp == nil { + return nil, nil + } + return lerr, err +} + +func hasErrMsg(d []byte) bool { + l := len(d) + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + return true + } + } + return false +} diff --git a/vendor/gopkg.in/mgo.v2/socket.go b/vendor/gopkg.in/mgo.v2/socket.go new file mode 100644 index 0000000..8891dd5 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/socket.go @@ -0,0 +1,707 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) + +type mongoSocket struct { + sync.Mutex + server *mongoServer // nil when cached + conn net.Conn + timeout time.Duration + addr string // For debugging only. + nextRequestId uint32 + replyFuncs map[uint32]replyFunc + references int + creds []Credential + logout []Credential + cachedNonce string + gotNonce sync.Cond + dead error + serverInfo *mongoServerInfo +} + +type queryOpFlags uint32 + +const ( + _ queryOpFlags = 1 << iota + flagTailable + flagSlaveOk + flagLogReplay + flagNoCursorTimeout + flagAwaitData +) + +type queryOp struct { + collection string + query interface{} + skip int32 + limit int32 + selector interface{} + flags queryOpFlags + replyFunc replyFunc + + mode Mode + options queryWrapper + hasOptions bool + serverTags []bson.D +} + +type queryWrapper struct { + Query interface{} "$query" + OrderBy interface{} "$orderby,omitempty" + Hint interface{} "$hint,omitempty" + Explain bool "$explain,omitempty" + Snapshot bool "$snapshot,omitempty" + ReadPreference bson.D "$readPreference,omitempty" + MaxScan int "$maxScan,omitempty" + MaxTimeMS int "$maxTimeMS,omitempty" + Comment string "$comment,omitempty" +} + +func (op *queryOp) finalQuery(socket *mongoSocket) interface{} { + if op.flags&flagSlaveOk != 0 && socket.ServerInfo().Mongos { + var modeName string + switch op.mode { + case Strong: + modeName = "primary" + case Monotonic, Eventual: + modeName = "secondaryPreferred" + case PrimaryPreferred: + modeName = "primaryPreferred" + case Secondary: + modeName = "secondary" + case SecondaryPreferred: + modeName = "secondaryPreferred" + case Nearest: + modeName = "nearest" + default: + panic(fmt.Sprintf("unsupported read mode: %d", op.mode)) + } + op.hasOptions = true + op.options.ReadPreference = make(bson.D, 0, 2) + op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"mode", modeName}) + if len(op.serverTags) > 0 { + op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"tags", op.serverTags}) + } + } + if op.hasOptions { + if op.query == nil { + var empty bson.D + op.options.Query = empty + } else { + op.options.Query = op.query + } + debugf("final query is %#v\n", &op.options) + return &op.options + } + return op.query +} + +type getMoreOp struct { + collection string + limit int32 + cursorId int64 + replyFunc replyFunc +} + +type replyOp struct { + flags uint32 + cursorId int64 + firstDoc int32 + replyDocs int32 +} + +type insertOp struct { + collection string // "database.collection" + documents []interface{} // One or more documents to insert + flags uint32 +} + +type updateOp struct { + Collection string `bson:"-"` // "database.collection" + Selector interface{} `bson:"q"` + Update interface{} `bson:"u"` + Flags uint32 `bson:"-"` + Multi bool `bson:"multi,omitempty"` + Upsert bool `bson:"upsert,omitempty"` +} + +type deleteOp struct { + Collection string `bson:"-"` // "database.collection" + Selector interface{} `bson:"q"` + Flags uint32 `bson:"-"` + Limit int `bson:"limit"` +} + +type killCursorsOp struct { + cursorIds []int64 +} + +type requestInfo struct { + bufferPos int + replyFunc replyFunc +} + +func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { + socket := &mongoSocket{ + conn: conn, + addr: server.Addr, + server: server, + replyFuncs: make(map[uint32]replyFunc), + } + socket.gotNonce.L = &socket.Mutex + if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + panic("newSocket: InitialAcquire returned error: " + err.Error()) + } + stats.socketsAlive(+1) + debugf("Socket %p to %s: initialized", socket, socket.addr) + socket.resetNonce() + go socket.readLoop() + return socket +} + +// Server returns the server that the socket is associated with. +// It returns nil while the socket is cached in its respective server. +func (socket *mongoSocket) Server() *mongoServer { + socket.Lock() + server := socket.server + socket.Unlock() + return server +} + +// ServerInfo returns details for the server at the time the socket +// was initially acquired. +func (socket *mongoSocket) ServerInfo() *mongoServerInfo { + socket.Lock() + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// InitialAcquire obtains the first reference to the socket, either +// right after the connection is made or once a recycled socket is +// being put back in use. +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { + socket.Lock() + if socket.references > 0 { + panic("Socket acquired out of cache with references") + } + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + return dead + } + socket.references++ + socket.serverInfo = serverInfo + socket.timeout = timeout + stats.socketsInUse(+1) + stats.socketRefs(+1) + socket.Unlock() + return nil +} + +// Acquire obtains an additional reference to the socket. +// The socket will only be recycled when it's released as many +// times as it's been acquired. +func (socket *mongoSocket) Acquire() (info *mongoServerInfo) { + socket.Lock() + if socket.references == 0 { + panic("Socket got non-initial acquire with references == 0") + } + // We'll track references to dead sockets as well. + // Caller is still supposed to release the socket. + socket.references++ + stats.socketRefs(+1) + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// Release decrements a socket reference. The socket will be +// recycled once its released as many times as it's been acquired. +func (socket *mongoSocket) Release() { + socket.Lock() + if socket.references == 0 { + panic("socket.Release() with references == 0") + } + socket.references-- + stats.socketRefs(-1) + if socket.references == 0 { + stats.socketsInUse(-1) + server := socket.server + socket.Unlock() + socket.LogoutAll() + // If the socket is dead server is nil. + if server != nil { + server.RecycleSocket(socket) + } + } else { + socket.Unlock() + } +} + +// SetTimeout changes the timeout used on socket operations. +func (socket *mongoSocket) SetTimeout(d time.Duration) { + socket.Lock() + socket.timeout = d + socket.Unlock() +} + +type deadlineType int + +const ( + readDeadline deadlineType = 1 + writeDeadline deadlineType = 2 +) + +func (socket *mongoSocket) updateDeadline(which deadlineType) { + var when time.Time + if socket.timeout > 0 { + when = time.Now().Add(socket.timeout) + } + whichstr := "" + switch which { + case readDeadline | writeDeadline: + whichstr = "read/write" + socket.conn.SetDeadline(when) + case readDeadline: + whichstr = "read" + socket.conn.SetReadDeadline(when) + case writeDeadline: + whichstr = "write" + socket.conn.SetWriteDeadline(when) + default: + panic("invalid parameter to updateDeadline") + } + debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) +} + +// Close terminates the socket use. +func (socket *mongoSocket) Close() { + socket.kill(errors.New("Closed explicitly"), false) +} + +func (socket *mongoSocket) kill(err error, abend bool) { + socket.Lock() + if socket.dead != nil { + debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error()) + socket.Unlock() + return + } + logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend) + socket.dead = err + socket.conn.Close() + stats.socketsAlive(-1) + replyFuncs := socket.replyFuncs + socket.replyFuncs = make(map[uint32]replyFunc) + server := socket.server + socket.server = nil + socket.gotNonce.Broadcast() + socket.Unlock() + for _, replyFunc := range replyFuncs { + logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error()) + replyFunc(err, nil, -1, nil) + } + if abend { + server.AbendSocket(socket) + } +} + +func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) { + var wait, change sync.Mutex + var replyDone bool + var replyData []byte + var replyErr error + wait.Lock() + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + change.Lock() + if !replyDone { + replyDone = true + replyErr = err + if err == nil { + replyData = docData + } + } + change.Unlock() + wait.Unlock() + } + err = socket.Query(op) + if err != nil { + return nil, err + } + wait.Lock() + change.Lock() + data = replyData + err = replyErr + change.Unlock() + return data, err +} + +func (socket *mongoSocket) Query(ops ...interface{}) (err error) { + + if lops := socket.flushLogout(); len(lops) > 0 { + ops = append(lops, ops...) + } + + buf := make([]byte, 0, 256) + + // Serialize operations synchronously to avoid interrupting + // other goroutines while we can't really be sending data. + // Also, record id positions so that we can compute request + // ids at once later with the lock already held. + requests := make([]requestInfo, len(ops)) + requestCount := 0 + + for _, op := range ops { + debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op) + if qop, ok := op.(*queryOp); ok { + if cmd, ok := qop.query.(*findCmd); ok { + debugf("Socket %p to %s: find command: %#v", socket, socket.addr, cmd) + } + } + start := len(buf) + var replyFunc replyFunc + switch op := op.(type) { + + case *updateOp: + buf = addHeader(buf, 2001) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.Collection) + buf = addInt32(buf, int32(op.Flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector) + buf, err = addBSON(buf, op.Selector) + if err != nil { + return err + } + debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.Update) + buf, err = addBSON(buf, op.Update) + if err != nil { + return err + } + + case *insertOp: + buf = addHeader(buf, 2002) + buf = addInt32(buf, int32(op.flags)) + buf = addCString(buf, op.collection) + for _, doc := range op.documents { + debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc) + buf, err = addBSON(buf, doc) + if err != nil { + return err + } + } + + case *queryOp: + buf = addHeader(buf, 2004) + buf = addInt32(buf, int32(op.flags)) + buf = addCString(buf, op.collection) + buf = addInt32(buf, op.skip) + buf = addInt32(buf, op.limit) + buf, err = addBSON(buf, op.finalQuery(socket)) + if err != nil { + return err + } + if op.selector != nil { + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + } + replyFunc = op.replyFunc + + case *getMoreOp: + buf = addHeader(buf, 2005) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, op.limit) + buf = addInt64(buf, op.cursorId) + replyFunc = op.replyFunc + + case *deleteOp: + buf = addHeader(buf, 2006) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.Collection) + buf = addInt32(buf, int32(op.Flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector) + buf, err = addBSON(buf, op.Selector) + if err != nil { + return err + } + + case *killCursorsOp: + buf = addHeader(buf, 2007) + buf = addInt32(buf, 0) // Reserved + buf = addInt32(buf, int32(len(op.cursorIds))) + for _, cursorId := range op.cursorIds { + buf = addInt64(buf, cursorId) + } + + default: + panic("internal error: unknown operation type") + } + + setInt32(buf, start, int32(len(buf)-start)) + + if replyFunc != nil { + request := &requests[requestCount] + request.replyFunc = replyFunc + request.bufferPos = start + requestCount++ + } + } + + // Buffer is ready for the pipe. Lock, allocate ids, and enqueue. + + socket.Lock() + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) + // XXX This seems necessary in case the session is closed concurrently + // with a query being performed, but it's not yet tested: + for i := 0; i != requestCount; i++ { + request := &requests[i] + if request.replyFunc != nil { + request.replyFunc(dead, nil, -1, nil) + } + } + return dead + } + + wasWaiting := len(socket.replyFuncs) > 0 + + // Reserve id 0 for requests which should have no responses. + requestId := socket.nextRequestId + 1 + if requestId == 0 { + requestId++ + } + socket.nextRequestId = requestId + uint32(requestCount) + for i := 0; i != requestCount; i++ { + request := &requests[i] + setInt32(buf, request.bufferPos+4, int32(requestId)) + socket.replyFuncs[requestId] = request.replyFunc + requestId++ + } + + debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf)) + stats.sentOps(len(ops)) + + socket.updateDeadline(writeDeadline) + _, err = socket.conn.Write(buf) + if !wasWaiting && requestCount > 0 { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + return err +} + +func fill(r net.Conn, b []byte) error { + l := len(b) + n, err := r.Read(b) + for n != l && err == nil { + var ni int + ni, err = r.Read(b[n:]) + n += ni + } + return err +} + +// Estimated minimum cost per socket: 1 goroutine + memory for the largest +// document ever seen. +func (socket *mongoSocket) readLoop() { + p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields + s := make([]byte, 4) + conn := socket.conn // No locking, conn never changes. + for { + err := fill(conn, p) + if err != nil { + socket.kill(err, true) + return + } + + totalLen := getInt32(p, 0) + responseTo := getInt32(p, 8) + opCode := getInt32(p, 12) + + // Don't use socket.server.Addr here. socket is not + // locked and socket.server may go away. + debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen) + + _ = totalLen + + if opCode != 1 { + socket.kill(errors.New("opcode != 1, corrupted data?"), true) + return + } + + reply := replyOp{ + flags: uint32(getInt32(p, 16)), + cursorId: getInt64(p, 20), + firstDoc: getInt32(p, 28), + replyDocs: getInt32(p, 32), + } + + stats.receivedOps(+1) + stats.receivedDocs(int(reply.replyDocs)) + + socket.Lock() + replyFunc, ok := socket.replyFuncs[uint32(responseTo)] + if ok { + delete(socket.replyFuncs, uint32(responseTo)) + } + socket.Unlock() + + if replyFunc != nil && reply.replyDocs == 0 { + replyFunc(nil, &reply, -1, nil) + } else { + for i := 0; i != int(reply.replyDocs); i++ { + err := fill(conn, s) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + b := make([]byte, int(getInt32(s, 0))) + + // copy(b, s) in an efficient way. + b[0] = s[0] + b[1] = s[1] + b[2] = s[2] + b[3] = s[3] + + err = fill(conn, b[4:]) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + if globalDebug && globalLogger != nil { + m := bson.M{} + if err := bson.Unmarshal(b, m); err == nil { + debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m) + } + } + + if replyFunc != nil { + replyFunc(nil, &reply, i, b) + } + + // XXX Do bound checking against totalLen. + } + } + + socket.Lock() + if len(socket.replyFuncs) == 0 { + // Nothing else to read for now. Disable deadline. + socket.conn.SetReadDeadline(time.Time{}) + } else { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + + // XXX Do bound checking against totalLen. + } +} + +var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +func addHeader(b []byte, opcode int) []byte { + i := len(b) + b = append(b, emptyHeader...) + // Enough for current opcodes. + b[i+12] = byte(opcode) + b[i+13] = byte(opcode >> 8) + return b +} + +func addInt32(b []byte, i int32) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24)) +} + +func addInt64(b []byte, i int64) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), + byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) +} + +func addCString(b []byte, s string) []byte { + b = append(b, []byte(s)...) + b = append(b, 0) + return b +} + +func addBSON(b []byte, doc interface{}) ([]byte, error) { + if doc == nil { + return append(b, 5, 0, 0, 0, 0), nil + } + data, err := bson.Marshal(doc) + if err != nil { + return b, err + } + return append(b, data...), nil +} + +func setInt32(b []byte, pos int, i int32) { + b[pos] = byte(i) + b[pos+1] = byte(i >> 8) + b[pos+2] = byte(i >> 16) + b[pos+3] = byte(i >> 24) +} + +func getInt32(b []byte, pos int) int32 { + return (int32(b[pos+0])) | + (int32(b[pos+1]) << 8) | + (int32(b[pos+2]) << 16) | + (int32(b[pos+3]) << 24) +} + +func getInt64(b []byte, pos int) int64 { + return (int64(b[pos+0])) | + (int64(b[pos+1]) << 8) | + (int64(b[pos+2]) << 16) | + (int64(b[pos+3]) << 24) | + (int64(b[pos+4]) << 32) | + (int64(b[pos+5]) << 40) | + (int64(b[pos+6]) << 48) | + (int64(b[pos+7]) << 56) +} diff --git a/vendor/gopkg.in/mgo.v2/stats.go b/vendor/gopkg.in/mgo.v2/stats.go new file mode 100644 index 0000000..59723e6 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/stats.go @@ -0,0 +1,147 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "sync" +) + +var stats *Stats +var statsMutex sync.Mutex + +func SetStats(enabled bool) { + statsMutex.Lock() + if enabled { + if stats == nil { + stats = &Stats{} + } + } else { + stats = nil + } + statsMutex.Unlock() +} + +func GetStats() (snapshot Stats) { + statsMutex.Lock() + snapshot = *stats + statsMutex.Unlock() + return +} + +func ResetStats() { + statsMutex.Lock() + debug("Resetting stats") + old := stats + stats = &Stats{} + // These are absolute values: + stats.Clusters = old.Clusters + stats.SocketsInUse = old.SocketsInUse + stats.SocketsAlive = old.SocketsAlive + stats.SocketRefs = old.SocketRefs + statsMutex.Unlock() + return +} + +type Stats struct { + Clusters int + MasterConns int + SlaveConns int + SentOps int + ReceivedOps int + ReceivedDocs int + SocketsAlive int + SocketsInUse int + SocketRefs int +} + +func (stats *Stats) cluster(delta int) { + if stats != nil { + statsMutex.Lock() + stats.Clusters += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) conn(delta int, master bool) { + if stats != nil { + statsMutex.Lock() + if master { + stats.MasterConns += delta + } else { + stats.SlaveConns += delta + } + statsMutex.Unlock() + } +} + +func (stats *Stats) sentOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SentOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedDocs(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedDocs += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsInUse(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsInUse += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsAlive(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsAlive += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketRefs(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketRefs += delta + statsMutex.Unlock() + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json index a687a53..7480eb9 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -719,6 +719,30 @@ "revisionTime": "2016-07-14T22:09:51Z" }, { + "checksumSHA1": "o+ZqT87RLprKCGKMboo5vCxdAvA=", + "path": "gopkg.in/mgo.v2", + "revision": "29cc868a5ca65f401ff318143f9408d02f4799cc", + "revisionTime": "2016-06-09T18:00:28Z" + }, + { + "checksumSHA1": "zTjQOFGy4XTv4L33Kd2FhrV+mbM=", + "path": "gopkg.in/mgo.v2/bson", + "revision": "29cc868a5ca65f401ff318143f9408d02f4799cc", + "revisionTime": "2016-06-09T18:00:28Z" + }, + { + "checksumSHA1": "CkCDTsyN2Lbj1cL8+oaYKoPpI9w=", + "path": "gopkg.in/mgo.v2/internal/sasl", + "revision": "29cc868a5ca65f401ff318143f9408d02f4799cc", + "revisionTime": "2016-06-09T18:00:28Z" + }, + { + "checksumSHA1": "+1WDRPaOphSCmRMxVPIPBV4aubc=", + "path": "gopkg.in/mgo.v2/internal/scram", + "revision": "29cc868a5ca65f401ff318143f9408d02f4799cc", + "revisionTime": "2016-06-09T18:00:28Z" + }, + { "checksumSHA1": "yV+2de12Q/t09hBGYGKEOFr1/zc=", "path": "gopkg.in/yaml.v2", "revision": "e4d366fc3c7938e2958e662b4258c7a89e1f0e3e", |