aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-07-10 22:35:13 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-07-17 13:58:10 +0100
commit49f40a952943f26494d6407dc608b50b2ec0df7f (patch)
treec261836adad0165642bb7ade18db78852ad6c5cb /cmd
parentdee5a19d36554a8f9a365efd65d13b134889bf63 (diff)
Add some handlers tests
Diffstat (limited to 'cmd')
-rw-r--r--cmd/cashierd/handlers_test.go139
-rw-r--r--cmd/cashierd/main.go21
2 files changed, 148 insertions, 12 deletions
diff --git a/cmd/cashierd/handlers_test.go b/cmd/cashierd/handlers_test.go
new file mode 100644
index 0000000..a214dfd
--- /dev/null
+++ b/cmd/cashierd/handlers_test.go
@@ -0,0 +1,139 @@
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "io/ioutil"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/oauth2"
+
+ "github.com/gorilla/sessions"
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/auth/testprovider"
+ "github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/signer"
+ "github.com/nsheridan/cashier/server/store"
+ "github.com/nsheridan/cashier/testdata"
+)
+
+func newContext(t *testing.T) *appContext {
+ f, err := ioutil.TempFile(os.TempDir(), "signing_key_")
+ if err != nil {
+ t.Error(err)
+ }
+ defer os.Remove(f.Name())
+ f.Write(testdata.Priv)
+ f.Close()
+ signer, err := signer.New(&config.SSH{
+ SigningKey: f.Name(),
+ MaxAge: "1h",
+ })
+ if err != nil {
+ t.Error(err)
+ }
+ return &appContext{
+ cookiestore: sessions.NewCookieStore([]byte("secret")),
+ authprovider: testprovider.New(),
+ certstore: store.NewMemoryStore(),
+ authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
+ sshKeySigner: signer,
+ }
+}
+
+func TestLoginHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/auth/login", nil)
+ resp := httptest.NewRecorder()
+ loginHandler(newContext(t), resp, req)
+ if resp.Code != http.StatusFound && resp.Header().Get("Location") != "https://www.example.com/auth" {
+ t.Error("Unexpected response")
+ }
+}
+
+func TestCallbackHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/auth/callback", nil)
+ req.Form = url.Values{"state": []string{"state"}, "code": []string{"abcdef"}}
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ ctx.setAuthStateCookie(resp, req, "state")
+ callbackHandler(ctx, resp, req)
+ if resp.Code != http.StatusFound && resp.Header().Get("Location") != "/" {
+ t.Error("Unexpected response")
+ }
+}
+
+func TestRootHandler(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/", nil)
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ tok := &oauth2.Token{
+ AccessToken: "XXX_TEST_TOKEN_STRING_XXX",
+ Expiry: time.Now().Add(1 * time.Hour),
+ }
+ ctx.setAuthTokenCookie(resp, req, tok)
+ rootHandler(ctx, resp, req)
+ if resp.Code != http.StatusOK && !strings.Contains(resp.Body.String(), "XXX_TEST_TOKEN_STRING_XXX") {
+ t.Error("Unable to find token in response")
+ }
+}
+
+func TestRootHandlerNoSession(t *testing.T) {
+ req, _ := http.NewRequest("GET", "/", nil)
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ rootHandler(ctx, resp, req)
+ if resp.Code != http.StatusSeeOther {
+ t.Errorf("Unexpected status: %s, wanted %s", http.StatusText(resp.Code), http.StatusText(http.StatusSeeOther))
+ }
+}
+
+func TestSignRevoke(t *testing.T) {
+ s, _ := json.Marshal(&lib.SignRequest{
+ Key: string(testdata.Pub),
+ })
+ req, _ := http.NewRequest("POST", "/sign", bytes.NewReader(s))
+ resp := httptest.NewRecorder()
+ ctx := newContext(t)
+ req.Header.Set("Authorization", "Bearer abcdef")
+ signHandler(ctx, resp, req)
+ if resp.Code != http.StatusOK {
+ t.Error("Unexpected response")
+ }
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ r := &lib.SignResponse{}
+ if err := json.Unmarshal(b, r); err != nil {
+ t.Error(err)
+ }
+ if r.Status != "ok" {
+ t.Error("Unexpected response")
+ }
+ k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(r.Response))
+ if err != nil {
+ t.Error(err)
+ }
+ cert, ok := k.(*ssh.Certificate)
+ if !ok {
+ t.Error("Did not receive a certificate")
+ }
+ // Revoke the cert and verify
+ req, _ = http.NewRequest("POST", "/revoke", nil)
+ req.Form = url.Values{"cert_id": []string{cert.KeyId}}
+ revokeCertHandler(ctx, resp, req)
+ req, _ = http.NewRequest("GET", "/revoked", nil)
+ revokedCertsHandler(ctx, resp, req)
+ revoked, _ := ioutil.ReadAll(resp.Body)
+ if string(revoked[:len(revoked)-1]) != r.Response {
+ t.Error("omg")
+ }
+}
diff --git a/cmd/cashierd/main.go b/cmd/cashierd/main.go
index 1db7d30..31ba104 100644
--- a/cmd/cashierd/main.go
+++ b/cmd/cashierd/main.go
@@ -123,11 +123,11 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
// Sign the pubkey and issue the cert.
req, err := parseKey(r)
- req.Principal = a.authprovider.Username(token)
- a.authprovider.Revoke(token) // We don't need this anymore.
if err != nil {
return http.StatusInternalServerError, err
}
+ req.Principal = a.authprovider.Username(token)
+ a.authprovider.Revoke(token) // We don't need this anymore.
cert, err := a.sshKeySigner.SignUserKey(req)
if err != nil {
return http.StatusInternalServerError, err
@@ -199,9 +199,6 @@ func revokedCertsHandler(a *appContext, w http.ResponseWriter, r *http.Request)
}
func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
- if r.Method == "GET" {
- return http.StatusMethodNotAllowed, errors.New(http.StatusText(http.StatusMethodNotAllowed))
- }
r.ParseForm()
id := r.FormValue("cert_id")
if id == "" {
@@ -268,7 +265,7 @@ func main() {
log.Fatal(err)
}
fs.Register(&config.AWS)
- signer, err := signer.New(config.SSH)
+ signer, err := signer.New(&config.SSH)
if err != nil {
log.Fatal(err)
}
@@ -304,12 +301,12 @@ func main() {
}
r := mux.NewRouter()
- r.Handle("/", appHandler{ctx, rootHandler})
- r.Handle("/auth/login", appHandler{ctx, loginHandler})
- r.Handle("/auth/callback", appHandler{ctx, callbackHandler})
- r.Handle("/sign", appHandler{ctx, signHandler})
- r.Handle("/revoked", appHandler{ctx, revokedCertsHandler})
- r.Handle("/revoke", appHandler{ctx, revokeCertHandler})
+ r.Methods("GET").Path("/").Handler(appHandler{ctx, rootHandler})
+ r.Methods("GET").Path("/auth/login").Handler(appHandler{ctx, loginHandler})
+ r.Methods("GET").Path("/auth/callback").Handler(appHandler{ctx, callbackHandler})
+ r.Methods("POST").Path("/sign").Handler(appHandler{ctx, signHandler})
+ r.Methods("GET").Path("/revoked").Handler(appHandler{ctx, revokedCertsHandler})
+ r.Methods("POST").Path("/revoke").Handler(appHandler{ctx, revokeCertHandler})
logfile := os.Stderr
if config.Server.HTTPLogFile != "" {
logfile, err = os.OpenFile(config.Server.HTTPLogFile, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0660)