From fb830dc3531904be0a58e2c4dd4638b390bbdab2 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 19 Feb 2017 23:28:33 +0000 Subject: Split the servers out of main --- cmd/cashierd/handlers_test.go | 146 --------------- cmd/cashierd/main.go | 404 +----------------------------------------- cmd/cashierd/rpc.go | 68 ------- server/handlers_test.go | 146 +++++++++++++++ server/rpc.go | 68 +++++++ server/server.go | 117 ++++++++++++ server/web.go | 313 ++++++++++++++++++++++++++++++++ 7 files changed, 647 insertions(+), 615 deletions(-) delete mode 100644 cmd/cashierd/handlers_test.go delete mode 100644 cmd/cashierd/rpc.go create mode 100644 server/handlers_test.go create mode 100644 server/rpc.go create mode 100644 server/server.go create mode 100644 server/web.go diff --git a/cmd/cashierd/handlers_test.go b/cmd/cashierd/handlers_test.go deleted file mode 100644 index 934d5d0..0000000 --- a/cmd/cashierd/handlers_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package main - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - "net/http/httptest" - "net/url" - "os" - "strings" - "testing" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/oauth2" - - "github.com/gorilla/sessions" - "github.com/nsheridan/cashier/lib" - "github.com/nsheridan/cashier/server/auth" - "github.com/nsheridan/cashier/server/auth/testprovider" - "github.com/nsheridan/cashier/server/config" - "github.com/nsheridan/cashier/server/signer" - "github.com/nsheridan/cashier/server/store" - "github.com/nsheridan/cashier/testdata" - "github.com/stripe/krl" -) - -func newContext(t *testing.T) *appContext { - f, err := ioutil.TempFile(os.TempDir(), "signing_key_") - if err != nil { - t.Error(err) - } - defer os.Remove(f.Name()) - f.Write(testdata.Priv) - f.Close() - if keysigner, err = signer.New(&config.SSH{ - SigningKey: f.Name(), - MaxAge: "1h", - }); err != nil { - t.Error(err) - } - authprovider = testprovider.New() - certstore = store.NewMemoryStore() - return &appContext{ - cookiestore: sessions.NewCookieStore([]byte("secret")), - authsession: &auth.Session{AuthURL: "https://www.example.com/auth"}, - } -} - -func TestLoginHandler(t *testing.T) { - t.Parallel() - req, _ := http.NewRequest("GET", "/auth/login", nil) - resp := httptest.NewRecorder() - loginHandler(newContext(t), resp, req) - if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" { - t.Error("Unexpected response") - } -} - -func TestCallbackHandler(t *testing.T) { - t.Parallel() - req, _ := http.NewRequest("GET", "/auth/callback", nil) - req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}} - resp := httptest.NewRecorder() - ctx := newContext(t) - ctx.setAuthStateCookie(resp, req, "state") - callbackHandler(ctx, resp, req) - if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" { - t.Error("Unexpected response") - } -} - -func TestRootHandler(t *testing.T) { - t.Parallel() - req, _ := http.NewRequest("GET", "/", nil) - resp := httptest.NewRecorder() - ctx := newContext(t) - tok := &oauth2.Token{ - AccessToken: "XXX_TEST_TOKEN_STRING_XXX", - Expiry: time.Now().Add(1 * time.Hour), - } - ctx.setAuthTokenCookie(resp, req, tok) - rootHandler(ctx, resp, req) - if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") { - t.Error("Unable to find token in response") - } -} - -func TestRootHandlerNoSession(t *testing.T) { - t.Parallel() - req, _ := http.NewRequest("GET", "/", nil) - resp := httptest.NewRecorder() - ctx := newContext(t) - rootHandler(ctx, resp, req) - if resp.Code != http.StatusSeeOther { - t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther)) - } -} - -func TestSignRevoke(t *testing.T) { - t.Parallel() - s, _ := json.Marshal(&lib.SignRequest{ - Key: string(testdata.Pub), - ValidUntil: time.Now().UTC().Add(1 * time.Hour), - }) - req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s)) - resp := httptest.NewRecorder() - ctx := newContext(t) - req.Header.Set("Authorization", "Bearer abcdef") - signHandler(ctx, resp, req) - if resp.Code != http.StatusOK { - t.Error("Unexpected response") - } - r := &lib.SignResponse{} - if err := json.NewDecoder(resp.Body).Decode(r); err != nil { - t.Error(err) - } - if r.Status != "ok" { - t.Error("Unexpected response") - } - k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response)) - if err != nil { - t.Error(err) - } - cert, ok := k.(*ssh.Certificate) - if !ok { - t.Error("Did not receive a certificate") - } - // Revoke the cert and verify - req, _ = http.NewRequest("POST", "/revoke", nil) - req.Form = url.Values{"cert_id": []string{cert.KeyId}} - tok := &oauth2.Token{ - AccessToken: "authenticated", - Expiry: time.Now().Add(1 * time.Hour), - } - ctx.setAuthTokenCookie(resp, req, tok) - revokeCertHandler(ctx, resp, req) - req, _ = http.NewRequest("GET", "/revoked", nil) - listRevokedCertsHandler(ctx, resp, req) - revoked, _ := ioutil.ReadAll(resp.Body) - rl, _ := krl.ParseKRL(revoked) - if !rl.IsRevoked(cert) { - t.Errorf("cert %s was not revoked", cert.KeyId) - } -} diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index d355604..2e378bc 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -1,315 +1,20 @@ package main import ( - "crypto/rand" - "crypto/tls" - "encoding/hex" - "encoding/json" "flag" - "fmt" - "html/template" - "io" "log" - "net" - "net/http" - "os" - "strconv" - "strings" - "github.com/pkg/errors" - "github.com/soheilhy/cmux" - - "go4.org/wkfs" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/oauth2" - - "github.com/gorilla/csrf" - "github.com/gorilla/handlers" - "github.com/gorilla/mux" - "github.com/gorilla/sessions" - wkfscache "github.com/nsheridan/autocert-wkfs-cache" - "github.com/nsheridan/cashier/lib" - "github.com/nsheridan/cashier/server/auth" - "github.com/nsheridan/cashier/server/auth/github" - "github.com/nsheridan/cashier/server/auth/gitlab" - "github.com/nsheridan/cashier/server/auth/google" + "github.com/nsheridan/cashier/server" "github.com/nsheridan/cashier/server/config" - "github.com/nsheridan/cashier/server/metrics" - "github.com/nsheridan/cashier/server/signer" - "github.com/nsheridan/cashier/server/static" - "github.com/nsheridan/cashier/server/store" - "github.com/nsheridan/cashier/server/templates" "github.com/nsheridan/cashier/server/wkfs/vaultfs" "github.com/nsheridan/wkfs/s3" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/sid77/drop" ) var ( cfg = flag.String("config_file", "cashierd.conf", "Path to configuration file.") - - authprovider auth.Provider - certstore store.CertStorer - keysigner *signer.KeySigner ) -// appContext contains local context - cookiestore, authsession etc. -type appContext struct { - cookiestore *sessions.CookieStore - authsession *auth.Session -} - -// getAuthTokenCookie retrieves a cookie from the request. -func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token { - session, _ := a.cookiestore.Get(r, "session") - t, ok := session.Values["token"] - if !ok { - return nil - } - var tok oauth2.Token - if err := json.Unmarshal(t.([]byte), &tok); err != nil { - return nil - } - if !tok.Valid() { - return nil - } - return &tok -} - -// setAuthTokenCookie marshals the auth token and stores it as a cookie. -func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { - session, _ := a.cookiestore.Get(r, "session") - val, _ := json.Marshal(t) - session.Values["token"] = val - session.Save(r, w) -} - -// getAuthStateCookie retrieves the oauth csrf state value from the client request. -func (a *appContext) getAuthStateCookie(r *http.Request) string { - session, _ := a.cookiestore.Get(r, "session") - state, ok := session.Values["state"] - if !ok { - return "" - } - return state.(string) -} - -// setAuthStateCookie saves the oauth csrf state value. -func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) { - session, _ := a.cookiestore.Get(r, "session") - session.Values["state"] = state - session.Save(r, w) -} - -func (a *appContext) getCurrentURL(r *http.Request) string { - session, _ := a.cookiestore.Get(r, "session") - path, ok := session.Values["auth_url"] - if !ok { - return "" - } - return path.(string) -} - -func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) { - session, _ := a.cookiestore.Get(r, "session") - session.Values["auth_url"] = r.URL.Path - session.Save(r, w) -} - -func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool { - tok := a.getAuthTokenCookie(r) - if !tok.Valid() || !authprovider.Valid(tok) { - return false - } - return true -} - -func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) { - a.setCurrentURL(w, r) - http.Redirect(w, r, "/auth/login", http.StatusSeeOther) - return http.StatusSeeOther, nil -} - -// parseKey retrieves and unmarshals the signing request. -func extractKey(r *http.Request) (*lib.SignRequest, error) { - var s lib.SignRequest - if err := json.NewDecoder(r.Body).Decode(&s); err != nil { - return nil, err - } - return &s, nil -} - -// signHandler handles the "/sign" path. -// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key. -func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - var t string - if ah := r.Header.Get("Authorization"); ah != "" { - if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { - t = ah[7:] - } - } - if t == "" { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - token := &oauth2.Token{ - AccessToken: t, - } - if !authprovider.Valid(token) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - - // Sign the pubkey and issue the cert. - req, err := extractKey(r) - if err != nil { - return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") - } - username := authprovider.Username(token) - authprovider.Revoke(token) // We don't need this anymore. - cert, err := keysigner.SignUserKey(req, username) - if err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "error signing key") - } - if err := certstore.SetCert(cert); err != nil { - log.Printf("Error recording cert: %v", err) - } - if err := json.NewEncoder(w).Encode(&lib.SignResponse{ - Status: "ok", - Response: string(lib.GetPublicKey(cert)), - }); err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "error encoding response") - } - return http.StatusOK, nil -} - -// loginHandler starts the authentication process with the provider. -func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - state := newState() - a.setAuthStateCookie(w, r, state) - a.authsession = authprovider.StartSession(state) - http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) - return http.StatusFound, nil -} - -// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. -func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if r.FormValue("state") != a.getAuthStateCookie(r) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - code := r.FormValue("code") - if err := a.authsession.Authorize(authprovider, code); err != nil { - return http.StatusInternalServerError, err - } - a.setAuthTokenCookie(w, r, a.authsession.Token) - http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound) - return http.StatusFound, nil -} - -// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. -func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - tok := a.getAuthTokenCookie(r) - page := struct { - Token string - }{tok.AccessToken} - - tmpl := template.Must(template.New("token.html").Parse(templates.Token)) - tmpl.Execute(w, page) - return http.StatusOK, nil -} - -func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - revoked, err := certstore.GetRevoked() - if err != nil { - return http.StatusInternalServerError, err - } - rl, err := keysigner.GenerateRevocationList(revoked) - if err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL") - } - w.Header().Set("Content-Type", "application/octet-stream") - w.Write(rl) - return http.StatusOK, nil -} - -func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - tmpl := template.Must(template.New("certs.html").Parse(templates.Certs)) - tmpl.Execute(w, map[string]interface{}{ - csrf.TemplateTag: csrf.TemplateField(r), - }) - return http.StatusOK, nil -} - -func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) - } - includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all")) - certs, err := certstore.List(includeExpired) - j, err := json.Marshal(certs) - if err != nil { - return http.StatusInternalServerError, errors.New(http.StatusText(http.StatusInternalServerError)) - } - w.Write(j) - return http.StatusOK, nil -} - -func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if !a.isLoggedIn(w, r) { - return a.login(w, r) - } - r.ParseForm() - for _, id := range r.Form["cert_id"] { - if err := certstore.Revoke(id); err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke") - } - } - http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) - return http.StatusSeeOther, nil -} - -// appHandler is a handler which uses appContext to manage state. -type appHandler struct { - *appContext - h func(*appContext, http.ResponseWriter, *http.Request) (int, error) -} - -// ServeHTTP handles the request and writes responses. -func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - status, err := ah.h(ah.appContext, w, r) - if err != nil { - log.Printf("HTTP %d: %q", status, err) - http.Error(w, err.Error(), status) - } -} - -// newState generates a state identifier for the oauth process. -func newState() string { - k := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, k); err != nil { - return "unexpectedstring" - } - return hex.EncodeToString(k) -} - -func loadCerts(certFile, keyFile string) (tls.Certificate, error) { - key, err := wkfs.ReadFile(keyFile) - if err != nil { - return tls.Certificate{}, errors.Wrap(err, "error reading TLS private key") - } - cert, err := wkfs.ReadFile(certFile) - if err != nil { - return tls.Certificate{}, errors.Wrap(err, "error reading TLS certificate") - } - return tls.X509KeyPair(cert, key) -} - func main() { - // Privileged section flag.Parse() conf, err := config.ReadConfig(*cfg) if err != nil { @@ -327,109 +32,6 @@ func main() { }) vaultfs.Register(conf.Vault) - keysigner, err = signer.New(conf.SSH) - if err != nil { - log.Fatal(err) - } - - logfile := os.Stderr - if conf.Server.HTTPLogFile != "" { - logfile, err = os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) - if err != nil { - log.Printf("unable to open %s for writing. logging to stdout", conf.Server.HTTPLogFile) - logfile = os.Stderr - } - } - - laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port) - l, err := net.Listen("tcp", laddr) - if err != nil { - log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port)) - } - - tlsConfig := &tls.Config{} - if conf.Server.UseTLS { - if conf.Server.LetsEncryptServername != "" { - m := autocert.Manager{ - Prompt: autocert.AcceptTOS, - Cache: wkfscache.Cache(conf.Server.LetsEncryptCache), - HostPolicy: autocert.HostWhitelist(conf.Server.LetsEncryptServername), - } - tlsConfig.GetCertificate = m.GetCertificate - } else { - if conf.Server.TLSCert == "" || conf.Server.TLSKey == "" { - log.Fatal("TLS cert or key not specified in config") - } - tlsConfig.Certificates = make([]tls.Certificate, 1) - tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey) - if err != nil { - log.Fatal(errors.Wrap(err, "unable to create TLS listener")) - } - } - l = tls.NewListener(l, tlsConfig) - } - - if conf.Server.User != "" { - log.Print("Dropping privileges...") - if err := drop.DropPrivileges(conf.Server.User); err != nil { - log.Fatal(errors.Wrap(err, "unable to drop privileges")) - } - } - - // Unprivileged section - metrics.Register() - - switch conf.Auth.Provider { - case "google": - authprovider, err = google.New(conf.Auth) - case "github": - authprovider, err = github.New(conf.Auth) - case "gitlab": - authprovider, err = gitlab.New(conf.Auth) - default: - log.Fatalf("Unknown provider %s\n", conf.Auth.Provider) - } - if err != nil { - log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider)) - } - - certstore, err = store.New(conf.Server.Database) - if err != nil { - log.Fatal(err) - } - ctx := &appContext{ - cookiestore: sessions.NewCookieStore([]byte(conf.Server.CookieSecret)), - } - ctx.cookiestore.Options = &sessions.Options{ - MaxAge: 900, - Path: "/", - Secure: conf.Server.UseTLS, - HttpOnly: true, - } - - CSRF := csrf.Protect([]byte(conf.Server.CSRFSecret), csrf.Secure(conf.Server.UseTLS)) - r := mux.NewRouter() - r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler}) - r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler}) - r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler}) - r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler}) - r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler}) - r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler})) - r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler})) - r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler}) - r.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) - r.PathPrefix("/").Handler(http.FileServer(static.FS(false))) - h := handlers.LoggingHandler(logfile, r) - - log.Printf("Starting server on %s", laddr) - s := &http.Server{ - Handler: h, - } - - cm := cmux.New(l) - httpl := cm.Match(cmux.HTTP1Fast()) - grpcl := cm.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) - go s.Serve(httpl) - go newGrpcServer(grpcl) - log.Fatal(cm.Serve()) + // Start the servers + server.Run(conf) } diff --git a/cmd/cashierd/rpc.go b/cmd/cashierd/rpc.go deleted file mode 100644 index ad8aa5d..0000000 --- a/cmd/cashierd/rpc.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -import ( - "log" - "net" - - "golang.org/x/net/context" - - "golang.org/x/oauth2" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - - "github.com/nsheridan/cashier/lib" - "github.com/nsheridan/cashier/proto" -) - -type rpcServer struct{} - -type key int - -const usernameKey key = 0 - -func (s *rpcServer) Sign(ctx context.Context, req *proto.SignRequest) (*proto.SignResponse, error) { - username, ok := ctx.Value(usernameKey).(string) - if !ok { - return nil, grpc.Errorf(codes.InvalidArgument, "Error reading username") - } - cert, err := keysigner.SignUserKeyFromRPC(req, username) - if err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, err.Error()) - } - if err := certstore.SetCert(cert); err != nil { - log.Printf("Error recording cert: %v", err) - } - resp := &proto.SignResponse{ - Cert: lib.GetPublicKey(cert), - } - return resp, nil -} - -func authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { - md, ok := metadata.FromContext(ctx) - if !ok { - return nil, grpc.Errorf(codes.Unauthenticated, "request not authenticated") - } - switch md["security"][0] { - case "authorization": - token := &oauth2.Token{ - AccessToken: md["payload"][0], - } - if !authprovider.Valid(token) { - return nil, grpc.Errorf(codes.PermissionDenied, "access denied") - } - authprovider.Revoke(token) - ctx = context.WithValue(ctx, usernameKey, authprovider.Username(token)) - default: - return nil, grpc.Errorf(codes.InvalidArgument, "unknown argument") - } - return handler(ctx, req) -} - -func newGrpcServer(l net.Listener) { - serv := grpc.NewServer(grpc.UnaryInterceptor(authInterceptor)) - proto.RegisterSignerServer(serv, &rpcServer{}) - serv.Serve(l) -} diff --git a/server/handlers_test.go b/server/handlers_test.go new file mode 100644 index 0000000..b2646a3 --- /dev/null +++ b/server/handlers_test.go @@ -0,0 +1,146 @@ +package server + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/oauth2" + + "github.com/gorilla/sessions" + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/auth/testprovider" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/signer" + "github.com/nsheridan/cashier/server/store" + "github.com/nsheridan/cashier/testdata" + "github.com/stripe/krl" +) + +func newContext(t *testing.T) *appContext { + f, err := ioutil.TempFile(os.TempDir(), "signing_key_") + if err != nil { + t.Error(err) + } + defer os.Remove(f.Name()) + f.Write(testdata.Priv) + f.Close() + if keysigner, err = signer.New(&config.SSH{ + SigningKey: f.Name(), + MaxAge: "1h", + }); err != nil { + t.Error(err) + } + authprovider = testprovider.New() + certstore = store.NewMemoryStore() + return &appContext{ + cookiestore: sessions.NewCookieStore([]byte("secret")), + authsession: &auth.Session{AuthURL: "https://www.example.com/auth"}, + } +} + +func TestLoginHandler(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", "/auth/login", nil) + resp := httptest.NewRecorder() + loginHandler(newContext(t), resp, req) + if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" { + t.Error("Unexpected response") + } +} + +func TestCallbackHandler(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", "/auth/callback", nil) + req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}} + resp := httptest.NewRecorder() + ctx := newContext(t) + ctx.setAuthStateCookie(resp, req, "state") + callbackHandler(ctx, resp, req) + if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" { + t.Error("Unexpected response") + } +} + +func TestRootHandler(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + ctx := newContext(t) + tok := &oauth2.Token{ + AccessToken: "XXX_TEST_TOKEN_STRING_XXX", + Expiry: time.Now().Add(1 * time.Hour), + } + ctx.setAuthTokenCookie(resp, req, tok) + rootHandler(ctx, resp, req) + if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") { + t.Error("Unable to find token in response") + } +} + +func TestRootHandlerNoSession(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + ctx := newContext(t) + rootHandler(ctx, resp, req) + if resp.Code != http.StatusSeeOther { + t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther)) + } +} + +func TestSignRevoke(t *testing.T) { + t.Parallel() + s, _ := json.Marshal(&lib.SignRequest{ + Key: string(testdata.Pub), + ValidUntil: time.Now().UTC().Add(1 * time.Hour), + }) + req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s)) + resp := httptest.NewRecorder() + ctx := newContext(t) + req.Header.Set("Authorization", "Bearer abcdef") + signHandler(ctx, resp, req) + if resp.Code != http.StatusOK { + t.Error("Unexpected response") + } + r := &lib.SignResponse{} + if err := json.NewDecoder(resp.Body).Decode(r); err != nil { + t.Error(err) + } + if r.Status != "ok" { + t.Error("Unexpected response") + } + k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response)) + if err != nil { + t.Error(err) + } + cert, ok := k.(*ssh.Certificate) + if !ok { + t.Error("Did not receive a certificate") + } + // Revoke the cert and verify + req, _ = http.NewRequest("POST", "/revoke", nil) + req.Form = url.Values{"cert_id": []string{cert.KeyId}} + tok := &oauth2.Token{ + AccessToken: "authenticated", + Expiry: time.Now().Add(1 * time.Hour), + } + ctx.setAuthTokenCookie(resp, req, tok) + revokeCertHandler(ctx, resp, req) + req, _ = http.NewRequest("GET", "/revoked", nil) + listRevokedCertsHandler(ctx, resp, req) + revoked, _ := ioutil.ReadAll(resp.Body) + rl, _ := krl.ParseKRL(revoked) + if !rl.IsRevoked(cert) { + t.Errorf("cert %s was not revoked", cert.KeyId) + } +} diff --git a/server/rpc.go b/server/rpc.go new file mode 100644 index 0000000..ce95e96 --- /dev/null +++ b/server/rpc.go @@ -0,0 +1,68 @@ +package server + +import ( + "log" + "net" + + "golang.org/x/net/context" + + "golang.org/x/oauth2" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/proto" +) + +type rpcServer struct{} + +type key int + +const usernameKey key = 0 + +func (s *rpcServer) Sign(ctx context.Context, req *proto.SignRequest) (*proto.SignResponse, error) { + username, ok := ctx.Value(usernameKey).(string) + if !ok { + return nil, grpc.Errorf(codes.InvalidArgument, "Error reading username") + } + cert, err := keysigner.SignUserKeyFromRPC(req, username) + if err != nil { + return nil, grpc.Errorf(codes.InvalidArgument, err.Error()) + } + if err := certstore.SetCert(cert); err != nil { + log.Printf("Error recording cert: %v", err) + } + resp := &proto.SignResponse{ + Cert: lib.GetPublicKey(cert), + } + return resp, nil +} + +func authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + md, ok := metadata.FromContext(ctx) + if !ok { + return nil, grpc.Errorf(codes.Unauthenticated, "request not authenticated") + } + switch md["security"][0] { + case "authorization": + token := &oauth2.Token{ + AccessToken: md["payload"][0], + } + if !authprovider.Valid(token) { + return nil, grpc.Errorf(codes.PermissionDenied, "access denied") + } + authprovider.Revoke(token) + ctx = context.WithValue(ctx, usernameKey, authprovider.Username(token)) + default: + return nil, grpc.Errorf(codes.InvalidArgument, "unknown argument") + } + return handler(ctx, req) +} + +func runGRPCServer(l net.Listener) { + serv := grpc.NewServer(grpc.UnaryInterceptor(authInterceptor)) + proto.RegisterSignerServer(serv, &rpcServer{}) + serv.Serve(l) +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..676a61b --- /dev/null +++ b/server/server.go @@ -0,0 +1,117 @@ +package server + +import ( + "crypto/tls" + "fmt" + "log" + "net" + + "github.com/pkg/errors" + "github.com/soheilhy/cmux" + + "go4.org/wkfs" + "golang.org/x/crypto/acme/autocert" + + wkfscache "github.com/nsheridan/autocert-wkfs-cache" + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/auth/github" + "github.com/nsheridan/cashier/server/auth/gitlab" + "github.com/nsheridan/cashier/server/auth/google" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/metrics" + "github.com/nsheridan/cashier/server/signer" + "github.com/nsheridan/cashier/server/store" + "github.com/sid77/drop" +) + +var ( + authprovider auth.Provider + certstore store.CertStorer + keysigner *signer.KeySigner +) + +func loadCerts(certFile, keyFile string) (tls.Certificate, error) { + key, err := wkfs.ReadFile(keyFile) + if err != nil { + return tls.Certificate{}, errors.Wrap(err, "error reading TLS private key") + } + cert, err := wkfs.ReadFile(certFile) + if err != nil { + return tls.Certificate{}, errors.Wrap(err, "error reading TLS certificate") + } + return tls.X509KeyPair(cert, key) +} + +// Run the HTTP and RPC servers. +func Run(conf *config.Config) { + var err error + keysigner, err = signer.New(conf.SSH) + if err != nil { + log.Fatal(err) + } + + laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port) + l, err := net.Listen("tcp", laddr) + if err != nil { + log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port)) + } + + tlsConfig := &tls.Config{} + if conf.Server.UseTLS { + if conf.Server.LetsEncryptServername != "" { + m := autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: wkfscache.Cache(conf.Server.LetsEncryptCache), + HostPolicy: autocert.HostWhitelist(conf.Server.LetsEncryptServername), + } + tlsConfig.GetCertificate = m.GetCertificate + } else { + if conf.Server.TLSCert == "" || conf.Server.TLSKey == "" { + log.Fatal("TLS cert or key not specified in config") + } + tlsConfig.Certificates = make([]tls.Certificate, 1) + tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey) + if err != nil { + log.Fatal(errors.Wrap(err, "unable to create TLS listener")) + } + } + l = tls.NewListener(l, tlsConfig) + } + + if conf.Server.User != "" { + log.Print("Dropping privileges...") + if err := drop.DropPrivileges(conf.Server.User); err != nil { + log.Fatal(errors.Wrap(err, "unable to drop privileges")) + } + } + + // Unprivileged section + metrics.Register() + + switch conf.Auth.Provider { + case "google": + authprovider, err = google.New(conf.Auth) + case "github": + authprovider, err = github.New(conf.Auth) + case "gitlab": + authprovider, err = gitlab.New(conf.Auth) + default: + log.Fatalf("Unknown provider %s\n", conf.Auth.Provider) + } + if err != nil { + log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider)) + } + + certstore, err = store.New(conf.Server.Database) + if err != nil { + log.Fatal(err) + } + + log.Printf("Starting server on %s", laddr) + cm := cmux.New(l) + httpl := cm.Match(cmux.HTTP1Fast()) + grpcl := cm.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + go runHTTPServer(conf.Server, httpl) + go runGRPCServer(grpcl) + log.Fatal(cm.Serve()) +} diff --git a/server/web.go b/server/web.go new file mode 100644 index 0000000..65eca49 --- /dev/null +++ b/server/web.go @@ -0,0 +1,313 @@ +package server + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "html/template" + "io" + "log" + "net" + "net/http" + "os" + "strconv" + "strings" + + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "golang.org/x/oauth2" + + "github.com/gorilla/csrf" + "github.com/gorilla/handlers" + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/static" + "github.com/nsheridan/cashier/server/templates" +) + +// appContext contains local context - cookiestore, authsession etc. +type appContext struct { + cookiestore *sessions.CookieStore + authsession *auth.Session +} + +// getAuthTokenCookie retrieves a cookie from the request. +func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token { + session, _ := a.cookiestore.Get(r, "session") + t, ok := session.Values["token"] + if !ok { + return nil + } + var tok oauth2.Token + if err := json.Unmarshal(t.([]byte), &tok); err != nil { + return nil + } + if !tok.Valid() { + return nil + } + return &tok +} + +// setAuthTokenCookie marshals the auth token and stores it as a cookie. +func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { + session, _ := a.cookiestore.Get(r, "session") + val, _ := json.Marshal(t) + session.Values["token"] = val + session.Save(r, w) +} + +// getAuthStateCookie retrieves the oauth csrf state value from the client request. +func (a *appContext) getAuthStateCookie(r *http.Request) string { + session, _ := a.cookiestore.Get(r, "session") + state, ok := session.Values["state"] + if !ok { + return "" + } + return state.(string) +} + +// setAuthStateCookie saves the oauth csrf state value. +func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) { + session, _ := a.cookiestore.Get(r, "session") + session.Values["state"] = state + session.Save(r, w) +} + +func (a *appContext) getCurrentURL(r *http.Request) string { + session, _ := a.cookiestore.Get(r, "session") + path, ok := session.Values["auth_url"] + if !ok { + return "" + } + return path.(string) +} + +func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) { + session, _ := a.cookiestore.Get(r, "session") + session.Values["auth_url"] = r.URL.Path + session.Save(r, w) +} + +func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool { + tok := a.getAuthTokenCookie(r) + if !tok.Valid() || !authprovider.Valid(tok) { + return false + } + return true +} + +func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) { + a.setCurrentURL(w, r) + http.Redirect(w, r, "/auth/login", http.StatusSeeOther) + return http.StatusSeeOther, nil +} + +// parseKey retrieves and unmarshals the signing request. +func extractKey(r *http.Request) (*lib.SignRequest, error) { + var s lib.SignRequest + if err := json.NewDecoder(r.Body).Decode(&s); err != nil { + return nil, err + } + return &s, nil +} + +// signHandler handles the "/sign" path. +// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key. +func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + var t string + if ah := r.Header.Get("Authorization"); ah != "" { + if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { + t = ah[7:] + } + } + if t == "" { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + token := &oauth2.Token{ + AccessToken: t, + } + if !authprovider.Valid(token) { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + + // Sign the pubkey and issue the cert. + req, err := extractKey(r) + if err != nil { + return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") + } + username := authprovider.Username(token) + authprovider.Revoke(token) // We don't need this anymore. + cert, err := keysigner.SignUserKey(req, username) + if err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "error signing key") + } + if err := certstore.SetCert(cert); err != nil { + log.Printf("Error recording cert: %v", err) + } + if err := json.NewEncoder(w).Encode(&lib.SignResponse{ + Status: "ok", + Response: string(lib.GetPublicKey(cert)), + }); err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "error encoding response") + } + return http.StatusOK, nil +} + +// loginHandler starts the authentication process with the provider. +func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + state := newState() + a.setAuthStateCookie(w, r, state) + a.authsession = authprovider.StartSession(state) + http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) + return http.StatusFound, nil +} + +// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. +func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if r.FormValue("state") != a.getAuthStateCookie(r) { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + code := r.FormValue("code") + if err := a.authsession.Authorize(authprovider, code); err != nil { + return http.StatusInternalServerError, err + } + a.setAuthTokenCookie(w, r, a.authsession.Token) + http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound) + return http.StatusFound, nil +} + +// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. +func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if !a.isLoggedIn(w, r) { + return a.login(w, r) + } + tok := a.getAuthTokenCookie(r) + page := struct { + Token string + }{tok.AccessToken} + + tmpl := template.Must(template.New("token.html").Parse(templates.Token)) + tmpl.Execute(w, page) + return http.StatusOK, nil +} + +func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + revoked, err := certstore.GetRevoked() + if err != nil { + return http.StatusInternalServerError, err + } + rl, err := keysigner.GenerateRevocationList(revoked) + if err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL") + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(rl) + return http.StatusOK, nil +} + +func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if !a.isLoggedIn(w, r) { + return a.login(w, r) + } + tmpl := template.Must(template.New("certs.html").Parse(templates.Certs)) + tmpl.Execute(w, map[string]interface{}{ + csrf.TemplateTag: csrf.TemplateField(r), + }) + return http.StatusOK, nil +} + +func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if !a.isLoggedIn(w, r) { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all")) + certs, err := certstore.List(includeExpired) + j, err := json.Marshal(certs) + if err != nil { + return http.StatusInternalServerError, errors.New(http.StatusText(http.StatusInternalServerError)) + } + w.Write(j) + return http.StatusOK, nil +} + +func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if !a.isLoggedIn(w, r) { + return a.login(w, r) + } + r.ParseForm() + for _, id := range r.Form["cert_id"] { + if err := certstore.Revoke(id); err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke") + } + } + http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) + return http.StatusSeeOther, nil +} + +// appHandler is a handler which uses appContext to manage state. +type appHandler struct { + *appContext + h func(*appContext, http.ResponseWriter, *http.Request) (int, error) +} + +// ServeHTTP handles the request and writes responses. +func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + status, err := ah.h(ah.appContext, w, r) + if err != nil { + log.Printf("HTTP %d: %q", status, err) + http.Error(w, err.Error(), status) + } +} + +// newState generates a state identifier for the oauth process. +func newState() string { + k := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return "unexpectedstring" + } + return hex.EncodeToString(k) +} + +func runHTTPServer(conf *config.Server, l net.Listener) { + var err error + ctx := &appContext{ + cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)), + } + ctx.cookiestore.Options = &sessions.Options{ + MaxAge: 900, + Path: "/", + Secure: conf.UseTLS, + HttpOnly: true, + } + + logfile := os.Stderr + if conf.HTTPLogFile != "" { + logfile, err = os.OpenFile(conf.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640) + if err != nil { + log.Printf("unable to open %s for writing. logging to stdout", conf.HTTPLogFile) + logfile = os.Stderr + } + } + + CSRF := csrf.Protect([]byte(conf.CSRFSecret), csrf.Secure(conf.UseTLS)) + r := mux.NewRouter() + r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler}) + r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler}) + r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler}) + r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler}) + r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler}) + r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler})) + r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler})) + r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler}) + r.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) + r.PathPrefix("/").Handler(http.FileServer(static.FS(false))) + h := handlers.LoggingHandler(logfile, r) + s := &http.Server{ + Handler: h, + } + s.Serve(l) +} -- cgit v1.2.3