From 99225736d41e86c7f47eac4db3455b18178bba24 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Mon, 20 Aug 2018 16:41:17 +0100 Subject: Make all handlers methods of app Merge server setup and helpers from web.go into server.go Handlers moved to handlers.go --- cmd/cashierd/main.go | 2 +- server/handlers.go | 167 ++++++++++++++++++++++ server/handlers_test.go | 54 +++++--- server/server.go | 172 +++++++++++++++++++++-- server/web.go | 358 ------------------------------------------------ 5 files changed, 360 insertions(+), 393 deletions(-) create mode 100644 server/handlers.go delete mode 100644 server/web.go diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index 5b0b390..b4f1fe7 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -40,6 +40,6 @@ func main() { }) vaultfs.Register(conf.Vault) - // Start the servers + // Start the server server.Run(conf) } diff --git a/server/handlers.go b/server/handlers.go new file mode 100644 index 0000000..b85550d --- /dev/null +++ b/server/handlers.go @@ -0,0 +1,167 @@ +package server + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "html/template" + "io" + "log" + "net/http" + "strconv" + "strings" + + "github.com/gorilla/csrf" + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/server/templates" + "github.com/pkg/errors" + "golang.org/x/oauth2" +) + +func (a *app) sign(w http.ResponseWriter, r *http.Request) { + var t string + if ah := r.Header.Get("Authorization"); ah != "" { + if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { + t = ah[7:] + } + } + + token := &oauth2.Token{ + AccessToken: t, + } + if !a.authprovider.Valid(token) { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, http.StatusText(http.StatusUnauthorized)) + return + } + + // Sign the pubkey and issue the cert. + req := &lib.SignRequest{} + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, http.StatusText(http.StatusBadRequest)) + return + } + + if a.requireReason && req.Message == "" { + w.Header().Add("X-Need-Reason", "required") + w.WriteHeader(http.StatusForbidden) + fmt.Fprint(w, http.StatusText(http.StatusForbidden)) + return + } + + username := a.authprovider.Username(token) + a.authprovider.Revoke(token) // We don't need this anymore. + cert, err := a.keysigner.SignUserKey(req, username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Error signing key") + return + } + if err := a.certstore.SetCert(cert); err != nil { + log.Printf("Error recording cert: %v", err) + } + if err := json.NewEncoder(w).Encode(&lib.SignResponse{ + Status: "ok", + Response: string(lib.GetPublicKey(cert)), + }); err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Error signing key") + return + } +} + +func (a *app) auth(w http.ResponseWriter, r *http.Request) { + switch r.URL.EscapedPath() { + case "/auth/login": + buf := make([]byte, 32) + io.ReadFull(rand.Reader, buf) + state := hex.EncodeToString(buf) + a.setSessionVariable(w, r, "state", state) + http.Redirect(w, r, a.authprovider.StartSession(state), http.StatusFound) + case "/auth/callback": + state := a.getSessionVariable(r, "state") + if r.FormValue("state") != state { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(http.StatusText(http.StatusUnauthorized))) + break + } + originURL := a.getSessionVariable(r, "origin_url") + if originURL == "" { + originURL = "/" + } + code := r.FormValue("code") + token, err := a.authprovider.Exchange(code) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(http.StatusText(http.StatusInternalServerError))) + w.Write([]byte(err.Error())) + break + } + a.setAuthToken(w, r, token) + http.Redirect(w, r, originURL, http.StatusFound) + default: + w.WriteHeader(http.StatusInternalServerError) + } +} + +func (a *app) index(w http.ResponseWriter, r *http.Request) { + tok := a.getAuthToken(r) + page := struct { + Token string + }{tok.AccessToken} + page.Token = encodeString(page.Token) + tmpl := template.Must(template.New("token.html").Parse(templates.Token)) + tmpl.Execute(w, page) +} + +func (a *app) revoked(w http.ResponseWriter, r *http.Request) { + revoked, err := a.certstore.GetRevoked() + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, errors.Wrap(err, "error retrieving revoked certs").Error()) + return + } + rl, err := a.keysigner.GenerateRevocationList(revoked) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, errors.Wrap(err, "unable to generate KRL").Error()) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(rl) +} + +func (a *app) getAllCerts(w http.ResponseWriter, r *http.Request) { + tmpl := template.Must(template.New("certs.html").Parse(templates.Certs)) + tmpl.Execute(w, map[string]interface{}{ + csrf.TemplateTag: csrf.TemplateField(r), + }) +} + +func (a *app) getCertsJSON(w http.ResponseWriter, r *http.Request) { + includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all")) + certs, err := a.certstore.List(includeExpired) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, http.StatusText(http.StatusInternalServerError)) + return + } + if err := json.NewEncoder(w).Encode(certs); err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, http.StatusText(http.StatusInternalServerError)) + return + } +} + +func (a *app) revoke(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + if err := a.certstore.Revoke(r.Form["cert_id"]); err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Unable to revoke certs")) + } else { + http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) + } +} diff --git a/server/handlers_test.go b/server/handlers_test.go index 6dc2236..44024ac 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -15,6 +15,7 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/oauth2" + "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/nsheridan/cashier/lib" "github.com/nsheridan/cashier/server/auth/testprovider" @@ -25,28 +26,33 @@ import ( "github.com/stripe/krl" ) -var ctx *appContext +var a *app func init() { f, _ := ioutil.TempFile(os.TempDir(), "signing_key_") defer os.Remove(f.Name()) f.Write(testdata.Priv) f.Close() - keysigner, _ = signer.New(&config.SSH{ + keysigner, _ := signer.New(&config.SSH{ SigningKey: f.Name(), - MaxAge: "1h", + MaxAge: "4h", }) - authprovider = testprovider.New() - certstore, _ = store.New(map[string]string{"type": "mem"}) - ctx = &appContext{ - cookiestore: sessions.NewCookieStore([]byte("secret")), + certstore, _ := store.New(map[string]string{"type": "mem"}) + a = &app{ + cookiestore: sessions.NewCookieStore([]byte("secret")), + authprovider: testprovider.New(), + keysigner: keysigner, + certstore: certstore, + router: mux.NewRouter(), + config: &config.Server{CSRFSecret: "0123456789abcdef"}, } + a.routes() } func TestLoginHandler(t *testing.T) { req, _ := http.NewRequest("GET", "/auth/login", nil) resp := httptest.NewRecorder() - loginHandler(ctx, resp, req) + a.router.ServeHTTP(resp, req) if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" { t.Error("Unexpected response") } @@ -56,10 +62,11 @@ func TestCallbackHandler(t *testing.T) { req, _ := http.NewRequest("GET", "/auth/callback", nil) req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}} resp := httptest.NewRecorder() - ctx.setAuthStateCookie(resp, req, "state") - callbackHandler(ctx, resp, req) + a.setSessionVariable(resp, req, "state", "state") + req.Header.Add("Cookie", resp.HeaderMap["Set-Cookie"][0]) + a.router.ServeHTTP(resp, req) if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" { - t.Error("Unexpected response") + t.Errorf("Response: %d\nHeaders: %v", resp.Code, resp.Header()) } } @@ -70,8 +77,9 @@ func TestRootHandler(t *testing.T) { AccessToken: "XXX_TEST_TOKEN_STRING_XXX", Expiry: time.Now().Add(1 * time.Hour), } - ctx.setAuthTokenCookie(resp, req, tok) - rootHandler(ctx, resp, req) + a.setAuthToken(resp, req, tok) + req.Header.Add("Cookie", resp.HeaderMap["Set-Cookie"][0]) + a.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") { t.Error("Unable to find token in response") } @@ -80,7 +88,7 @@ func TestRootHandler(t *testing.T) { func TestRootHandlerNoSession(t *testing.T) { req, _ := http.NewRequest("GET", "/", nil) resp := httptest.NewRecorder() - rootHandler(ctx, resp, req) + a.router.ServeHTTP(resp, req) if resp.Code != http.StatusSeeOther { t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther)) } @@ -89,12 +97,12 @@ func TestRootHandlerNoSession(t *testing.T) { func TestSignRevoke(t *testing.T) { s, _ := json.Marshal(&lib.SignRequest{ Key: string(testdata.Pub), - ValidUntil: time.Now().UTC().Add(1 * time.Hour), + ValidUntil: time.Now().UTC().Add(4 * time.Hour), }) req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s)) resp := httptest.NewRecorder() req.Header.Set("Authorization", "Bearer abcdef") - signHandler(ctx, resp, req) + a.router.ServeHTTP(resp, req) if resp.Code != http.StatusOK { t.Error("Unexpected response") } @@ -114,18 +122,22 @@ func TestSignRevoke(t *testing.T) { t.Error("Did not receive a certificate") } // Revoke the cert and verify - req, _ = http.NewRequest("POST", "/revoke", nil) + req, _ = http.NewRequest("POST", "/admin/revoke", nil) req.Form = url.Values{"cert_id": []string{cert.KeyId}} tok := &oauth2.Token{ AccessToken: "authenticated", Expiry: time.Now().Add(1 * time.Hour), } - ctx.setAuthTokenCookie(resp, req, tok) - revokeCertHandler(ctx, resp, req) + a.certstore.Revoke([]string{cert.KeyId}) + a.setAuthToken(resp, req, tok) + a.router.ServeHTTP(resp, req) req, _ = http.NewRequest("GET", "/revoked", nil) - listRevokedCertsHandler(ctx, resp, req) + a.router.ServeHTTP(resp, req) revoked, _ := ioutil.ReadAll(resp.Body) - rl, _ := krl.ParseKRL(revoked) + rl, err := krl.ParseKRL(revoked) + if err != nil { + t.Fail() + } if !rl.IsRevoked(cert) { t.Errorf("cert %s was not revoked", cert.KeyId) } diff --git a/server/server.go b/server/server.go index 1b8468e..2a6af15 100644 --- a/server/server.go +++ b/server/server.go @@ -1,17 +1,33 @@ package server import ( + "bytes" "crypto/tls" + "encoding/base64" + "encoding/json" "fmt" "log" "net" + "net/http" + "os" + "time" + "github.com/gorilla/csrf" + + "github.com/gobuffalo/packr" + "github.com/gorilla/handlers" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/gorilla/mux" + "github.com/gorilla/sessions" "github.com/pkg/errors" "go4.org/wkfs" "golang.org/x/crypto/acme/autocert" + "golang.org/x/oauth2" wkfscache "github.com/nsheridan/autocert-wkfs-cache" + "github.com/nsheridan/cashier/lib" "github.com/nsheridan/cashier/server/auth" "github.com/nsheridan/cashier/server/auth/github" "github.com/nsheridan/cashier/server/auth/gitlab" @@ -24,12 +40,6 @@ import ( "github.com/sid77/drop" ) -var ( - authprovider auth.Provider - certstore store.CertStorer - keysigner *signer.KeySigner -) - func loadCerts(certFile, keyFile string) (tls.Certificate, error) { key, err := wkfs.ReadFile(keyFile) if err != nil { @@ -42,13 +52,9 @@ func loadCerts(certFile, keyFile string) (tls.Certificate, error) { return tls.X509KeyPair(cert, key) } -// Run the HTTP server. +// Run the server. func Run(conf *config.Config) { var err error - keysigner, err = signer.New(conf.SSH) - if err != nil { - log.Fatal(err) - } laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port) l, err := net.Listen("tcp", laddr) @@ -90,6 +96,7 @@ func Run(conf *config.Config) { // Unprivileged section metrics.Register() + var authprovider auth.Provider switch conf.Auth.Provider { case "github": authprovider, err = github.New(conf.Auth) @@ -106,11 +113,150 @@ func Run(conf *config.Config) { log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider)) } - certstore, err = store.New(conf.Server.Database) + keysigner, err := signer.New(conf.SSH) + if err != nil { + log.Fatal(err) + } + + certstore, err := store.New(conf.Server.Database) if err != nil { log.Fatal(err) } + ctx := &app{ + cookiestore: sessions.NewCookieStore([]byte(conf.Server.CookieSecret)), + requireReason: conf.Server.RequireReason, + keysigner: keysigner, + certstore: certstore, + authprovider: authprovider, + config: conf.Server, + router: mux.NewRouter(), + } + ctx.cookiestore.Options = &sessions.Options{ + MaxAge: 900, + Path: "/", + Secure: conf.Server.UseTLS, + HttpOnly: true, + } + + logfile := os.Stderr + if conf.Server.HTTPLogFile != "" { + logfile, err = os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) + if err != nil { + log.Printf("error opening log: %v. logging to stdout", err) + } + } + + ctx.routes() + ctx.router.Use(mwVersion) + ctx.router.Use(handlers.CompressHandler) + ctx.router.Use(handlers.RecoveryHandler()) + r := handlers.LoggingHandler(logfile, ctx.router) + s := &http.Server{ + Handler: r, + ReadTimeout: 20 * time.Second, + WriteTimeout: 20 * time.Second, + IdleTimeout: 120 * time.Second, + } + log.Printf("Starting server on %s", laddr) - runHTTPServer(conf.Server, l) + s.Serve(l) +} + +// mwVersion is middleware to add a X-Cashier-Version header to the response. +func mwVersion(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Cashier-Version", lib.Version) + next.ServeHTTP(w, r) + }) +} + +func encodeString(s string) string { + var buffer bytes.Buffer + chunkSize := 70 + runes := []rune(base64.StdEncoding.EncodeToString([]byte(s))) + + for i := 0; i < len(runes); i += chunkSize { + end := i + chunkSize + if end > len(runes) { + end = len(runes) + } + buffer.WriteString(string(runes[i:end])) + buffer.WriteString("\n") + } + buffer.WriteString(".\n") + return buffer.String() +} + +// app contains local context - cookiestore, authsession etc. +type app struct { + cookiestore *sessions.CookieStore + authprovider auth.Provider + certstore store.CertStorer + keysigner *signer.KeySigner + router *mux.Router + config *config.Server + requireReason bool +} + +func (a *app) routes() { + // login required + csrfHandler := csrf.Protect([]byte(a.config.CSRFSecret), csrf.Secure(a.config.UseTLS)) + a.router.Methods("GET").Path("/").Handler(a.authed(http.HandlerFunc(a.index))) + a.router.Methods("POST").Path("/admin/revoke").Handler(a.authed(csrfHandler(http.HandlerFunc(a.revoke)))) + a.router.Methods("GET").Path("/admin/certs").Handler(a.authed(csrfHandler(http.HandlerFunc(a.getAllCerts)))) + a.router.Methods("GET").Path("/admin/certs.json").Handler(a.authed(http.HandlerFunc(a.getCertsJSON))) + + // no login required + a.router.Methods("GET").Path("/auth/login").HandlerFunc(a.auth) + a.router.Methods("GET").Path("/auth/callback").HandlerFunc(a.auth) + a.router.Methods("GET").Path("/revoked").HandlerFunc(a.revoked) + a.router.Methods("POST").Path("/sign").HandlerFunc(a.sign) + + a.router.Methods("GET").Path("/healthcheck").HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "ok") + }) + a.router.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) + box := packr.NewBox("static") + a.router.PathPrefix("/static/").Handler(http.StripPrefix("/static", http.FileServer(box))) +} + +func (a *app) getAuthToken(r *http.Request) *oauth2.Token { + token := &oauth2.Token{} + marshalled := a.getSessionVariable(r, "token") + json.Unmarshal([]byte(marshalled), token) + return token +} + +func (a *app) setAuthToken(w http.ResponseWriter, r *http.Request, token *oauth2.Token) { + v, _ := json.Marshal(token) + a.setSessionVariable(w, r, "token", string(v)) +} + +func (a *app) getSessionVariable(r *http.Request, key string) string { + session, _ := a.cookiestore.Get(r, "session") + v, ok := session.Values[key].(string) + if !ok { + v = "" + } + return v +} + +func (a *app) setSessionVariable(w http.ResponseWriter, r *http.Request, key, value string) { + session, _ := a.cookiestore.Get(r, "session") + session.Values[key] = value + session.Save(r, w) +} + +func (a *app) authed(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t := a.getAuthToken(r) + if !t.Valid() || !a.authprovider.Valid(t) { + a.setSessionVariable(w, r, "origin_url", r.URL.EscapedPath()) + http.Redirect(w, r, "/auth/login", http.StatusSeeOther) + return + } + next.ServeHTTP(w, r) + }) } diff --git a/server/web.go b/server/web.go deleted file mode 100644 index 9114de1..0000000 --- a/server/web.go +++ /dev/null @@ -1,358 +0,0 @@ -package server - -import ( - "bytes" - "crypto/rand" - "encoding/base64" - "encoding/hex" - "encoding/json" - "fmt" - "html/template" - "io" - "log" - "net" - "net/http" - "os" - "strconv" - "strings" - - "github.com/gobuffalo/packr" - - "github.com/pkg/errors" - "github.com/prometheus/client_golang/prometheus/promhttp" - - "golang.org/x/oauth2" - - "github.com/gorilla/csrf" - "github.com/gorilla/handlers" - "github.com/gorilla/mux" - "github.com/gorilla/sessions" - "github.com/nsheridan/cashier/lib" - "github.com/nsheridan/cashier/server/config" - "github.com/nsheridan/cashier/server/templates" -) - -// appContext contains local context - cookiestore, authsession etc. -type appContext struct { - cookiestore *sessions.CookieStore - requireReason bool -} - -// getAuthTokenCookie retrieves a cookie from the request. -func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token { - session, _ := a.cookiestore.Get(r, "session") - t, ok := session.Values["token"] - if !ok { - return nil - } - var tok oauth2.Token - if err := json.Unmarshal(t.([]byte), &tok); err != nil { - return nil - } - if !tok.Valid() { - return nil - } - return &tok -} - -// setAuthTokenCookie marshals the auth token and stores it as a cookie. -func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { - session, _ := a.cookiestore.Get(r, "session") - val, _ := json.Marshal(t) - session.Values["token"] = val - session.Save(r, w) -} - -// getAuthStateCookie retrieves the oauth csrf state value from the client request. -func (a *appContext) getAuthStateCookie(r *http.Request) string { - session, _ := a.cookiestore.Get(r, "session") - state, ok := session.Values["state"] - if !ok { - return "" - } - return state.(string) -} - -// setAuthStateCookie saves the oauth csrf state value. -func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) { - session, _ := a.cookiestore.Get(r, "session") - session.Values["state"] = state - session.Save(r, w) -} - -func (a *appContext) getCurrentURL(r *http.Request) string { - session, _ := a.cookiestore.Get(r, "session") - path, ok := session.Values["auth_url"] - if !ok { - return "" - } - return path.(string) -} - -func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) { - session, _ := a.cookiestore.Get(r, "session") - session.Values["auth_url"] = r.URL.Path - session.Save(r, w) -} - -func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool { - tok := a.getAuthTokenCookie(r) - if !tok.Valid() || !authprovider.Valid(tok) { - return false - } - return true -} - -func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) { - a.setCurrentURL(w, r) - http.Redirect(w, r, "/auth/login", http.StatusSeeOther) - return http.StatusSeeOther, nil -} - -// parseKey retrieves and unmarshals the signing request. -func extractKey(r *http.Request) (*lib.SignRequest, error) { - var s lib.SignRequest - if err := json.NewDecoder(r.Body).Decode(&s); err != nil { - return nil, err - } - return &s, nil -} - -// signHandler handles the "/sign" path. -// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key. -func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - var t string - if ah := r.Header.Get("Authorization"); ah != "" { - if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { - t = ah[7:] - } - } - if t == "" { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - token := &oauth2.Token{ - AccessToken: t, - } - if !authprovider.Valid(token) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - - // Sign the pubkey and issue the cert. - req, err := extractKey(r) - if err != nil { - return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") - } - - if a.requireReason && req.Message == "" { - w.Header().Add("X-Need-Reason", "required") - return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden)) - } - - username := authprovider.Username(token) - authprovider.Revoke(token) // We don't need this anymore. - cert, err := keysigner.SignUserKey(req, username) - if err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "error signing key") - } - if err := certstore.SetCert(cert); err != nil { - log.Printf("Error recording cert: %v", err) - } - if err := json.NewEncoder(w).Encode(&lib.SignResponse{ - Status: "ok", - Response: string(lib.GetPublicKey(cert)), - }); err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "error encoding response") - } - return http.StatusOK, nil -} - -// loginHandler starts the authentication process with the provider. -func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - state := newState() - a.setAuthStateCookie(w, r, state) - http.Redirect(w, r, authprovider.StartSession(state), http.StatusFound) - return http.StatusFound, nil -} - -// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. -func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if r.FormValue("state") != a.getAuthStateCookie(r) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - code := r.FormValue("code") - token, err := authprovider.Exchange(code) - if err != nil { - return http.StatusInternalServerError, err - } - a.setAuthTokenCookie(w, r, token) - http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound) - return http.StatusFound, nil -} - -func encodeString(s string) string { - var buffer bytes.Buffer - chunkSize := 70 - runes := []rune(base64.StdEncoding.EncodeToString([]byte(s))) - - for i := 0; i < len(runes); i += chunkSize { - end := i + chunkSize - if end > len(runes) { - end = len(runes) - } - buffer.WriteString(string(runes[i:end])) - buffer.WriteString("\n") - } - buffer.WriteString(".\n") - return buffer.String() -} - -// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. -func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - tok := a.getAuthTokenCookie(r) - page := struct { - Token string - }{tok.AccessToken} - page.Token = encodeString(page.Token) - - tmpl := template.Must(template.New("token.html").Parse(templates.Token)) - tmpl.Execute(w, page) - return http.StatusOK, nil -} - -func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - revoked, err := certstore.GetRevoked() - if err != nil { - return http.StatusInternalServerError, err - } - rl, err := keysigner.GenerateRevocationList(revoked) - if err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL") - } - w.Header().Set("Content-Type", "application/octet-stream") - w.Write(rl) - return http.StatusOK, nil -} - -func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - tmpl := template.Must(template.New("certs.html").Parse(templates.Certs)) - tmpl.Execute(w, map[string]interface{}{ - csrf.TemplateTag: csrf.TemplateField(r), - }) - return http.StatusOK, nil -} - -func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all")) - certs, err := certstore.List(includeExpired) - if err != nil { - return http.StatusInternalServerError, err - } - j, err := json.Marshal(certs) - if err != nil { - return http.StatusInternalServerError, err - } - w.Write(j) - return http.StatusOK, nil -} - -func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - r.ParseForm() - if err := certstore.Revoke(r.Form["cert_id"]); err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke certs") - } - http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) - return http.StatusSeeOther, nil -} - -func healthcheck(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, "ok") -} - -// appHandler is a handler which uses appContext to manage state. -type appHandler struct { - *appContext - h func(*appContext, http.ResponseWriter, *http.Request) (int, error) -} - -// ServeHTTP handles the request and writes responses. -func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - status, err := ah.h(ah.appContext, w, r) - if err != nil { - http.Error(w, err.Error(), status) - } -} - -// newState generates a state identifier for the oauth process. -func newState() string { - k := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return "unexpectedstring" - } - return hex.EncodeToString(k) -} - -// mwVersion is middleware to add a X-Cashier-Version header to the response. -func mwVersion(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Cashier-Version", lib.Version) - next.ServeHTTP(w, r) - }) -} - -func runHTTPServer(conf *config.Server, l net.Listener) { - var err error - ctx := &appContext{ - cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)), - requireReason: conf.RequireReason, - } - ctx.cookiestore.Options = &sessions.Options{ - MaxAge: 900, - Path: "/", - Secure: conf.UseTLS, - HttpOnly: true, - } - - logfile := os.Stderr - if conf.HTTPLogFile != "" { - logfile, err = os.OpenFile(conf.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) - if err != nil { - log.Printf("unable to open %s for writing. logging to stdout", conf.HTTPLogFile) - logfile = os.Stderr - } - } - - CSRF := csrf.Protect([]byte(conf.CSRFSecret), csrf.Secure(conf.UseTLS)) - r := mux.NewRouter() - r.Use(mwVersion) - r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler}) - r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler}) - r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler}) - r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler}) - r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler}) - r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler})) - r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler})) - r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler}) - r.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) - r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck) - - box := packr.NewBox("static") - r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box))) - h := handlers.LoggingHandler(logfile, r) - s := &http.Server{ - Handler: h, - } - s.Serve(l) -} -- cgit v1.2.3