diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2016-09-03 19:14:13 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2016-09-03 19:14:13 +0100 |
commit | 0af43a29b7cabb6710cd1cb335785ff60dbf758f (patch) | |
tree | 28733eca29ef955254ba449504534fb3e6da0986 | |
parent | dba3de4451f29fc0b8cb6474b9bbb18ed61d9eac (diff) |
Move signing & agent logic out of the main package
-rw-r--r-- | cmd/cashier/client/client.go | 111 | ||||
-rw-r--r-- | cmd/cashier/client/client_test.go (renamed from cmd/cashier/client_test.go) | 12 | ||||
-rw-r--r-- | cmd/cashier/client/config.go (renamed from cmd/cashier/config.go) | 8 | ||||
-rw-r--r-- | cmd/cashier/client/keys.go (renamed from cmd/cashier/keys.go) | 14 | ||||
-rw-r--r-- | cmd/cashier/client/keys_test.go | 28 | ||||
-rw-r--r-- | cmd/cashier/main.go | 109 |
6 files changed, 161 insertions, 121 deletions
diff --git a/cmd/cashier/client/client.go b/cmd/cashier/client/client.go new file mode 100644 index 0000000..d8def27 --- /dev/null +++ b/cmd/cashier/client/client.go @@ -0,0 +1,111 @@ +package client + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "path" + "time" + + "github.com/nsheridan/cashier/lib" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +// InstallCert adds the private key and signed certificate to the ssh agent. +func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error { + t := time.Unix(int64(cert.ValidBefore), 0) + lifetime := t.Sub(time.Now()).Seconds() + comment := fmt.Sprintf("%s [Expires %s]", cert.KeyId, t) + pubcert := agent.AddedKey{ + PrivateKey: key, + Certificate: cert, + Comment: comment, + LifetimeSecs: uint32(lifetime), + } + if err := a.Add(pubcert); err != nil { + return fmt.Errorf("error importing certificate: %s", err) + } + privkey := agent.AddedKey{ + PrivateKey: key, + Comment: comment, + LifetimeSecs: uint32(lifetime), + } + if err := a.Add(privkey); err != nil { + return fmt.Errorf("error importing key: %s", err) + } + return nil +} + +// send the signing request to the CA. +func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, + } + client := &http.Client{Transport: transport} + u, err := url.Parse(ca) + if err != nil { + return nil, err + } + u.Path = path.Join(u.Path, "/sign") + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(s)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Bad response from server: %s", resp.Status) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + c := &lib.SignResponse{} + if err := json.Unmarshal(body, c); err != nil { + return nil, err + } + return c, nil +} + +func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) { + validity, err := time.ParseDuration(conf.Validity) + if err != nil { + return nil, err + } + marshaled := ssh.MarshalAuthorizedKey(pub) + marshaled = marshaled[:len(marshaled)-1] + s, err := json.Marshal(&lib.SignRequest{ + Key: string(marshaled), + ValidUntil: time.Now().Add(validity), + }) + if err != nil { + return nil, err + } + resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) + if err != nil { + return nil, err + } + if resp.Status != "ok" { + return nil, fmt.Errorf("error: %s", resp.Response) + } + k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) + if err != nil { + return nil, err + } + cert, ok := k.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("did not receive a certificate from server") + } + return cert, nil +} diff --git a/cmd/cashier/client_test.go b/cmd/cashier/client/client_test.go index e03b712..b7df3fd 100644 --- a/cmd/cashier/client_test.go +++ b/cmd/cashier/client/client_test.go @@ -1,4 +1,4 @@ -package main +package client import ( "bytes" @@ -36,7 +36,7 @@ func TestLoadCert(t *testing.T) { } c.SignCert(rand.Reader, signer) a := agent.NewKeyring() - if err := installCert(a, c, key); err != nil { + if err := InstallCert(a, c, key); err != nil { t.Error(err) } listedKeys, err := a.List() @@ -77,11 +77,11 @@ func TestSignGood(t *testing.T) { if err != nil { t.Error(err) } - c := &config{ + c := &Config{ CA: ts.URL, Validity: "24h", } - cert, err := sign(k, "token", c) + cert, err := Sign(k, "token", c) if cert == nil && err != nil { t.Error(err) } @@ -106,11 +106,11 @@ func TestSignBad(t *testing.T) { if err != nil { t.Error(err) } - c := &config{ + c := &Config{ CA: ts.URL, Validity: "24h", } - cert, err := sign(k, "token", c) + cert, err := Sign(k, "token", c) if cert != nil && err == nil { t.Error(err) } diff --git a/cmd/cashier/config.go b/cmd/cashier/client/config.go index eed98e1..d4defef 100644 --- a/cmd/cashier/config.go +++ b/cmd/cashier/client/config.go @@ -1,11 +1,11 @@ -package main +package client import ( "github.com/spf13/pflag" "github.com/spf13/viper" ) -type config struct { +type Config struct { CA string `mapstructure:"ca"` Keytype string `mapstructure:"key_type"` Keysize int `mapstructure:"key_size"` @@ -21,14 +21,14 @@ func setDefaults() { viper.SetDefault("validateTLSCertificate", true) } -func readConfig(path string) (*config, error) { +func ReadConfig(path string) (*Config, error) { setDefaults() viper.SetConfigFile(path) viper.SetConfigType("hcl") if err := viper.ReadInConfig(); err != nil { return nil, err } - c := &config{} + c := &Config{} if err := viper.Unmarshal(c); err != nil { return nil, err } diff --git a/cmd/cashier/keys.go b/cmd/cashier/client/keys.go index ac0a9f7..877ff42 100644 --- a/cmd/cashier/keys.go +++ b/cmd/cashier/client/keys.go @@ -1,4 +1,4 @@ -package main +package client import ( "crypto/ecdsa" @@ -11,8 +11,8 @@ import ( "golang.org/x/crypto/ssh" ) -type key interface{} -type keyfunc func(int) (key, ssh.PublicKey, error) +type Key interface{} +type keyfunc func(int) (Key, ssh.PublicKey, error) var ( keytypes = map[string]keyfunc{ @@ -22,7 +22,7 @@ var ( } ) -func generateED25519Key(bits int) (key, ssh.PublicKey, error) { +func generateED25519Key(bits int) (Key, ssh.PublicKey, error) { p, k, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, nil, err @@ -34,7 +34,7 @@ func generateED25519Key(bits int) (key, ssh.PublicKey, error) { return &k, pub, nil } -func generateRSAKey(bits int) (key, ssh.PublicKey, error) { +func generateRSAKey(bits int) (Key, ssh.PublicKey, error) { k, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { return nil, nil, err @@ -46,7 +46,7 @@ func generateRSAKey(bits int) (key, ssh.PublicKey, error) { return k, pub, nil } -func generateECDSAKey(bits int) (key, ssh.PublicKey, error) { +func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) { var curve elliptic.Curve switch bits { case 256: @@ -69,7 +69,7 @@ func generateECDSAKey(bits int) (key, ssh.PublicKey, error) { return k, pub, nil } -func generateKey(keytype string, bits int) (key, ssh.PublicKey, error) { +func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) { f, ok := keytypes[keytype] if !ok { var valid []string diff --git a/cmd/cashier/client/keys_test.go b/cmd/cashier/client/keys_test.go new file mode 100644 index 0000000..9e930d5 --- /dev/null +++ b/cmd/cashier/client/keys_test.go @@ -0,0 +1,28 @@ +package client + +import ( + "reflect" + "testing" +) + +func TestGenerateKeys(t *testing.T) { + var tests = []struct { + key string + size int + want string + }{ + {"ecdsa", 256, "*ecdsa.PrivateKey"}, + {"rsa", 1024, "*rsa.PrivateKey"}, + {"ed25519", 256, "*ed25519.PrivateKey"}, + } + + for _, tst := range tests { + k, _, err := GenerateKey(tst.key, tst.size) + if err != nil { + t.Error(err) + } + if reflect.TypeOf(k).String() != tst.want { + t.Errorf("Wrong key type returned. Got %s, wanted %s", reflect.TypeOf(k).String(), tst.want) + } + } +} diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index 72355e3..4ceaa80 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -1,24 +1,17 @@ package main import ( - "bytes" - "crypto/tls" - "encoding/json" "fmt" - "io/ioutil" "log" "net" - "net/http" - "net/url" "os" "os/user" "path" "time" - "github.com/nsheridan/cashier/lib" + "github.com/nsheridan/cashier/cmd/cashier/client" "github.com/pkg/browser" "github.com/spf13/pflag" - "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -31,102 +24,10 @@ var ( keytype = pflag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519") ) -func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { - t := time.Unix(int64(cert.ValidBefore), 0) - lifetime := t.Sub(time.Now()).Seconds() - comment := fmt.Sprintf("%s [Expires %s]", cert.KeyId, t) - pubcert := agent.AddedKey{ - PrivateKey: key, - Certificate: cert, - Comment: comment, - LifetimeSecs: uint32(lifetime), - } - if err := a.Add(pubcert); err != nil { - return fmt.Errorf("error importing certificate: %s", err) - } - privkey := agent.AddedKey{ - PrivateKey: key, - Comment: comment, - LifetimeSecs: uint32(lifetime), - } - if err := a.Add(privkey); err != nil { - return fmt.Errorf("error importing key: %s", err) - } - return nil -} - -func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, - } - client := &http.Client{Transport: transport} - u, err := url.Parse(ca) - if err != nil { - return nil, err - } - u.Path = path.Join(u.Path, "/sign") - req, err := http.NewRequest("POST", u.String(), bytes.NewReader(s)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err := client.Do(req) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Bad response from server: %s", resp.Status) - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, err - } - c := &lib.SignResponse{} - if err := json.Unmarshal(body, c); err != nil { - return nil, err - } - return c, nil -} - -func sign(pub ssh.PublicKey, token string, conf *config) (*ssh.Certificate, error) { - validity, err := time.ParseDuration(conf.Validity) - if err != nil { - return nil, err - } - marshaled := ssh.MarshalAuthorizedKey(pub) - marshaled = marshaled[:len(marshaled)-1] - s, err := json.Marshal(&lib.SignRequest{ - Key: string(marshaled), - ValidUntil: time.Now().Add(validity), - }) - if err != nil { - return nil, err - } - resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) - if err != nil { - return nil, err - } - if resp.Status != "ok" { - return nil, fmt.Errorf("error: %s", resp.Response) - } - k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) - if err != nil { - return nil, err - } - cert, ok := k.(*ssh.Certificate) - if !ok { - return nil, fmt.Errorf("did not receive a certificate from server") - } - return cert, nil -} - func main() { pflag.Parse() - c, err := readConfig(*cfg) + c, err := client.ReadConfig(*cfg) if err != nil { log.Fatalf("Error parsing config file: %v\n", err) } @@ -135,7 +36,7 @@ func main() { fmt.Println("Error launching web browser. Go to the link in your web browser") } fmt.Println("Generating new key pair") - priv, pub, err := generateKey(c.Keytype, c.Keysize) + priv, pub, err := client.GenerateKey(c.Keytype, c.Keysize) if err != nil { log.Fatalln("Error generating key pair: ", err) } @@ -144,7 +45,7 @@ func main() { var token string fmt.Scanln(&token) - cert, err := sign(pub, token, c) + cert, err := client.Sign(pub, token, c) if err != nil { log.Fatalln(err) } @@ -154,7 +55,7 @@ func main() { } defer sock.Close() a := agent.NewClient(sock) - if err := installCert(a, cert, priv); err != nil { + if err := client.InstallCert(a, cert, priv); err != nil { log.Fatalln(err) } fmt.Println("Credentials added.") |