From 8ee3c6473f3e2373303b9cb16ab5f059f9e6369e Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 15 Apr 2017 18:28:23 +0100 Subject: Revoke multiple certs in a single call --- server/store/mem.go | 11 +++++------ server/store/sqldb.go | 16 ++++++---------- server/store/store.go | 2 +- server/store/store_test.go | 2 +- server/web.go | 6 ++---- 5 files changed, 15 insertions(+), 22 deletions(-) (limited to 'server') diff --git a/server/store/mem.go b/server/store/mem.go index e289b16..68c5a13 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -57,13 +57,12 @@ func (ms *MemoryStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (ms *MemoryStore) Revoke(id string) error { - r, err := ms.Get(id) - if err != nil { - return err +func (ms *MemoryStore) Revoke(ids []string) error { + ms.Lock() + defer ms.Unlock() + for _, id := range ids { + ms.certs[id].Revoked = true } - r.Revoked = true - ms.SetRecord(r) return nil } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index bdb8893..c5a0f4e 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -10,6 +10,7 @@ import ( "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "github.com/nsheridan/cashier/server/config" + "github.com/pkg/errors" ) var _ CertStorer = (*SQLStore)(nil) @@ -22,7 +23,6 @@ type SQLStore struct { set *sqlx.Stmt listAll *sqlx.Stmt listCurrent *sqlx.Stmt - revoke *sqlx.Stmt revoked *sqlx.Stmt } @@ -76,9 +76,6 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) } - if db.revoke, err = conn.Preparex("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err) - } if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err) } @@ -133,14 +130,13 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { } // Revoke an issued cert by id. -func (db *SQLStore) Revoke(id string) error { +func (db *SQLStore) Revoke(ids []string) error { if err := db.conn.Ping(); err != nil { - return err + return errors.Wrap(err, "unable to connect to database") } - if _, err := db.revoke.Exec(id); err != nil { - return err - } - return nil + q, args, err := sqlx.In("UPDATE issued_certs SET revoked = 1 WHERE key_id IN (?)", ids) + _, err = db.conn.Query(q, args...) + return err } // GetRevoked returns all revoked certs diff --git a/server/store/store.go b/server/store/store.go index b200e81..4edb446 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -29,7 +29,7 @@ type CertStorer interface { SetCert(cert *ssh.Certificate) error SetRecord(record *CertRecord) error List(includeExpired bool) ([]*CertRecord, error) - Revoke(id string) error + Revoke(id []string) error GetRevoked() ([]*CertRecord, error) Close() error } diff --git a/server/store/store_test.go b/server/store/store_test.go index 9a8a4be..d18d02b 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -87,7 +87,7 @@ func testStore(t *testing.T, db CertStorer) { if ret.KeyID != cert.KeyId { t.Error("key mismatch") } - if err := db.Revoke("key"); err != nil { + if err := db.Revoke([]string{"key"}); err != nil { t.Error(err) } diff --git a/server/web.go b/server/web.go index edaa394..08162d5 100644 --- a/server/web.go +++ b/server/web.go @@ -240,10 +240,8 @@ func revokeCertHandler(a *appContext, w http.ResponseWriter, r *http.Request) (i return a.login(w, r) } r.ParseForm() - for _, id := range r.Form["cert_id"] { - if err := certstore.Revoke(id); err != nil { - return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke") - } + if err := certstore.Revoke(r.Form["cert_id"]); err != nil { + return http.StatusInternalServerError, errors.Wrap(err, "unable to revoke certs") } http.Redirect(w, r, "/admin/certs", http.StatusSeeOther) return http.StatusSeeOther, nil -- cgit v1.2.3