aboutsummaryrefslogtreecommitdiff
path: root/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go
blob: 8375dddf82a160adefc32c6dae4016a54078eb5b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Package sasl is an implementation detail of the mgo package.
//
// This package is not meant to be used by itself.
//

// +build !windows

package sasl

// #cgo LDFLAGS: -lsasl2
//
// struct sasl_conn {};
//
// #include <stdlib.h>
// #include <sasl/sasl.h>
//
// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password);
//
import "C"

import (
	"fmt"
	"strings"
	"sync"
	"unsafe"
)

type (
	// saslStepper drives a SASL conversation one round at a time and
	// releases whatever it holds on Close.
	saslStepper interface {
		// Step consumes the server's latest payload and returns the client's
		// reply; done reports whether the exchange has completed.
		Step(serverData []byte) (clientData []byte, done bool, err error)
		Close()
	}

	// saslSession is the cgo-backed saslStepper returned by New.
	saslSession struct {
		conn *C.sasl_conn_t // connection handle filled in by sasl_client_new
		step int            // number of Step calls made so far
		mech string         // mechanism name passed to sasl_client_start

		cstrings  []*C.char          // C strings allocated by cstr; freed in Close
		callbacks *C.sasl_callback_t // result of mgo_sasl_callbacks; freed in Close
	}
)

// initOnce guards the one-time, process-wide SASL client initialization done
// by initSASL; initError records its outcome so New can report it.
var (
	initError error
	initOnce  sync.Once
)

// initSASL initializes the client side of the SASL library. Failures are
// stored in initError instead of being returned, because it runs through
// sync.Once, which accepts only a func().
func initSASL() {
	if rc := C.sasl_client_init(nil); rc != C.SASL_OK {
		initError = saslError(rc, nil, "cannot initialize SASL library")
	}
}

// New creates a SASL client session that authenticates username/password
// using the given mechanism.
//
// An empty service defaults to "mongodb". Any ":port" suffix is removed from
// host before it is handed to the SASL library. A bracketed IPv6 literal such
// as "[::1]:27017" is reduced to "::1", and an unbracketed IPv6 literal is
// passed through unchanged.
//
// The SASL library is initialized once per process. If that initialization
// failed, its error is returned on every call. Callers should call Close on
// the returned stepper to release its C resources.
func New(username, password, mechanism, service, host string) (saslStepper, error) {
	initOnce.Do(initSASL)
	if initError != nil {
		return nil, initError
	}

	ss := &saslSession{mech: mechanism}
	if service == "" {
		service = "mongodb"
	}
	host = hostWithoutPort(host)

	ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password))
	rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn)
	if rc != C.SASL_OK {
		// Release the C strings and callbacks allocated above.
		ss.Close()
		return nil, saslError(rc, nil, "cannot create new SASL client")
	}
	return ss, nil
}

// hostWithoutPort returns the host part of a "host[:port]" address. It
// understands bracketed IPv6 literals ("[::1]" or "[::1]:27017"). An address
// that contains more than one colon but no brackets is treated as a bare IPv6
// literal and returned unchanged.
func hostWithoutPort(addr string) string {
	if strings.HasPrefix(addr, "[") {
		if end := strings.Index(addr, "]"); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	if strings.Count(addr, ":") > 1 {
		// The colons belong to the IPv6 address itself, not to a port.
		return addr
	}
	if i := strings.Index(addr, ":"); i >= 0 {
		return addr[:i]
	}
	return addr
}

// cstr converts s to a C string owned by the session. Every allocation is
// tracked in ss.cstrings and released by Close, so callers must not free it.
func (ss *saslSession) cstr(s string) *C.char {
	ss.cstrings = append(ss.cstrings, C.CString(s))
	return ss.cstrings[len(ss.cstrings)-1]
}

// Close releases the C resources held by the session. These are the SASL
// connection, the callbacks obtained from mgo_sasl_callbacks in New, and
// every C string allocated through cstr.
//
// The connection is disposed first, so nothing it may still reference (the
// callbacks were handed to sasl_client_new) is freed underneath it. Each
// handle is cleared after release, so a second Close frees nothing twice.
func (ss *saslSession) Close() {
	// The documentation of SASL dispose makes it clear that this should only
	// be done when the connection is done, not when the authentication phase
	// is done, because an encryption layer may have been negotiated.
	// Even then, we'll do this for now, because it's simpler and prevents
	// keeping track of this state for every socket. If it breaks, we'll fix it.
	if ss.conn != nil {
		C.sasl_dispose(&ss.conn)
		ss.conn = nil
	}

	if ss.callbacks != nil {
		C.free(unsafe.Pointer(ss.callbacks))
		ss.callbacks = nil
	}

	for _, cstr := range ss.cstrings {
		C.free(unsafe.Pointer(cstr))
	}
	ss.cstrings = nil
}

// Step advances the SASL exchange by one round.
//
// The first call starts the client with the session's mechanism and does not
// look at serverData. Later calls feed serverData to sasl_client_step.
//
// It returns the bytes the library produced for the server. done is true once
// the library reports SASL_OK and false while it reports SASL_CONTINUE. Any
// other result code becomes an error. More than 10 calls fail without
// touching the library.
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
	ss.step++
	if ss.step > 10 {
		return nil, false, fmt.Errorf("too many SASL steps without authentication")
	}

	var (
		out    *C.char
		outLen C.uint
		rc     C.int
	)
	switch ss.step {
	case 1:
		// The mechanism chosen by the library is not used; it must be the
		// one requested.
		var chosen *C.char
		rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &out, &outLen, &chosen)
	default:
		var in *C.char
		var inLen C.uint
		if n := len(serverData); n > 0 {
			in = (*C.char)(unsafe.Pointer(&serverData[0]))
			inLen = C.uint(n)
		}
		rc = C.sasl_client_step(ss.conn, in, inLen, nil, &out, &outLen)
	}

	// Copy the library-owned output into Go memory.
	if out != nil && outLen > 0 {
		clientData = C.GoBytes(unsafe.Pointer(out), C.int(outLen))
	}

	switch rc {
	case C.SASL_OK:
		return clientData, true, nil
	case C.SASL_CONTINUE:
		return clientData, false, nil
	}
	return nil, false, saslError(rc, ss.conn, "cannot establish SASL session")
}

// saslError builds an error of the form "msg: detail".
//
// With a nil conn, detail is the generic text sasl_errstring returns for rc.
// Otherwise it is the connection-specific text from sasl_errdetail.
func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error {
	var detail string
	if conn == nil {
		detail = C.GoString(C.sasl_errstring(rc, nil, nil))
	} else {
		detail = C.GoString(C.sasl_errdetail(conn))
	}
	// Use a constant format string: detail comes from the SASL library and
	// could contain '%' sequences that Errorf would misread as verbs.
	return fmt.Errorf("%s: %s", msg, detail)
}