From a4b5776500b1250b61c3dafd17e464fdf3f3aae8 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Tue, 3 Jan 2017 23:34:47 +0000 Subject: Simplify key generation Use functions to build key generation options. Make it entirely optional. --- client/keys.go | 119 +++++++++++++++++++++++++++++++--------------------- client/keys_test.go | 52 ++++++++++++++++++++--- cmd/cashier/main.go | 2 +- 3 files changed, 118 insertions(+), 55 deletions(-) diff --git a/client/keys.go b/client/keys.go index 0ec0f1d..3d2fb31 100644 --- a/client/keys.go +++ b/client/keys.go @@ -1,56 +1,66 @@ package client import ( + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "fmt" - "strings" "golang.org/x/crypto/ed25519" "golang.org/x/crypto/ssh" ) // Key is a private key. -type Key interface{} -type keyfunc func(int) (Key, ssh.PublicKey, error) +type Key crypto.Signer -var ( - keytypes = map[string]keyfunc{ - "rsa": generateRSAKey, - "ecdsa": generateECDSAKey, - "ed25519": generateED25519Key, - } -) +// Options for key generation. +// Defaults will generate a 2048 bit RSA key. +type options struct { + keytype string + size int +} -func generateED25519Key(bits int) (Key, ssh.PublicKey, error) { - p, k, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, err - } - pub, err := ssh.NewPublicKey(p) - if err != nil { - return nil, nil, err - } - return &k, pub, nil +var defaultOptions = options{ + keytype: "rsa", + size: 0, // Different key types have different default sizes. } -func generateRSAKey(bits int) (Key, ssh.PublicKey, error) { - k, err := rsa.GenerateKey(rand.Reader, bits) - if err != nil { - return nil, nil, err +// A KeyOption is used to generate keys of different types and sizes. +type KeyOption func(*options) + +// KeyType sets the type of key to generate. +// Valid types are: "rsa", "ecdsa", "ed25519". +// Default is "rsa" +func KeyType(keyType string) KeyOption { + return func(o *options) { + o.keytype = keyType } - pub, err := ssh.NewPublicKey(&k.PublicKey) - if err != nil { - return nil, nil, err +} + +// KeySize sets the size of the key in bits. +// RSA keys must be a minimum of 1024 bits. The default is 2048 bits. +// ECDSA keys must be one of 256, 384, or 521 bits. The default is 256 bits. +// Ed25519 keys are of a fixed size. This option is ignored. +func KeySize(size int) KeyOption { + return func(o *options) { + o.size = size } - return k, pub, nil } -func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) { +func generateED25519Key() (Key, error) { + _, k, err := ed25519.GenerateKey(rand.Reader) + return &k, err +} + +func generateRSAKey(size int) (Key, error) { + return rsa.GenerateKey(rand.Reader, size) +} + +func generateECDSAKey(size int) (Key, error) { var curve elliptic.Curve - switch bits { + switch size { case 256: curve = elliptic.P256() case 384: @@ -58,28 +68,41 @@ func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) { case 521: curve = elliptic.P521() default: - return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'") - } - k, err := ecdsa.GenerateKey(curve, rand.Reader) - if err != nil { - return nil, nil, err - } - pub, err := ssh.NewPublicKey(&k.PublicKey) - if err != nil { - return nil, nil, err + return nil, fmt.Errorf("Unsupported key size: %d. Valid sizes are '256', '384', '521'", size) } - return k, pub, nil + return ecdsa.GenerateKey(curve, rand.Reader) } // GenerateKey generates a ssh key-pair according to the type and size specified. -func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) { - f, ok := keytypes[keytype] - if !ok { - var valid []string - for k := range keytypes { - valid = append(valid, k) +func GenerateKey(options ...func(*options)) (Key, ssh.PublicKey, error) { + var privkey Key + var pubkey ssh.PublicKey + var err error + + config := defaultOptions + for _, o := range options { + o(&config) + } + + switch config.keytype { + case "rsa": + if config.size == 0 { + config.size = 2048 + } + privkey, err = generateRSAKey(config.size) + case "ecdsa": + if config.size == 0 { + config.size = 256 } - return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, strings.Join(valid, "|")) + privkey, err = generateECDSAKey(config.size) + case "ed25519": + privkey, err = generateED25519Key() + default: + privkey, err = generateRSAKey(config.size) + } + if err != nil { + return nil, nil, err } - return f(bits) + pubkey, err = ssh.NewPublicKey(privkey.Public()) + return privkey, pubkey, err } diff --git a/client/keys_test.go b/client/keys_test.go index d98a982..6a69492 100644 --- a/client/keys_test.go +++ b/client/keys_test.go @@ -1,23 +1,30 @@ package client import ( + "crypto/rsa" "reflect" "testing" + + "golang.org/x/crypto/ed25519" ) func TestGenerateKeys(t *testing.T) { var tests = []struct { - key string - size int - want string + keytype string + keysize int + want string }{ - {"ecdsa", 256, "*ecdsa.PrivateKey"}, {"rsa", 1024, "*rsa.PrivateKey"}, - {"ed25519", 256, "*ed25519.PrivateKey"}, + {"rsa", 0, "*rsa.PrivateKey"}, + {"ecdsa", 0, "*ecdsa.PrivateKey"}, + {"ecdsa", 384, "*ecdsa.PrivateKey"}, + {"ed25519", 0, "*ed25519.PrivateKey"}, } for _, tst := range tests { - k, _, err := GenerateKey(tst.key, tst.size) + var k Key + var err error + k, _, err = GenerateKey(KeyType(tst.keytype), KeySize(tst.keysize)) if err != nil { t.Error(err) } @@ -26,3 +33,36 @@ func TestGenerateKeys(t *testing.T) { } } } + +func TestDefaultOptions(t *testing.T) { + k, _, err := GenerateKey() + if err != nil { + t.Error(err) + } + _, ok := k.(*rsa.PrivateKey) + if !ok { + t.Errorf("Unexpected key type %T, wanted *rsa.PrivateKey", k) + } +} + +func TestGenerateKeyType(t *testing.T) { + k, _, err := GenerateKey(KeyType("ed25519")) + if err != nil { + t.Error(err) + } + _, ok := k.(*ed25519.PrivateKey) + if !ok { + t.Errorf("Unexpected key type %T, wanted *ed25519.PrivateKey", k) + } +} + +func TestGenerateKeySize(t *testing.T) { + k, _, err := GenerateKey(KeySize(1024)) + if err != nil { + t.Error(err) + } + _, ok := k.(*rsa.PrivateKey) + if !ok { + t.Errorf("Unexpected key type %T, wanted *rsa.PrivateKey", k) + } +} diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index b25e36a..26c6cbf 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -36,7 +36,7 @@ func main() { fmt.Println("Error launching web browser. Go to the link in your web browser") } fmt.Println("Generating new key pair") - priv, pub, err := client.GenerateKey(c.Keytype, c.Keysize) + priv, pub, err := client.GenerateKey(client.KeyType(c.Keytype), client.KeySize(c.Keysize)) if err != nil { log.Fatalln("Error generating key pair: ", err) } -- cgit v1.2.3