diff options
Diffstat (limited to 'server/store')
-rw-r--r-- | server/store/mem.go | 11 | ||||
-rw-r--r-- | server/store/sqldb.go | 16 | ||||
-rw-r--r-- | server/store/store.go | 2 | ||||
-rw-r--r-- | server/store/store_test.go | 2 |
4 files changed, 13 insertions, 18 deletions
diff --git a/server/store/mem.go b/server/store/mem.go index e289b16..68c5a13 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -57,13 +57,12 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (ms *MemoryStore) Revoke(id string) error { - r, err := ms.Get(id) - if err != nil { - return err +func (ms *MemoryStore) Revoke(ids []string) error { + ms.Lock() + defer ms.Unlock() + for _, id := range ids { + ms.certs[id].Revoked = true } - r.Revoked = true - ms.SetRecord(r) return nil } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index bdb8893..c5a0f4e 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -10,6 +10,7 @@ import ( "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "github.com/nsheridan/cashier/server/config" + "github.com/pkg/errors" ) var _ CertStorer = (*SQLStore)(nil) @@ -22,7 +23,6 @@ type SQLStore struct { set *sqlx.Stmt listAll *sqlx.Stmt listCurrent *sqlx.Stmt - revoke *sqlx.Stmt revoked *sqlx.Stmt } @@ -76,9 +76,6 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) } - if db.revoke, err = conn.Preparex("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err) - } if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err) } @@ -133,14 +130,13 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (db *SQLStore) Revoke(id string) error { +func (db *SQLStore) Revoke(ids []string) error { if err := db.conn.Ping(); err != nil { - return err + return errors.Wrap(err, "unable to connect to database") } - if _, err := db.revoke.Exec(id); err != nil { - return err - } - return nil + q, args, err := sqlx.In("UPDATE issued_certs SET revoked = 1 WHERE key_id IN (?)", ids) + _, err = db.conn.Query(q, args...) + return err } // GetRevoked returns all revoked certs diff --git a/server/store/store.go b/server/store/store.go index b200e81..4edb446 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -29,7 +29,7 @@ type CertStorer interface { SetCert(cert *ssh.Certificate) error SetRecord(record *CertRecord) error List(includeExpired bool) ([]*CertRecord, error) - Revoke(id string) error + Revoke(id []string) error GetRevoked() ([]*CertRecord, error) Close() error } diff --git a/server/store/store_test.go b/server/store/store_test.go index 9a8a4be..d18d02b 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -87,7 +87,7 @@ func testStore(t *testing.T, db CertStorer) { if ret.KeyID != cert.KeyId { t.Error("key mismatch") } - if err := db.Revoke("key"); err != nil { + if err := db.Revoke([]string{"key"}); err != nil { t.Error(err) } |