diff options
-rw-r--r-- | cmd/cashierd/handlers_test.go | 139 | ||||
-rw-r--r-- | cmd/cashierd/main.go | 21 | ||||
-rw-r--r-- | server/auth/testprovider/testprovider.go | 56 | ||||
-rw-r--r-- | server/signer/signer.go | 2 |
4 files changed, 205 insertions, 13 deletions
diff --git a/cmd/cashierd/handlers_test.go b/cmd/cashierd/handlers_test.go new file mode 100644 index 0000000..a214dfd --- /dev/null +++ b/cmd/cashierd/handlers_test.go @@ -0,0 +1,139 @@ +package main + +import ( + "bytes" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/oauth2" + + "github.com/gorilla/sessions" + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/server/auth" + "github.com/nsheridan/cashier/server/auth/testprovider" + "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/signer" + "github.com/nsheridan/cashier/server/store" + "github.com/nsheridan/cashier/testdata" +) + +func newContext(t *testing.T) *appContext { + f, err := ioutil.TempFile(os.TempDir(), "signing_key_") + if err != nil { + t.Error(err) + } + defer os.Remove(f.Name()) + f.Write(testdata.Priv) + f.Close() + signer, err := signer.New(&config.SSH{ + SigningKey: f.Name(), + MaxAge: "1h", + }) + if err != nil { + t.Error(err) + } + return &appContext{ + cookiestore: sessions.NewCookieStore([]byte("secret")), + authprovider: testprovider.New(), + certstore: store.NewMemoryStore(), + authsession: &auth.Session{AuthURL: "https://www.example.com/auth"}, + sshKeySigner: signer, + } +} + +func TestLoginHandler(t *testing.T) { + req, _ := http.NewRequest("GET", "/auth/login", nil) + resp := httptest.NewRecorder() + loginHandler(newContext(t), resp, req) + if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" { + t.Error("Unexpected response") + } +} + +func TestCallbackHandler(t *testing.T) { + req, _ := http.NewRequest("GET", "/auth/callback", nil) + req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}} + resp := httptest.NewRecorder() + ctx := newContext(t) + ctx.setAuthStateCookie(resp, req, "state") + callbackHandler(ctx, resp, req) + if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" { + t.Error("Unexpected response") + } +} + +func TestRootHandler(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + ctx := newContext(t) + tok := &oauth2.Token{ + AccessToken: "XXX_TEST_TOKEN_STRING_XXX", + Expiry: time.Now().Add(1 * time.Hour), + } + ctx.setAuthTokenCookie(resp, req, tok) + rootHandler(ctx, resp, req) + if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") { + t.Error("Unable to find token in response") + } +} + +func TestRootHandlerNoSession(t *testing.T) { + req, _ := http.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + ctx := newContext(t) + rootHandler(ctx, resp, req) + if resp.Code != http.StatusSeeOther { + t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther)) + } +} + +func TestSignRevoke(t *testing.T) { + s, _ := json.Marshal(&lib.SignRequest{ + Key: string(testdata.Pub), + }) + req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s)) + resp := httptest.NewRecorder() + ctx := newContext(t) + req.Header.Set("Authorization", "Bearer abcdef") + signHandler(ctx, resp, req) + if resp.Code != http.StatusOK { + t.Error("Unexpected response") + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + r := &lib.SignResponse{} + if err := json.Unmarshal(b, r); err != nil { + t.Error(err) + } + if r.Status != "ok" { + t.Error("Unexpected response") + } + k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response)) + if err != nil { + t.Error(err) + } + cert, ok := k.(*ssh.Certificate) + if !ok { + t.Error("Did not receive a certificate") + } + // Revoke the cert and verify + req, _ = http.NewRequest("POST", "/revoke", nil) + req.Form = url.Values{"cert_id": []string{cert.KeyId}} + revokeCertHandler(ctx, resp, req) + req, _ = http.NewRequest("GET", "/revoked", nil) + revokedCertsHandler(ctx, resp, req) + revoked, _ := ioutil.ReadAll(resp.Body) + if string(revoked[:len(revoked)-1]) != r.Response { + t.Error("omg") + } +} diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go index 1db7d30..31ba104 100644 --- a/cmd/cashierd/main.go +++ b/cmd/cashierd/main.go @@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er // Sign the pubkey and issue the cert. req, err := parseKey(r) - req.Principal = a.authprovider.Username(token) - a.authprovider.Revoke(token) // We don't need this anymore. if err != nil { return http.StatusInternalServerError, err } + req.Principal = a.authprovider.Username(token) + a.authprovider.Revoke(token) // We don't need this anymore. cert, err := a.sshKeySigner.SignUserKey(req) if err != nil { return http.StatusInternalServerError, err @@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request) } func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { - if r.Method == "GET" { - return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed)) - } r.ParseForm() id := r.FormValue("cert_id") if id == "" { @@ -268,7 +265,7 @@ func main() { log.Fatal(err) } fs.Register(&config.AWS) - signer, err := signer.New(config.SSH) + signer, err := signer.New(&config.SSH) if err != nil { log.Fatal(err) } @@ -304,12 +301,12 @@ func main() { } r := mux.NewRouter() - r.Handle("/", appHandler{ctx, rootHandler}) - r.Handle("/auth/login", appHandler{ctx, loginHandler}) - r.Handle("/auth/callback", appHandler{ctx, callbackHandler}) - r.Handle("/sign", appHandler{ctx, signHandler}) - r.Handle("/revoked", appHandler{ctx, revokedCertsHandler}) - r.Handle("/revoke", appHandler{ctx, revokeCertHandler}) + r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler}) + r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler}) + r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler}) + r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler}) + r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler}) + r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler}) logfile := os.Stderr if config.Server.HTTPLogFile != "" { logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660) diff --git a/server/auth/testprovider/testprovider.go b/server/auth/testprovider/testprovider.go new file mode 100644 index 0000000..3d2b13a --- /dev/null +++ b/server/auth/testprovider/testprovider.go @@ -0,0 +1,56 @@ +package testprovider + +import ( + "time" + + "github.com/nsheridan/cashier/server/auth" + + "golang.org/x/oauth2" +) + +const ( + name = "testprovider" +) + +// Config is an implementation of `auth.Provider` for testing. +type Config struct{} + +// New creates a new provider. +func New() auth.Provider { + return &Config{} +} + +// Name returns the name of the provider. +func (c *Config) Name() string { + return name +} + +// Valid validates the oauth token. +func (c *Config) Valid(token *oauth2.Token) bool { + return true +} + +// Revoke disables the access token. +func (c *Config) Revoke(token *oauth2.Token) error { + return nil +} + +// StartSession retrieves an authentication endpoint. +func (c *Config) StartSession(state string) *auth.Session { + return &auth.Session{ + AuthURL: "https://www.example.com/auth", + } +} + +// Exchange authorizes the session and returns an access token. +func (c *Config) Exchange(code string) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "token", + Expiry: time.Now().Add(1 * time.Hour), + }, nil +} + +// Username retrieves the username portion of the user's email address. +func (c *Config) Username(token *oauth2.Token) string { + return "test" +} diff --git a/server/signer/signer.go b/server/signer/signer.go index a3f056a..8169c11 100644 --- a/server/signer/signer.go +++ b/server/signer/signer.go @@ -69,7 +69,7 @@ func makeperms(perms []string) map[string]string { } // New creates a new KeySigner from the supplied configuration. -func New(conf config.SSH) (*KeySigner, error) { +func New(conf *config.SSH) (*KeySigner, error) { data, err := wkfs.ReadFile(conf.SigningKey) if err != nil { return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err) |