From d21fac6f190c1079ca247658530d465ad5867ff5 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Thu, 9 Aug 2018 20:47:50 +0100 Subject: Only request a reason from the client if the server requires it --- client/client.go | 57 ++++++++++++++++++++++++++++++++++++++------------- client/client_test.go | 8 ++++---- 2 files changed, 47 insertions(+), 18 deletions(-) (limited to 'client') diff --git a/client/client.go b/client/client.go index 58cc6bb..628783a 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "bufio" "bytes" "crypto/tls" "encoding/base64" @@ -10,7 +11,9 @@ import ( "io/ioutil" "net/http" "net/url" + "os" "path" + "strings" "time" "github.com/nsheridan/cashier/lib" @@ -19,6 +22,10 @@ import ( "golang.org/x/crypto/ssh/agent" ) +var ( + errNeedsReason = errors.New("reason required") +) + // SavePublicFiles installs the public part of the cert and key. func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error { if prefix == "" { @@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { } // send the signing request to the CA. -func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { +func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { + s, err := json.Marshal(sr) + if err != nil { + return nil, errors.Wrap(err, "unable to create sign request") + } transport := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, } @@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes return nil, err } defer resp.Body.Close() + signResponse := &lib.SignResponse{} if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Bad response from server: %s", resp.Status) + if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") { + return signResponse, errNeedsReason + } + return signResponse, fmt.Errorf("bad response from server: %s", resp.Status) } - c := &lib.SignResponse{} - if err := json.NewDecoder(resp.Body).Decode(c); err != nil { + if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil { return nil, errors.Wrap(err, "unable to decode server response") } - return c, nil + return signResponse, nil +} + +func promptForReason() (message string) { + fmt.Print("Enter message: ") + scanner := bufio.NewScanner(os.Stdin) + if scanner.Scan() { + message = scanner.Text() + } + return message } // Sign sends the public key to the CA to be signed. -func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) { +func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) { + var err error validity, err := time.ParseDuration(conf.Validity) if err != nil { return nil, err } - s, err := json.Marshal(&lib.SignRequest{ + s := &lib.SignRequest{ Key: string(lib.GetPublicKey(pub)), ValidUntil: time.Now().Add(validity), - Message: message, - }) - if err != nil { - return nil, errors.Wrap(err, "unable to create sign request") } - resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) - if err != nil { - return nil, errors.Wrap(err, "error sending request to CA") + resp := &lib.SignResponse{} + for { + resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate) + if err == nil { + break + } + if err != nil && err == errNeedsReason { + s.Message = promptForReason() + continue + } else if err != nil { + return nil, errors.Wrap(err, "error sending request to CA") + } } if resp.Status != "ok" { return nil, fmt.Errorf("bad response from CA: %s", resp.Response) diff --git a/client/client_test.go b/client/client_test.go index fddd543..2447db3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - _, err := send([]byte(`{}`), "token", ts.URL, true) + _, err := send(&lib.SignRequest{}, "token", ts.URL, true) if err != nil { t.Error(err) } @@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) { CA: ts.URL, Validity: "24h", } - cert, err := Sign(k, "token", "message", c) + cert, err := Sign(k, "token", c) if cert == nil && err != nil { t.Error(err) } @@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - _, err := send([]byte(`{}`), "token", ts.URL, true) + _, err := send(&lib.SignRequest{}, "token", ts.URL, true) if err != nil { t.Error(err) } @@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) { CA: ts.URL, Validity: "24h", } - cert, err := Sign(k, "token", "message", c) + cert, err := Sign(k, "token", c) if cert != nil && err == nil { t.Error(err) } -- cgit v1.2.3