aboutsummaryrefslogtreecommitdiff
path: root/server/store
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-06-19 23:44:25 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-07-03 18:01:24 +0100
commitdee5a19d36554a8f9a365efd65d13b134889bf63 (patch)
tree41103a2d3665d604fe22dcd16d110ed56c466f6d /server/store
parent6e7dfa0df6b102219817e26095f2ba636cd9288c (diff)
first pass at a certificate store
Diffstat (limited to 'server/store')
-rw-r--r--server/store/mem.go86
-rw-r--r--server/store/mysql.go160
-rw-r--r--server/store/store.go39
-rw-r--r--server/store/store_test.go102
4 files changed, 387 insertions, 0 deletions
diff --git a/server/store/mem.go b/server/store/mem.go
new file mode 100644
index 0000000..8b78e27
--- /dev/null
+++ b/server/store/mem.go
@@ -0,0 +1,86 @@
+package store
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+)
+
+type memoryStore struct {
+ sync.Mutex
+ certs map[string]*CertRecord
+}
+
+func (ms *memoryStore) Get(id string) (*CertRecord, error) {
+ ms.Lock()
+ defer ms.Unlock()
+ r, ok := ms.certs[id]
+ if !ok {
+ return nil, fmt.Errorf("unknown cert %s", id)
+ }
+ return r, nil
+}
+
+func (ms *memoryStore) SetCert(cert *ssh.Certificate) error {
+ return ms.SetRecord(parseCertificate(cert))
+}
+
+func (ms *memoryStore) SetRecord(record *CertRecord) error {
+ ms.Lock()
+ defer ms.Unlock()
+ ms.certs[record.KeyID] = record
+ return nil
+}
+
+func (ms *memoryStore) List() ([]*CertRecord, error) {
+ var records []*CertRecord
+ ms.Lock()
+ defer ms.Unlock()
+ for _, value := range ms.certs {
+ records = append(records, value)
+ }
+ return records, nil
+}
+
+func (ms *memoryStore) Revoke(id string) error {
+ r, err := ms.Get(id)
+ if err != nil {
+ return err
+ }
+ r.Revoked = true
+ ms.SetRecord(r)
+ return nil
+}
+
+func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
+ var revoked []*CertRecord
+ all, _ := ms.List()
+ for _, r := range all {
+ if r.Revoked && uint64(time.Now().UTC().Unix()) <= r.Expires {
+ revoked = append(revoked, r)
+ }
+ }
+ return revoked, nil
+}
+
+func (ms *memoryStore) Close() error {
+ ms.Lock()
+ defer ms.Unlock()
+ ms.certs = nil
+ return nil
+}
+
+func (ms *memoryStore) clear() {
+ for k := range ms.certs {
+ delete(ms.certs, k)
+ }
+}
+
+// NewMemoryStore returns an in-memory CertStorer.
+func NewMemoryStore() CertStorer {
+ return &memoryStore{
+ certs: make(map[string]*CertRecord),
+ }
+}
diff --git a/server/store/mysql.go b/server/store/mysql.go
new file mode 100644
index 0000000..b108fdc
--- /dev/null
+++ b/server/store/mysql.go
@@ -0,0 +1,160 @@
+package store
+
+import (
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "github.com/go-sql-driver/mysql"
+)
+
+type mysqlDB struct {
+ conn *sql.DB
+
+ get *sql.Stmt
+ set *sql.Stmt
+ list *sql.Stmt
+ revoke *sql.Stmt
+ revoked *sql.Stmt
+}
+
+func parseConfig(config string) string {
+ s := strings.Split(config, ":")
+ if len(s) == 4 {
+ s = append(s, "3306")
+ }
+ _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4]
+ c := &mysql.Config{
+ User: user,
+ Passwd: passwd,
+ Net: "tcp",
+ Addr: fmt.Sprintf("%s:%s", host, port),
+ DBName: "certs",
+ }
+ return c.FormatDSN()
+}
+
+// NewMySQLStore returns a MySQL CertStorer.
+func NewMySQLStore(config string) (CertStorer, error) {
+ conn, err := sql.Open("mysql", parseConfig(config))
+ if err != nil {
+ return nil, fmt.Errorf("mysql: could not get a connection: %v", err)
+ }
+ if err := conn.Ping(); err != nil {
+ conn.Close()
+ return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err)
+ }
+
+ db := &mysqlDB{
+ conn: conn,
+ }
+
+ if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil {
+ return nil, fmt.Errorf("mysql: prepare set: %v", err)
+ }
+ if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
+ return nil, fmt.Errorf("mysql: prepare get: %v", err)
+ }
+ if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+ return nil, fmt.Errorf("mysql: prepare list: %v", err)
+ }
+ if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil {
+ return nil, fmt.Errorf("mysql: prepare revoke: %v", err)
+ }
+ if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil {
+ return nil, fmt.Errorf("mysql: prepare revoked: %v", err)
+ }
+ return db, nil
+}
+
+// rowScanner is implemented by sql.Row and sql.Rows
+type rowScanner interface {
+ Scan(dest ...interface{}) error
+}
+
+func scanCert(s rowScanner) (*CertRecord, error) {
+ var (
+ keyID sql.NullString
+ principals sql.NullString
+ createdAt sql.NullInt64
+ expires sql.NullInt64
+ revoked sql.NullBool
+ raw sql.NullString
+ )
+ if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil {
+ return nil, err
+ }
+ var p []string
+ if err := json.Unmarshal([]byte(principals.String), &p); err != nil {
+ return nil, err
+ }
+ return &CertRecord{
+ KeyID: keyID.String,
+ Principals: p,
+ CreatedAt: uint64(createdAt.Int64),
+ Expires: uint64(expires.Int64),
+ Revoked: revoked.Bool,
+ Raw: raw.String,
+ }, nil
+}
+
+func (db *mysqlDB) Get(id string) (*CertRecord, error) {
+ return scanCert(db.get.QueryRow(id))
+}
+
+func (db *mysqlDB) SetCert(cert *ssh.Certificate) error {
+ return db.SetRecord(parseCertificate(cert))
+}
+
+func (db *mysqlDB) SetRecord(rec *CertRecord) error {
+ principals, err := json.Marshal(rec.Principals)
+ if err != nil {
+ return err
+ }
+ _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw)
+ return err
+}
+
+func (db *mysqlDB) List() ([]*CertRecord, error) {
+ var recs []*CertRecord
+ rows, _ := db.list.Query()
+ defer rows.Close()
+ for rows.Next() {
+ cert, err := scanCert(rows)
+ if err != nil {
+ return nil, err
+ }
+ recs = append(recs, cert)
+ }
+ return recs, nil
+}
+
+func (db *mysqlDB) Revoke(id string) error {
+ _, err := db.revoke.Exec(id)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) {
+ var recs []*CertRecord
+ rows, _ := db.revoked.Query(time.Now().UTC().Unix())
+ defer rows.Close()
+ for rows.Next() {
+ cert, err := scanCert(rows)
+ if err != nil {
+ return nil, err
+ }
+ recs = append(recs, cert)
+ }
+ return recs, nil
+}
+
+func (db *mysqlDB) Close() error {
+ return db.conn.Close()
+}
diff --git a/server/store/store.go b/server/store/store.go
new file mode 100644
index 0000000..ad4922a
--- /dev/null
+++ b/server/store/store.go
@@ -0,0 +1,39 @@
+package store
+
+import (
+ "golang.org/x/crypto/ssh"
+
+ "github.com/nsheridan/cashier/server/certutil"
+)
+
+// CertStorer records issued certs in a persistent store for audit and
+// revocation purposes.
+type CertStorer interface {
+ Get(id string) (*CertRecord, error)
+ SetCert(cert *ssh.Certificate) error
+ SetRecord(record *CertRecord) error
+ List() ([]*CertRecord, error)
+ Revoke(id string) error
+ GetRevoked() ([]*CertRecord, error)
+ Close() error
+}
+
+// A CertRecord is a representation of a ssh certificate used by a CertStorer.
+type CertRecord struct {
+ KeyID string
+ Principals []string
+ CreatedAt uint64
+ Expires uint64
+ Revoked bool
+ Raw string
+}
+
+func parseCertificate(cert *ssh.Certificate) *CertRecord {
+ return &CertRecord{
+ KeyID: cert.KeyId,
+ Principals: cert.ValidPrincipals,
+ CreatedAt: cert.ValidAfter,
+ Expires: cert.ValidBefore,
+ Raw: certutil.GetPublicKey(cert),
+ }
+}
diff --git a/server/store/store_test.go b/server/store/store_test.go
new file mode 100644
index 0000000..d3aa3c1
--- /dev/null
+++ b/server/store/store_test.go
@@ -0,0 +1,102 @@
+package store
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/nsheridan/cashier/testdata"
+ "github.com/stretchr/testify/assert"
+
+ "golang.org/x/crypto/ssh"
+)
+
+func TestParseCertificate(t *testing.T) {
+ a := assert.New(t)
+ now := uint64(time.Now().Unix())
+ r, _ := rsa.GenerateKey(rand.Reader, 1024)
+ pub, _ := ssh.NewPublicKey(r.Public())
+ c := &ssh.Certificate{
+ KeyId: "id",
+ ValidPrincipals: []string{"principal"},
+ ValidBefore: now,
+ CertType: ssh.UserCert,
+ Key: pub,
+ }
+ s, _ := ssh.NewSignerFromKey(r)
+ c.SignCert(rand.Reader, s)
+ rec := parseCertificate(c)
+
+ a.Equal(c.KeyId, rec.KeyID)
+ a.Equal(c.ValidPrincipals, rec.Principals)
+ a.Equal(c.ValidBefore, rec.Expires)
+ a.Equal(c.ValidAfter, rec.CreatedAt)
+}
+
+func testStore(t *testing.T, db CertStorer) {
+ defer db.Close()
+
+ ids := []string{"a", "b"}
+ for _, id := range ids {
+ r := &CertRecord{
+ KeyID: id,
+ Expires: uint64(time.Now().UTC().Unix()) - 10,
+ }
+ if err := db.SetRecord(r); err != nil {
+ t.Error(err)
+ }
+ }
+ recs, err := db.List()
+ if err != nil {
+ t.Error(err)
+ }
+ if len(recs) != len(ids) {
+ t.Errorf("Want %d records, got %d", len(ids), len(recs))
+ }
+
+ c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert)
+ cert := c.(*ssh.Certificate)
+ cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix())
+ if err := db.SetCert(cert); err != nil {
+ t.Error(err)
+ }
+
+ if _, err := db.Get("key"); err != nil {
+ t.Error(err)
+ }
+ if err := db.Revoke("key"); err != nil {
+ t.Error(err)
+ }
+
+ // A revoked key shouldn't get returned if it's already expired
+ db.Revoke("a")
+
+ revoked, err := db.GetRevoked()
+ if err != nil {
+ t.Error(err)
+ }
+ for _, k := range revoked {
+ if k.KeyID != "key" {
+ t.Errorf("Unexpected key: %s", k.KeyID)
+ }
+ }
+}
+
+func TestMemoryStore(t *testing.T) {
+ db := NewMemoryStore()
+ testStore(t, db)
+}
+
+func TestMySQLStore(t *testing.T) {
+ config := os.Getenv("MYSQL_TEST_CONFIG")
+ if config == "" {
+ t.Skip("No MYSQL_TEST_CONFIG environment variable")
+ }
+ db, err := NewMySQLStore(config)
+ if err != nil {
+ t.Error(err)
+ }
+ testStore(t, db)
+}