aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/cashierd/handlers_test.go139
-rw-r--r--cmd/cashierd/main.go21
-rw-r--r--server/auth/testprovider/testprovider.go56
-rw-r--r--server/signer/signer.go2
4 files changed, 205 insertions, 13 deletions
diff --git a/cmd/cashierd/handlers_test.go b/cmd/cashierd/handlers_test.go
new file mode 100644
index 0000000..a214dfd
--- /dev/null
+++ b/cmd/cashierd/handlers_test.go
@@ -0,0 +1,139 @@
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/oauth2"
+
+ "github.com/gorilla/sessions"
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/auth/testprovider"
+ "github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/signer"
+ "github.com/nsheridan/cashier/server/store"
+ "github.com/nsheridan/cashier/testdata"
+)
+
+func newContext(t *testing.T) *appContext {
+ f, err := ioutil.TempFile(os.TempDir(), "signing_key_")
+ if err != nil {
+ t.Error(err)
+ }
+ defer os.Remove(f.Name())
+ f.Write(testdata.Priv)
+ f.Close()
+ signer, err := signer.New(&config.SSH{
+ SigningKey: f.Name(),
+ MaxAge: "1h",
+ })
+ if err != nil {
+ t.Error(err)
+ }
+ return &appContext{
+ cookiestore: sessions.NewCookieStore([]byte("secret")),
+ authprovider: testprovider.New(),
+ certstore: store.NewMemoryStore(),
+ authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
+ sshKeySigner: signer,
+ }
+}
+
+func TestLoginHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/auth/login", nil)
+ resp := httptest.NewRecorder()
+ loginHandler(newContext(t), resp, req)
+ if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
+ t.Error("Unexpected response")
+ }
+}
+
+func TestCallbackHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/auth/callback", nil)
+ req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ ctx.setAuthStateCookie(resp, req, "state")
+ callbackHandler(ctx, resp, req)
+ if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
+ t.Error("Unexpected response")
+ }
+}
+
+func TestRootHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/", nil)
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ tok := &oauth2.Token{
+ AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
+ Expiry: time.Now().Add(1 * time.Hour),
+ }
+ ctx.setAuthTokenCookie(resp, req, tok)
+ rootHandler(ctx, resp, req)
+ if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
+ t.Error("Unable to find token in response")
+ }
+}
+
+func TestRootHandlerNoSession(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/", nil)
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ rootHandler(ctx, resp, req)
+ if resp.Code != http.StatusSeeOther {
+ t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
+ }
+}
+
+func TestSignRevoke(t *testing.T) {
+ s, _ := json.Marshal(&lib.SignRequest{
+ Key: string(testdata.Pub),
+ })
+ req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ req.Header.Set("Authorization", "Bearer abcdef")
+ signHandler(ctx, resp, req)
+ if resp.Code != http.StatusOK {
+ t.Error("Unexpected response")
+ }
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ r := &lib.SignResponse{}
+ if err := json.Unmarshal(b, r); err != nil {
+ t.Error(err)
+ }
+ if r.Status != "ok" {
+ t.Error("Unexpected response")
+ }
+ k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response))
+ if err != nil {
+ t.Error(err)
+ }
+ cert, ok := k.(*ssh.Certificate)
+ if !ok {
+ t.Error("Did not receive a certificate")
+ }
+ // Revoke the cert and verify
+ req, _ = http.NewRequest("POST", "/revoke", nil)
+ req.Form = url.Values{"cert_id": []string{cert.KeyId}}
+ revokeCertHandler(ctx, resp, req)
+ req, _ = http.NewRequest("GET", "/revoked", nil)
+ revokedCertsHandler(ctx, resp, req)
+ revoked, _ := ioutil.ReadAll(resp.Body)
+ if string(revoked[:len(revoked)-1]) != r.Response {
+ t.Error("omg")
+ }
+}
diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go
index 1db7d30..31ba104 100644
--- a/cmd/cashierd/main.go
+++ b/cmd/cashierd/main.go
@@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// Sign the pubkey and issue the cert.
req, err := parseKey(r)
- req.Principal = a.authprovider.Username(token)
- a.authprovider.Revoke(token) // We don't need this anymore.
if err != nil {
return http.StatusInternalServerError, err
}
+ req.Principal = a.authprovider.Username(token)
+ a.authprovider.Revoke(token) // We don't need this anymore.
cert, err := a.sshKeySigner.SignUserKey(req)
if err != nil {
return http.StatusInternalServerError, err
@@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request)
}
func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if r.Method == "GET" {
- return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed))
- }
r.ParseForm()
id := r.FormValue("cert_id")
if id == "" {
@@ -268,7 +265,7 @@ func main() {
log.Fatal(err)
}
fs.Register(&config.AWS)
- signer, err := signer.New(config.SSH)
+ signer, err := signer.New(&config.SSH)
if err != nil {
log.Fatal(err)
}
@@ -304,12 +301,12 @@ func main() {
}
r := mux.NewRouter()
- r.Handle("/", appHandler{ctx, rootHandler})
- r.Handle("/auth/login", appHandler{ctx, loginHandler})
- r.Handle("/auth/callback", appHandler{ctx, callbackHandler})
- r.Handle("/sign", appHandler{ctx, signHandler})
- r.Handle("/revoked", appHandler{ctx, revokedCertsHandler})
- r.Handle("/revoke", appHandler{ctx, revokeCertHandler})
+ r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
+ r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
+ r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
+ r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
+ r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler})
+ r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler})
logfile := os.Stderr
if config.Server.HTTPLogFile != "" {
logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)
diff --git a/server/auth/testprovider/testprovider.go b/server/auth/testprovider/testprovider.go
new file mode 100644
index 0000000..3d2b13a
--- /dev/null
+++ b/server/auth/testprovider/testprovider.go
@@ -0,0 +1,56 @@
+package testprovider
+
+import (
+ "time"
+
+ "github.com/nsheridan/cashier/server/auth"
+
+ "golang.org/x/oauth2"
+)
+
+const (
+ name = "testprovider"
+)
+
+// Config is an implementation of `auth.Provider` for testing.
+type Config struct{}
+
+// New creates a new provider.
+func New() auth.Provider {
+ return &Config{}
+}
+
+// Name returns the name of the provider.
+func (c *Config) Name() string {
+ return name
+}
+
+// Valid validates the oauth token.
+func (c *Config) Valid(token *oauth2.Token) bool {
+ return true
+}
+
+// Revoke disables the access token.
+func (c *Config) Revoke(token *oauth2.Token) error {
+ return nil
+}
+
+// StartSession retrieves an authentication endpoint.
+func (c *Config) StartSession(state string) *auth.Session {
+ return &auth.Session{
+ AuthURL: "https://www.example.com/auth",
+ }
+}
+
+// Exchange authorizes the session and returns an access token.
+func (c *Config) Exchange(code string) (*oauth2.Token, error) {
+ return &oauth2.Token{
+ AccessToken: "token",
+ Expiry: time.Now().Add(1 * time.Hour),
+ }, nil
+}
+
+// Username retrieves the username portion of the user's email address.
+func (c *Config) Username(token *oauth2.Token) string {
+ return "test"
+}
diff --git a/server/signer/signer.go b/server/signer/signer.go
index a3f056a..8169c11 100644
--- a/server/signer/signer.go
+++ b/server/signer/signer.go
@@ -69,7 +69,7 @@ func makeperms(perms []string) map[string]string {
}
// New creates a new KeySigner from the supplied configuration.
-func New(conf config.SSH) (*KeySigner, error) {
+func New(conf *config.SSH) (*KeySigner, error) {
data, err := wkfs.ReadFile(conf.SigningKey)
if err != nil {
return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)