aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2017-02-19 23:28:33 +0000
committerNiall Sheridan <nsheridan@gmail.com>2017-02-20 22:13:56 +0000
commitfb830dc3531904be0a58e2c4dd4638b390bbdab2 (patch)
treefa9dc298dc7463be55d66ea855d82b9d111382fe
parenteb57eaf30965ba24ff669d6f9c8d11cd24951777 (diff)
Split the servers out of main
-rw-r--r--cmd/cashierd/main.go404
-rw-r--r--server/handlers_test.go (renamed from cmd/cashierd/handlers_test.go)2
-rw-r--r--server/rpc.go (renamed from cmd/cashierd/rpc.go)4
-rw-r--r--server/server.go117
-rw-r--r--server/web.go313
5 files changed, 436 insertions, 404 deletions
diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go
index d355604..2e378bc 100644
--- a/cmd/cashierd/main.go
+++ b/cmd/cashierd/main.go
@@ -1,315 +1,20 @@
package main
import (
- "crypto/rand"
- "crypto/tls"
- "encoding/hex"
- "encoding/json"
"flag"
- "fmt"
- "html/template"
- "io"
"log"
- "net"
- "net/http"
- "os"
- "strconv"
- "strings"
- "github.com/pkg/errors"
- "github.com/soheilhy/cmux"
-
- "go4.org/wkfs"
- "golang.org/x/crypto/acme/autocert"
- "golang.org/x/oauth2"
-
- "github.com/gorilla/csrf"
- "github.com/gorilla/handlers"
- "github.com/gorilla/mux"
- "github.com/gorilla/sessions"
- wkfscache "github.com/nsheridan/autocert-wkfs-cache"
- "github.com/nsheridan/cashier/lib"
- "github.com/nsheridan/cashier/server/auth"
- "github.com/nsheridan/cashier/server/auth/github"
- "github.com/nsheridan/cashier/server/auth/gitlab"
- "github.com/nsheridan/cashier/server/auth/google"
+ "github.com/nsheridan/cashier/server"
"github.com/nsheridan/cashier/server/config"
- "github.com/nsheridan/cashier/server/metrics"
- "github.com/nsheridan/cashier/server/signer"
- "github.com/nsheridan/cashier/server/static"
- "github.com/nsheridan/cashier/server/store"
- "github.com/nsheridan/cashier/server/templates"
"github.com/nsheridan/cashier/server/wkfs/vaultfs"
"github.com/nsheridan/wkfs/s3"
- "github.com/prometheus/client_golang/prometheus/promhttp"
- "github.com/sid77/drop"
)
var (
cfg = flag.String("config_file", "cashierd.conf", "Path to configuration file.")
-
- authprovider auth.Provider
- certstore store.CertStorer
- keysigner *signer.KeySigner
)
-// appContext contains local context - cookiestore, authsession etc.
-type appContext struct {
- cookiestore *sessions.CookieStore
- authsession *auth.Session
-}
-
-// getAuthTokenCookie retrieves a cookie from the request.
-func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token {
- session, _ := a.cookiestore.Get(r, "session")
- t, ok := session.Values["token"]
- if !ok {
- return nil
- }
- var tok oauth2.Token
- if err := json.Unmarshal(t.([]byte), &tok); err != nil {
- return nil
- }
- if !tok.Valid() {
- return nil
- }
- return &tok
-}
-
-// setAuthTokenCookie marshals the auth token and stores it as a cookie.
-func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
- session, _ := a.cookiestore.Get(r, "session")
- val, _ := json.Marshal(t)
- session.Values["token"] = val
- session.Save(r, w)
-}
-
-// getAuthStateCookie retrieves the oauth csrf state value from the client request.
-func (a *appContext) getAuthStateCookie(r *http.Request) string {
- session, _ := a.cookiestore.Get(r, "session")
- state, ok := session.Values["state"]
- if !ok {
- return ""
- }
- return state.(string)
-}
-
-// setAuthStateCookie saves the oauth csrf state value.
-func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) {
- session, _ := a.cookiestore.Get(r, "session")
- session.Values["state"] = state
- session.Save(r, w)
-}
-
-func (a *appContext) getCurrentURL(r *http.Request) string {
- session, _ := a.cookiestore.Get(r, "session")
- path, ok := session.Values["auth_url"]
- if !ok {
- return ""
- }
- return path.(string)
-}
-
-func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) {
- session, _ := a.cookiestore.Get(r, "session")
- session.Values["auth_url"] = r.URL.Path
- session.Save(r, w)
-}
-
-func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
- tok := a.getAuthTokenCookie(r)
- if !tok.Valid() || !authprovider.Valid(tok) {
- return false
- }
- return true
-}
-
-func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) {
- a.setCurrentURL(w, r)
- http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
- return http.StatusSeeOther, nil
-}
-
-// parseKey retrieves and unmarshals the signing request.
-func extractKey(r *http.Request) (*lib.SignRequest, error) {
- var s lib.SignRequest
- if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
- return nil, err
- }
- return &s, nil
-}
-
-// signHandler handles the "/sign" path.
-// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key.
-func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- var t string
- if ah := r.Header.Get("Authorization"); ah != "" {
- if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " {
- t = ah[7:]
- }
- }
- if t == "" {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- token := &oauth2.Token{
- AccessToken: t,
- }
- if !authprovider.Valid(token) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
-
- // Sign the pubkey and issue the cert.
- req, err := extractKey(r)
- if err != nil {
- return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
- }
- username := authprovider.Username(token)
- authprovider.Revoke(token) // We don't need this anymore.
- cert, err := keysigner.SignUserKey(req, username)
- if err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "error signing key")
- }
- if err := certstore.SetCert(cert); err != nil {
- log.Printf("Error recording cert: %v", err)
- }
- if err := json.NewEncoder(w).Encode(&lib.SignResponse{
- Status: "ok",
- Response: string(lib.GetPublicKey(cert)),
- }); err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "error encoding response")
- }
- return http.StatusOK, nil
-}
-
-// loginHandler starts the authentication process with the provider.
-func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- state := newState()
- a.setAuthStateCookie(w, r, state)
- a.authsession = authprovider.StartSession(state)
- http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
- return http.StatusFound, nil
-}
-
-// callbackHandler handles retrieving the access token from the auth provider and saves it for later use.
-func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if r.FormValue("state") != a.getAuthStateCookie(r) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- code := r.FormValue("code")
- if err := a.authsession.Authorize(authprovider, code); err != nil {
- return http.StatusInternalServerError, err
- }
- a.setAuthTokenCookie(w, r, a.authsession.Token)
- http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound)
- return http.StatusFound, nil
-}
-
-// rootHandler starts the auth process. If the client is authenticated it renders the token to the user.
-func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- tok := a.getAuthTokenCookie(r)
- page := struct {
- Token string
- }{tok.AccessToken}
-
- tmpl := template.Must(template.New("token.html").Parse(templates.Token))
- tmpl.Execute(w, page)
- return http.StatusOK, nil
-}
-
-func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- revoked, err := certstore.GetRevoked()
- if err != nil {
- return http.StatusInternalServerError, err
- }
- rl, err := keysigner.GenerateRevocationList(revoked)
- if err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL")
- }
- w.Header().Set("Content-Type", "application/octet-stream")
- w.Write(rl)
- return http.StatusOK, nil
-}
-
-func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- tmpl := template.Must(template.New("certs.html").Parse(templates.Certs))
- tmpl.Execute(w, map[string]interface{}{
- csrf.TemplateTag: csrf.TemplateField(r),
- })
- return http.StatusOK, nil
-}
-
-func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
- }
- includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all"))
- certs, err := certstore.List(includeExpired)
- j, err := json.Marshal(certs)
- if err != nil {
- return http.StatusInternalServerError, errors.New(http.StatusText(http.StatusInternalServerError))
- }
- w.Write(j)
- return http.StatusOK, nil
-}
-
-func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if !a.isLoggedIn(w, r) {
- return a.login(w, r)
- }
- r.ParseForm()
- for _, id := range r.Form["cert_id"] {
- if err := certstore.Revoke(id); err != nil {
- return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke")
- }
- }
- http.Redirect(w, r, "/admin/certs", http.StatusSeeOther)
- return http.StatusSeeOther, nil
-}
-
-// appHandler is a handler which uses appContext to manage state.
-type appHandler struct {
- *appContext
- h func(*appContext, http.ResponseWriter, *http.Request) (int, error)
-}
-
-// ServeHTTP handles the request and writes responses.
-func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- status, err := ah.h(ah.appContext, w, r)
- if err != nil {
- log.Printf("HTTP %d: %q", status, err)
- http.Error(w, err.Error(), status)
- }
-}
-
-// newState generates a state identifier for the oauth process.
-func newState() string {
- k := make([]byte, 32)
- if _, err := io.ReadFull(rand.Reader, k); err != nil {
- return "unexpectedstring"
- }
- return hex.EncodeToString(k)
-}
-
-func loadCerts(certFile, keyFile string) (tls.Certificate, error) {
- key, err := wkfs.ReadFile(keyFile)
- if err != nil {
- return tls.Certificate{}, errors.Wrap(err, "error reading TLS private key")
- }
- cert, err := wkfs.ReadFile(certFile)
- if err != nil {
- return tls.Certificate{}, errors.Wrap(err, "error reading TLS certificate")
- }
- return tls.X509KeyPair(cert, key)
-}
-
func main() {
- // Privileged section
flag.Parse()
conf, err := config.ReadConfig(*cfg)
if err != nil {
@@ -327,109 +32,6 @@ func main() {
})
vaultfs.Register(conf.Vault)
- keysigner, err = signer.New(conf.SSH)
- if err != nil {
- log.Fatal(err)
- }
-
- logfile := os.Stderr
- if conf.Server.HTTPLogFile != "" {
- logfile, err = os.OpenFile(conf.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
- if err != nil {
- log.Printf("unable to open %s for writing. logging to stdout", conf.Server.HTTPLogFile)
- logfile = os.Stderr
- }
- }
-
- laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port)
- l, err := net.Listen("tcp", laddr)
- if err != nil {
- log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port))
- }
-
- tlsConfig := &tls.Config{}
- if conf.Server.UseTLS {
- if conf.Server.LetsEncryptServername != "" {
- m := autocert.Manager{
- Prompt: autocert.AcceptTOS,
- Cache: wkfscache.Cache(conf.Server.LetsEncryptCache),
- HostPolicy: autocert.HostWhitelist(conf.Server.LetsEncryptServername),
- }
- tlsConfig.GetCertificate = m.GetCertificate
- } else {
- if conf.Server.TLSCert == "" || conf.Server.TLSKey == "" {
- log.Fatal("TLS cert or key not specified in config")
- }
- tlsConfig.Certificates = make([]tls.Certificate, 1)
- tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey)
- if err != nil {
- log.Fatal(errors.Wrap(err, "unable to create TLS listener"))
- }
- }
- l = tls.NewListener(l, tlsConfig)
- }
-
- if conf.Server.User != "" {
- log.Print("Dropping privileges...")
- if err := drop.DropPrivileges(conf.Server.User); err != nil {
- log.Fatal(errors.Wrap(err, "unable to drop privileges"))
- }
- }
-
- // Unprivileged section
- metrics.Register()
-
- switch conf.Auth.Provider {
- case "google":
- authprovider, err = google.New(conf.Auth)
- case "github":
- authprovider, err = github.New(conf.Auth)
- case "gitlab":
- authprovider, err = gitlab.New(conf.Auth)
- default:
- log.Fatalf("Unknown provider %s\n", conf.Auth.Provider)
- }
- if err != nil {
- log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider))
- }
-
- certstore, err = store.New(conf.Server.Database)
- if err != nil {
- log.Fatal(err)
- }
- ctx := &appContext{
- cookiestore: sessions.NewCookieStore([]byte(conf.Server.CookieSecret)),
- }
- ctx.cookiestore.Options = &sessions.Options{
- MaxAge: 900,
- Path: "/",
- Secure: conf.Server.UseTLS,
- HttpOnly: true,
- }
-
- CSRF := csrf.Protect([]byte(conf.Server.CSRFSecret), csrf.Secure(conf.Server.UseTLS))
- r := mux.NewRouter()
- r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
- r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
- r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
- r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
- r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler})
- r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler}))
- r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler}))
- r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
- r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
- r.PathPrefix("/").Handler(http.FileServer(static.FS(false)))
- h := handlers.LoggingHandler(logfile, r)
-
- log.Printf("Starting server on %s", laddr)
- s := &http.Server{
- Handler: h,
- }
-
- cm := cmux.New(l)
- httpl := cm.Match(cmux.HTTP1Fast())
- grpcl := cm.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
- go s.Serve(httpl)
- go newGrpcServer(grpcl)
- log.Fatal(cm.Serve())
+ // Start the servers
+ server.Run(conf)
}
diff --git a/cmd/cashierd/handlers_test.go b/server/handlers_test.go
index 934d5d0..b2646a3 100644
--- a/cmd/cashierd/handlers_test.go
+++ b/server/handlers_test.go
@@ -1,4 +1,4 @@
-package main
+package server
import (
"bytes"
diff --git a/cmd/cashierd/rpc.go b/server/rpc.go
index ad8aa5d..ce95e96 100644
--- a/cmd/cashierd/rpc.go
+++ b/server/rpc.go
@@ -1,4 +1,4 @@
-package main
+package server
import (
"log"
@@ -61,7 +61,7 @@ func authInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServe
return handler(ctx, req)
}
-func newGrpcServer(l net.Listener) {
+func runGRPCServer(l net.Listener) {
serv := grpc.NewServer(grpc.UnaryInterceptor(authInterceptor))
proto.RegisterSignerServer(serv, &rpcServer{})
serv.Serve(l)
diff --git a/server/server.go b/server/server.go
new file mode 100644
index 0000000..676a61b
--- /dev/null
+++ b/server/server.go
@@ -0,0 +1,117 @@
+package server
+
+import (
+ "crypto/tls"
+ "fmt"
+ "log"
+ "net"
+
+ "github.com/pkg/errors"
+ "github.com/soheilhy/cmux"
+
+ "go4.org/wkfs"
+ "golang.org/x/crypto/acme/autocert"
+
+ wkfscache "github.com/nsheridan/autocert-wkfs-cache"
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/auth/github"
+ "github.com/nsheridan/cashier/server/auth/gitlab"
+ "github.com/nsheridan/cashier/server/auth/google"
+ "github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/metrics"
+ "github.com/nsheridan/cashier/server/signer"
+ "github.com/nsheridan/cashier/server/store"
+ "github.com/sid77/drop"
+)
+
+var (
+ authprovider auth.Provider
+ certstore store.CertStorer
+ keysigner *signer.KeySigner
+)
+
+func loadCerts(certFile, keyFile string) (tls.Certificate, error) {
+ key, err := wkfs.ReadFile(keyFile)
+ if err != nil {
+ return tls.Certificate{}, errors.Wrap(err, "error reading TLS private key")
+ }
+ cert, err := wkfs.ReadFile(certFile)
+ if err != nil {
+ return tls.Certificate{}, errors.Wrap(err, "error reading TLS certificate")
+ }
+ return tls.X509KeyPair(cert, key)
+}
+
+// Run the HTTP and RPC servers.
+func Run(conf *config.Config) {
+ var err error
+ keysigner, err = signer.New(conf.SSH)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ laddr := fmt.Sprintf("%s:%d", conf.Server.Addr, conf.Server.Port)
+ l, err := net.Listen("tcp", laddr)
+ if err != nil {
+ log.Fatal(errors.Wrapf(err, "unable to listen on %s:%d", conf.Server.Addr, conf.Server.Port))
+ }
+
+ tlsConfig := &tls.Config{}
+ if conf.Server.UseTLS {
+ if conf.Server.LetsEncryptServername != "" {
+ m := autocert.Manager{
+ Prompt: autocert.AcceptTOS,
+ Cache: wkfscache.Cache(conf.Server.LetsEncryptCache),
+ HostPolicy: autocert.HostWhitelist(conf.Server.LetsEncryptServername),
+ }
+ tlsConfig.GetCertificate = m.GetCertificate
+ } else {
+ if conf.Server.TLSCert == "" || conf.Server.TLSKey == "" {
+ log.Fatal("TLS cert or key not specified in config")
+ }
+ tlsConfig.Certificates = make([]tls.Certificate, 1)
+ tlsConfig.Certificates[0], err = loadCerts(conf.Server.TLSCert, conf.Server.TLSKey)
+ if err != nil {
+ log.Fatal(errors.Wrap(err, "unable to create TLS listener"))
+ }
+ }
+ l = tls.NewListener(l, tlsConfig)
+ }
+
+ if conf.Server.User != "" {
+ log.Print("Dropping privileges...")
+ if err := drop.DropPrivileges(conf.Server.User); err != nil {
+ log.Fatal(errors.Wrap(err, "unable to drop privileges"))
+ }
+ }
+
+ // Unprivileged section
+ metrics.Register()
+
+ switch conf.Auth.Provider {
+ case "google":
+ authprovider, err = google.New(conf.Auth)
+ case "github":
+ authprovider, err = github.New(conf.Auth)
+ case "gitlab":
+ authprovider, err = gitlab.New(conf.Auth)
+ default:
+ log.Fatalf("Unknown provider %s\n", conf.Auth.Provider)
+ }
+ if err != nil {
+ log.Fatal(errors.Wrapf(err, "unable to use provider '%s'", conf.Auth.Provider))
+ }
+
+ certstore, err = store.New(conf.Server.Database)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ log.Printf("Starting server on %s", laddr)
+ cm := cmux.New(l)
+ httpl := cm.Match(cmux.HTTP1Fast())
+ grpcl := cm.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
+ go runHTTPServer(conf.Server, httpl)
+ go runGRPCServer(grpcl)
+ log.Fatal(cm.Serve())
+}
diff --git a/server/web.go b/server/web.go
new file mode 100644
index 0000000..65eca49
--- /dev/null
+++ b/server/web.go
@@ -0,0 +1,313 @@
+package server
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "html/template"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+
+ "github.com/pkg/errors"
+ "github.com/prometheus/client_golang/prometheus/promhttp"
+
+ "golang.org/x/oauth2"
+
+ "github.com/gorilla/csrf"
+ "github.com/gorilla/handlers"
+ "github.com/gorilla/mux"
+ "github.com/gorilla/sessions"
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/static"
+ "github.com/nsheridan/cashier/server/templates"
+)
+
+// appContext contains local context - cookiestore, authsession etc.
+type appContext struct {
+ cookiestore *sessions.CookieStore
+ authsession *auth.Session
+}
+
+// getAuthTokenCookie retrieves a cookie from the request.
+func (a *appContext) getAuthTokenCookie(r *http.Request) *oauth2.Token {
+ session, _ := a.cookiestore.Get(r, "session")
+ t, ok := session.Values["token"]
+ if !ok {
+ return nil
+ }
+ var tok oauth2.Token
+ if err := json.Unmarshal(t.([]byte), &tok); err != nil {
+ return nil
+ }
+ if !tok.Valid() {
+ return nil
+ }
+ return &tok
+}
+
+// setAuthTokenCookie marshals the auth token and stores it as a cookie.
+func (a *appContext) setAuthTokenCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
+ session, _ := a.cookiestore.Get(r, "session")
+ val, _ := json.Marshal(t)
+ session.Values["token"] = val
+ session.Save(r, w)
+}
+
+// getAuthStateCookie retrieves the oauth csrf state value from the client request.
+func (a *appContext) getAuthStateCookie(r *http.Request) string {
+ session, _ := a.cookiestore.Get(r, "session")
+ state, ok := session.Values["state"]
+ if !ok {
+ return ""
+ }
+ return state.(string)
+}
+
+// setAuthStateCookie saves the oauth csrf state value.
+func (a *appContext) setAuthStateCookie(w http.ResponseWriter, r *http.Request, state string) {
+ session, _ := a.cookiestore.Get(r, "session")
+ session.Values["state"] = state
+ session.Save(r, w)
+}
+
+func (a *appContext) getCurrentURL(r *http.Request) string {
+ session, _ := a.cookiestore.Get(r, "session")
+ path, ok := session.Values["auth_url"]
+ if !ok {
+ return ""
+ }
+ return path.(string)
+}
+
+func (a *appContext) setCurrentURL(w http.ResponseWriter, r *http.Request) {
+ session, _ := a.cookiestore.Get(r, "session")
+ session.Values["auth_url"] = r.URL.Path
+ session.Save(r, w)
+}
+
+func (a *appContext) isLoggedIn(w http.ResponseWriter, r *http.Request) bool {
+ tok := a.getAuthTokenCookie(r)
+ if !tok.Valid() || !authprovider.Valid(tok) {
+ return false
+ }
+ return true
+}
+
+func (a *appContext) login(w http.ResponseWriter, r *http.Request) (int, error) {
+ a.setCurrentURL(w, r)
+ http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
+ return http.StatusSeeOther, nil
+}
+
+// parseKey retrieves and unmarshals the signing request.
+func extractKey(r *http.Request) (*lib.SignRequest, error) {
+ var s lib.SignRequest
+ if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+// signHandler handles the "/sign" path.
+// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key.
+func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ var t string
+ if ah := r.Header.Get("Authorization"); ah != "" {
+ if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " {
+ t = ah[7:]
+ }
+ }
+ if t == "" {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ token := &oauth2.Token{
+ AccessToken: t,
+ }
+ if !authprovider.Valid(token) {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+
+ // Sign the pubkey and issue the cert.
+ req, err := extractKey(r)
+ if err != nil {
+ return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
+ }
+ username := authprovider.Username(token)
+ authprovider.Revoke(token) // We don't need this anymore.
+ cert, err := keysigner.SignUserKey(req, username)
+ if err != nil {
+ return http.StatusInternalServerError, errors.Wrap(err, "error signing key")
+ }
+ if err := certstore.SetCert(cert); err != nil {
+ log.Printf("Error recording cert: %v", err)
+ }
+ if err := json.NewEncoder(w).Encode(&lib.SignResponse{
+ Status: "ok",
+ Response: string(lib.GetPublicKey(cert)),
+ }); err != nil {
+ return http.StatusInternalServerError, errors.Wrap(err, "error encoding response")
+ }
+ return http.StatusOK, nil
+}
+
+// loginHandler starts the authentication process with the provider.
+func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ state := newState()
+ a.setAuthStateCookie(w, r, state)
+ a.authsession = authprovider.StartSession(state)
+ http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
+ return http.StatusFound, nil
+}
+
+// callbackHandler handles retrieving the access token from the auth provider and saves it for later use.
+func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if r.FormValue("state") != a.getAuthStateCookie(r) {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ code := r.FormValue("code")
+ if err := a.authsession.Authorize(authprovider, code); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ a.setAuthTokenCookie(w, r, a.authsession.Token)
+ http.Redirect(w, r, a.getCurrentURL(r), http.StatusFound)
+ return http.StatusFound, nil
+}
+
+// rootHandler starts the auth process. If the client is authenticated it renders the token to the user.
+func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if !a.isLoggedIn(w, r) {
+ return a.login(w, r)
+ }
+ tok := a.getAuthTokenCookie(r)
+ page := struct {
+ Token string
+ }{tok.AccessToken}
+
+ tmpl := template.Must(template.New("token.html").Parse(templates.Token))
+ tmpl.Execute(w, page)
+ return http.StatusOK, nil
+}
+
+func listRevokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ revoked, err := certstore.GetRevoked()
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ rl, err := keysigner.GenerateRevocationList(revoked)
+ if err != nil {
+ return http.StatusInternalServerError, errors.Wrap(err, "unable to generate KRL")
+ }
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Write(rl)
+ return http.StatusOK, nil
+}
+
+func listAllCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if !a.isLoggedIn(w, r) {
+ return a.login(w, r)
+ }
+ tmpl := template.Must(template.New("certs.html").Parse(templates.Certs))
+ tmpl.Execute(w, map[string]interface{}{
+ csrf.TemplateTag: csrf.TemplateField(r),
+ })
+ return http.StatusOK, nil
+}
+
+func listCertsJSONHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if !a.isLoggedIn(w, r) {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ includeExpired, _ := strconv.ParseBool(r.URL.Query().Get("all"))
+ certs, err := certstore.List(includeExpired)
+ j, err := json.Marshal(certs)
+ if err != nil {
+ return http.StatusInternalServerError, errors.New(http.StatusText(http.StatusInternalServerError))
+ }
+ w.Write(j)
+ return http.StatusOK, nil
+}
+
+func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if !a.isLoggedIn(w, r) {
+ return a.login(w, r)
+ }
+ r.ParseForm()
+ for _, id := range r.Form["cert_id"] {
+ if err := certstore.Revoke(id); err != nil {
+ return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke")
+ }
+ }
+ http.Redirect(w, r, "/admin/certs", http.StatusSeeOther)
+ return http.StatusSeeOther, nil
+}
+
+// appHandler is a handler which uses appContext to manage state.
+type appHandler struct {
+ *appContext
+ h func(*appContext, http.ResponseWriter, *http.Request) (int, error)
+}
+
+// ServeHTTP handles the request and writes responses.
+func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ status, err := ah.h(ah.appContext, w, r)
+ if err != nil {
+ log.Printf("HTTP %d: %q", status, err)
+ http.Error(w, err.Error(), status)
+ }
+}
+
+// newState generates a state identifier for the oauth process.
+func newState() string {
+ k := make([]byte, 32)
+ if _, err := io.ReadFull(rand.Reader, k); err != nil {
+ return "unexpectedstring"
+ }
+ return hex.EncodeToString(k)
+}
+
+func runHTTPServer(conf *config.Server, l net.Listener) {
+ var err error
+ ctx := &appContext{
+ cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
+ }
+ ctx.cookiestore.Options = &sessions.Options{
+ MaxAge: 900,
+ Path: "/",
+ Secure: conf.UseTLS,
+ HttpOnly: true,
+ }
+
+ logfile := os.Stderr
+ if conf.HTTPLogFile != "" {
+ logfile, err = os.OpenFile(conf.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0640)
+ if err != nil {
+ log.Printf("unable to open %s for writing. logging to stdout", conf.HTTPLogFile)
+ logfile = os.Stderr
+ }
+ }
+
+ CSRF := csrf.Protect([]byte(conf.CSRFSecret), csrf.Secure(conf.UseTLS))
+ r := mux.NewRouter()
+ r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
+ r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
+ r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
+ r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
+ r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, listRevokedCertsHandler})
+ r.Methods("POST").Path("/admin/revoke").Handler(CSRF(appHandler{ctx, revokeCertHandler}))
+ r.Methods("GET").Path("/admin/certs").Handler(CSRF(appHandler{ctx, listAllCertsHandler}))
+ r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
+ r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
+ r.PathPrefix("/").Handler(http.FileServer(static.FS(false)))
+ h := handlers.LoggingHandler(logfile, r)
+ s := &http.Server{
+ Handler: h,
+ }
+ s.Serve(l)
+}