aboutsummaryrefslogtreecommitdiff
path: root/vendor/gopkg.in/mgo.v2/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/gopkg.in/mgo.v2/socket.go')
-rw-r--r--vendor/gopkg.in/mgo.v2/socket.go707
1 files changed, 707 insertions, 0 deletions
diff --git a/vendor/gopkg.in/mgo.v2/socket.go b/vendor/gopkg.in/mgo.v2/socket.go
new file mode 100644
index 0000000..8891dd5
--- /dev/null
+++ b/vendor/gopkg.in/mgo.v2/socket.go
@@ -0,0 +1,707 @@
+// mgo - MongoDB driver for Go
+//
+// Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net>
+//
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice, this
+// list of conditions and the following disclaimer.
+// 2. Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+package mgo
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "gopkg.in/mgo.v2/bson"
+)
+
+type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
+
+type mongoSocket struct {
+ sync.Mutex
+ server *mongoServer // nil when cached
+ conn net.Conn
+ timeout time.Duration
+ addr string // For debugging only.
+ nextRequestId uint32
+ replyFuncs map[uint32]replyFunc
+ references int
+ creds []Credential
+ logout []Credential
+ cachedNonce string
+ gotNonce sync.Cond
+ dead error
+ serverInfo *mongoServerInfo
+}
+
+type queryOpFlags uint32
+
+const (
+ _ queryOpFlags = 1 << iota
+ flagTailable
+ flagSlaveOk
+ flagLogReplay
+ flagNoCursorTimeout
+ flagAwaitData
+)
+
+type queryOp struct {
+ collection string
+ query interface{}
+ skip int32
+ limit int32
+ selector interface{}
+ flags queryOpFlags
+ replyFunc replyFunc
+
+ mode Mode
+ options queryWrapper
+ hasOptions bool
+ serverTags []bson.D
+}
+
+type queryWrapper struct {
+ Query interface{} "$query"
+ OrderBy interface{} "$orderby,omitempty"
+ Hint interface{} "$hint,omitempty"
+ Explain bool "$explain,omitempty"
+ Snapshot bool "$snapshot,omitempty"
+ ReadPreference bson.D "$readPreference,omitempty"
+ MaxScan int "$maxScan,omitempty"
+ MaxTimeMS int "$maxTimeMS,omitempty"
+ Comment string "$comment,omitempty"
+}
+
+func (op *queryOp) finalQuery(socket *mongoSocket) interface{} {
+ if op.flags&flagSlaveOk != 0 && socket.ServerInfo().Mongos {
+ var modeName string
+ switch op.mode {
+ case Strong:
+ modeName = "primary"
+ case Monotonic, Eventual:
+ modeName = "secondaryPreferred"
+ case PrimaryPreferred:
+ modeName = "primaryPreferred"
+ case Secondary:
+ modeName = "secondary"
+ case SecondaryPreferred:
+ modeName = "secondaryPreferred"
+ case Nearest:
+ modeName = "nearest"
+ default:
+ panic(fmt.Sprintf("unsupported read mode: %d", op.mode))
+ }
+ op.hasOptions = true
+ op.options.ReadPreference = make(bson.D, 0, 2)
+ op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"mode", modeName})
+ if len(op.serverTags) > 0 {
+ op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"tags", op.serverTags})
+ }
+ }
+ if op.hasOptions {
+ if op.query == nil {
+ var empty bson.D
+ op.options.Query = empty
+ } else {
+ op.options.Query = op.query
+ }
+ debugf("final query is %#v\n", &op.options)
+ return &op.options
+ }
+ return op.query
+}
+
+type getMoreOp struct {
+ collection string
+ limit int32
+ cursorId int64
+ replyFunc replyFunc
+}
+
+type replyOp struct {
+ flags uint32
+ cursorId int64
+ firstDoc int32
+ replyDocs int32
+}
+
+type insertOp struct {
+ collection string // "database.collection"
+ documents []interface{} // One or more documents to insert
+ flags uint32
+}
+
+type updateOp struct {
+ Collection string `bson:"-"` // "database.collection"
+ Selector interface{} `bson:"q"`
+ Update interface{} `bson:"u"`
+ Flags uint32 `bson:"-"`
+ Multi bool `bson:"multi,omitempty"`
+ Upsert bool `bson:"upsert,omitempty"`
+}
+
+type deleteOp struct {
+ Collection string `bson:"-"` // "database.collection"
+ Selector interface{} `bson:"q"`
+ Flags uint32 `bson:"-"`
+ Limit int `bson:"limit"`
+}
+
+type killCursorsOp struct {
+ cursorIds []int64
+}
+
+type requestInfo struct {
+ bufferPos int
+ replyFunc replyFunc
+}
+
+func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket {
+ socket := &mongoSocket{
+ conn: conn,
+ addr: server.Addr,
+ server: server,
+ replyFuncs: make(map[uint32]replyFunc),
+ }
+ socket.gotNonce.L = &socket.Mutex
+ if err := socket.InitialAcquire(server.Info(), timeout); err != nil {
+ panic("newSocket: InitialAcquire returned error: " + err.Error())
+ }
+ stats.socketsAlive(+1)
+ debugf("Socket %p to %s: initialized", socket, socket.addr)
+ socket.resetNonce()
+ go socket.readLoop()
+ return socket
+}
+
+// Server returns the server that the socket is associated with.
+// It returns nil while the socket is cached in its respective server.
+func (socket *mongoSocket) Server() *mongoServer {
+ socket.Lock()
+ server := socket.server
+ socket.Unlock()
+ return server
+}
+
+// ServerInfo returns details for the server at the time the socket
+// was initially acquired.
+func (socket *mongoSocket) ServerInfo() *mongoServerInfo {
+ socket.Lock()
+ serverInfo := socket.serverInfo
+ socket.Unlock()
+ return serverInfo
+}
+
+// InitialAcquire obtains the first reference to the socket, either
+// right after the connection is made or once a recycled socket is
+// being put back in use.
+func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error {
+ socket.Lock()
+ if socket.references > 0 {
+ panic("Socket acquired out of cache with references")
+ }
+ if socket.dead != nil {
+ dead := socket.dead
+ socket.Unlock()
+ return dead
+ }
+ socket.references++
+ socket.serverInfo = serverInfo
+ socket.timeout = timeout
+ stats.socketsInUse(+1)
+ stats.socketRefs(+1)
+ socket.Unlock()
+ return nil
+}
+
+// Acquire obtains an additional reference to the socket.
+// The socket will only be recycled when it's released as many
+// times as it's been acquired.
+func (socket *mongoSocket) Acquire() (info *mongoServerInfo) {
+ socket.Lock()
+ if socket.references == 0 {
+ panic("Socket got non-initial acquire with references == 0")
+ }
+ // We'll track references to dead sockets as well.
+ // Caller is still supposed to release the socket.
+ socket.references++
+ stats.socketRefs(+1)
+ serverInfo := socket.serverInfo
+ socket.Unlock()
+ return serverInfo
+}
+
+// Release decrements a socket reference. The socket will be
+// recycled once its released as many times as it's been acquired.
+func (socket *mongoSocket) Release() {
+ socket.Lock()
+ if socket.references == 0 {
+ panic("socket.Release() with references == 0")
+ }
+ socket.references--
+ stats.socketRefs(-1)
+ if socket.references == 0 {
+ stats.socketsInUse(-1)
+ server := socket.server
+ socket.Unlock()
+ socket.LogoutAll()
+ // If the socket is dead server is nil.
+ if server != nil {
+ server.RecycleSocket(socket)
+ }
+ } else {
+ socket.Unlock()
+ }
+}
+
+// SetTimeout changes the timeout used on socket operations.
+func (socket *mongoSocket) SetTimeout(d time.Duration) {
+ socket.Lock()
+ socket.timeout = d
+ socket.Unlock()
+}
+
+type deadlineType int
+
+const (
+ readDeadline deadlineType = 1
+ writeDeadline deadlineType = 2
+)
+
+func (socket *mongoSocket) updateDeadline(which deadlineType) {
+ var when time.Time
+ if socket.timeout > 0 {
+ when = time.Now().Add(socket.timeout)
+ }
+ whichstr := ""
+ switch which {
+ case readDeadline | writeDeadline:
+ whichstr = "read/write"
+ socket.conn.SetDeadline(when)
+ case readDeadline:
+ whichstr = "read"
+ socket.conn.SetReadDeadline(when)
+ case writeDeadline:
+ whichstr = "write"
+ socket.conn.SetWriteDeadline(when)
+ default:
+ panic("invalid parameter to updateDeadline")
+ }
+ debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when)
+}
+
+// Close terminates the socket use.
+func (socket *mongoSocket) Close() {
+ socket.kill(errors.New("Closed explicitly"), false)
+}
+
+func (socket *mongoSocket) kill(err error, abend bool) {
+ socket.Lock()
+ if socket.dead != nil {
+ debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error())
+ socket.Unlock()
+ return
+ }
+ logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend)
+ socket.dead = err
+ socket.conn.Close()
+ stats.socketsAlive(-1)
+ replyFuncs := socket.replyFuncs
+ socket.replyFuncs = make(map[uint32]replyFunc)
+ server := socket.server
+ socket.server = nil
+ socket.gotNonce.Broadcast()
+ socket.Unlock()
+ for _, replyFunc := range replyFuncs {
+ logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error())
+ replyFunc(err, nil, -1, nil)
+ }
+ if abend {
+ server.AbendSocket(socket)
+ }
+}
+
+func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) {
+ var wait, change sync.Mutex
+ var replyDone bool
+ var replyData []byte
+ var replyErr error
+ wait.Lock()
+ op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) {
+ change.Lock()
+ if !replyDone {
+ replyDone = true
+ replyErr = err
+ if err == nil {
+ replyData = docData
+ }
+ }
+ change.Unlock()
+ wait.Unlock()
+ }
+ err = socket.Query(op)
+ if err != nil {
+ return nil, err
+ }
+ wait.Lock()
+ change.Lock()
+ data = replyData
+ err = replyErr
+ change.Unlock()
+ return data, err
+}
+
+func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
+
+ if lops := socket.flushLogout(); len(lops) > 0 {
+ ops = append(lops, ops...)
+ }
+
+ buf := make([]byte, 0, 256)
+
+ // Serialize operations synchronously to avoid interrupting
+ // other goroutines while we can't really be sending data.
+ // Also, record id positions so that we can compute request
+ // ids at once later with the lock already held.
+ requests := make([]requestInfo, len(ops))
+ requestCount := 0
+
+ for _, op := range ops {
+ debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op)
+ if qop, ok := op.(*queryOp); ok {
+ if cmd, ok := qop.query.(*findCmd); ok {
+ debugf("Socket %p to %s: find command: %#v", socket, socket.addr, cmd)
+ }
+ }
+ start := len(buf)
+ var replyFunc replyFunc
+ switch op := op.(type) {
+
+ case *updateOp:
+ buf = addHeader(buf, 2001)
+ buf = addInt32(buf, 0) // Reserved
+ buf = addCString(buf, op.Collection)
+ buf = addInt32(buf, int32(op.Flags))
+ debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
+ buf, err = addBSON(buf, op.Selector)
+ if err != nil {
+ return err
+ }
+ debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.Update)
+ buf, err = addBSON(buf, op.Update)
+ if err != nil {
+ return err
+ }
+
+ case *insertOp:
+ buf = addHeader(buf, 2002)
+ buf = addInt32(buf, int32(op.flags))
+ buf = addCString(buf, op.collection)
+ for _, doc := range op.documents {
+ debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc)
+ buf, err = addBSON(buf, doc)
+ if err != nil {
+ return err
+ }
+ }
+
+ case *queryOp:
+ buf = addHeader(buf, 2004)
+ buf = addInt32(buf, int32(op.flags))
+ buf = addCString(buf, op.collection)
+ buf = addInt32(buf, op.skip)
+ buf = addInt32(buf, op.limit)
+ buf, err = addBSON(buf, op.finalQuery(socket))
+ if err != nil {
+ return err
+ }
+ if op.selector != nil {
+ buf, err = addBSON(buf, op.selector)
+ if err != nil {
+ return err
+ }
+ }
+ replyFunc = op.replyFunc
+
+ case *getMoreOp:
+ buf = addHeader(buf, 2005)
+ buf = addInt32(buf, 0) // Reserved
+ buf = addCString(buf, op.collection)
+ buf = addInt32(buf, op.limit)
+ buf = addInt64(buf, op.cursorId)
+ replyFunc = op.replyFunc
+
+ case *deleteOp:
+ buf = addHeader(buf, 2006)
+ buf = addInt32(buf, 0) // Reserved
+ buf = addCString(buf, op.Collection)
+ buf = addInt32(buf, int32(op.Flags))
+ debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector)
+ buf, err = addBSON(buf, op.Selector)
+ if err != nil {
+ return err
+ }
+
+ case *killCursorsOp:
+ buf = addHeader(buf, 2007)
+ buf = addInt32(buf, 0) // Reserved
+ buf = addInt32(buf, int32(len(op.cursorIds)))
+ for _, cursorId := range op.cursorIds {
+ buf = addInt64(buf, cursorId)
+ }
+
+ default:
+ panic("internal error: unknown operation type")
+ }
+
+ setInt32(buf, start, int32(len(buf)-start))
+
+ if replyFunc != nil {
+ request := &requests[requestCount]
+ request.replyFunc = replyFunc
+ request.bufferPos = start
+ requestCount++
+ }
+ }
+
+ // Buffer is ready for the pipe. Lock, allocate ids, and enqueue.
+
+ socket.Lock()
+ if socket.dead != nil {
+ dead := socket.dead
+ socket.Unlock()
+ debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error())
+ // XXX This seems necessary in case the session is closed concurrently
+ // with a query being performed, but it's not yet tested:
+ for i := 0; i != requestCount; i++ {
+ request := &requests[i]
+ if request.replyFunc != nil {
+ request.replyFunc(dead, nil, -1, nil)
+ }
+ }
+ return dead
+ }
+
+ wasWaiting := len(socket.replyFuncs) > 0
+
+ // Reserve id 0 for requests which should have no responses.
+ requestId := socket.nextRequestId + 1
+ if requestId == 0 {
+ requestId++
+ }
+ socket.nextRequestId = requestId + uint32(requestCount)
+ for i := 0; i != requestCount; i++ {
+ request := &requests[i]
+ setInt32(buf, request.bufferPos+4, int32(requestId))
+ socket.replyFuncs[requestId] = request.replyFunc
+ requestId++
+ }
+
+ debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf))
+ stats.sentOps(len(ops))
+
+ socket.updateDeadline(writeDeadline)
+ _, err = socket.conn.Write(buf)
+ if !wasWaiting && requestCount > 0 {
+ socket.updateDeadline(readDeadline)
+ }
+ socket.Unlock()
+ return err
+}
+
+func fill(r net.Conn, b []byte) error {
+ l := len(b)
+ n, err := r.Read(b)
+ for n != l && err == nil {
+ var ni int
+ ni, err = r.Read(b[n:])
+ n += ni
+ }
+ return err
+}
+
+// Estimated minimum cost per socket: 1 goroutine + memory for the largest
+// document ever seen.
+func (socket *mongoSocket) readLoop() {
+ p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields
+ s := make([]byte, 4)
+ conn := socket.conn // No locking, conn never changes.
+ for {
+ err := fill(conn, p)
+ if err != nil {
+ socket.kill(err, true)
+ return
+ }
+
+ totalLen := getInt32(p, 0)
+ responseTo := getInt32(p, 8)
+ opCode := getInt32(p, 12)
+
+ // Don't use socket.server.Addr here. socket is not
+ // locked and socket.server may go away.
+ debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen)
+
+ _ = totalLen
+
+ if opCode != 1 {
+ socket.kill(errors.New("opcode != 1, corrupted data?"), true)
+ return
+ }
+
+ reply := replyOp{
+ flags: uint32(getInt32(p, 16)),
+ cursorId: getInt64(p, 20),
+ firstDoc: getInt32(p, 28),
+ replyDocs: getInt32(p, 32),
+ }
+
+ stats.receivedOps(+1)
+ stats.receivedDocs(int(reply.replyDocs))
+
+ socket.Lock()
+ replyFunc, ok := socket.replyFuncs[uint32(responseTo)]
+ if ok {
+ delete(socket.replyFuncs, uint32(responseTo))
+ }
+ socket.Unlock()
+
+ if replyFunc != nil && reply.replyDocs == 0 {
+ replyFunc(nil, &reply, -1, nil)
+ } else {
+ for i := 0; i != int(reply.replyDocs); i++ {
+ err := fill(conn, s)
+ if err != nil {
+ if replyFunc != nil {
+ replyFunc(err, nil, -1, nil)
+ }
+ socket.kill(err, true)
+ return
+ }
+
+ b := make([]byte, int(getInt32(s, 0)))
+
+ // copy(b, s) in an efficient way.
+ b[0] = s[0]
+ b[1] = s[1]
+ b[2] = s[2]
+ b[3] = s[3]
+
+ err = fill(conn, b[4:])
+ if err != nil {
+ if replyFunc != nil {
+ replyFunc(err, nil, -1, nil)
+ }
+ socket.kill(err, true)
+ return
+ }
+
+ if globalDebug && globalLogger != nil {
+ m := bson.M{}
+ if err := bson.Unmarshal(b, m); err == nil {
+ debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m)
+ }
+ }
+
+ if replyFunc != nil {
+ replyFunc(nil, &reply, i, b)
+ }
+
+ // XXX Do bound checking against totalLen.
+ }
+ }
+
+ socket.Lock()
+ if len(socket.replyFuncs) == 0 {
+ // Nothing else to read for now. Disable deadline.
+ socket.conn.SetReadDeadline(time.Time{})
+ } else {
+ socket.updateDeadline(readDeadline)
+ }
+ socket.Unlock()
+
+ // XXX Do bound checking against totalLen.
+ }
+}
+
+var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+
+func addHeader(b []byte, opcode int) []byte {
+ i := len(b)
+ b = append(b, emptyHeader...)
+ // Enough for current opcodes.
+ b[i+12] = byte(opcode)
+ b[i+13] = byte(opcode >> 8)
+ return b
+}
+
+func addInt32(b []byte, i int32) []byte {
+ return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24))
+}
+
+func addInt64(b []byte, i int64) []byte {
+ return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24),
+ byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56))
+}
+
+func addCString(b []byte, s string) []byte {
+ b = append(b, []byte(s)...)
+ b = append(b, 0)
+ return b
+}
+
+func addBSON(b []byte, doc interface{}) ([]byte, error) {
+ if doc == nil {
+ return append(b, 5, 0, 0, 0, 0), nil
+ }
+ data, err := bson.Marshal(doc)
+ if err != nil {
+ return b, err
+ }
+ return append(b, data...), nil
+}
+
+func setInt32(b []byte, pos int, i int32) {
+ b[pos] = byte(i)
+ b[pos+1] = byte(i >> 8)
+ b[pos+2] = byte(i >> 16)
+ b[pos+3] = byte(i >> 24)
+}
+
+func getInt32(b []byte, pos int) int32 {
+ return (int32(b[pos+0])) |
+ (int32(b[pos+1]) << 8) |
+ (int32(b[pos+2]) << 16) |
+ (int32(b[pos+3]) << 24)
+}
+
+func getInt64(b []byte, pos int) int64 {
+ return (int64(b[pos+0])) |
+ (int64(b[pos+1]) << 8) |
+ (int64(b[pos+2]) << 16) |
+ (int64(b[pos+3]) << 24) |
+ (int64(b[pos+4]) << 32) |
+ (int64(b[pos+5]) << 40) |
+ (int64(b[pos+6]) << 48) |
+ (int64(b[pos+7]) << 56)
+}