aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2018-08-06 00:21:11 +0100
committerNiall Sheridan <nsheridan@gmail.com>2018-08-07 15:43:39 +0100
commit4f2385db4b3d4171fff841594f8c591703e84b0f (patch)
tree0a28668961e727881add7753fc7e8daa0ef0d998 /server
parent162efe8828ce1c2828206e5050a2c0c175265d70 (diff)
Unexport store implementations
Return an error if the store isn't known, instead of defaulting to a mem store
Diffstat (limited to 'server')
-rw-r--r--server/handlers_test.go2
-rw-r--r--server/store/mem.go30
-rw-r--r--server/store/sqldb.go40
-rw-r--r--server/store/store.go7
-rw-r--r--server/store/store_test.go6
5 files changed, 42 insertions, 43 deletions
diff --git a/server/handlers_test.go b/server/handlers_test.go
index 1670f2f..7f31452 100644
--- a/server/handlers_test.go
+++ b/server/handlers_test.go
@@ -38,7 +38,7 @@ func init() {
MaxAge: "1h",
})
authprovider = testprovider.New()
- certstore = store.NewMemoryStore()
+ certstore, _ = store.New(map[string]string{"type": "mem"})
ctx = &appContext{
cookiestore: sessions.NewCookieStore([]byte("secret")),
authsession: &auth.Session{AuthURL: "https://www.example.com/auth"},
diff --git a/server/store/mem.go b/server/store/mem.go
index c4fe14c..9d5038d 100644
--- a/server/store/mem.go
+++ b/server/store/mem.go
@@ -2,23 +2,22 @@ package store
import (
"fmt"
- "log"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
-var _ CertStorer = (*MemoryStore)(nil)
+var _ CertStorer = (*memoryStore)(nil)
-// MemoryStore is an in-memory CertStorer
-type MemoryStore struct {
+// memoryStore is an in-memory CertStorer
+type memoryStore struct {
sync.Mutex
certs map[string]*CertRecord
}
// Get a single *CertRecord
-func (ms *MemoryStore) Get(id string) (*CertRecord, error) {
+func (ms *memoryStore) Get(id string) (*CertRecord, error) {
ms.Lock()
defer ms.Unlock()
r, ok := ms.certs[id]
@@ -29,12 +28,12 @@ func (ms *MemoryStore) Get(id string) (*CertRecord, error) {
}
// SetCert parses a *ssh.Certificate and records it
-func (ms *MemoryStore) SetCert(cert *ssh.Certificate) error {
+func (ms *memoryStore) SetCert(cert *ssh.Certificate) error {
return ms.SetRecord(parseCertificate(cert))
}
// SetRecord records a *CertRecord
-func (ms *MemoryStore) SetRecord(record *CertRecord) error {
+func (ms *memoryStore) SetRecord(record *CertRecord) error {
ms.Lock()
defer ms.Unlock()
ms.certs[record.KeyID] = record
@@ -43,7 +42,7 @@ func (ms *MemoryStore) SetRecord(record *CertRecord) error {
// List returns all recorded certs.
// By default only active certs are returned.
-func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
+func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
var records []*CertRecord
ms.Lock()
defer ms.Unlock()
@@ -58,7 +57,7 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) {
}
// Revoke an issued cert by id.
-func (ms *MemoryStore) Revoke(ids []string) error {
+func (ms *memoryStore) Revoke(ids []string) error {
ms.Lock()
defer ms.Unlock()
for _, id := range ids {
@@ -68,7 +67,7 @@ func (ms *MemoryStore) Revoke(ids []string) error {
}
// GetRevoked returns all revoked certs
-func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) {
+func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
var revoked []*CertRecord
all, _ := ms.List(false)
for _, r := range all {
@@ -80,23 +79,22 @@ func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) {
}
// Close the store. This will clear the contents.
-func (ms *MemoryStore) Close() error {
+func (ms *memoryStore) Close() error {
ms.Lock()
defer ms.Unlock()
ms.certs = nil
return nil
}
-func (ms *MemoryStore) clear() {
+func (ms *memoryStore) clear() {
for k := range ms.certs {
delete(ms.certs, k)
}
}
-// NewMemoryStore returns an in-memory CertStorer.
-func NewMemoryStore() *MemoryStore {
- log.Println("WARNING: Using memory store to record issued certs.")
- return &MemoryStore{
+// newMemoryStore returns an in-memory CertStorer.
+func newMemoryStore() *memoryStore {
+ return &memoryStore{
certs: make(map[string]*CertRecord),
}
}
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index 3526a2b..b5948b7 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -18,10 +18,10 @@ import (
migrate "github.com/rubenv/sql-migrate"
)
-var _ CertStorer = (*SQLStore)(nil)
+var _ CertStorer = (*sqlStore)(nil)
-// SQLStore is an sql-based CertStorer
-type SQLStore struct {
+// sqlStore is an sql-based CertStorer
+type sqlStore struct {
conn *sqlx.DB
get *sqlx.Stmt
@@ -31,8 +31,8 @@ type SQLStore struct {
revoked *sqlx.Stmt
}
-// NewSQLStore returns a *sql.DB CertStorer.
-func NewSQLStore(c config.Database) (*SQLStore, error) {
+// newSQLStore returns a *sql.DB CertStorer.
+func newSQLStore(c config.Database) (*sqlStore, error) {
var driver string
var dsn string
switch c["type"] {
@@ -61,30 +61,30 @@ func NewSQLStore(c config.Database) (*SQLStore, error) {
conn, err := sqlx.Connect(driver, dsn)
if err != nil {
- return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err)
+ return nil, fmt.Errorf("sqlStore: could not get a connection: %v", err)
}
if err := autoMigrate(driver, conn); err != nil {
- return nil, fmt.Errorf("SQLStore: could not update schema: %v", err)
+ return nil, fmt.Errorf("sqlStore: could not update schema: %v", err)
}
- db := &SQLStore{
+ db := &sqlStore{
conn: conn,
}
if db.set, err = conn.Preparex("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
- return nil, fmt.Errorf("SQLStore: prepare set: %v", err)
+ return nil, fmt.Errorf("sqlStore: prepare set: %v", err)
}
if db.get, err = conn.Preparex("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
- return nil, fmt.Errorf("SQLStore: prepare get: %v", err)
+ return nil, fmt.Errorf("sqlStore: prepare get: %v", err)
}
if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil {
- return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err)
+ return nil, fmt.Errorf("sqlStore: prepare listAll: %v", err)
}
if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE expires_at >= ?"); err != nil {
- return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
+ return nil, fmt.Errorf("sqlStore: prepare listCurrent: %v", err)
}
if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
- return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
+ return nil, fmt.Errorf("sqlStore: prepare revoked: %v", err)
}
return db, nil
}
@@ -114,7 +114,7 @@ type rowScanner interface {
}
// Get a single *CertRecord
-func (db *SQLStore) Get(id string) (*CertRecord, error) {
+func (db *sqlStore) Get(id string) (*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
@@ -123,12 +123,12 @@ func (db *SQLStore) Get(id string) (*CertRecord, error) {
}
// SetCert parses a *ssh.Certificate and records it
-func (db *SQLStore) SetCert(cert *ssh.Certificate) error {
+func (db *sqlStore) SetCert(cert *ssh.Certificate) error {
return db.SetRecord(parseCertificate(cert))
}
// SetRecord records a *CertRecord
-func (db *SQLStore) SetRecord(rec *CertRecord) error {
+func (db *sqlStore) SetRecord(rec *CertRecord) error {
if err := db.conn.Ping(); err != nil {
return errors.Wrap(err, "unable to connect to database")
}
@@ -138,7 +138,7 @@ func (db *SQLStore) SetRecord(rec *CertRecord) error {
// List returns all recorded certs.
// By default only active certs are returned.
-func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
+func (db *sqlStore) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
@@ -156,7 +156,7 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
}
// Revoke an issued cert by id.
-func (db *SQLStore) Revoke(ids []string) error {
+func (db *sqlStore) Revoke(ids []string) error {
if err := db.conn.Ping(); err != nil {
return errors.Wrap(err, "unable to connect to database")
}
@@ -166,7 +166,7 @@ func (db *SQLStore) Revoke(ids []string) error {
}
// GetRevoked returns all revoked certs
-func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
+func (db *sqlStore) GetRevoked() ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, errors.Wrap(err, "unable to connect to database")
}
@@ -178,6 +178,6 @@ func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
}
// Close the connection to the database
-func (db *SQLStore) Close() error {
+func (db *sqlStore) Close() error {
return db.conn.Close()
}
diff --git a/server/store/store.go b/server/store/store.go
index 4edb446..4863ff0 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -2,6 +2,7 @@ package store
import (
"encoding/json"
+ "fmt"
"time"
"golang.org/x/crypto/ssh"
@@ -15,11 +16,11 @@ import (
func New(c config.Database) (CertStorer, error) {
switch c["type"] {
case "mysql", "sqlite":
- return NewSQLStore(c)
+ return newSQLStore(c)
case "mem":
- return NewMemoryStore(), nil
+ return newMemoryStore(), nil
}
- return NewMemoryStore(), nil
+ return nil, fmt.Errorf("unable to create store with driver %s", c["type"])
}
// CertStorer records issued certs in a persistent store for audit and
diff --git a/server/store/store_test.go b/server/store/store_test.go
index 5704ce0..d9ae325 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -101,7 +101,7 @@ func testStore(t *testing.T, db CertStorer) {
}
func TestMemoryStore(t *testing.T) {
- db := NewMemoryStore()
+ db := newMemoryStore()
testStore(t, db)
}
@@ -120,7 +120,7 @@ func TestMySQLStore(t *testing.T) {
} else {
sqlConfig["username"] = u.Username
}
- db, err := NewSQLStore(sqlConfig)
+ db, err := newSQLStore(sqlConfig)
if err != nil {
t.Error(err)
}
@@ -134,7 +134,7 @@ func TestSQLiteStore(t *testing.T) {
}
defer os.Remove(f.Name())
config := map[string]string{"type": "sqlite", "filename": f.Name()}
- db, err := NewSQLStore(config)
+ db, err := newSQLStore(config)
if err != nil {
t.Error(err)
}