diff options
Diffstat (limited to 'cmd')
-rw-r--r-- | cmd/cashier/client_test.go | 99 | ||||
-rw-r--r-- | cmd/cashier/keys.go | 82 | ||||
-rw-r--r-- | cmd/cashier/main.go | 127 | ||||
-rw-r--r-- | cmd/cashierd/main.go | 246 |
4 files changed, 554 insertions, 0 deletions
diff --git a/cmd/cashier/client_test.go b/cmd/cashier/client_test.go new file mode 100644 index 0000000..492f4fc --- /dev/null +++ b/cmd/cashier/client_test.go @@ -0,0 +1,99 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/testdata" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func TestLoadCert(t *testing.T) { + priv, _ := ssh.ParseRawPrivateKey(testdata.Priv) + key := priv.(*rsa.PrivateKey) + pub, _ := ssh.NewPublicKey(&key.PublicKey) + c := &ssh.Certificate{ + Key: pub, + CertType: ssh.UserCert, + ValidBefore: ssh.CertTimeInfinity, + ValidAfter: 0, + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Fatal(err) + } + c.SignCert(rand.Reader, signer) + a := agent.NewKeyring() + if err := installCert(a, c, key); err != nil { + t.Fatal(err) + } + listedKeys, err := a.List() + if err != nil { + t.Fatalf("Error reading from agent: %v", err) + } + if len(listedKeys) != 1 { + t.Fatalf("Expected 1 key, got %d", len(listedKeys)) + } + if !bytes.Equal(listedKeys[0].Marshal(), c.Marshal()) { + t.Fatal("Certs not equal") + } +} + +func TestSignGood(t *testing.T) { + res := &lib.SignResponse{ + Status: "ok", + Response: string(testdata.Cert), + } + j, _ := json.Marshal(res) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, string(j)) + })) + defer ts.Close() + *ca = ts.URL + _, err := send([]byte(`{}`), "token") + if err != nil { + t.Fatal(err) + } + k, _, _, _, err := ssh.ParseAuthorizedKey(testdata.Pub) + if err != nil { + t.Fatal(err) + } + cert, err := sign(k, "token") + if cert == nil && err != nil { + t.Fatal(err) + } +} + +func TestSignBad(t *testing.T) { + res := &lib.SignResponse{ + Status: "error", + Response: `{"response": "error"}`, + } + j, _ := json.Marshal(res) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, string(j)) + })) + defer ts.Close() + *ca = ts.URL + _, err := send([]byte(`{}`), "token") + if err != nil { + t.Fatal(err) + } + k, _, _, _, err := ssh.ParseAuthorizedKey(testdata.Pub) + if err != nil { + t.Fatal(err) + } + cert, err := sign(k, "token") + if cert != nil && err == nil { + t.Fatal(err) + } +} diff --git a/cmd/cashier/keys.go b/cmd/cashier/keys.go new file mode 100644 index 0000000..a2f95e9 --- /dev/null +++ b/cmd/cashier/keys.go @@ -0,0 +1,82 @@ +package main + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "fmt" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" +) + +type key interface{} +type keyfunc func(int) (key, ssh.PublicKey, error) + +var ( + keytypes = map[string]keyfunc{ + "rsa": generateRSAKey, + "ecdsa": generateECDSAKey, + "ed25519": generateED25519Key, + } +) + +func generateED25519Key(bits int) (key, ssh.PublicKey, error) { + p, k, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(p) + if err != nil { + return nil, nil, err + } + return k, pub, nil +} + +func generateRSAKey(bits int) (key, ssh.PublicKey, error) { + k, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(&k.PublicKey) + if err != nil { + return nil, nil, err + } + return k, pub, nil +} + +func generateECDSAKey(bits int) (key, ssh.PublicKey, error) { + var curve elliptic.Curve + switch bits { + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + case 521: + curve = elliptic.P521() + default: + return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'") + } + k, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(&k.PublicKey) + if err != nil { + return nil, nil, err + } + return k, pub, nil +} + +func generateKey(keytype string, bits int) (key, ssh.PublicKey, error) { + f, ok := keytypes[keytype] + if !ok { + var valid []string + for k := range keytypes { + valid = append(valid, k) + } + return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, valid) + } + return f(bits) +} diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go new file mode 100644 index 0000000..8bcc3e7 --- /dev/null +++ b/cmd/cashier/main.go @@ -0,0 +1,127 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "os" + "time" + + "github.com/nsheridan/cashier/lib" + "github.com/pkg/browser" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +var ( + ca = flag.String("ca", "http://localhost:10000", "CA server") + keybits = flag.Int("bits", 2048, "Key size. Ignored for ed25519 keys") + validity = flag.Duration("validity", time.Hour*24, "Key validity") + keytype = flag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519") +) + +func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { + pubcert := agent.AddedKey{ + PrivateKey: key, + Certificate: cert, + Comment: cert.KeyId, + } + if err := a.Add(pubcert); err != nil { + return fmt.Errorf("error importing certificate: %s", err) + } + return nil +} + +func send(s []byte, token string) (*lib.SignResponse, error) { + req, err := http.NewRequest("POST", *ca+"/sign", bytes.NewReader(s)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Bad response from server: %s", resp.Status) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + c := &lib.SignResponse{} + if err := json.Unmarshal(body, c); err != nil { + return nil, err + } + return c, nil +} + +func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) { + marshaled := ssh.MarshalAuthorizedKey(pub) + marshaled = marshaled[:len(marshaled)-1] + s, err := json.Marshal(&lib.SignRequest{ + Key: string(marshaled), + ValidUntil: time.Now().Add(*validity), + }) + if err != nil { + return nil, err + } + resp, err := send(s, token) + if err != nil { + return nil, err + } + if resp.Status != "ok" { + return nil, fmt.Errorf("error: %s", resp.Response) + } + k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) + if err != nil { + return nil, err + } + cert, ok := k.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("did not receive a certificate from server") + } + return cert, nil +} + +func main() { + flag.Parse() + + fmt.Printf("Your browser has been opened to visit %s\n", *ca) + if err := browser.OpenURL(*ca); err != nil { + fmt.Println("Error launching web browser. Go to the link in your web browser") + } + fmt.Println("Generating new key pair") + priv, pub, err := generateKey(*keytype, *keybits) + if err != nil { + log.Fatalln("Error generating key pair: ", err) + } + + fmt.Print("Enter token: ") + var token string + fmt.Scanln(&token) + + cert, err := sign(pub, token) + if err != nil { + log.Fatalln(err) + } + sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err != nil { + log.Fatalln("Error connecting to agent: %s", err) + } + defer sock.Close() + a := agent.NewClient(sock) + if err := installCert(a, cert, priv); err != nil { + log.Fatalln(err) + } + fmt.Println("Certificate added.") +} diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go new file mode 100644 index 0000000..bc460da --- /dev/null +++ b/cmd/cashierd/main.go @@ -0,0 +1,246 @@ +package main + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "flag" + "fmt" + "html/template" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "strings" + "time" + + "golang.org/x/oauth2" + + "github.com/gorilla/mux" + "github.com/gorilla/sessions" + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/auth/github" + "github.com/nsheridan/cashier/server/auth/google" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/signer" +) + +var ( + cfg = flag.String("config_file", "config.json", "Path to configuration file.") +) + +// appContext contains local context - cookiestore, authprovider, authsession, templates etc. +type appContext struct { + cookiestore *sessions.CookieStore + authprovider auth.Provider + authsession *auth.Session + views *template.Template + sshKeySigner *signer.KeySigner +} + +// getAuthCookie retrieves a cookie from the request and validates it. +func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { + session, _ := a.cookiestore.Get(r, "tok") + t, ok := session.Values["token"] + if !ok { + return nil + } + var tok oauth2.Token + if err := json.Unmarshal(t.([]byte), &tok); err != nil { + return nil + } + if !tok.Valid() { + return nil + } + return &tok +} + +// setAuthCookie marshals the auth token and stores it as a cookie. +func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { + session, _ := a.cookiestore.Get(r, "tok") + val, _ := json.Marshal(t) + session.Values["token"] = val + session.Save(r, w) +} + +// parseKey retrieves and unmarshals the signing request. +func parseKey(r *http.Request) (*lib.SignRequest, error) { + var s lib.SignRequest + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + if err := json.Unmarshal(body, &s); err != nil { + return nil, err + } + return &s, nil +} + +// signHandler handles the "/sign" path. +// It unmarshals the client token to an oauth token, validates it and signs the provided public ssh key. +func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + var t string + if ah := r.Header.Get("Authorization"); ah != "" { + if len(ah) > 6 && strings.ToUpper(ah[0:7]) == "BEARER " { + t = ah[7:] + } + } + if t == "" { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + token := &oauth2.Token{ + AccessToken: t, + } + ok := a.authprovider.Valid(token) + if !ok { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + + // Sign the pubkey and issue the cert. + req, err := parseKey(r) + req.Principal = a.authprovider.Username(token) + a.authprovider.Revoke(token) // We don't need this anymore. + if err != nil { + return http.StatusInternalServerError, err + } + signed, err := a.sshKeySigner.SignUserKey(req) + if err != nil { + return http.StatusInternalServerError, err + } + json.NewEncoder(w).Encode(&lib.SignResponse{ + Status: "ok", + Response: signed, + }) + return http.StatusOK, nil +} + +// loginHandler starts the authentication process with the provider. +func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + a.authsession = a.authprovider.StartSession(newState()) + http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) + return http.StatusFound, nil +} + +// callbackHandler handles retrieving the access token from the auth provider and saves it for later use. +func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + if r.FormValue("state") != a.authsession.State { + return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) + } + code := r.FormValue("code") + if err := a.authsession.Authorize(a.authprovider, code); err != nil { + return http.StatusInternalServerError, err + } + // Github tokens don't have an expiry. Set one so that the session expires + // after a period. + if a.authsession.Token.Expiry.Unix() <= 0 { + a.authsession.Token.Expiry = time.Now().Add(1 * time.Hour) + } + a.setAuthCookie(w, r, a.authsession.Token) + http.Redirect(w, r, "/", http.StatusFound) + return http.StatusFound, nil +} + +// rootHandler starts the auth process. If the client is authenticated it renders the token to the user. +func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { + tok := a.getAuthCookie(r) + if !tok.Valid() || !a.authprovider.Valid(tok) { + http.Redirect(w, r, "/auth/login", http.StatusSeeOther) + return http.StatusSeeOther, nil + } + page := struct { + Token string + }{tok.AccessToken} + a.views.ExecuteTemplate(w, "token.html", page) + return http.StatusOK, nil +} + +// appHandler is a handler which uses appContext to manage state. +type appHandler struct { + *appContext + h func(*appContext, http.ResponseWriter, *http.Request) (int, error) +} + +// ServeHTTP handles the request and writes responses. +func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + status, err := ah.h(ah.appContext, w, r) + if err != nil { + log.Printf("HTTP %d: %q", status, err) + switch status { + case http.StatusNotFound: + http.NotFound(w, r) + case http.StatusInternalServerError: + http.Error(w, http.StatusText(status), status) + default: + http.Error(w, http.StatusText(status), status) + } + } +} + +// newState generates a state identifier for the oauth process. +func newState() string { + k := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, k); err != nil { + return "unexpectedstring" + } + return hex.EncodeToString(k) +} + +func readConfig(filename string) (*config.Config, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + return config.ReadConfig(f) +} + +func main() { + flag.Parse() + config, err := readConfig(*cfg) + if err != nil { + log.Fatal(err) + } + signer, err := signer.New(config.SSH) + if err != nil { + log.Fatal(err) + } + + var authprovider auth.Provider + switch config.Auth.Provider { + case "google": + authprovider = google.New(&config.Auth) + case "github": + authprovider = github.New(&config.Auth) + default: + log.Fatalln("Unknown provider %s", config.Auth.Provider) + } + + ctx := &appContext{ + cookiestore: sessions.NewCookieStore([]byte(config.Server.CookieSecret)), + authprovider: authprovider, + views: template.Must(template.ParseGlob("templates/*")), + sshKeySigner: signer, + } + ctx.cookiestore.Options = &sessions.Options{ + MaxAge: 900, + Path: "/", + Secure: config.Server.UseTLS, + HttpOnly: true, + } + + m := mux.NewRouter() + m.Handle("/", appHandler{ctx, rootHandler}) + m.Handle("/auth/login", appHandler{ctx, loginHandler}) + m.Handle("/auth/callback", appHandler{ctx, callbackHandler}) + m.Handle("/sign", appHandler{ctx, signHandler}) + + fmt.Println("Starting server...") + l := fmt.Sprintf(":%d", config.Server.Port) + if config.Server.UseTLS { + log.Fatal(http.ListenAndServeTLS(l, config.Server.TLSCert, config.Server.TLSKey, m)) + } + log.Fatal(http.ListenAndServe(l, m)) +} |