aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/keys.go119
-rw-r--r--client/keys_test.go52
-rw-r--r--cmd/cashier/main.go2
3 files changed, 118 insertions, 55 deletions
diff --git a/client/keys.go b/client/keys.go
index 0ec0f1d..3d2fb31 100644
--- a/client/keys.go
+++ b/client/keys.go
@@ -1,56 +1,66 @@
package client
import (
+ "crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"fmt"
- "strings"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
// Key is a private key.
-type Key interface{}
-type keyfunc func(int) (Key, ssh.PublicKey, error)
+type Key crypto.Signer
-var (
- keytypes = map[string]keyfunc{
- "rsa": generateRSAKey,
- "ecdsa": generateECDSAKey,
- "ed25519": generateED25519Key,
- }
-)
+// Options for key generation.
+// Defaults will generate a 2048 bit RSA key.
+type options struct {
+ keytype string
+ size int
+}
-func generateED25519Key(bits int) (Key, ssh.PublicKey, error) {
- p, k, err := ed25519.GenerateKey(rand.Reader)
- if err != nil {
- return nil, nil, err
- }
- pub, err := ssh.NewPublicKey(p)
- if err != nil {
- return nil, nil, err
- }
- return &k, pub, nil
+var defaultOptions = options{
+ keytype: "rsa",
+ size: 0, // Different key types have different default sizes.
}
-func generateRSAKey(bits int) (Key, ssh.PublicKey, error) {
- k, err := rsa.GenerateKey(rand.Reader, bits)
- if err != nil {
- return nil, nil, err
+// A KeyOption is used to generate keys of different types and sizes.
+type KeyOption func(*options)
+
+// KeyType sets the type of key to generate.
+// Valid types are: "rsa", "ecdsa", "ed25519".
+// Default is "rsa"
+func KeyType(keyType string) KeyOption {
+ return func(o *options) {
+ o.keytype = keyType
}
- pub, err := ssh.NewPublicKey(&k.PublicKey)
- if err != nil {
- return nil, nil, err
+}
+
+// KeySize sets the size of the key in bits.
+// RSA keys must be a minimum of 1024 bits. The default is 2048 bits.
+// ECDSA keys must be one of 256, 384, or 521 bits. The default is 256 bits.
+// Ed25519 keys are of a fixed size. This option is ignored.
+func KeySize(size int) KeyOption {
+ return func(o *options) {
+ o.size = size
}
- return k, pub, nil
}
-func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) {
+func generateED25519Key() (Key, error) {
+ _, k, err := ed25519.GenerateKey(rand.Reader)
+ return &k, err
+}
+
+func generateRSAKey(size int) (Key, error) {
+ return rsa.GenerateKey(rand.Reader, size)
+}
+
+func generateECDSAKey(size int) (Key, error) {
var curve elliptic.Curve
- switch bits {
+ switch size {
case 256:
curve = elliptic.P256()
case 384:
@@ -58,28 +68,41 @@ func generateECDSAKey(bits int) (Key, ssh.PublicKey, error) {
case 521:
curve = elliptic.P521()
default:
- return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'")
- }
- k, err := ecdsa.GenerateKey(curve, rand.Reader)
- if err != nil {
- return nil, nil, err
- }
- pub, err := ssh.NewPublicKey(&k.PublicKey)
- if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("Unsupported key size: %d. Valid sizes are '256', '384', '521'", size)
}
- return k, pub, nil
+ return ecdsa.GenerateKey(curve, rand.Reader)
}
// GenerateKey generates a ssh key-pair according to the type and size specified.
-func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) {
- f, ok := keytypes[keytype]
- if !ok {
- var valid []string
- for k := range keytypes {
- valid = append(valid, k)
+func GenerateKey(options ...func(*options)) (Key, ssh.PublicKey, error) {
+ var privkey Key
+ var pubkey ssh.PublicKey
+ var err error
+
+ config := defaultOptions
+ for _, o := range options {
+ o(&config)
+ }
+
+ switch config.keytype {
+ case "rsa":
+ if config.size == 0 {
+ config.size = 2048
+ }
+ privkey, err = generateRSAKey(config.size)
+ case "ecdsa":
+ if config.size == 0 {
+ config.size = 256
}
- return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, strings.Join(valid, "|"))
+ privkey, err = generateECDSAKey(config.size)
+ case "ed25519":
+ privkey, err = generateED25519Key()
+ default:
+ privkey, err = generateRSAKey(config.size)
+ }
+ if err != nil {
+ return nil, nil, err
}
- return f(bits)
+ pubkey, err = ssh.NewPublicKey(privkey.Public())
+ return privkey, pubkey, err
}
diff --git a/client/keys_test.go b/client/keys_test.go
index d98a982..6a69492 100644
--- a/client/keys_test.go
+++ b/client/keys_test.go
@@ -1,23 +1,30 @@
package client
import (
+ "crypto/rsa"
"reflect"
"testing"
+
+ "golang.org/x/crypto/ed25519"
)
func TestGenerateKeys(t *testing.T) {
var tests = []struct {
- key string
- size int
- want string
+ keytype string
+ keysize int
+ want string
}{
- {"ecdsa", 256, "*ecdsa.PrivateKey"},
{"rsa", 1024, "*rsa.PrivateKey"},
- {"ed25519", 256, "*ed25519.PrivateKey"},
+ {"rsa", 0, "*rsa.PrivateKey"},
+ {"ecdsa", 0, "*ecdsa.PrivateKey"},
+ {"ecdsa", 384, "*ecdsa.PrivateKey"},
+ {"ed25519", 0, "*ed25519.PrivateKey"},
}
for _, tst := range tests {
- k, _, err := GenerateKey(tst.key, tst.size)
+ var k Key
+ var err error
+ k, _, err = GenerateKey(KeyType(tst.keytype), KeySize(tst.keysize))
if err != nil {
t.Error(err)
}
@@ -26,3 +33,36 @@ func TestGenerateKeys(t *testing.T) {
}
}
}
+
+func TestDefaultOptions(t *testing.T) {
+ k, _, err := GenerateKey()
+ if err != nil {
+ t.Error(err)
+ }
+ _, ok := k.(*rsa.PrivateKey)
+ if !ok {
+ t.Errorf("Unexpected key type %T, wanted *rsa.PrivateKey", k)
+ }
+}
+
+func TestGenerateKeyType(t *testing.T) {
+ k, _, err := GenerateKey(KeyType("ed25519"))
+ if err != nil {
+ t.Error(err)
+ }
+ _, ok := k.(*ed25519.PrivateKey)
+ if !ok {
+ t.Errorf("Unexpected key type %T, wanted *ed25519.PrivateKey", k)
+ }
+}
+
+func TestGenerateKeySize(t *testing.T) {
+ k, _, err := GenerateKey(KeySize(1024))
+ if err != nil {
+ t.Error(err)
+ }
+ _, ok := k.(*rsa.PrivateKey)
+ if !ok {
+ t.Errorf("Unexpected key type %T, wanted *rsa.PrivateKey", k)
+ }
+}
diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go
index b25e36a..26c6cbf 100644
--- a/cmd/cashier/main.go
+++ b/cmd/cashier/main.go
@@ -36,7 +36,7 @@ func main() {
fmt.Println("Error launching web browser. Go to the link in your web browser")
}
fmt.Println("Generating new key pair")
- priv, pub, err := client.GenerateKey(c.Keytype, c.Keysize)
+ priv, pub, err := client.GenerateKey(client.KeyType(c.Keytype), client.KeySize(c.Keysize))
if err != nil {
log.Fatalln("Error generating key pair: ", err)
}