From 4f2385db4b3d4171fff841594f8c591703e84b0f Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Mon, 6 Aug 2018 00:21:11 +0100 Subject: Unexport store implementations Return an error if the store isn't known, instead of defaulting to a mem store --- server/handlers_test.go | 2 +- server/store/mem.go | 30 ++++++++++++++---------------- server/store/sqldb.go | 40 ++++++++++++++++++++-------------------- server/store/store.go | 7 ++++--- server/store/store_test.go | 6 +++--- 5 files changed, 42 insertions(+), 43 deletions(-) (limited to 'server') diff --git a/server/handlers_test.go b/server/handlers_test.go index 1670f2f..7f31452 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -38,7 +38,7 @@ func init() { MaxAge: "1h", }) authprovider = testprovider.New() - certstore = store.NewMemoryStore() + certstore, _ = store.New(map[string]string{"type": "mem"}) ctx = &appContext{ cookiestore: sessions.NewCookieStore([]byte("secret")), authsession: &auth.Session{AuthURL: "https://www.example.com/auth"}, diff --git a/server/store/mem.go b/server/store/mem.go index c4fe14c..9d5038d 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -2,23 +2,22 @@ package store import ( "fmt" - "log" "sync" "time" "golang.org/x/crypto/ssh" ) -var _ CertStorer = (*MemoryStore)(nil) +var _ CertStorer = (*memoryStore)(nil) -// MemoryStore is an in-memory CertStorer -type MemoryStore struct { +// memoryStore is an in-memory CertStorer +type memoryStore struct { sync.Mutex certs map[string]*CertRecord } // Get a single *CertRecord -func (ms *MemoryStore) Get(id string) (*CertRecord, error) { +func (ms *memoryStore) Get(id string) (*CertRecord, error) { ms.Lock() defer ms.Unlock() r, ok := ms.certs[id] @@ -29,12 +28,12 @@ func (ms *MemoryStore) Get(id string) (*CertRecord, error) { } // SetCert parses a *ssh.Certificate and records it -func (ms *MemoryStore) SetCert(cert *ssh.Certificate) error { +func (ms *memoryStore) SetCert(cert *ssh.Certificate) error { return ms.SetRecord(parseCertificate(cert)) } // SetRecord records a *CertRecord -func (ms *MemoryStore) SetRecord(record *CertRecord) error { +func (ms *memoryStore) SetRecord(record *CertRecord) error { ms.Lock() defer ms.Unlock() ms.certs[record.KeyID] = record @@ -43,7 +42,7 @@ func (ms *MemoryStore) SetRecord(record *CertRecord) error { // List returns all recorded certs. // By default only active certs are returned. -func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) { +func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) { var records []*CertRecord ms.Lock() defer ms.Unlock() @@ -58,7 +57,7 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (ms *MemoryStore) Revoke(ids []string) error { +func (ms *memoryStore) Revoke(ids []string) error { ms.Lock() defer ms.Unlock() for _, id := range ids { @@ -68,7 +67,7 @@ func (ms *MemoryStore) Revoke(ids []string) error { } // GetRevoked returns all revoked certs -func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) { +func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { var revoked []*CertRecord all, _ := ms.List(false) for _, r := range all { @@ -80,23 +79,22 @@ func (ms *MemoryStore) GetRevoked() ([]*CertRecord, error) { } // Close the store. This will clear the contents. -func (ms *MemoryStore) Close() error { +func (ms *memoryStore) Close() error { ms.Lock() defer ms.Unlock() ms.certs = nil return nil } -func (ms *MemoryStore) clear() { +func (ms *memoryStore) clear() { for k := range ms.certs { delete(ms.certs, k) } } -// NewMemoryStore returns an in-memory CertStorer. -func NewMemoryStore() *MemoryStore { - log.Println("WARNING: Using memory store to record issued certs.") - return &MemoryStore{ +// newMemoryStore returns an in-memory CertStorer. +func newMemoryStore() *memoryStore { + return &memoryStore{ certs: make(map[string]*CertRecord), } } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index 3526a2b..b5948b7 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -18,10 +18,10 @@ import ( migrate "github.com/rubenv/sql-migrate" ) -var _ CertStorer = (*SQLStore)(nil) +var _ CertStorer = (*sqlStore)(nil) -// SQLStore is an sql-based CertStorer -type SQLStore struct { +// sqlStore is an sql-based CertStorer +type sqlStore struct { conn *sqlx.DB get *sqlx.Stmt @@ -31,8 +31,8 @@ type SQLStore struct { revoked *sqlx.Stmt } -// NewSQLStore returns a *sql.DB CertStorer. -func NewSQLStore(c config.Database) (*SQLStore, error) { +// newSQLStore returns a *sql.DB CertStorer. +func newSQLStore(c config.Database) (*sqlStore, error) { var driver string var dsn string switch c["type"] { @@ -61,30 +61,30 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { conn, err := sqlx.Connect(driver, dsn) if err != nil { - return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err) + return nil, fmt.Errorf("sqlStore: could not get a connection: %v", err) } if err := autoMigrate(driver, conn); err != nil { - return nil, fmt.Errorf("SQLStore: could not update schema: %v", err) + return nil, fmt.Errorf("sqlStore: could not update schema: %v", err) } - db := &SQLStore{ + db := &sqlStore{ conn: conn, } if db.set, err = conn.Preparex("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare set: %v", err) + return nil, fmt.Errorf("sqlStore: prepare set: %v", err) } if db.get, err = conn.Preparex("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare get: %v", err) + return nil, fmt.Errorf("sqlStore: prepare get: %v", err) } if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err) + return nil, fmt.Errorf("sqlStore: prepare listAll: %v", err) } if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE expires_at >= ?"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) + return nil, fmt.Errorf("sqlStore: prepare listCurrent: %v", err) } if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err) + return nil, fmt.Errorf("sqlStore: prepare revoked: %v", err) } return db, nil } @@ -114,7 +114,7 @@ type rowScanner interface { } // Get a single *CertRecord -func (db *SQLStore) Get(id string) (*CertRecord, error) { +func (db *sqlStore) Get(id string) (*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, errors.Wrap(err, "unable to connect to database") } @@ -123,12 +123,12 @@ func (db *SQLStore) Get(id string) (*CertRecord, error) { } // SetCert parses a *ssh.Certificate and records it -func (db *SQLStore) SetCert(cert *ssh.Certificate) error { +func (db *sqlStore) SetCert(cert *ssh.Certificate) error { return db.SetRecord(parseCertificate(cert)) } // SetRecord records a *CertRecord -func (db *SQLStore) SetRecord(rec *CertRecord) error { +func (db *sqlStore) SetRecord(rec *CertRecord) error { if err := db.conn.Ping(); err != nil { return errors.Wrap(err, "unable to connect to database") } @@ -138,7 +138,7 @@ func (db *SQLStore) SetRecord(rec *CertRecord) error { // List returns all recorded certs. // By default only active certs are returned. -func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { +func (db *sqlStore) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, errors.Wrap(err, "unable to connect to database") } @@ -156,7 +156,7 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (db *SQLStore) Revoke(ids []string) error { +func (db *sqlStore) Revoke(ids []string) error { if err := db.conn.Ping(); err != nil { return errors.Wrap(err, "unable to connect to database") } @@ -166,7 +166,7 @@ func (db *SQLStore) Revoke(ids []string) error { } // GetRevoked returns all revoked certs -func (db *SQLStore) GetRevoked() ([]*CertRecord, error) { +func (db *sqlStore) GetRevoked() ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, errors.Wrap(err, "unable to connect to database") } @@ -178,6 +178,6 @@ func (db *SQLStore) GetRevoked() ([]*CertRecord, error) { } // Close the connection to the database -func (db *SQLStore) Close() error { +func (db *sqlStore) Close() error { return db.conn.Close() } diff --git a/server/store/store.go b/server/store/store.go index 4edb446..4863ff0 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -2,6 +2,7 @@ package store import ( "encoding/json" + "fmt" "time" "golang.org/x/crypto/ssh" @@ -15,11 +16,11 @@ import ( func New(c config.Database) (CertStorer, error) { switch c["type"] { case "mysql", "sqlite": - return NewSQLStore(c) + return newSQLStore(c) case "mem": - return NewMemoryStore(), nil + return newMemoryStore(), nil } - return NewMemoryStore(), nil + return nil, fmt.Errorf("unable to create store with driver %s", c["type"]) } // CertStorer records issued certs in a persistent store for audit and diff --git a/server/store/store_test.go b/server/store/store_test.go index 5704ce0..d9ae325 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -101,7 +101,7 @@ func testStore(t *testing.T, db CertStorer) { } func TestMemoryStore(t *testing.T) { - db := NewMemoryStore() + db := newMemoryStore() testStore(t, db) } @@ -120,7 +120,7 @@ func TestMySQLStore(t *testing.T) { } else { sqlConfig["username"] = u.Username } - db, err := NewSQLStore(sqlConfig) + db, err := newSQLStore(sqlConfig) if err != nil { t.Error(err) } @@ -134,7 +134,7 @@ func TestSQLiteStore(t *testing.T) { } defer os.Remove(f.Name()) config := map[string]string{"type": "sqlite", "filename": f.Name()} - db, err := NewSQLStore(config) + db, err := newSQLStore(config) if err != nil { t.Error(err) } -- cgit v1.2.3