aboutsummaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-04-20 21:18:25 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-04-20 21:19:20 +0100
commitb13a57ea4488d5d3ab95ae975ce5cf31632cb59c (patch)
tree8ddc66a70ae45eb1a515e834d5777ba8b1e7622e /client
parent6967fe9b4fd06e643124867ab8997bfe612c13c7 (diff)
Simplify this a bit
Diffstat (limited to 'client')
-rw-r--r--client/keys.go32
1 files changed, 18 insertions, 14 deletions
diff --git a/client/keys.go b/client/keys.go
index 4acfbb9..866b062 100644
--- a/client/keys.go
+++ b/client/keys.go
@@ -10,14 +10,17 @@ import (
"golang.org/x/crypto/ssh"
)
-const (
- rsaKey = "rsa"
- ecdsaKey = "ecdsa"
-)
-
type key interface{}
+type keyfunc func(int) (key, ssh.PublicKey, error)
+
+var (
+ keytypes = map[string]keyfunc{
+ "rsa": generateRSAKey,
+ "ecdsa": generateECDSAKey,
+ }
+)
-func generateRSAKey(bits int) (*rsa.PrivateKey, ssh.PublicKey, error) {
+func generateRSAKey(bits int) (key, ssh.PublicKey, error) {
k, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, nil, err
@@ -29,7 +32,7 @@ func generateRSAKey(bits int) (*rsa.PrivateKey, ssh.PublicKey, error) {
return k, pub, nil
}
-func generateECDSAKey(bits int) (*ecdsa.PrivateKey, ssh.PublicKey, error) {
+func generateECDSAKey(bits int) (key, ssh.PublicKey, error) {
var curve elliptic.Curve
switch bits {
case 256:
@@ -53,12 +56,13 @@ func generateECDSAKey(bits int) (*ecdsa.PrivateKey, ssh.PublicKey, error) {
}
func generateKey(keytype string, bits int) (key, ssh.PublicKey, error) {
- switch keytype {
- case rsaKey:
- return generateRSAKey(bits)
- case ecdsaKey:
- return generateECDSAKey(bits)
- default:
- return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are [%s, %s]", keytype, rsaKey, ecdsaKey)
+ f, ok := keytypes[keytype]
+ if !ok {
+ var valid []string
+ for k, _ := range keytypes {
+ valid = append(valid, k)
+ }
+ return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, valid)
}
+ return f(bits)
}