aboutsummaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/client.go57
-rw-r--r--client/client_test.go8
2 files changed, 47 insertions, 18 deletions
diff --git a/client/client.go b/client/client.go
index 58cc6bb..628783a 100644
--- a/client/client.go
+++ b/client/client.go
@@ -1,6 +1,7 @@
package client
import (
+ "bufio"
"bytes"
"crypto/tls"
"encoding/base64"
@@ -10,7 +11,9 @@ import (
"io/ioutil"
"net/http"
"net/url"
+ "os"
"path"
+ "strings"
"time"
"github.com/nsheridan/cashier/lib"
@@ -19,6 +22,10 @@ import (
"golang.org/x/crypto/ssh/agent"
)
+var (
+ errNeedsReason = errors.New("reason required")
+)
+
// SavePublicFiles installs the public part of the cert and key.
func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error {
if prefix == "" {
@@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error {
}
// send the signing request to the CA.
-func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
+ s, err := json.Marshal(sr)
+ if err != nil {
+ return nil, errors.Wrap(err, "unable to create sign request")
+ }
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate},
}
@@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes
return nil, err
}
defer resp.Body.Close()
+ signResponse := &lib.SignResponse{}
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("Bad response from server: %s", resp.Status)
+ if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") {
+ return signResponse, errNeedsReason
+ }
+ return signResponse, fmt.Errorf("bad response from server: %s", resp.Status)
}
- c := &lib.SignResponse{}
- if err := json.NewDecoder(resp.Body).Decode(c); err != nil {
+ if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil {
return nil, errors.Wrap(err, "unable to decode server response")
}
- return c, nil
+ return signResponse, nil
+}
+
+func promptForReason() (message string) {
+ fmt.Print("Enter message: ")
+ scanner := bufio.NewScanner(os.Stdin)
+ if scanner.Scan() {
+ message = scanner.Text()
+ }
+ return message
}
// Sign sends the public key to the CA to be signed.
-func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) {
+func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) {
+ var err error
validity, err := time.ParseDuration(conf.Validity)
if err != nil {
return nil, err
}
- s, err := json.Marshal(&lib.SignRequest{
+ s := &lib.SignRequest{
Key: string(lib.GetPublicKey(pub)),
ValidUntil: time.Now().Add(validity),
- Message: message,
- })
- if err != nil {
- return nil, errors.Wrap(err, "unable to create sign request")
}
- resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate)
- if err != nil {
- return nil, errors.Wrap(err, "error sending request to CA")
+ resp := &lib.SignResponse{}
+ for {
+ resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate)
+ if err == nil {
+ break
+ }
+ if err != nil && err == errNeedsReason {
+ s.Message = promptForReason()
+ continue
+ } else if err != nil {
+ return nil, errors.Wrap(err, "error sending request to CA")
+ }
}
if resp.Status != "ok" {
return nil, fmt.Errorf("bad response from CA: %s", resp.Response)
diff --git a/client/client_test.go b/client/client_test.go
index fddd543..2447db3 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
- _, err := send([]byte(`{}`), "token", ts.URL, true)
+ _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
@@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
- cert, err := Sign(k, "token", "message", c)
+ cert, err := Sign(k, "token", c)
if cert == nil && err != nil {
t.Error(err)
}
@@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
- _, err := send([]byte(`{}`), "token", ts.URL, true)
+ _, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
@@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
- cert, err := Sign(k, "token", "message", c)
+ cert, err := Sign(k, "token", c)
if cert != nil && err == nil {
t.Error(err)
}