diff options
-rw-r--r-- | client/client.go | 57 | ||||
-rw-r--r-- | client/client_test.go | 8 | ||||
-rw-r--r-- | cmd/cashier/main.go | 10 | ||||
-rw-r--r-- | server/config/config.go | 1 | ||||
-rw-r--r-- | server/web.go | 16 |
5 files changed, 61 insertions, 31 deletions
diff --git a/client/client.go b/client/client.go index 58cc6bb..628783a 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "bufio" "bytes" "crypto/tls" "encoding/base64" @@ -10,7 +11,9 @@ import ( "io/ioutil" "net/http" "net/url" + "os" "path" + "strings" "time" "github.com/nsheridan/cashier/lib" @@ -19,6 +22,10 @@ import ( "golang.org/x/crypto/ssh/agent" ) +var ( + errNeedsReason = errors.New("reason required") +) + // SavePublicFiles installs the public part of the cert and key. func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error { if prefix == "" { @@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { } // send the signing request to the CA. -func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { +func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { + s, err := json.Marshal(sr) + if err != nil { + return nil, errors.Wrap(err, "unable to create sign request") + } transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, } @@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes return nil, err } defer resp.Body.Close() + signResponse := &lib.SignResponse{} if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Bad response from server: %s", resp.Status) + if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") { + return signResponse, errNeedsReason + } + return signResponse, fmt.Errorf("bad response from server: %s", resp.Status) } - c := &lib.SignResponse{} - if err := json.NewDecoder(resp.Body).Decode(c); err != nil { + if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil { return nil, errors.Wrap(err, "unable to decode server response") } - return c, nil + return signResponse, nil +} + +func promptForReason() (message string) { + fmt.Print("Enter message: ") + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + message = scanner.Text() + } + return message } // Sign sends the public key to the CA to be signed. -func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) { +func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) { + var err error validity, err := time.ParseDuration(conf.Validity) if err != nil { return nil, err } - s, err := json.Marshal(&lib.SignRequest{ + s := &lib.SignRequest{ Key: string(lib.GetPublicKey(pub)), ValidUntil: time.Now().Add(validity), - Message: message, - }) - if err != nil { - return nil, errors.Wrap(err, "unable to create sign request") } - resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) - if err != nil { - return nil, errors.Wrap(err, "error sending request to CA") + resp := &lib.SignResponse{} + for { + resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate) + if err == nil { + break + } + if err != nil && err == errNeedsReason { + s.Message = promptForReason() + continue + } else if err != nil { + return nil, errors.Wrap(err, "error sending request to CA") + } } if resp.Status != "ok" { return nil, fmt.Errorf("bad response from CA: %s", resp.Response) diff --git a/client/client_test.go b/client/client_test.go index fddd543..2447db3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - _, err := send([]byte(`{}`), "token", ts.URL, true) + _, err := send(&lib.SignRequest{}, "token", ts.URL, true) if err != nil { t.Error(err) } @@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) { CA: ts.URL, Validity: "24h", } - cert, err := Sign(k, "token", "message", c) + cert, err := Sign(k, "token", c) if cert == nil && err != nil { t.Error(err) } @@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - _, err := send([]byte(`{}`), "token", ts.URL, true) + _, err := send(&lib.SignRequest{}, "token", ts.URL, true) if err != nil { t.Error(err) } @@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) { CA: ts.URL, Validity: "24h", } - cert, err := Sign(k, "token", "message", c) + cert, err := Sign(k, "token", c) if cert != nil && err == nil { t.Error(err) } diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index 7054bef..f448a25 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "fmt" "log" "net" @@ -50,14 +49,7 @@ func main() { var token string fmt.Scanln(&token) - var message string - fmt.Print("Enter message: ") - scanner := bufio.NewScanner(os.Stdin) - if scanner.Scan() { - message = scanner.Text() - } - - cert, err := client.Sign(pub, token, message, c) + cert, err := client.Sign(pub, token, c) if err != nil { log.Fatalln(err) } diff --git a/server/config/config.go b/server/config/config.go index 422a135..1985800 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -37,6 +37,7 @@ type Server struct { CSRFSecret string `hcl:"csrf_secret"` HTTPLogFile string `hcl:"http_logfile"` Database Database `hcl:"database"` + RequireReason bool `hcl:"require_reason"` } // Auth holds the configuration specific to the OAuth provider. diff --git a/server/web.go b/server/web.go index 5677429..e238150 100644 --- a/server/web.go +++ b/server/web.go @@ -33,8 +33,9 @@ import ( // appContext contains local context - cookiestore, authsession etc. type appContext struct { - cookiestore *sessions.CookieStore - authsession *auth.Session + cookiestore *sessions.CookieStore + authsession *auth.Session + requireReason bool } // getAuthTokenCookie retrieves a cookie from the request. @@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er if err != nil { return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request") } + + if a.requireReason && req.Message == "" { + w.Header().Add("X-Need-Reason", "required") + return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden)) + } + username := authprovider.Username(token) authprovider.Revoke(token) // We don't need this anymore. cert, err := keysigner.SignUserKey(req, username) @@ -266,7 +273,6 @@ type appHandler struct { func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, err := ah.h(ah.appContext, w, r) if err != nil { - log.Printf("HTTP %d: %q", status, err) http.Error(w, err.Error(), status) } } @@ -283,7 +289,8 @@ func newState() string { func runHTTPServer(conf *config.Server, l net.Listener) { var err error ctx := &appContext{ - cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)), + cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)), + requireReason: conf.RequireReason, } ctx.cookiestore.Options = &sessions.Options{ MaxAge: 900, @@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) { r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler}) r.Methods("GET").Path("/metrics").Handler(promhttp.Handler()) r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck) + box := packr.NewBox("static") r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box))) h := handlers.LoggingHandler(logfile, r) |