diff options
| author | Niall Sheridan <nsheridan@gmail.com> | 2016-06-19 23:44:25 +0100 | 
|---|---|---|
| committer | Niall Sheridan <nsheridan@gmail.com> | 2016-07-03 18:01:24 +0100 | 
| commit | dee5a19d36554a8f9a365efd65d13b134889bf63 (patch) | |
| tree | 41103a2d3665d604fe22dcd16d110ed56c466f6d /server | |
| parent | 6e7dfa0df6b102219817e26095f2ba636cd9288c (diff) | |
first pass at a certificate store
Diffstat (limited to 'server')
| -rw-r--r-- | server/certutil/util.go | 10 | ||||
| -rw-r--r-- | server/certutil/util_test.go | 15 | ||||
| -rw-r--r-- | server/config/config.go | 1 | ||||
| -rw-r--r-- | server/signer/signer.go | 11 | ||||
| -rw-r--r-- | server/signer/signer_test.go | 7 | ||||
| -rw-r--r-- | server/store/mem.go | 86 | ||||
| -rw-r--r-- | server/store/mysql.go | 160 | ||||
| -rw-r--r-- | server/store/store.go | 39 | ||||
| -rw-r--r-- | server/store/store_test.go | 102 | 
9 files changed, 418 insertions, 13 deletions
diff --git a/server/certutil/util.go b/server/certutil/util.go new file mode 100644 index 0000000..eb1900b --- /dev/null +++ b/server/certutil/util.go @@ -0,0 +1,10 @@ +package certutil + +import "golang.org/x/crypto/ssh" + +// GetPublicKey marshals a ssh certificate to a string. +func GetPublicKey(cert *ssh.Certificate) string { +	marshaled := ssh.MarshalAuthorizedKey(cert) +	// Strip trailing newline +	return string(marshaled[:len(marshaled)-1]) +} diff --git a/server/certutil/util_test.go b/server/certutil/util_test.go new file mode 100644 index 0000000..abb8f10 --- /dev/null +++ b/server/certutil/util_test.go @@ -0,0 +1,15 @@ +package certutil + +import ( +	"testing" + +	"github.com/nsheridan/cashier/testdata" +	"golang.org/x/crypto/ssh" +) + +func TestGetPublicKey(t *testing.T) { +	c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert) +	if GetPublicKey(c.(*ssh.Certificate)) != string(testdata.Cert) { +		t.Fail() +	} +} diff --git a/server/config/config.go b/server/config/config.go index 0ef417f..674ceee 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -32,6 +32,7 @@ type Server struct {  	Port         int    `mapstructure:"port"`  	CookieSecret string `mapstructure:"cookie_secret"`  	HTTPLogFile  string `mapstructure:"http_logfile"` +	Datastore    string `mapstructure:"datastore"`  }  // Auth holds the configuration specific to the OAuth provider. diff --git a/server/signer/signer.go b/server/signer/signer.go index 1be6d75..a3f056a 100644 --- a/server/signer/signer.go +++ b/server/signer/signer.go @@ -25,10 +25,10 @@ type KeySigner struct {  }  // SignUserKey returns a signed ssh certificate. -func (s *KeySigner) SignUserKey(req *lib.SignRequest) (string, error) { +func (s *KeySigner) SignUserKey(req *lib.SignRequest) (*ssh.Certificate, error) {  	pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(req.Key))  	if err != nil { -		return "", err +		return nil, err  	}  	expires := time.Now().UTC().Add(s.validity)  	if req.ValidUntil.After(expires) { @@ -45,13 +45,10 @@ func (s *KeySigner) SignUserKey(req *lib.SignRequest) (string, error) {  	cert.ValidPrincipals = append(cert.ValidPrincipals, s.principals...)  	cert.Extensions = s.permissions  	if err := cert.SignCert(rand.Reader, s.ca); err != nil { -		return "", err +		return nil, err  	} -	marshaled := ssh.MarshalAuthorizedKey(cert) -	// Remove the trailing newline. -	marshaled = marshaled[:len(marshaled)-1]  	log.Printf("Issued cert id: %s principals: %s fp: %s valid until: %s\n", cert.KeyId, cert.ValidPrincipals, fingerprint(pubkey), time.Unix(int64(cert.ValidBefore), 0).UTC()) -	return string(marshaled), nil +	return cert, nil  }  func makeperms(perms []string) map[string]string { diff --git a/server/signer/signer_test.go b/server/signer/signer_test.go index 08f9025..a80e64a 100644 --- a/server/signer/signer_test.go +++ b/server/signer/signer_test.go @@ -27,15 +27,10 @@ func TestCert(t *testing.T) {  		Principal:  "gopher1",  		ValidUntil: time.Now().Add(1 * time.Hour),  	} -	ret, err := signer.SignUserKey(r) +	cert, err := signer.SignUserKey(r)  	if err != nil {  		t.Fatal(err)  	} -	c, _, _, _, err := ssh.ParseAuthorizedKey([]byte(ret)) -	cert, ok := c.(*ssh.Certificate) -	if !ok { -		t.Fatalf("Expected type *ssh.Certificate, got %v (%T)", cert, cert) -	}  	if !bytes.Equal(cert.SignatureKey.Marshal(), signer.ca.PublicKey().Marshal()) {  		t.Fatal("Cert signer and server signer don't match")  	} diff --git a/server/store/mem.go b/server/store/mem.go new file mode 100644 index 0000000..8b78e27 --- /dev/null +++ b/server/store/mem.go @@ -0,0 +1,86 @@ +package store + +import ( +	"fmt" +	"sync" +	"time" + +	"golang.org/x/crypto/ssh" +) + +type memoryStore struct { +	sync.Mutex +	certs map[string]*CertRecord +} + +func (ms *memoryStore) Get(id string) (*CertRecord, error) { +	ms.Lock() +	defer ms.Unlock() +	r, ok := ms.certs[id] +	if !ok { +		return nil, fmt.Errorf("unknown cert %s", id) +	} +	return r, nil +} + +func (ms *memoryStore) SetCert(cert *ssh.Certificate) error { +	return ms.SetRecord(parseCertificate(cert)) +} + +func (ms *memoryStore) SetRecord(record *CertRecord) error { +	ms.Lock() +	defer ms.Unlock() +	ms.certs[record.KeyID] = record +	return nil +} + +func (ms *memoryStore) List() ([]*CertRecord, error) { +	var records []*CertRecord +	ms.Lock() +	defer ms.Unlock() +	for _, value := range ms.certs { +		records = append(records, value) +	} +	return records, nil +} + +func (ms *memoryStore) Revoke(id string) error { +	r, err := ms.Get(id) +	if err != nil { +		return err +	} +	r.Revoked = true +	ms.SetRecord(r) +	return nil +} + +func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { +	var revoked []*CertRecord +	all, _ := ms.List() +	for _, r := range all { +		if r.Revoked && uint64(time.Now().UTC().Unix()) <= r.Expires { +			revoked = append(revoked, r) +		} +	} +	return revoked, nil +} + +func (ms *memoryStore) Close() error { +	ms.Lock() +	defer ms.Unlock() +	ms.certs = nil +	return nil +} + +func (ms *memoryStore) clear() { +	for k := range ms.certs { +		delete(ms.certs, k) +	} +} + +// NewMemoryStore returns an in-memory CertStorer. +func NewMemoryStore() CertStorer { +	return &memoryStore{ +		certs: make(map[string]*CertRecord), +	} +} diff --git a/server/store/mysql.go b/server/store/mysql.go new file mode 100644 index 0000000..b108fdc --- /dev/null +++ b/server/store/mysql.go @@ -0,0 +1,160 @@ +package store + +import ( +	"database/sql" +	"encoding/json" +	"fmt" +	"strings" +	"time" + +	"golang.org/x/crypto/ssh" + +	"github.com/go-sql-driver/mysql" +) + +type mysqlDB struct { +	conn *sql.DB + +	get     *sql.Stmt +	set     *sql.Stmt +	list    *sql.Stmt +	revoke  *sql.Stmt +	revoked *sql.Stmt +} + +func parseConfig(config string) string { +	s := strings.Split(config, ":") +	if len(s) == 4 { +		s = append(s, "3306") +	} +	_, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] +	c := &mysql.Config{ +		User:   user, +		Passwd: passwd, +		Net:    "tcp", +		Addr:   fmt.Sprintf("%s:%s", host, port), +		DBName: "certs", +	} +	return c.FormatDSN() +} + +// NewMySQLStore returns a MySQL CertStorer. +func NewMySQLStore(config string) (CertStorer, error) { +	conn, err := sql.Open("mysql", parseConfig(config)) +	if err != nil { +		return nil, fmt.Errorf("mysql: could not get a connection: %v", err) +	} +	if err := conn.Ping(); err != nil { +		conn.Close() +		return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err) +	} + +	db := &mysqlDB{ +		conn: conn, +	} + +	if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil { +		return nil, fmt.Errorf("mysql: prepare set: %v", err) +	} +	if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { +		return nil, fmt.Errorf("mysql: prepare get: %v", err) +	} +	if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { +		return nil, fmt.Errorf("mysql: prepare list: %v", err) +	} +	if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil { +		return nil, fmt.Errorf("mysql: prepare revoke: %v", err) +	} +	if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil { +		return nil, fmt.Errorf("mysql: prepare revoked: %v", err) +	} +	return db, nil +} + +// rowScanner is implemented by sql.Row and sql.Rows +type rowScanner interface { +	Scan(dest ...interface{}) error +} + +func scanCert(s rowScanner) (*CertRecord, error) { +	var ( +		keyID      sql.NullString +		principals sql.NullString +		createdAt  sql.NullInt64 +		expires    sql.NullInt64 +		revoked    sql.NullBool +		raw        sql.NullString +	) +	if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil { +		return nil, err +	} +	var p []string +	if err := json.Unmarshal([]byte(principals.String), &p); err != nil { +		return nil, err +	} +	return &CertRecord{ +		KeyID:      keyID.String, +		Principals: p, +		CreatedAt:  uint64(createdAt.Int64), +		Expires:    uint64(expires.Int64), +		Revoked:    revoked.Bool, +		Raw:        raw.String, +	}, nil +} + +func (db *mysqlDB) Get(id string) (*CertRecord, error) { +	return scanCert(db.get.QueryRow(id)) +} + +func (db *mysqlDB) SetCert(cert *ssh.Certificate) error { +	return db.SetRecord(parseCertificate(cert)) +} + +func (db *mysqlDB) SetRecord(rec *CertRecord) error { +	principals, err := json.Marshal(rec.Principals) +	if err != nil { +		return err +	} +	_, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw) +	return err +} + +func (db *mysqlDB) List() ([]*CertRecord, error) { +	var recs []*CertRecord +	rows, _ := db.list.Query() +	defer rows.Close() +	for rows.Next() { +		cert, err := scanCert(rows) +		if err != nil { +			return nil, err +		} +		recs = append(recs, cert) +	} +	return recs, nil +} + +func (db *mysqlDB) Revoke(id string) error { +	_, err := db.revoke.Exec(id) +	if err != nil { +		return err +	} +	return nil +} + +func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { +	var recs []*CertRecord +	rows, _ := db.revoked.Query(time.Now().UTC().Unix()) +	defer rows.Close() +	for rows.Next() { +		cert, err := scanCert(rows) +		if err != nil { +			return nil, err +		} +		recs = append(recs, cert) +	} +	return recs, nil +} + +func (db *mysqlDB) Close() error { +	return db.conn.Close() +} diff --git a/server/store/store.go b/server/store/store.go new file mode 100644 index 0000000..ad4922a --- /dev/null +++ b/server/store/store.go @@ -0,0 +1,39 @@ +package store + +import ( +	"golang.org/x/crypto/ssh" + +	"github.com/nsheridan/cashier/server/certutil" +) + +// CertStorer records issued certs in a persistent store for audit and +// revocation purposes. +type CertStorer interface { +	Get(id string) (*CertRecord, error) +	SetCert(cert *ssh.Certificate) error +	SetRecord(record *CertRecord) error +	List() ([]*CertRecord, error) +	Revoke(id string) error +	GetRevoked() ([]*CertRecord, error) +	Close() error +} + +// A CertRecord is a representation of a ssh certificate used by a CertStorer. +type CertRecord struct { +	KeyID      string +	Principals []string +	CreatedAt  uint64 +	Expires    uint64 +	Revoked    bool +	Raw        string +} + +func parseCertificate(cert *ssh.Certificate) *CertRecord { +	return &CertRecord{ +		KeyID:      cert.KeyId, +		Principals: cert.ValidPrincipals, +		CreatedAt:  cert.ValidAfter, +		Expires:    cert.ValidBefore, +		Raw:        certutil.GetPublicKey(cert), +	} +} diff --git a/server/store/store_test.go b/server/store/store_test.go new file mode 100644 index 0000000..d3aa3c1 --- /dev/null +++ b/server/store/store_test.go @@ -0,0 +1,102 @@ +package store + +import ( +	"crypto/rand" +	"crypto/rsa" +	"os" +	"testing" +	"time" + +	"github.com/nsheridan/cashier/testdata" +	"github.com/stretchr/testify/assert" + +	"golang.org/x/crypto/ssh" +) + +func TestParseCertificate(t *testing.T) { +	a := assert.New(t) +	now := uint64(time.Now().Unix()) +	r, _ := rsa.GenerateKey(rand.Reader, 1024) +	pub, _ := ssh.NewPublicKey(r.Public()) +	c := &ssh.Certificate{ +		KeyId:           "id", +		ValidPrincipals: []string{"principal"}, +		ValidBefore:     now, +		CertType:        ssh.UserCert, +		Key:             pub, +	} +	s, _ := ssh.NewSignerFromKey(r) +	c.SignCert(rand.Reader, s) +	rec := parseCertificate(c) + +	a.Equal(c.KeyId, rec.KeyID) +	a.Equal(c.ValidPrincipals, rec.Principals) +	a.Equal(c.ValidBefore, rec.Expires) +	a.Equal(c.ValidAfter, rec.CreatedAt) +} + +func testStore(t *testing.T, db CertStorer) { +	defer db.Close() + +	ids := []string{"a", "b"} +	for _, id := range ids { +		r := &CertRecord{ +			KeyID:   id, +			Expires: uint64(time.Now().UTC().Unix()) - 10, +		} +		if err := db.SetRecord(r); err != nil { +			t.Error(err) +		} +	} +	recs, err := db.List() +	if err != nil { +		t.Error(err) +	} +	if len(recs) != len(ids) { +		t.Errorf("Want %d records, got %d", len(ids), len(recs)) +	} + +	c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert) +	cert := c.(*ssh.Certificate) +	cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix()) +	if err := db.SetCert(cert); err != nil { +		t.Error(err) +	} + +	if _, err := db.Get("key"); err != nil { +		t.Error(err) +	} +	if err := db.Revoke("key"); err != nil { +		t.Error(err) +	} + +	// A revoked key shouldn't get returned if it's already expired +	db.Revoke("a") + +	revoked, err := db.GetRevoked() +	if err != nil { +		t.Error(err) +	} +	for _, k := range revoked { +		if k.KeyID != "key" { +			t.Errorf("Unexpected key: %s", k.KeyID) +		} +	} +} + +func TestMemoryStore(t *testing.T) { +	db := NewMemoryStore() +	testStore(t, db) +} + +func TestMySQLStore(t *testing.T) { +	config := os.Getenv("MYSQL_TEST_CONFIG") +	if config == "" { +		t.Skip("No MYSQL_TEST_CONFIG environment variable") +	} +	db, err := NewMySQLStore(config) +	if err != nil { +		t.Error(err) +	} +	testStore(t, db) +}  | 
