aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2018-08-20 16:41:17 +0100
committerNiall Sheridan <nsheridan@gmail.com>2018-08-22 14:52:00 +0100
commit99225736d41e86c7f47eac4db3455b18178bba24 (patch)
treec6f8530dabddf98f4832caa650b63299448c2db7
parentdbefe685912e286fce16bf9dd3773f4037cdcdf1 (diff)
Make all handlers methods of app
Merge server setup and helpers from web.go into server.go Handlers moved to handlers.go
-rw-r--r--cmd/cashierd/main.go2
-rw-r--r--server/handlers.go167
-rw-r--r--server/handlers_test.go54
-rw-r--r--server/server.go172
-rw-r--r--server/web.go358
5 files changed, 360 insertions, 393 deletions
diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go
index 5b0b390..b4f1fe7 100644
--- a/cmd/cashierd/main.go
+++ b/cmd/cashierd/main.go
@@ -40,6 +40,6 @@ func main() {
})
vaultfs.Register(conf.Vault)
- // Start the servers
+ // Start the server
server.Run(conf)
}
diff --git a/server/handlers.go b/server/handlers.go
new file mode 100644
index 0000000..b85550d
--- /dev/null
+++ b/server/handlers.go
@@ -0,0 +1,167 @@
+package server
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "io"
+ "log"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "github.com/gorilla/csrf"
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/templates"
+ "github.com/pkg/errors"
+ "golang.org/x/oauth2"
+)
+
+func (a *app) sign(w http.ResponseWriter, r *http.Request) {
+ var t string
+ if ah := r.Header.Get("Authorization"); ah != "" {
+ if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " {
+ t = ah[7:]
+ }
+ }
+
+ token := &oauth2.Token{
+ AccessToken: t,
+ }
+ if !a.authprovider.Valid(token) {
+ w.WriteHeader(http.StatusUnauthorized)
+ fmt.Fprint(w, http.StatusText(http.StatusUnauthorized))
+ return
+ }
+
+ // Sign the pubkey and issue the cert.
+ req := &lib.SignRequest{}
+ if err := json.NewDecoder(r.Body).Decode(req); err != nil {
+ fmt.Println(err)
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprint(w, http.StatusText(http.StatusBadRequest))
+ return
+ }
+
+ if a.requireReason && req.Message == "" {
+ w.Header().Add("X-Need-Reason", "required")
+ w.WriteHeader(http.StatusForbidden)
+ fmt.Fprint(w, http.StatusText(http.StatusForbidden))
+ return
+ }
+
+ username := a.authprovider.Username(token)
+ a.authprovider.Revoke(token) // We don't need this anymore.
+ cert, err := a.keysigner.SignUserKey(req, username)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "Error signing key")
+ return
+ }
+ if err := a.certstore.SetCert(cert); err != nil {
+ log.Printf("Error recording cert: %v", err)
+ }
+ if err := json.NewEncoder(w).Encode(&lib.SignResponse{
+ Status: "ok",
+ Response: string(lib.GetPublicKey(cert)),
+ }); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "Error signing key")
+ return
+ }
+}
+
+func (a *app) auth(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.EscapedPath() {
+ case "/auth/login":
+ buf := make([]byte, 32)
+ io.ReadFull(rand.Reader, buf)
+ state := hex.EncodeToString(buf)
+ a.setSessionVariable(w, r, "state", state)
+ http.Redirect(w, r, a.authprovider.StartSession(state), http.StatusFound)
+ case "/auth/callback":
+ state := a.getSessionVariable(r, "state")
+ if r.FormValue("state") != state {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte(http.StatusText(http.StatusUnauthorized)))
+ break
+ }
+ originURL := a.getSessionVariable(r, "origin_url")
+ if originURL == "" {
+ originURL = "/"
+ }
+ code := r.FormValue("code")
+ token, err := a.authprovider.Exchange(code)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte(http.StatusText(http.StatusInternalServerError)))
+ w.Write([]byte(err.Error()))
+ break
+ }
+ a.setAuthToken(w, r, token)
+ http.Redirect(w, r, originURL, http.StatusFound)
+ default:
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+}
+
+func (a *app) index(w http.ResponseWriter, r *http.Request) {
+ tok := a.getAuthToken(r)
+ page := struct {
+ Token string
+ }{tok.AccessToken}
+ page.Token = encodeString(page.Token)
+ tmpl := template.Must(template.New("token.html").Parse(templates.Token))
+ tmpl.Execute(w, page)
+}
+
+func (a *app) revoked(w http.ResponseWriter, r *http.Request) {
+ revoked, err := a.certstore.GetRevoked()
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, errors.Wrap(err, "error retrieving revoked certs").Error())
+ return
+ }
+ rl, err := a.keysigner.GenerateRevocationList(revoked)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, errors.Wrap(err, "unable to generate KRL").Error())
+ return
+ }
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Write(rl)
+}
+
+func (a *app) getAllCerts(w http.ResponseWriter, r *http.Request) {
+ tmpl := template.Must(template.New("certs.html").Parse(templates.Certs))
+ tmpl.Execute(w, map[string]interface{}{
+ csrf.TemplateTag: csrf.TemplateField(r),
+ })
+}
+
+func (a *app) getCertsJSON(w http.ResponseWriter, r *http.Request) {
+ includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all"))
+ certs, err := a.certstore.List(includeExpired)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprint(w, http.StatusText(http.StatusInternalServerError))
+ return
+ }
+ if err := json.NewEncoder(w).Encode(certs); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprint(w, http.StatusText(http.StatusInternalServerError))
+ return
+ }
+}
+
+func (a *app) revoke(w http.ResponseWriter, r *http.Request) {
+ r.ParseForm()
+ if err := a.certstore.Revoke(r.Form["cert_id"]); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("Unable to revoke certs"))
+ } else {
+ http.Redirect(w, r, "/admin/certs", http.StatusSeeOther)
+ }
+}
diff --git a/server/handlers_test.go b/server/handlers_test.go
index 6dc2236..44024ac 100644
--- a/server/handlers_test.go
+++ b/server/handlers_test.go
@@ -15,6 +15,7 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/oauth2"
+ "github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/auth/testprovider"
@@ -25,28 +26,33 @@ import (
"github.com/stripe/krl"
)
-var ctx *appContext
+var a *app
func init() {
f, _ := ioutil.TempFile(os.TempDir(), "signing_key_")
defer os.Remove(f.Name())
f.Write(testdata.Priv)
f.Close()
- keysigner, _ = signer.New(&config.SSH{
+ keysigner, _ := signer.New(&config.SSH{
SigningKey: f.Name(),
- MaxAge: "1h",
+ MaxAge: "4h",
})
- authprovider = testprovider.New()
- certstore, _ = store.New(map[string]string{"type": "mem"})
- ctx = &appContext{
- cookiestore: sessions.NewCookieStore([]byte("secret")),
+ certstore, _ := store.New(map[string]string{"type": "mem"})
+ a = &app{
+ cookiestore: sessions.NewCookieStore([]byte("secret")),
+ authprovider: testprovider.New(),
+ keysigner: keysigner,
+ certstore: certstore,
+ router: mux.NewRouter(),
+ config: &config.Server{CSRFSecret: "0123456789abcdef"},
}
+ a.routes()
}
func TestLoginHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/login", nil)
resp := httptest.NewRecorder()
- loginHandler(ctx, resp, req)
+ a.router.ServeHTTP(resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
t.Error("Unexpected response")
}
@@ -56,10 +62,11 @@ func TestCallbackHandler(t *testing.T) {
req, _ := http.NewRequest("GET", "/auth/callback", nil)
req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
resp := httptest.NewRecorder()
- ctx.setAuthStateCookie(resp, req, "state")
- callbackHandler(ctx, resp, req)
+ a.setSessionVariable(resp, req, "state", "state")
+ req.Header.Add("Cookie", resp.HeaderMap["Set-Cookie"][0])
+ a.router.ServeHTTP(resp, req)
if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
- t.Error("Unexpected response")
+ t.Errorf("Response: %d\nHeaders: %v", resp.Code, resp.Header())
}
}
@@ -70,8 +77,9 @@ func TestRootHandler(t *testing.T) {
AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
Expiry: time.Now().Add(1 * time.Hour),
}
- ctx.setAuthTokenCookie(resp, req, tok)
- rootHandler(ctx, resp, req)
+ a.setAuthToken(resp, req, tok)
+ req.Header.Add("Cookie", resp.HeaderMap["Set-Cookie"][0])
+ a.router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
t.Error("Unable to find token in response")
}
@@ -80,7 +88,7 @@ func TestRootHandler(t *testing.T) {
func TestRootHandlerNoSession(t *testing.T) {
req, _ := http.NewRequest("GET", "/", nil)
resp := httptest.NewRecorder()
- rootHandler(ctx, resp, req)
+ a.router.ServeHTTP(resp, req)
if resp.Code != http.StatusSeeOther {
t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
}
@@ -89,12 +97,12 @@ func TestRootHandlerNoSession(t *testing.T) {
func TestSignRevoke(t *testing.T) {
s, _ := json.Marshal(&lib.SignRequest{
Key: string(testdata.Pub),
- ValidUntil: time.Now().UTC().Add(1 * time.Hour),
+ ValidUntil: time.Now().UTC().Add(4 * time.Hour),
})
req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
resp := httptest.NewRecorder()
req.Header.Set("Authorization", "Bearer abcdef")
- signHandler(ctx, resp, req)
+ a.router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Error("Unexpected response")
}
@@ -114,18 +122,22 @@ func TestSignRevoke(t *testing.T) {
t.Error("Did not receive a certificate")
}
// Revoke the cert and verify
- req, _ = http.NewRequest("POST", "/revoke", nil)
+ req, _ = http.NewRequest("POST", "/admin/revoke", nil)
req.Form = url.Values{"cert_id": []string{cert.KeyId}}
tok := &oauth2.Token{
AccessToken: "authenticated",
Expiry: time.Now().Add(1 * time.Hour),
}
- ctx.setAuthTokenCookie(resp, req, tok)
- revokeCertHandler(ctx, resp, req)
+ a.certstore.Revoke([]string{cert.KeyId})
+ a.setAuthToken(resp, req, tok)
+ a.router.ServeHTTP(resp, req)
req, _ = http.NewRequest("GET", "/revoked", nil)
- listRevokedCertsHandler(ctx, resp, req)
+ a.router.ServeHTTP(resp, req)
revoked, _ := ioutil.ReadAll(resp.Body)
- rl, _ := krl.ParseKRL(revoked)
+ rl, err := krl.ParseKRL(revoked)
+ if err != nil {
+ t.Fail()
+ }
if !rl.IsRevoked(cert) {
t.Errorf("cert %s was not revoked", cert.KeyId)
}
diff --git a/server/server.go b/server/server.go
index 1b8468e..2a6af15 100644
--- a/server/server.go
+++ b/server/server.go
@@ -1,17 +1,33 @@
package server
import (
+ "bytes"
"crypto/tls"
+ "encoding/base64"
+ "encoding/json"
"fmt"
"log"
"net"
+ "net/http"
+ "os"
+ "time"
+ "github.com/gorilla/csrf"
+
+ "github.com/gobuffalo/packr"
+ "github.com/gorilla/handlers"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+
+ "github.com/gorilla/mux"
+ "github.com/gorilla/sessions"
"github.com/pkg/errors"
"go4.org/wkfs"
"golang.org/x/crypto/acme/autocert"
+ "golang.org/x/oauth2"
wkfscache "github.com/nsheridan/autocert-wkfs-cache"
+ "github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/auth"
"github.com/nsheridan/cashier/server/auth/github"
"github.com/nsheridan/cashier/server/auth/gitlab"
@@ -24,12 +40,6 @@ import (
"github.com/sid77/drop"
)
-var (
- authprovider auth.Provider
- certstore store.CertStorer
- keysigner *signer.KeySigner
-)
-
func loadCerts(certFile, keyFile string) (tls.Certificate, error) {
key, err := wkfs.ReadFile(keyFile)
if err != nil {
@@ -42,13 +52,9 @@ func loadCerts(certFile, keyFile string) (tls.Certificate, error) {
return tls.X509KeyPair(cert, key)
}
-// Run the HTTP server.
+// Run the server.
func Run(conf *config.Config) {
var err error
- keysigner, err = signer.New(conf.SSH)
- if err != nil {
- log.Fatal(err)
- }
laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port)
l, err := net.Listen("tcp", laddr)
@@ -90,6 +96,7 @@ func Run(conf *config.Config) {
// Unprivileged section
metrics.Register()
+ var authprovider auth.Provider
switch conf.Auth.Provider {
case "github":
authprovider, err = github.New(conf.Auth)
@@ -106,11 +113,150 @@ func Run(conf *config.Config) {
log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider))
}
- certstore, err = store.New(conf.Server.Database)
+ keysigner, err := signer.New(conf.SSH)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ certstore, err := store.New(conf.Server.Database)
if err != nil {
log.Fatal(err)
}
+ ctx := &app{
+ cookiestore: sessions.NewCookieStore([]byte(conf.Server.CookieSecret)),
+ requireReason: conf.Server.RequireReason,
+ keysigner: keysigner,
+ certstore: certstore,
+ authprovider: authprovider,
+ config: conf.Server,
+ router: mux.NewRouter(),
+ }
+ ctx.cookiestore.Options = &sessions.Options{
+ MaxAge: 900,
+ Path: "/",
+ Secure: conf.Server.UseTLS,
+ HttpOnly: true,
+ }
+
+ logfile := os.Stderr
+ if conf.Server.HTTPLogFile != "" {
+ logfile, err = os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
+ if err != nil {
+ log.Printf("error opening log: %v. logging to stdout", err)
+ }
+ }
+
+ ctx.routes()
+ ctx.router.Use(mwVersion)
+ ctx.router.Use(handlers.CompressHandler)
+ ctx.router.Use(handlers.RecoveryHandler())
+ r := handlers.LoggingHandler(logfile, ctx.router)
+ s := &http.Server{
+ Handler: r,
+ ReadTimeout: 20 * time.Second,
+ WriteTimeout: 20 * time.Second,
+ IdleTimeout: 120 * time.Second,
+ }
+
log.Printf("Starting server on %s", laddr)
- runHTTPServer(conf.Server, l)
+ s.Serve(l)
+}
+
+// mwVersion is middleware to add a X-Cashier-Version header to the response.
+func mwVersion(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Cashier-Version", lib.Version)
+ next.ServeHTTP(w, r)
+ })
+}
+
+func encodeString(s string) string {
+ var buffer bytes.Buffer
+ chunkSize := 70
+ runes := []rune(base64.StdEncoding.EncodeToString([]byte(s)))
+
+ for i := 0; i < len(runes); i += chunkSize {
+ end := i + chunkSize
+ if end > len(runes) {
+ end = len(runes)
+ }
+ buffer.WriteString(string(runes[i:end]))
+ buffer.WriteString("\n")
+ }
+ buffer.WriteString(".\n")
+ return buffer.String()
+}
+
+// app contains local context - cookiestore, authsession etc.
+type app struct {
+ cookiestore *sessions.CookieStore
+ authprovider auth.Provider
+ certstore store.CertStorer
+ keysigner *signer.KeySigner
+ router *mux.Router
+ config *config.Server
+ requireReason bool
+}
+
+func (a *app) routes() {
+ // login required
+ csrfHandler := csrf.Protect([]byte(a.config.CSRFSecret), csrf.Secure(a.config.UseTLS))
+ a.router.Methods("GET").Path("/").Handler(a.authed(http.HandlerFunc(a.index)))
+ a.router.Methods("POST").Path("/admin/revoke").Handler(a.authed(csrfHandler(http.HandlerFunc(a.revoke))))
+ a.router.Methods("GET").Path("/admin/certs").Handler(a.authed(csrfHandler(http.HandlerFunc(a.getAllCerts))))
+ a.router.Methods("GET").Path("/admin/certs.json").Handler(a.authed(http.HandlerFunc(a.getCertsJSON)))
+
+ // no login required
+ a.router.Methods("GET").Path("/auth/login").HandlerFunc(a.auth)
+ a.router.Methods("GET").Path("/auth/callback").HandlerFunc(a.auth)
+ a.router.Methods("GET").Path("/revoked").HandlerFunc(a.revoked)
+ a.router.Methods("POST").Path("/sign").HandlerFunc(a.sign)
+
+ a.router.Methods("GET").Path("/healthcheck").HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, "ok")
+ })
+ a.router.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
+ box := packr.NewBox("static")
+ a.router.PathPrefix("/static/").Handler(http.StripPrefix("/static", http.FileServer(box)))
+}
+
+func (a *app) getAuthToken(r *http.Request) *oauth2.Token {
+ token := &oauth2.Token{}
+ marshalled := a.getSessionVariable(r, "token")
+ json.Unmarshal([]byte(marshalled), token)
+ return token
+}
+
+func (a *app) setAuthToken(w http.ResponseWriter, r *http.Request, token *oauth2.Token) {
+ v, _ := json.Marshal(token)
+ a.setSessionVariable(w, r, "token", string(v))
+}
+
+func (a *app) getSessionVariable(r *http.Request, key string) string {
+ session, _ := a.cookiestore.Get(r, "session")
+ v, ok := session.Values[key].(string)
+ if !ok {
+ v = ""
+ }
+ return v
+}
+
+func (a *app) setSessionVariable(w http.ResponseWriter, r *http.Request, key, value string) {
+ session, _ := a.cookiestore.Get(r, "session")
+ session.Values[key] = value
+ session.Save(r, w)
+}
+
+func (a *app) authed(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t := a.getAuthToken(r)
+ if !t.Valid() || !a.authprovider.Valid(t) {
+ a.setSessionVariable(w, r, "origin_url", r.URL.EscapedPath())
+ http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
}
diff --git a/server/web.go b/server/web.go
deleted file mode 100644
index 9114de1..0000000
--- a/server/web.go
+++ /dev/null
@@ -1,358 +0,0 @@
-package server
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/base64"
- "encoding/hex"
- "encoding/json"
- "fmt"
- "html/template"
- "io"
- "log"
- "net"
- "net/http"
- "os"
- "strconv"
- "strings"
-
- "github.com/gobuffalo/packr"
-
- "github.com/pkg/errors"
- "github.com/prometheus/client_golang/prometheus/promhttp"
-
- "golang.org/x/oauth2"
-
- "github.com/gorilla/csrf"
- "github.com/gorilla/handlers"
- "github.com/gorilla/mux"
- "github.com/gorilla/sessions"
- "github.com/nsheridan/cashier/lib"
- "github.com/nsheridan/cashier/server/config"
- "github.com/nsheridan/cashier/server/templates"
-)
-
-// appContext contains local context - cookiestore, authsession etc.
-type appContext struct {
- cookiestore *sessions.CookieStore
- requireReason bool
-}
-
-// getAuthTokenCookie retrieves a cookie from the request.
-func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token {
- session, _ := a.cookiestore.Get(r, "session")
- t, ok := session.Values["token"]
- if !ok {
- return nil
- }
- var tok oauth2.Token
- if err := json.Unmarshal(t.([]byte), &tok); err != nil {
- return nil
- }
- if !tok.Valid() {
- return nil
- }
- return &tok
-}
-
-// setAuthTokenCookie marshals the auth token and stores it as a cookie.
-func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
- session, _ := a.cookiestore.Get(r, "session")
- val, _ := json.Marshal(t)
- session.Values["token"] = val
- session.Save(r, w)
-}
-
-// getAuthStateCookie retrieves the oauth csrf state value from the client request.
-func (a *appContext) getAuthStateCookie(r *http.Request) string {
- session, _ := a.cookiestore.Get(r, "session")
- state, ok := session.Values["state"]
- if !ok {
- return ""
- }
- return state.(string)
-}
-
-// setAuthStateCookie saves the oauth csrf state value.
-func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) {
- session, _ := a.cookiestore.Get(r, "session")
- session.Values["state"] = state
- session.Save(r, w)
-}
-
-func (a *appContext) getCurrentURL(r *http.Request) string {
- session, _ := a.cookiestore.Get(r, "session")
- path, ok := session.Values["auth_url"]
- if !ok {
- return ""
- }
- return path.(string)
-}
-
-func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) {
- session, _ := a.cookiestore.Get(r, "session")
- session.Values["auth_url"] = r.URL.Path
- session.Save(r, w)
-}
-
-func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
- tok := a.getAuthTokenCookie(r)
- if !tok.Valid() || !authprovider.Valid(tok) {
- return false
- }
- return true
-}
-
-func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) {
- a.setCurrentURL(w, r)
- http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
- return http.StatusSeeOther, nil
-}
-
-// parseKey retrieves and unmarshals the signing request.
-func extractKey(r *http.Request) (*lib.SignRequest, error) {
- var s lib.SignRequest
- if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
- return nil, err
- }
- return &s, nil
-}
-
-// signHandler handles the "/sign" path.
-// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key.
-func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- var t string
- if ah := r.Header.Get("Authorization"); ah != "" {
- if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " {
- t = ah[7:]
- }
- }
- if t == "" {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- token := &oauth2.Token{
- AccessToken: t,
- }
- if !authprovider.Valid(token) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
-
- // Sign the pubkey and issue the cert.
- req, err := extractKey(r)
- if err != nil {
- return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
- }
-
- if a.requireReason && req.Message == "" {
- w.Header().Add("X-Need-Reason", "required")
- return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))
- }
-
- username := authprovider.Username(token)
- authprovider.Revoke(token) // We don't need this anymore.
- cert, err := keysigner.SignUserKey(req, username)
- if err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "error signing key")
- }
- if err := certstore.SetCert(cert); err != nil {
- log.Printf("Error recording cert: %v", err)
- }
- if err := json.NewEncoder(w).Encode(&lib.SignResponse{
- Status: "ok",
- Response: string(lib.GetPublicKey(cert)),
- }); err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "error encoding response")
- }
- return http.StatusOK, nil
-}
-
-// loginHandler starts the authentication process with the provider.
-func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- state := newState()
- a.setAuthStateCookie(w, r, state)
- http.Redirect(w, r, authprovider.StartSession(state), http.StatusFound)
- return http.StatusFound, nil
-}
-
-// callbackHandler handles retrieving the access token from the auth provider and saves it for later use.
-func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if r.FormValue("state") != a.getAuthStateCookie(r) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- code := r.FormValue("code")
- token, err := authprovider.Exchange(code)
- if err != nil {
- return http.StatusInternalServerError, err
- }
- a.setAuthTokenCookie(w, r, token)
- http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound)
- return http.StatusFound, nil
-}
-
-func encodeString(s string) string {
- var buffer bytes.Buffer
- chunkSize := 70
- runes := []rune(base64.StdEncoding.EncodeToString([]byte(s)))
-
- for i := 0; i < len(runes); i += chunkSize {
- end := i + chunkSize
- if end > len(runes) {
- end = len(runes)
- }
- buffer.WriteString(string(runes[i:end]))
- buffer.WriteString("\n")
- }
- buffer.WriteString(".\n")
- return buffer.String()
-}
-
-// rootHandler starts the auth process. If the client is authenticated it renders the token to the user.
-func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- tok := a.getAuthTokenCookie(r)
- page := struct {
- Token string
- }{tok.AccessToken}
- page.Token = encodeString(page.Token)
-
- tmpl := template.Must(template.New("token.html").Parse(templates.Token))
- tmpl.Execute(w, page)
- return http.StatusOK, nil
-}
-
-func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- revoked, err := certstore.GetRevoked()
- if err != nil {
- return http.StatusInternalServerError, err
- }
- rl, err := keysigner.GenerateRevocationList(revoked)
- if err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL")
- }
- w.Header().Set("Content-Type", "application/octet-stream")
- w.Write(rl)
- return http.StatusOK, nil
-}
-
-func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- tmpl := template.Must(template.New("certs.html").Parse(templates.Certs))
- tmpl.Execute(w, map[string]interface{}{
- csrf.TemplateTag: csrf.TemplateField(r),
- })
- return http.StatusOK, nil
-}
-
-func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all"))
- certs, err := certstore.List(includeExpired)
- if err != nil {
- return http.StatusInternalServerError, err
- }
- j, err := json.Marshal(certs)
- if err != nil {
- return http.StatusInternalServerError, err
- }
- w.Write(j)
- return http.StatusOK, nil
-}
-
-func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- r.ParseForm()
- if err := certstore.Revoke(r.Form["cert_id"]); err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke certs")
- }
- http.Redirect(w, r, "/admin/certs", http.StatusSeeOther)
- return http.StatusSeeOther, nil
-}
-
-func healthcheck(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- fmt.Fprint(w, "ok")
-}
-
-// appHandler is a handler which uses appContext to manage state.
-type appHandler struct {
- *appContext
- h func(*appContext, http.ResponseWriter, *http.Request) (int, error)
-}
-
-// ServeHTTP handles the request and writes responses.
-func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- status, err := ah.h(ah.appContext, w, r)
- if err != nil {
- http.Error(w, err.Error(), status)
- }
-}
-
-// newState generates a state identifier for the oauth process.
-func newState() string {
- k := make([]byte, 32)
- if _, err := io.ReadFull(rand.Reader, k); err != nil {
- return "unexpectedstring"
- }
- return hex.EncodeToString(k)
-}
-
-// mwVersion is middleware to add a X-Cashier-Version header to the response.
-func mwVersion(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("X-Cashier-Version", lib.Version)
- next.ServeHTTP(w, r)
- })
-}
-
-func runHTTPServer(conf *config.Server, l net.Listener) {
- var err error
- ctx := &appContext{
- cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
- requireReason: conf.RequireReason,
- }
- ctx.cookiestore.Options = &sessions.Options{
- MaxAge: 900,
- Path: "/",
- Secure: conf.UseTLS,
- HttpOnly: true,
- }
-
- logfile := os.Stderr
- if conf.HTTPLogFile != "" {
- logfile, err = os.OpenFile(conf.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
- if err != nil {
- log.Printf("unable to open %s for writing. logging to stdout", conf.HTTPLogFile)
- logfile = os.Stderr
- }
- }
-
- CSRF := csrf.Protect([]byte(conf.CSRFSecret), csrf.Secure(conf.UseTLS))
- r := mux.NewRouter()
- r.Use(mwVersion)
- r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
- r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
- r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
- r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
- r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler})
- r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler}))
- r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler}))
- r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
- r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
- r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck)
-
- box := packr.NewBox("static")
- r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box)))
- h := handlers.LoggingHandler(logfile, r)
- s := &http.Server{
- Handler: h,
- }
- s.Serve(l)
-}