From dee5a19d36554a8f9a365efd65d13b134889bf63 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 19 Jun 2016 23:44:25 +0100 Subject: first pass at a certificate store --- server/certutil/util.go | 10 +++ server/certutil/util_test.go | 15 ++++ server/config/config.go | 1 + server/signer/signer.go | 11 ++- server/signer/signer_test.go | 7 +- server/store/mem.go | 86 +++++++++++++++++++++++ server/store/mysql.go | 160 +++++++++++++++++++++++++++++++++++++++++++ server/store/store.go | 39 +++++++++++ server/store/store_test.go | 102 +++++++++++++++++++++++++++ 9 files changed, 418 insertions(+), 13 deletions(-) create mode 100644 server/certutil/util.go create mode 100644 server/certutil/util_test.go create mode 100644 server/store/mem.go create mode 100644 server/store/mysql.go create mode 100644 server/store/store.go create mode 100644 server/store/store_test.go (limited to 'server') diff --git a/server/certutil/util.go b/server/certutil/util.go new file mode 100644 index 0000000..eb1900b --- /dev/null +++ b/server/certutil/util.go @@ -0,0 +1,10 @@ +package certutil + +import "golang.org/x/crypto/ssh" + +// GetPublicKey marshals a ssh certificate to a string. +func GetPublicKey(cert *ssh.Certificate) string { + marshaled := ssh.MarshalAuthorizedKey(cert) + // Strip trailing newline + return string(marshaled[:len(marshaled)-1]) +} diff --git a/server/certutil/util_test.go b/server/certutil/util_test.go new file mode 100644 index 0000000..abb8f10 --- /dev/null +++ b/server/certutil/util_test.go @@ -0,0 +1,15 @@ +package certutil + +import ( + "testing" + + "github.com/nsheridan/cashier/testdata" + "golang.org/x/crypto/ssh" +) + +func TestGetPublicKey(t *testing.T) { + c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert) + if GetPublicKey(c.(*ssh.Certificate)) != string(testdata.Cert) { + t.Fail() + } +} diff --git a/server/config/config.go b/server/config/config.go index 0ef417f..674ceee 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -32,6 +32,7 @@ type Server struct { Port int `mapstructure:"port"` CookieSecret string `mapstructure:"cookie_secret"` HTTPLogFile string `mapstructure:"http_logfile"` + Datastore string `mapstructure:"datastore"` } // Auth holds the configuration specific to the OAuth provider. diff --git a/server/signer/signer.go b/server/signer/signer.go index 1be6d75..a3f056a 100644 --- a/server/signer/signer.go +++ b/server/signer/signer.go @@ -25,10 +25,10 @@ type KeySigner struct { } // SignUserKey returns a signed ssh certificate. -func (s *KeySigner) SignUserKey(req *lib.SignRequest) (string, error) { +func (s *KeySigner) SignUserKey(req *lib.SignRequest) (*ssh.Certificate, error) { pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(req.Key)) if err != nil { - return "", err + return nil, err } expires := time.Now().UTC().Add(s.validity) if req.ValidUntil.After(expires) { @@ -45,13 +45,10 @@ func (s *KeySigner) SignUserKey(req *lib.SignRequest) (string, error) { cert.ValidPrincipals = append(cert.ValidPrincipals, s.principals...) cert.Extensions = s.permissions if err := cert.SignCert(rand.Reader, s.ca); err != nil { - return "", err + return nil, err } - marshaled := ssh.MarshalAuthorizedKey(cert) - // Remove the trailing newline. - marshaled = marshaled[:len(marshaled)-1] log.Printf("Issued cert id: %s principals: %s fp: %s valid until: %s\n", cert.KeyId, cert.ValidPrincipals, fingerprint(pubkey), time.Unix(int64(cert.ValidBefore), 0).UTC()) - return string(marshaled), nil + return cert, nil } func makeperms(perms []string) map[string]string { diff --git a/server/signer/signer_test.go b/server/signer/signer_test.go index 08f9025..a80e64a 100644 --- a/server/signer/signer_test.go +++ b/server/signer/signer_test.go @@ -27,15 +27,10 @@ func TestCert(t *testing.T) { Principal: "gopher1", ValidUntil: time.Now().Add(1 * time.Hour), } - ret, err := signer.SignUserKey(r) + cert, err := signer.SignUserKey(r) if err != nil { t.Fatal(err) } - c, _, _, _, err := ssh.ParseAuthorizedKey([]byte(ret)) - cert, ok := c.(*ssh.Certificate) - if !ok { - t.Fatalf("Expected type *ssh.Certificate, got %v (%T)", cert, cert) - } if !bytes.Equal(cert.SignatureKey.Marshal(), signer.ca.PublicKey().Marshal()) { t.Fatal("Cert signer and server signer don't match") } diff --git a/server/store/mem.go b/server/store/mem.go new file mode 100644 index 0000000..8b78e27 --- /dev/null +++ b/server/store/mem.go @@ -0,0 +1,86 @@ +package store + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +type memoryStore struct { + sync.Mutex + certs map[string]*CertRecord +} + +func (ms *memoryStore) Get(id string) (*CertRecord, error) { + ms.Lock() + defer ms.Unlock() + r, ok := ms.certs[id] + if !ok { + return nil, fmt.Errorf("unknown cert %s", id) + } + return r, nil +} + +func (ms *memoryStore) SetCert(cert *ssh.Certificate) error { + return ms.SetRecord(parseCertificate(cert)) +} + +func (ms *memoryStore) SetRecord(record *CertRecord) error { + ms.Lock() + defer ms.Unlock() + ms.certs[record.KeyID] = record + return nil +} + +func (ms *memoryStore) List() ([]*CertRecord, error) { + var records []*CertRecord + ms.Lock() + defer ms.Unlock() + for _, value := range ms.certs { + records = append(records, value) + } + return records, nil +} + +func (ms *memoryStore) Revoke(id string) error { + r, err := ms.Get(id) + if err != nil { + return err + } + r.Revoked = true + ms.SetRecord(r) + return nil +} + +func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { + var revoked []*CertRecord + all, _ := ms.List() + for _, r := range all { + if r.Revoked && uint64(time.Now().UTC().Unix()) <= r.Expires { + revoked = append(revoked, r) + } + } + return revoked, nil +} + +func (ms *memoryStore) Close() error { + ms.Lock() + defer ms.Unlock() + ms.certs = nil + return nil +} + +func (ms *memoryStore) clear() { + for k := range ms.certs { + delete(ms.certs, k) + } +} + +// NewMemoryStore returns an in-memory CertStorer. +func NewMemoryStore() CertStorer { + return &memoryStore{ + certs: make(map[string]*CertRecord), + } +} diff --git a/server/store/mysql.go b/server/store/mysql.go new file mode 100644 index 0000000..b108fdc --- /dev/null +++ b/server/store/mysql.go @@ -0,0 +1,160 @@ +package store + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/go-sql-driver/mysql" +) + +type mysqlDB struct { + conn *sql.DB + + get *sql.Stmt + set *sql.Stmt + list *sql.Stmt + revoke *sql.Stmt + revoked *sql.Stmt +} + +func parseConfig(config string) string { + s := strings.Split(config, ":") + if len(s) == 4 { + s = append(s, "3306") + } + _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] + c := &mysql.Config{ + User: user, + Passwd: passwd, + Net: "tcp", + Addr: fmt.Sprintf("%s:%s", host, port), + DBName: "certs", + } + return c.FormatDSN() +} + +// NewMySQLStore returns a MySQL CertStorer. +func NewMySQLStore(config string) (CertStorer, error) { + conn, err := sql.Open("mysql", parseConfig(config)) + if err != nil { + return nil, fmt.Errorf("mysql: could not get a connection: %v", err) + } + if err := conn.Ping(); err != nil { + conn.Close() + return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err) + } + + db := &mysqlDB{ + conn: conn, + } + + if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil { + return nil, fmt.Errorf("mysql: prepare set: %v", err) + } + if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { + return nil, fmt.Errorf("mysql: prepare get: %v", err) + } + if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { + return nil, fmt.Errorf("mysql: prepare list: %v", err) + } + if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil { + return nil, fmt.Errorf("mysql: prepare revoke: %v", err) + } + if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil { + return nil, fmt.Errorf("mysql: prepare revoked: %v", err) + } + return db, nil +} + +// rowScanner is implemented by sql.Row and sql.Rows +type rowScanner interface { + Scan(dest ...interface{}) error +} + +func scanCert(s rowScanner) (*CertRecord, error) { + var ( + keyID sql.NullString + principals sql.NullString + createdAt sql.NullInt64 + expires sql.NullInt64 + revoked sql.NullBool + raw sql.NullString + ) + if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil { + return nil, err + } + var p []string + if err := json.Unmarshal([]byte(principals.String), &p); err != nil { + return nil, err + } + return &CertRecord{ + KeyID: keyID.String, + Principals: p, + CreatedAt: uint64(createdAt.Int64), + Expires: uint64(expires.Int64), + Revoked: revoked.Bool, + Raw: raw.String, + }, nil +} + +func (db *mysqlDB) Get(id string) (*CertRecord, error) { + return scanCert(db.get.QueryRow(id)) +} + +func (db *mysqlDB) SetCert(cert *ssh.Certificate) error { + return db.SetRecord(parseCertificate(cert)) +} + +func (db *mysqlDB) SetRecord(rec *CertRecord) error { + principals, err := json.Marshal(rec.Principals) + if err != nil { + return err + } + _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw) + return err +} + +func (db *mysqlDB) List() ([]*CertRecord, error) { + var recs []*CertRecord + rows, _ := db.list.Query() + defer rows.Close() + for rows.Next() { + cert, err := scanCert(rows) + if err != nil { + return nil, err + } + recs = append(recs, cert) + } + return recs, nil +} + +func (db *mysqlDB) Revoke(id string) error { + _, err := db.revoke.Exec(id) + if err != nil { + return err + } + return nil +} + +func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { + var recs []*CertRecord + rows, _ := db.revoked.Query(time.Now().UTC().Unix()) + defer rows.Close() + for rows.Next() { + cert, err := scanCert(rows) + if err != nil { + return nil, err + } + recs = append(recs, cert) + } + return recs, nil +} + +func (db *mysqlDB) Close() error { + return db.conn.Close() +} diff --git a/server/store/store.go b/server/store/store.go new file mode 100644 index 0000000..ad4922a --- /dev/null +++ b/server/store/store.go @@ -0,0 +1,39 @@ +package store + +import ( + "golang.org/x/crypto/ssh" + + "github.com/nsheridan/cashier/server/certutil" +) + +// CertStorer records issued certs in a persistent store for audit and +// revocation purposes. +type CertStorer interface { + Get(id string) (*CertRecord, error) + SetCert(cert *ssh.Certificate) error + SetRecord(record *CertRecord) error + List() ([]*CertRecord, error) + Revoke(id string) error + GetRevoked() ([]*CertRecord, error) + Close() error +} + +// A CertRecord is a representation of a ssh certificate used by a CertStorer. +type CertRecord struct { + KeyID string + Principals []string + CreatedAt uint64 + Expires uint64 + Revoked bool + Raw string +} + +func parseCertificate(cert *ssh.Certificate) *CertRecord { + return &CertRecord{ + KeyID: cert.KeyId, + Principals: cert.ValidPrincipals, + CreatedAt: cert.ValidAfter, + Expires: cert.ValidBefore, + Raw: certutil.GetPublicKey(cert), + } +} diff --git a/server/store/store_test.go b/server/store/store_test.go new file mode 100644 index 0000000..d3aa3c1 --- /dev/null +++ b/server/store/store_test.go @@ -0,0 +1,102 @@ +package store + +import ( + "crypto/rand" + "crypto/rsa" + "os" + "testing" + "time" + + "github.com/nsheridan/cashier/testdata" + "github.com/stretchr/testify/assert" + + "golang.org/x/crypto/ssh" +) + +func TestParseCertificate(t *testing.T) { + a := assert.New(t) + now := uint64(time.Now().Unix()) + r, _ := rsa.GenerateKey(rand.Reader, 1024) + pub, _ := ssh.NewPublicKey(r.Public()) + c := &ssh.Certificate{ + KeyId: "id", + ValidPrincipals: []string{"principal"}, + ValidBefore: now, + CertType: ssh.UserCert, + Key: pub, + } + s, _ := ssh.NewSignerFromKey(r) + c.SignCert(rand.Reader, s) + rec := parseCertificate(c) + + a.Equal(c.KeyId, rec.KeyID) + a.Equal(c.ValidPrincipals, rec.Principals) + a.Equal(c.ValidBefore, rec.Expires) + a.Equal(c.ValidAfter, rec.CreatedAt) +} + +func testStore(t *testing.T, db CertStorer) { + defer db.Close() + + ids := []string{"a", "b"} + for _, id := range ids { + r := &CertRecord{ + KeyID: id, + Expires: uint64(time.Now().UTC().Unix()) - 10, + } + if err := db.SetRecord(r); err != nil { + t.Error(err) + } + } + recs, err := db.List() + if err != nil { + t.Error(err) + } + if len(recs) != len(ids) { + t.Errorf("Want %d records, got %d", len(ids), len(recs)) + } + + c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert) + cert := c.(*ssh.Certificate) + cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix()) + if err := db.SetCert(cert); err != nil { + t.Error(err) + } + + if _, err := db.Get("key"); err != nil { + t.Error(err) + } + if err := db.Revoke("key"); err != nil { + t.Error(err) + } + + // A revoked key shouldn't get returned if it's already expired + db.Revoke("a") + + revoked, err := db.GetRevoked() + if err != nil { + t.Error(err) + } + for _, k := range revoked { + if k.KeyID != "key" { + t.Errorf("Unexpected key: %s", k.KeyID) + } + } +} + +func TestMemoryStore(t *testing.T) { + db := NewMemoryStore() + testStore(t, db) +} + +func TestMySQLStore(t *testing.T) { + config := os.Getenv("MYSQL_TEST_CONFIG") + if config == "" { + t.Skip("No MYSQL_TEST_CONFIG environment variable") + } + db, err := NewMySQLStore(config) + if err != nil { + t.Error(err) + } + testStore(t, db) +} -- cgit v1.2.3