aboutsummaryrefslogtreecommitdiff
path: root/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go
blob: 3302cfe05d6838f0e8dac12e6ae9a0975a3ef929 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package sasl

// #include "sasl_windows.h"
import "C"

import (
	"fmt"
	"strings"
	"sync"
	"unsafe"
)

// saslStepper is the client side of a SASL exchange as seen by this package.
// Each server payload is passed to Step, which returns the client payload to
// send next; Close releases whatever the implementation holds.
type saslStepper interface {
	// Step consumes serverData (the latest payload from the server) and
	// returns the client's reply, whether the exchange has finished, and any
	// error encountered.
	Step(serverData []byte) (clientData []byte, done bool, err error)

	// Close releases resources held by the stepper.
	Close()
}

// saslSession is the Windows SSPI implementation of saslStepper.
type saslSession struct {
	// Naming and credentials, all filled in by New.
	mech          string // mechanism passed to New (stored only; not read in this file)
	service       string // service name; "mongodb" when New received ""
	host          string // host with any ":port" suffix stripped
	userPlusRealm string // full "user@REALM" username, wrapped in step 3
	target        string // SSPI target name, "service/host"
	domain        string // realm part of the username

	// Go-side progress of the conversation.
	authComplete bool // set when a Step call gets SEC_E_OK
	errored      bool // set when an SSPI call fails; NOTE(review): not read in this file
	step         int  // number of Step calls so far; Step fails once it exceeds 10

	// State handed to the C helpers by address.
	credHandle C.CredHandle // filled by sspi_acquire_credentials_handle in New
	context    C.CtxtHandle // security context passed to sspi_step / sspi_send_client_authz_id
	hasContext C.int        // 0 until a step returns SEC_I_CONTINUE_NEEDED, then 1

	// C strings allocated through cstr; Close frees them.
	stringsToFree []*C.char
}

// initOnce guards the one-time call to initSSPI; initError holds the failure
// from that call, or nil if it succeeded (or has not run yet).
var (
	initError error
	initOnce  sync.Once
)

// initSSPI invokes the C helper load_secur32_dll. New runs it through initOnce.
// A non-zero return code is stored in initError; success leaves initError nil.
func initSSPI() {
	if rc := C.load_secur32_dll(); rc != 0 {
		initError = fmt.Errorf("Error loading libraries: %v", rc)
	}
}

// New creates an SSPI-backed SASL client session.
//
// username must have the form "user@REALM": the part before the first "@" is
// the user and the part after it is used as the domain. service defaults to
// "mongodb" when empty, and any ":port" suffix is stripped from host; the two
// are combined into the SSPI target name "service/host". When password is
// empty, nil is passed to sspi_acquire_credentials_handle in its place
// (presumably so the current logon credentials are used — confirm against the
// C helper).
//
// An error is returned if the one-time library load in initSSPI failed, if
// username has no realm, or if acquiring the credentials handle fails. On
// error no session is returned, so any C strings allocated so far are freed
// before returning.
func New(username, password, mechanism, service, host string) (saslStepper, error) {
	initOnce.Do(initSSPI)
	// If the library load failed, the sspi_* helpers below cannot be relied
	// on; surface the load error instead of calling into them.
	if initError != nil {
		return nil, initError
	}
	ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username}
	if service == "" {
		service = "mongodb"
	}
	if i := strings.Index(host, ":"); i >= 0 {
		host = host[:i]
	}
	ss.service = service
	ss.host = host

	usernameComponents := strings.Split(username, "@")
	if len(usernameComponents) < 2 {
		return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username)
	}
	user := usernameComponents[0]
	ss.domain = usernameComponents[1]
	ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host)

	var status C.SECURITY_STATUS
	// Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle
	if len(password) > 0 {
		status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain))
	} else {
		status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain))
	}
	if status != C.SEC_E_OK {
		ss.errored = true
		// The caller receives no session to Close, so release the C strings
		// (including the password copy) here instead of leaking them.
		ss.Close()
		return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status)
	}
	return ss, nil
}

// cstr returns a C copy of s and records it in stringsToFree so that Close can
// release it; the pointer stays valid until Close runs.
func (ss *saslSession) cstr(s string) (p *C.char) {
	p = C.CString(s)
	ss.stringsToFree = append(ss.stringsToFree, p)
	return p
}

// Close frees every C string allocated through cstr for this session.
//
// The list is cleared afterwards, so calling Close more than once is safe and
// does not free the same pointers twice.
//
// NOTE(review): Close does not release credHandle or context. If the C side
// exposes FreeCredentialsHandle / DeleteSecurityContext wrappers, they likely
// belong here — confirm against sasl_windows.h.
func (ss *saslSession) Close() {
	for _, cstr := range ss.stringsToFree {
		C.free(unsafe.Pointer(cstr))
	}
	// Drop the freed pointers so a repeated Close is a no-op rather than a
	// double free.
	ss.stringsToFree = nil
}

// Step performs one round of the SSPI conversation.
//
// serverData is the payload most recently received from the server (it may be
// empty). Until a round returns SEC_E_OK, the round goes through sspi_step to
// build the security context. After that, the next round calls
// sspi_send_client_authz_id with userPlusRealm.
//
// done is true when the C helper returned SEC_E_OK. An error is returned once
// Step has been called more than 10 times, or when the helper reports any
// status other than SEC_E_OK or SEC_I_CONTINUE_NEEDED.
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
	ss.step++
	if ss.step > 10 {
		return nil, false, fmt.Errorf("too many SSPI steps without authentication")
	}

	// Give C a copy of serverData in C memory rather than a pointer into the Go
	// slice. &buffer is passed to C, and a Go pointer stored in it would break
	// cgo's pointer-passing rules. The copy also guarantees the free of buffer
	// below can never be applied to Go memory.
	var input unsafe.Pointer
	if len(serverData) > 0 {
		input = C.CBytes(serverData)
		defer C.free(input)
	}
	buffer := C.PVOID(input)
	bufferLength := C.ULONG(len(serverData))

	var status C.int
	if ss.authComplete {
		// Step 3: last bit of magic to use the correct server credentials
		status = C.sspi_send_client_authz_id(&ss.context, &buffer, &bufferLength, ss.cstr(ss.userPlusRealm))
	} else {
		// Step 1 + Step 2: set up security context with the server and TGT
		status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, &buffer, &bufferLength, ss.cstr(ss.target))
	}
	// The C helpers may replace buffer with an output buffer of their own.
	// Free it only if it no longer points at our input copy, which the defer
	// above already owns; otherwise we would double-free.
	if buffer != nil && unsafe.Pointer(buffer) != input {
		defer C.free(unsafe.Pointer(buffer))
	}
	if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED {
		ss.errored = true
		return nil, false, ss.handleSSPIErrorCode(status)
	}

	// Copy the output into Go memory before the deferred frees run.
	clientData = C.GoBytes(unsafe.Pointer(buffer), C.int(bufferLength))
	if status == C.SEC_E_OK {
		ss.authComplete = true
		return clientData, true, nil
	}
	// SEC_I_CONTINUE_NEEDED: later sspi_step calls must reuse ss.context.
	ss.hasContext = 1
	return clientData, false, nil
}

// handleSSPIErrorCode turns a failing SSPI status into a Go error. An unknown
// target gets a message naming the target and domain; any other code is
// reported together with the current step number.
func (ss *saslSession) handleSSPIErrorCode(code C.int) error {
	switch code {
	case C.SEC_E_TARGET_UNKNOWN:
		return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain)
	default:
		return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code)
	}
}