// Package sasl implements a SASL client stepper on top of the SSPI helper
// functions declared in sasl_windows.h.
package sasl

// #include "sasl_windows.h"
import "C"

import (
	"fmt"
	"strings"
	"sync"
	"unsafe"
)

// maxSSPISteps bounds the number of Step calls a session accepts before it
// gives up with an error.
const maxSSPISteps = 10

// saslStepper is the interface returned by New. Step consumes the latest
// server payload and produces the next client payload; done reports that the
// exchange has finished. Close releases resources held by the session.
type saslStepper interface {
	Step(serverData []byte) (clientData []byte, done bool, err error)
	Close()
}

// saslSession holds the Go-side and C-side state of one SSPI exchange.
type saslSession struct {
	// Credentials
	mech          string
	service       string
	host          string
	userPlusRealm string
	target        string
	domain        string

	// Internal state
	authComplete bool
	errored      bool
	step         int

	// C internal state
	credHandle C.CredHandle
	context    C.CtxtHandle
	hasContext C.int

	// Keep track of pointers we need to explicitly free
	stringsToFree []*C.char
}

var (
	// initError records a failure reported by initSSPI; it is written at
	// most once, under initOnce.
	initError error
	initOnce  sync.Once
)

// initSSPI loads the SSPI library through load_secur32_dll and stores any
// failure in initError. It is meant to be run through initOnce.
func initSSPI() {
	rc := C.load_secur32_dll()
	if rc != 0 {
		initError = fmt.Errorf("Error loading libraries: %v", rc)
	}
}

// New creates an SSPI-backed SASL session.
//
// username must have the form "user@realm"; the part before the first "@" is
// passed as the user and the part after it as the domain. An empty service
// defaults to "mongodb". Anything from the first ":" in host onwards is
// dropped before the "service/host" target is built. When password is empty a
// nil password is handed to sspi_acquire_credentials_handle.
//
// An error is returned if the SSPI library could not be loaded, if username
// has no realm, or if the credentials handle cannot be acquired.
func New(username, password, mechanism, service, host string) (saslStepper, error) {
	initOnce.Do(initSSPI)
	// Without a successfully loaded library none of the SSPI calls below can
	// be trusted, so surface the load failure instead of proceeding.
	if initError != nil {
		return nil, initError
	}

	ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username}
	if service == "" {
		service = "mongodb"
	}
	if i := strings.Index(host, ":"); i >= 0 {
		host = host[:i]
	}
	ss.service = service
	ss.host = host

	usernameComponents := strings.Split(username, "@")
	if len(usernameComponents) < 2 {
		return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username)
	}
	user := usernameComponents[0]
	ss.domain = usernameComponents[1]
	ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host)

	// Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle.
	// A nil password pointer is passed when no password was supplied.
	var cPassword *C.char
	if len(password) > 0 {
		cPassword = ss.cstr(password)
	}
	var status C.SECURITY_STATUS
	status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), cPassword, ss.cstr(ss.domain))
	if status != C.SEC_E_OK {
		ss.errored = true
		// The session is not returned on this path, so the caller can never
		// Close it; free the C strings allocated above here.
		ss.Close()
		return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status)
	}
	return ss, nil
}

// cstr converts s to a C string and remembers the pointer so that Close can
// free it later.
func (ss *saslSession) cstr(s string) *C.char {
	cstr := C.CString(s)
	ss.stringsToFree = append(ss.stringsToFree, cstr)
	return cstr
}

// Close frees every C string allocated through cstr. The list is cleared
// afterwards so that calling Close again does not free the same pointers
// twice.
//
// NOTE(review): credHandle and context are not released here; confirm whether
// sasl_windows.h offers calls for that.
func (ss *saslSession) Close() {
	for _, cstr := range ss.stringsToFree {
		C.free(unsafe.Pointer(cstr))
	}
	ss.stringsToFree = nil
}

// Step advances the exchange by one round. serverData is the payload received
// from the server (it may be empty). It returns the payload to send back and
// done == true once the underlying call reports SEC_E_OK. An error is returned
// after more than maxSSPISteps calls, or when the underlying call reports
// anything other than SEC_E_OK or SEC_I_CONTINUE_NEEDED.
func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
	ss.step++
	if ss.step > maxSSPISteps {
		return nil, false, fmt.Errorf("too many SSPI steps without authentication")
	}

	var buffer C.PVOID
	var bufferLength C.ULONG
	var outBuffer C.PVOID
	var outBufferLength C.ULONG
	if len(serverData) > 0 {
		buffer = (C.PVOID)(unsafe.Pointer(&serverData[0]))
		bufferLength = C.ULONG(len(serverData))
	}

	var status C.int
	if ss.authComplete {
		// Step 3: last bit of magic to use the correct server credentials
		status = C.sspi_send_client_authz_id(&ss.context, &outBuffer, &outBufferLength, ss.cstr(ss.userPlusRealm))
	} else {
		// Step 1 + Step 2: set up security context with the server and TGT
		status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, buffer, bufferLength, &outBuffer, &outBufferLength, ss.cstr(ss.target))
	}

	// The output buffer comes from the C side; it is copied into Go memory
	// below and released when Step returns.
	if outBuffer != C.PVOID(nil) {
		defer C.free(unsafe.Pointer(outBuffer))
	}
	if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED {
		ss.errored = true
		return nil, false, ss.handleSSPIErrorCode(status)
	}

	clientData = C.GoBytes(unsafe.Pointer(outBuffer), C.int(outBufferLength))
	if status == C.SEC_E_OK {
		ss.authComplete = true
		return clientData, true, nil
	}
	// SEC_I_CONTINUE_NEEDED: tell later sspi_step calls that a context exists.
	ss.hasContext = 1
	return clientData, false, nil
}

// handleSSPIErrorCode turns an SSPI status code into a Go error, with a
// dedicated message for SEC_E_TARGET_UNKNOWN.
func (ss *saslSession) handleSSPIErrorCode(code C.int) error {
	if code == C.SEC_E_TARGET_UNKNOWN {
		return fmt.Errorf("Target %v@%v not found", ss.target, ss.domain)
	}
	return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code)
}