aboutsummaryrefslogtreecommitdiff
path: root/server/store/sqldb.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/store/sqldb.go')
-rw-r--r--server/store/sqldb.go47
1 files changed, 29 insertions, 18 deletions
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index d7ef878..a51678e 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -13,7 +13,10 @@ import (
"github.com/nsheridan/cashier/server/config"
)
-type sqldb struct {
+var _ CertStorer = (*SQLStore)(nil)
+
+// SQLStore is an sql-based CertStorer
+type SQLStore struct {
conn *sql.DB
get *sql.Stmt
@@ -25,7 +28,7 @@ type sqldb struct {
}
// NewSQLStore returns a *sql.DB CertStorer.
-func NewSQLStore(c config.Database) (CertStorer, error) {
+func NewSQLStore(c config.Database) (*SQLStore, error) {
var driver string
var dsn string
switch c["type"] {
@@ -51,34 +54,34 @@ func NewSQLStore(c config.Database) (CertStorer, error) {
}
conn, err := sql.Open(driver, dsn)
if err != nil {
- return nil, fmt.Errorf("sqldb: could not get a connection: %v", err)
+ return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err)
}
if err := conn.Ping(); err != nil {
conn.Close()
- return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err)
+ return nil, fmt.Errorf("SQLStore: could not establish a good connection: %v", err)
}
- db := &sqldb{
+ db := &SQLStore{
conn: conn,
}
if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare set: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare set: %v", err)
}
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare get: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare get: %v", err)
}
if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare listAll: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err)
}
if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
}
if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err)
}
if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare revoked: %v", err)
+ return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
}
return db, nil
}
@@ -114,18 +117,21 @@ func scanCert(s rowScanner) (*CertRecord, error) {
}, nil
}
-func (db *sqldb) Get(id string) (*CertRecord, error) {
+// Get a single *CertRecord
+func (db *SQLStore) Get(id string) (*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
return scanCert(db.get.QueryRow(id))
}
-func (db *sqldb) SetCert(cert *ssh.Certificate) error {
+// SetCert parses a *ssh.Certificate and records it
+func (db *SQLStore) SetCert(cert *ssh.Certificate) error {
return db.SetRecord(parseCertificate(cert))
}
-func (db *sqldb) SetRecord(rec *CertRecord) error {
+// SetRecord records a *CertRecord
+func (db *SQLStore) SetRecord(rec *CertRecord) error {
principals, err := json.Marshal(rec.Principals)
if err != nil {
return err
@@ -137,7 +143,9 @@ func (db *sqldb) SetRecord(rec *CertRecord) error {
return err
}
-func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
+// List returns all recorded certs.
+// By default only active certs are returned.
+func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
@@ -159,7 +167,8 @@ func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
return recs, nil
}
-func (db *sqldb) Revoke(id string) error {
+// Revoke an issued cert by id.
+func (db *SQLStore) Revoke(id string) error {
if err := db.conn.Ping(); err != nil {
return err
}
@@ -170,7 +179,8 @@ func (db *sqldb) Revoke(id string) error {
return nil
}
-func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
+// GetRevoked returns all revoked certs
+func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
@@ -187,6 +197,7 @@ func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
return recs, nil
}
-func (db *sqldb) Close() error {
+// Close the connection to the database
+func (db *SQLStore) Close() error {
return db.conn.Close()
}