From f456753248612222ad9bb6f3de74b7e28771470e Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Mon, 6 Jun 2016 00:31:15 +0100 Subject: Save oauth 'state' identifier in the client --- cmd/cashierd/main.go | 39 +++++++++++++++++++++++++++++---------- server/auth/github/github.go | 1 - server/auth/github/github_test.go | 1 - server/auth/google/google.go | 1 - server/auth/google/google_test.go | 1 - server/auth/provider.go | 1 - 6 files changed, 29 insertions(+), 15 deletions(-) diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index 1b1035c..12072d6 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -40,9 +40,9 @@ type appContext struct { sshKeySigner *signer.KeySigner } -// getAuthCookie retrieves a cookie from the request and validates it. -func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { - session, _ := a.cookiestore.Get(r, "tok") +// getAuthTokenCookie retrieves a cookie from the request. +func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token { + session, _ := a.cookiestore.Get(r, "session") t, ok := session.Values["token"] if !ok { return nil @@ -57,14 +57,31 @@ func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { return &tok } -// setAuthCookie marshals the auth token and stores it as a cookie. -func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { - session, _ := a.cookiestore.Get(r, "tok") +// setAuthTokenCookie marshals the auth token and stores it as a cookie. +func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { + session, _ := a.cookiestore.Get(r, "session") val, _ := json.Marshal(t) session.Values["token"] = val session.Save(r, w) } +// getAuthStateCookie retrieves the oauth csrf state value from the client request. +func (a *appContext) getAuthStateCookie(r *http.Request) string { + session, _ := a.cookiestore.Get(r, "session") + state, ok := session.Values["state"] + if !ok { + return "" + } + return state.(string) +} + +// setAuthStateCookie saves the oauth csrf state value. +func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) { + session, _ := a.cookiestore.Get(r, "session") + session.Values["state"] = state + session.Save(r, w) +} + // parseKey retrieves and unmarshals the signing request. func parseKey(r *http.Request) (*lib.SignRequest, error) { var s lib.SignRequest @@ -118,28 +135,30 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er // loginHandler starts the authentication process with the provider. func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - a.authsession = a.authprovider.StartSession(newState()) + state := newState() + a.setAuthStateCookie(w, r, state) + a.authsession = a.authprovider.StartSession(state) http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) return http.StatusFound, nil } // callbackHandler handles retrieving the access token from the auth provider and saves it for later use. func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if r.FormValue("state") != a.authsession.State { + if r.FormValue("state") != a.getAuthStateCookie(r) { return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) } code := r.FormValue("code") if err := a.authsession.Authorize(a.authprovider, code); err != nil { return http.StatusInternalServerError, err } - a.setAuthCookie(w, r, a.authsession.Token) + a.setAuthTokenCookie(w, r, a.authsession.Token) http.Redirect(w, r, "/", http.StatusFound) return http.StatusFound, nil } // rootHandler starts the auth process. If the client is authenticated it renders the token to the user. func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - tok := a.getAuthCookie(r) + tok := a.getAuthTokenCookie(r) if !tok.Valid() || !a.authprovider.Valid(tok) { http.Redirect(w, r, "/auth/login", http.StatusSeeOther) return http.StatusSeeOther, nil diff --git a/server/auth/github/github.go b/server/auth/github/github.go index d7a57af..7904e26 100644 --- a/server/auth/github/github.go +++ b/server/auth/github/github.go @@ -78,7 +78,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { func (c *Config) StartSession(state string) *auth.Session { return &auth.Session{ AuthURL: c.config.AuthCodeURL(state), - State: state, } } diff --git a/server/auth/github/github_test.go b/server/auth/github/github_test.go index f50d134..1d6b801 100644 --- a/server/auth/github/github_test.go +++ b/server/auth/github/github_test.go @@ -42,7 +42,6 @@ func TestStartSession(t *testing.T) { p, _ := newGithub() s := p.StartSession("test_state") - a.Equal(s.State, "test_state") a.Contains(s.AuthURL, "github.com/login/oauth/authorize") a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, fmt.Sprintf("client_id=%s", oauthClientID)) diff --git a/server/auth/google/google.go b/server/auth/google/google.go index 7c9b930..e2c6724 100644 --- a/server/auth/google/google.go +++ b/server/auth/google/google.go @@ -90,7 +90,6 @@ func (c *Config) Revoke(token *oauth2.Token) error { func (c *Config) StartSession(state string) *auth.Session { return &auth.Session{ AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)), - State: state, } } diff --git a/server/auth/google/google_test.go b/server/auth/google/google_test.go index 4d41986..9970c21 100644 --- a/server/auth/google/google_test.go +++ b/server/auth/google/google_test.go @@ -44,7 +44,6 @@ func TestStartSession(t *testing.T) { p, err := newGoogle() a.NoError(err) s := p.StartSession("test_state") - a.Equal(s.State, "test_state") a.Contains(s.AuthURL, "accounts.google.com/o/oauth2/auth") a.Contains(s.AuthURL, "state=test_state") a.Contains(s.AuthURL, fmt.Sprintf("hd=%s", domain)) diff --git a/server/auth/provider.go b/server/auth/provider.go index d7d5ed5..06dc1c9 100644 --- a/server/auth/provider.go +++ b/server/auth/provider.go @@ -16,7 +16,6 @@ type Provider interface { type Session struct { AuthURL string Token *oauth2.Token - State string } // Authorize obtains data from the provider and retains an access token that -- cgit v1.2.3