From a30d6403f723765b8f9b7609e7eb3ade0f5434a0 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 10 Sep 2016 17:40:23 +0100 Subject: Make client a top-level package for consistency --- client/client.go | 113 ++++++++++++++++++++++++++++++++++++++++++++++++ client/client_test.go | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++ client/config.go | 38 ++++++++++++++++ client/keys.go | 84 ++++++++++++++++++++++++++++++++++++ client/keys_test.go | 28 ++++++++++++ 5 files changed, 380 insertions(+) create mode 100644 client/client.go create mode 100644 client/client_test.go create mode 100644 client/config.go create mode 100644 client/keys.go create mode 100644 client/keys_test.go (limited to 'client') diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..ba5b900 --- /dev/null +++ b/client/client.go @@ -0,0 +1,113 @@ +package client + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "path" + "time" + + "github.com/nsheridan/cashier/lib" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// InstallCert adds the private key and signed certificate to the ssh agent. +func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { + t := time.Unix(int64(cert.ValidBefore), 0) + lifetime := t.Sub(time.Now()).Seconds() + comment := fmt.Sprintf("%s [Expires %s]", cert.KeyId, t) + pubcert := agent.AddedKey{ + PrivateKey: key, + Certificate: cert, + Comment: comment, + LifetimeSecs: uint32(lifetime), + } + if err := a.Add(pubcert); err != nil { + return fmt.Errorf("error importing certificate: %s", err) + } + privkey := agent.AddedKey{ + PrivateKey: key, + Comment: comment, + LifetimeSecs: uint32(lifetime), + } + if err := a.Add(privkey); err != nil { + return fmt.Errorf("error importing key: %s", err) + } + return nil +} + +// send the signing request to the CA. +func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, + } + client := &http.Client{Transport: transport} + u, err := url.Parse(ca) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, "/sign") + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(s)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Bad response from server: %s", resp.Status) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + c := &lib.SignResponse{} + if err := json.Unmarshal(body, c); err != nil { + return nil, err + } + return c, nil +} + +// Sign sends the public key to the CA to be signed. +func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) { + validity, err := time.ParseDuration(conf.Validity) + if err != nil { + return nil, err + } + marshaled := ssh.MarshalAuthorizedKey(pub) + // Remove the trailing newline. + marshaled = marshaled[:len(marshaled)-1] + s, err := json.Marshal(&lib.SignRequest{ + Key: string(marshaled), + ValidUntil: time.Now().Add(validity), + }) + if err != nil { + return nil, err + } + resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) + if err != nil { + return nil, err + } + if resp.Status != "ok" { + return nil, fmt.Errorf("error: %s", resp.Response) + } + k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) + if err != nil { + return nil, err + } + cert, ok := k.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("did not receive a certificate from server") + } + return cert, nil +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 0000000..b7df3fd --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,117 @@ +package client + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/testdata" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func TestLoadCert(t *testing.T) { + t.Parallel() + priv, _ := ssh.ParseRawPrivateKey(testdata.Priv) + key := priv.(*rsa.PrivateKey) + pub, _ := ssh.NewPublicKey(&key.PublicKey) + c := &ssh.Certificate{ + KeyId: "test_key_12345", + Key: pub, + CertType: ssh.UserCert, + ValidBefore: ssh.CertTimeInfinity, + ValidAfter: 0, + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + t.Error(err) + } + c.SignCert(rand.Reader, signer) + a := agent.NewKeyring() + if err := InstallCert(a, c, key); err != nil { + t.Error(err) + } + listedKeys, err := a.List() + if err != nil { + t.Errorf("Error reading from agent: %v", err) + } + if len(listedKeys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(listedKeys)) + } + if !bytes.Equal(listedKeys[0].Marshal(), c.Marshal()) { + t.Error("Certs not equal") + } + for _, k := range listedKeys { + exp := time.Unix(int64(c.ValidBefore), 0).String() + want := fmt.Sprintf("%s [Expires %s]", c.KeyId, exp) + if k.Comment != want { + t.Errorf("key comment:\nwanted:%s\ngot: %s", want, k.Comment) + } + } +} + +func TestSignGood(t *testing.T) { + t.Parallel() + res := &lib.SignResponse{ + Status: "ok", + Response: string(testdata.Cert), + } + j, _ := json.Marshal(res) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, string(j)) + })) + defer ts.Close() + _, err := send([]byte(`{}`), "token", ts.URL, true) + if err != nil { + t.Error(err) + } + k, _, _, _, err := ssh.ParseAuthorizedKey(testdata.Pub) + if err != nil { + t.Error(err) + } + c := &Config{ + CA: ts.URL, + Validity: "24h", + } + cert, err := Sign(k, "token", c) + if cert == nil && err != nil { + t.Error(err) + } +} + +func TestSignBad(t *testing.T) { + t.Parallel() + res := &lib.SignResponse{ + Status: "error", + Response: `{"response": "error"}`, + } + j, _ := json.Marshal(res) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, string(j)) + })) + defer ts.Close() + _, err := send([]byte(`{}`), "token", ts.URL, true) + if err != nil { + t.Error(err) + } + k, _, _, _, err := ssh.ParseAuthorizedKey(testdata.Pub) + if err != nil { + t.Error(err) + } + c := &Config{ + CA: ts.URL, + Validity: "24h", + } + cert, err := Sign(k, "token", c) + if cert != nil && err == nil { + t.Error(err) + } +} diff --git a/client/config.go b/client/config.go new file mode 100644 index 0000000..1cc9401 --- /dev/null +++ b/client/config.go @@ -0,0 +1,38 @@ +package client + +import ( + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +// Config holds the client configuration. +type Config struct { + CA string `mapstructure:"ca"` + Keytype string `mapstructure:"key_type"` + Keysize int `mapstructure:"key_size"` + Validity string `mapstructure:"validity"` + ValidateTLSCertificate bool `mapstructure:"validate_tls_certificate"` +} + +func setDefaults() { + viper.BindPFlag("ca", pflag.Lookup("ca")) + viper.BindPFlag("key_type", pflag.Lookup("key_type")) + viper.BindPFlag("key_size", pflag.Lookup("key_size")) + viper.BindPFlag("validity", pflag.Lookup("validity")) + viper.SetDefault("validateTLSCertificate", true) +} + +// ReadConfig reads the client configuration from a file into a Config struct. +func ReadConfig(path string) (*Config, error) { + setDefaults() + viper.SetConfigFile(path) + viper.SetConfigType("hcl") + if err := viper.ReadInConfig(); err != nil { + return nil, err + } + c := &Config{} + if err := viper.Unmarshal(c); err != nil { + return nil, err + } + return c, nil +} diff --git a/client/keys.go b/client/keys.go new file mode 100644 index 0000000..4b3b69e --- /dev/null +++ b/client/keys.go @@ -0,0 +1,84 @@ +package client + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "fmt" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh" +) + +// Key is a private key. +type Key interface{} +type keyfunc func(int) (Key, ssh.PublicKey, error) + +var ( + keytypes = map[string]keyfunc{ + "rsa": generateRSAKey, + "ecdsa": generateECDSAKey, + "ed25519": generateED25519Key, + } +) + +func generateED25519Key(bits int) (Key, ssh.PublicKey, error) { + p, k, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(p) + if err != nil { + return nil, nil, err + } + return &k, pub, nil +} + +func generateRSAKey(bits int) (Key, ssh.PublicKey, error) { + k, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(&k.PublicKey) + if err != nil { + return nil, nil, err + } + return k, pub, nil +} + +func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) { + var curve elliptic.Curve + switch bits { + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + case 521: + curve = elliptic.P521() + default: + return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'") + } + k, err := ecdsa.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, nil, err + } + pub, err := ssh.NewPublicKey(&k.PublicKey) + if err != nil { + return nil, nil, err + } + return k, pub, nil +} + +// GenerateKey generates a ssh key-pair according to the type and size specified. +func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) { + f, ok := keytypes[keytype] + if !ok { + var valid []string + for k := range keytypes { + valid = append(valid, k) + } + return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, valid) + } + return f(bits) +} diff --git a/client/keys_test.go b/client/keys_test.go new file mode 100644 index 0000000..9e930d5 --- /dev/null +++ b/client/keys_test.go @@ -0,0 +1,28 @@ +package client + +import ( + "reflect" + "testing" +) + +func TestGenerateKeys(t *testing.T) { + var tests = []struct { + key string + size int + want string + }{ + {"ecdsa", 256, "*ecdsa.PrivateKey"}, + {"rsa", 1024, "*rsa.PrivateKey"}, + {"ed25519", 256, "*ed25519.PrivateKey"}, + } + + for _, tst := range tests { + k, _, err := GenerateKey(tst.key, tst.size) + if err != nil { + t.Error(err) + } + if reflect.TypeOf(k).String() != tst.want { + t.Errorf("Wrong key type returned. Got %s, wanted %s", reflect.TypeOf(k).String(), tst.want) + } + } +} -- cgit v1.2.3