diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/auth/microsoft/microsoft.go | 203 | ||||
-rw-r--r-- | server/auth/microsoft/microsoft_test.go | 72 | ||||
-rw-r--r-- | server/server.go | 7 | ||||
-rw-r--r-- | server/web.go | 20 |
4 files changed, 300 insertions, 2 deletions
diff --git a/server/auth/microsoft/microsoft.go b/server/auth/microsoft/microsoft.go new file mode 100644 index 0000000..49d9b82 --- /dev/null +++ b/server/auth/microsoft/microsoft.go @@ -0,0 +1,203 @@ +package microsoft + +import ( + "encoding/json" + "errors" + "net/http" + "path" + "strings" + + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/metrics" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/microsoft" +) + +const ( + name = "microsoft" +) + +// Config is an implementation of `auth.Provider` for authenticating using a +// Office 365 account. +type Config struct { + config *oauth2.Config + tenant string + groups map[string]bool + whitelist map[string]bool +} + +var _ auth.Provider = (*Config)(nil) + +// New creates a new Microsoft provider from a configuration. +func New(c *config.Auth) (*Config, error) { + whitelist := make(map[string]bool) + for _, u := range c.UsersWhitelist { + whitelist[u] = true + } + if c.ProviderOpts["tenant"] == "" && len(whitelist) == 0 { + return nil, errors.New("either Office 365 tenant or users whitelist must be specified") + } + groupMap := make(map[string]bool) + if groups, ok := c.ProviderOpts["groups"]; ok { + for _, group := range strings.Split(groups, ",") { + groupMap[strings.Trim(group, " ")] = true + } + } + + return &Config{ + config: &oauth2.Config{ + ClientID: c.OauthClientID, + ClientSecret: c.OauthClientSecret, + RedirectURL: c.OauthCallbackURL, + Endpoint: microsoft.AzureADEndpoint(c.ProviderOpts["tenant"]), + Scopes: []string{"user.Read.All", "Directory.Read.All"}, + }, + tenant: c.ProviderOpts["tenant"], + whitelist: whitelist, + groups: groupMap, + }, nil +} + +// A new oauth2 http client. +func (c *Config) newClient(token *oauth2.Token) *http.Client { + return c.config.Client(oauth2.NoContext, token) +} + +// Gets a response for an graph api call. +func (c *Config) getDocument(token *oauth2.Token, pathElements ...string) map[string]interface{} { + client := c.newClient(token) + url := "https://" + path.Join("graph.microsoft.com/v1.0", path.Join(pathElements...)) + resp, err := client.Get(url) + if err != nil { + return nil + } + defer resp.Body.Close() + var document map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&document); err != nil { + return nil + } + return document +} + +// Get info from the "/me" endpoint of the Microsoft Graph API (MSG-API). +// https://developer.microsoft.com/en-us/graph/docs/concepts/v1-overview +func (c *Config) getMe(token *oauth2.Token, item string) string { + document := c.getDocument(token, "/me") + if value, ok := document[item].(string); ok { + return value + } + return "" +} + +// Check against verified domains from "/organization" endpoint of MSG-API. +func (c *Config) verifyTenant(token *oauth2.Token) bool { + document := c.getDocument(token, "/organization") + // The domains for an organisation are in an array of structs under + // verifiedDomains, which is in a struct which is in turn an array + // of such structs under value in the document. Which in json looks + // like this: + // { "@odata.context": "https://graph.microsoft.com/v1.0/$metadata#organization", + // "value": [ { + // ... + // "verifiedDomains": [ { + // ... + // "name": "M365x214355.onmicrosoft.com", + // } ] + // } ] + //} + var value []interface{} + var ok bool + if value, ok = document["value"].([]interface{}); !ok { + return false + } + for _, valueEntry := range value { + if value, ok = valueEntry.(map[string]interface{})["verifiedDomains"].([]interface{}); !ok { + continue + } + for _, val := range value { + domain := val.(map[string]interface{})["name"].(string) + if domain == c.tenant { + return true + } + } + } + return false +} + +// Check against groups from /users/{id}/memberOf endpoint of MSG-API. +func (c *Config) verifyGroups(token *oauth2.Token) bool { + document := c.getDocument(token, "/users/me/memberOf") + var value []interface{} + var ok bool + if value, ok = document["value"].([]interface{}); !ok { + return false + } + for _, valueEntry := range value { + if group, ok := valueEntry.(map[string]interface{})["displayName"].(string); ok { + if c.groups[group] { + return true + } + } + } + return false +} + +// Name returns the name of the provider. +func (c *Config) Name() string { + return name +} + +// Valid validates the oauth token. +func (c *Config) Valid(token *oauth2.Token) bool { + if len(c.whitelist) > 0 && !c.whitelist[c.Email(token)] { + return false + } + if !token.Valid() { + return false + } + metrics.M.AuthValid.WithLabelValues("microsoft").Inc() + if c.tenant != "" { + if c.verifyTenant(token) { + if len(c.groups) > 0 { + return c.verifyGroups(token) + } + return true + } + } + return false +} + +// Revoke disables the access token. +func (c *Config) Revoke(token *oauth2.Token) error { + return nil +} + +// StartSession retrieves an authentication endpoint from Microsoft. +func (c *Config) StartSession(state string) *auth.Session { + return &auth.Session{ + AuthURL: c.config.AuthCodeURL(state, + oauth2.SetAuthURLParam("hd", c.tenant), + oauth2.SetAuthURLParam("prompt", "login")), + } +} + +// Exchange authorizes the session and returns an access token. +func (c *Config) Exchange(code string) (*oauth2.Token, error) { + t, err := c.config.Exchange(oauth2.NoContext, code) + if err == nil { + metrics.M.AuthExchange.WithLabelValues("microsoft").Inc() + } + return t, err +} + +// Email retrieves the email address of the user. +func (c *Config) Email(token *oauth2.Token) string { + return c.getMe(token, "mail") +} + +// Username retrieves the username portion of the user's email address. +func (c *Config) Username(token *oauth2.Token) string { + return strings.Split(c.Email(token), "@")[0] +} diff --git a/server/auth/microsoft/microsoft_test.go b/server/auth/microsoft/microsoft_test.go new file mode 100644 index 0000000..c2c2c17 --- /dev/null +++ b/server/auth/microsoft/microsoft_test.go @@ -0,0 +1,72 @@ +package microsoft + +import ( + "fmt" + "testing" + + "github.com/nsheridan/cashier/server/config" + "github.com/stretchr/testify/assert" +) + +var ( + oauthClientID = "id" + oauthClientSecret = "secret" + oauthCallbackURL = "url" + tenant = "example.com" + users = []string{"user"} +) + +func TestNew(t *testing.T) { + a := assert.New(t) + p, err := newMicrosoft() + a.NoError(err) + a.Equal(p.config.ClientID, oauthClientID) + a.Equal(p.config.ClientSecret, oauthClientSecret) + a.Equal(p.config.RedirectURL, oauthCallbackURL) + a.Equal(p.tenant, tenant) + a.Equal(p.whitelist, map[string]bool{"user": true}) +} + +func TestWhitelist(t *testing.T) { + c := &config.Auth{ + OauthClientID: oauthClientID, + OauthClientSecret: oauthClientSecret, + OauthCallbackURL: oauthCallbackURL, + ProviderOpts: map[string]string{"tenant": ""}, + UsersWhitelist: []string{}, + } + if _, err := New(c); err == nil { + t.Error("creating a provider without a tenant set should return an error") + } + // Set a user whitelist but no tenant + c.UsersWhitelist = users + if _, err := New(c); err != nil { + t.Error("creating a provider with users but no tenant should not return an error") + } + // Unset the user whitelist and set a tenant + c.UsersWhitelist = []string{} + c.ProviderOpts = map[string]string{"tenant": tenant} + if _, err := New(c); err != nil { + t.Error("creating a provider with a tenant set but without a user whitelist should not return an error") + } +} + +func TestStartSession(t *testing.T) { + a := assert.New(t) + + p, err := newMicrosoft() + a.NoError(err) + s := p.StartSession("test_state") + a.Contains(s.AuthURL, fmt.Sprintf("login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenant)) +} + +func newMicrosoft() (*Config, error) { + c := &config.Auth{ + OauthClientID: oauthClientID, + OauthClientSecret: oauthClientSecret, + OauthCallbackURL: oauthCallbackURL, + ProviderOpts: map[string]string{"tenant": tenant}, + UsersWhitelist: users, + } + return New(c) +} diff --git a/server/server.go b/server/server.go index 42476f3..1b8468e 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ import ( "github.com/nsheridan/cashier/server/auth/github" "github.com/nsheridan/cashier/server/auth/gitlab" "github.com/nsheridan/cashier/server/auth/google" + "github.com/nsheridan/cashier/server/auth/microsoft" "github.com/nsheridan/cashier/server/config" "github.com/nsheridan/cashier/server/metrics" "github.com/nsheridan/cashier/server/signer" @@ -90,12 +91,14 @@ func Run(conf *config.Config) { metrics.Register() switch conf.Auth.Provider { - case "google": - authprovider, err = google.New(conf.Auth) case "github": authprovider, err = github.New(conf.Auth) case "gitlab": authprovider, err = gitlab.New(conf.Auth) + case "google": + authprovider, err = google.New(conf.Auth) + case "microsoft": + authprovider, err = microsoft.New(conf.Auth) default: log.Fatalf("Unknown provider %s\n", conf.Auth.Provider) } diff --git a/server/web.go b/server/web.go index e238150..d55aa52 100644 --- a/server/web.go +++ b/server/web.go @@ -1,7 +1,9 @@ package server import ( + "bytes" "crypto/rand" + "encoding/base64" "encoding/hex" "encoding/json" "fmt" @@ -189,6 +191,23 @@ func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int return http.StatusFound, nil } +func encodeString(s string) string { + var buffer bytes.Buffer + chunkSize := 70 + runes := []rune(base64.StdEncoding.EncodeToString([]byte(s))) + + for i := 0; i < len(runes); i += chunkSize { + end := i + chunkSize + if end > len(runes) { + end = len(runes) + } + buffer.WriteString(string(runes[i:end])) + buffer.WriteString("\n") + } + buffer.WriteString(".\n") + return buffer.String() +} + // rootHandler starts the auth process. If the client is authenticated it renders the token to the user. func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { if !a.isLoggedIn(w, r) { @@ -198,6 +217,7 @@ func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er page := struct { Token string }{tok.AccessToken} + page.Token = encodeString(page.Token) tmpl := template.Must(template.New("token.html").Parse(templates.Token)) tmpl.Execute(w, page) |