aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/client.go57
-rw-r--r--client/client_test.go8
-rw-r--r--cmd/cashier/main.go10
-rw-r--r--server/config/config.go1
-rw-r--r--server/web.go16
5 files changed, 61 insertions, 31 deletions
diff --git a/client/client.go b/client/client.go
index 58cc6bb..628783a 100644
--- a/client/client.go
+++ b/client/client.go
@@ -1,6 +1,7 @@
package client
import (
+ "bufio"
"bytes"
"crypto/tls"
"encoding/base64"
@@ -10,7 +11,9 @@ import (
"io/ioutil"
"net/http"
"net/url"
+ "os"
"path"
+ "strings"
"time"
"github.com/nsheridan/cashier/lib"
@@ -19,6 +22,10 @@ import (
"golang.org/x/crypto/ssh/agent"
)
+var (
+ errNeedsReason = errors.New("reason required")
+)
+
// SavePublicFiles installs the public part of the cert and key.
func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error {
if prefix == "" {
@@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error {
}
// send the signing request to the CA.
-func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+ s, err := json.Marshal(sr)
+ if err != nil {
+ return nil, errors.Wrap(err, "unable to create sign request")
+ }
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate},
}
@@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes
return nil, err
}
defer resp.Body.Close()
+ signResponse := &lib.SignResponse{}
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("Bad response from server: %s", resp.Status)
+ if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") {
+ return signResponse, errNeedsReason
+ }
+ return signResponse, fmt.Errorf("bad response from server: %s", resp.Status)
}
- c := &lib.SignResponse{}
- if err := json.NewDecoder(resp.Body).Decode(c); err != nil {
+ if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil {
return nil, errors.Wrap(err, "unable to decode server response")
}
- return c, nil
+ return signResponse, nil
+}
+
+func promptForReason() (message string) {
+ fmt.Print("Enter message: ")
+ scanner := bufio.NewScanner(os.Stdin)
+ if scanner.Scan() {
+ message = scanner.Text()
+ }
+ return message
}
// Sign sends the public key to the CA to be signed.
-func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) {
+func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) {
+ var err error
validity, err := time.ParseDuration(conf.Validity)
if err != nil {
return nil, err
}
- s, err := json.Marshal(&lib.SignRequest{
+ s := &lib.SignRequest{
Key: string(lib.GetPublicKey(pub)),
ValidUntil: time.Now().Add(validity),
- Message: message,
- })
- if err != nil {
- return nil, errors.Wrap(err, "unable to create sign request")
}
- resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate)
- if err != nil {
- return nil, errors.Wrap(err, "error sending request to CA")
+ resp := &lib.SignResponse{}
+ for {
+ resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate)
+ if err == nil {
+ break
+ }
+ if err != nil && err == errNeedsReason {
+ s.Message = promptForReason()
+ continue
+ } else if err != nil {
+ return nil, errors.Wrap(err, "error sending request to CA")
+ }
}
if resp.Status != "ok" {
return nil, fmt.Errorf("bad response from CA: %s", resp.Response)
diff --git a/client/client_test.go b/client/client_test.go
index fddd543..2447db3 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
- _, err := send([]byte(`{}`), "token", ts.URL, true)
+ _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
@@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
- cert, err := Sign(k, "token", "message", c)
+ cert, err := Sign(k, "token", c)
if cert == nil && err != nil {
t.Error(err)
}
@@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
- _, err := send([]byte(`{}`), "token", ts.URL, true)
+ _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
@@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
- cert, err := Sign(k, "token", "message", c)
+ cert, err := Sign(k, "token", c)
if cert != nil && err == nil {
t.Error(err)
}
diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go
index 7054bef..f448a25 100644
--- a/cmd/cashier/main.go
+++ b/cmd/cashier/main.go
@@ -1,7 +1,6 @@
package main
import (
- "bufio"
"fmt"
"log"
"net"
@@ -50,14 +49,7 @@ func main() {
var token string
fmt.Scanln(&token)
- var message string
- fmt.Print("Enter message: ")
- scanner := bufio.NewScanner(os.Stdin)
- if scanner.Scan() {
- message = scanner.Text()
- }
-
- cert, err := client.Sign(pub, token, message, c)
+ cert, err := client.Sign(pub, token, c)
if err != nil {
log.Fatalln(err)
}
diff --git a/server/config/config.go b/server/config/config.go
index 422a135..1985800 100644
--- a/server/config/config.go
+++ b/server/config/config.go
@@ -37,6 +37,7 @@ type Server struct {
CSRFSecret string `hcl:"csrf_secret"`
HTTPLogFile string `hcl:"http_logfile"`
Database Database `hcl:"database"`
+ RequireReason bool `hcl:"require_reason"`
}
// Auth holds the configuration specific to the OAuth provider.
diff --git a/server/web.go b/server/web.go
index 5677429..e238150 100644
--- a/server/web.go
+++ b/server/web.go
@@ -33,8 +33,9 @@ import (
// appContext contains local context - cookiestore, authsession etc.
type appContext struct {
- cookiestore *sessions.CookieStore
- authsession *auth.Session
+ cookiestore *sessions.CookieStore
+ authsession *auth.Session
+ requireReason bool
}
// getAuthTokenCookie retrieves a cookie from the request.
@@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
if err != nil {
return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
}
+
+ if a.requireReason && req.Message == "" {
+ w.Header().Add("X-Need-Reason", "required")
+ return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))
+ }
+
username := authprovider.Username(token)
authprovider.Revoke(token) // We don't need this anymore.
cert, err := keysigner.SignUserKey(req, username)
@@ -266,7 +273,6 @@ type appHandler struct {
func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, err := ah.h(ah.appContext, w, r)
if err != nil {
- log.Printf("HTTP %d: %q", status, err)
http.Error(w, err.Error(), status)
}
}
@@ -283,7 +289,8 @@ func newState() string {
func runHTTPServer(conf *config.Server, l net.Listener) {
var err error
ctx := &appContext{
- cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
+ cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
+ requireReason: conf.RequireReason,
}
ctx.cookiestore.Options = &sessions.Options{
MaxAge: 900,
@@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) {
r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck)
+
box := packr.NewBox("static")
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box)))
h := handlers.LoggingHandler(logfile, r)