From 2e7c8c2f521c9e50bb3aea4df16771c22fe70e58 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 10 Sep 2016 20:14:20 +0100 Subject: Allow filtering results --- server/store/mem.go | 14 +++++++------- server/store/mongo.go | 9 +++++++-- server/store/sqldb.go | 27 ++++++++++++++++++--------- server/store/store.go | 14 +++++++------- 4 files changed, 39 insertions(+), 25 deletions(-) (limited to 'server') diff --git a/server/store/mem.go b/server/store/mem.go index 92167a9..e63d00a 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -34,14 +34,16 @@ func (ms *memoryStore) SetRecord(record *CertRecord) error { return nil } -func (ms *memoryStore) List() ([]*CertRecord, error) { +func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) { var records []*CertRecord ms.Lock() defer ms.Unlock() + for _, value := range ms.certs { - if value.Expires.After(time.Now().UTC()) { - records = append(records, value) + if !includeExpired && value.Expires.After(time.Now().UTC()) { + continue } + records = append(records, value) } return records, nil } @@ -58,11 +60,9 @@ func (ms *memoryStore) Revoke(id string) error { func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { var revoked []*CertRecord - all, _ := ms.List() + all, _ := ms.List(false) for _, r := range all { - if r.Revoked && time.Now().UTC().Unix() <= r.Expires.UTC().Unix() { - revoked = append(revoked, r) - } + revoked = append(revoked, r) } return revoked, nil } diff --git a/server/store/mongo.go b/server/store/mongo.go index 79df69d..8a3ccda 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -67,12 +67,17 @@ func (m *mongoDB) SetRecord(record *CertRecord) error { return m.collection.Insert(record) } -func (m *mongoDB) List() ([]*CertRecord, error) { +func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { if err := m.session.Ping(); err != nil { return nil, err } var result []*CertRecord - err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + var err error + if includeExpired { + err = m.collection.Find(nil).All(&result) + } else { + err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + } return result, err } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index 54a52c6..81784b0 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -16,11 +16,12 @@ import ( type sqldb struct { conn *sql.DB - get *sql.Stmt - set *sql.Stmt - list *sql.Stmt - revoke *sql.Stmt - revoked *sql.Stmt + get *sql.Stmt + set *sql.Stmt + listAll *sql.Stmt + listCurrent *sql.Stmt + revoke *sql.Stmt + revoked *sql.Stmt } func parse(config string) []string { @@ -66,8 +67,11 @@ func NewSQLStore(config string) (CertStorer, error) { if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("sqldb: prepare get: %v", err) } - if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { - return nil, fmt.Errorf("sqldb: prepare list: %v", err) + if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { + return nil, fmt.Errorf("sqldb: prepare listAll: %v", err) + } + if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { + return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err) } if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("sqldb: prepare revoke: %v", err) @@ -132,12 +136,17 @@ func (db *sqldb) SetRecord(rec *CertRecord) error { return err } -func (db *sqldb) List() ([]*CertRecord, error) { +func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } var recs []*CertRecord - rows, _ := db.revoked.Query(time.Now().UTC()) + var rows *sql.Rows + if includeExpired { + rows, _ = db.listAll.Query() + } else { + rows, _ = db.listCurrent.Query(time.Now().UTC()) + } defer rows.Close() for rows.Next() { cert, err := scanCert(rows) diff --git a/server/store/store.go b/server/store/store.go index f6ac66e..a846bda 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -14,7 +14,7 @@ type CertStorer interface { Get(id string) (*CertRecord, error) SetCert(cert *ssh.Certificate) error SetRecord(record *CertRecord) error - List() ([]*CertRecord, error) + List(includeExpired bool) ([]*CertRecord, error) Revoke(id string) error GetRevoked() ([]*CertRecord, error) Close() error @@ -22,12 +22,12 @@ type CertStorer interface { // A CertRecord is a representation of a ssh certificate used by a CertStorer. type CertRecord struct { - KeyID string - Principals []string - CreatedAt time.Time - Expires time.Time - Revoked bool - Raw string + KeyID string `json:"key_id"` + Principals []string `json:"principals"` + CreatedAt time.Time `json:"created_at"` + Expires time.Time `json:"expires"` + Revoked bool `json:"revoked"` + Raw string `json:"-"` } func parseTime(t uint64) time.Time { -- cgit v1.2.3