aboutsummaryrefslogtreecommitdiff
path: root/server/store
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-09-10 20:14:20 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-09-11 20:41:32 +0100
commit2e7c8c2f521c9e50bb3aea4df16771c22fe70e58 (patch)
tree44daf7fea192d0e2368b2bb93545098c0adf610a /server/store
parentadc3c7f16051d51a58d96e32082aaeb051e3da20 (diff)
Allow filtering results
Diffstat (limited to 'server/store')
-rw-r--r--server/store/mem.go14
-rw-r--r--server/store/mongo.go9
-rw-r--r--server/store/sqldb.go27
-rw-r--r--server/store/store.go14
4 files changed, 39 insertions, 25 deletions
diff --git a/server/store/mem.go b/server/store/mem.go
index 92167a9..e63d00a 100644
--- a/server/store/mem.go
+++ b/server/store/mem.go
@@ -34,14 +34,16 @@ func (ms *memoryStore) SetRecord(record *CertRecord) error {
return nil
}
-func (ms *memoryStore) List() ([]*CertRecord, error) {
+func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) {
var records []*CertRecord
ms.Lock()
defer ms.Unlock()
+
for _, value := range ms.certs {
- if value.Expires.After(time.Now().UTC()) {
- records = append(records, value)
+ if !includeExpired && value.Expires.After(time.Now().UTC()) {
+ continue
}
+ records = append(records, value)
}
return records, nil
}
@@ -58,11 +60,9 @@ func (ms *memoryStore) Revoke(id string) error {
func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) {
var revoked []*CertRecord
- all, _ := ms.List()
+ all, _ := ms.List(false)
for _, r := range all {
- if r.Revoked && time.Now().UTC().Unix() <= r.Expires.UTC().Unix() {
- revoked = append(revoked, r)
- }
+ revoked = append(revoked, r)
}
return revoked, nil
}
diff --git a/server/store/mongo.go b/server/store/mongo.go
index 79df69d..8a3ccda 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -67,12 +67,17 @@ func (m *mongoDB) SetRecord(record *CertRecord) error {
return m.collection.Insert(record)
}
-func (m *mongoDB) List() ([]*CertRecord, error) {
+func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
if err := m.session.Ping(); err != nil {
return nil, err
}
var result []*CertRecord
- err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+ var err error
+ if includeExpired {
+ err = m.collection.Find(nil).All(&result)
+ } else {
+ err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+ }
return result, err
}
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index 54a52c6..81784b0 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -16,11 +16,12 @@ import (
type sqldb struct {
conn *sql.DB
- get *sql.Stmt
- set *sql.Stmt
- list *sql.Stmt
- revoke *sql.Stmt
- revoked *sql.Stmt
+ get *sql.Stmt
+ set *sql.Stmt
+ listAll *sql.Stmt
+ listCurrent *sql.Stmt
+ revoke *sql.Stmt
+ revoked *sql.Stmt
}
func parse(config string) []string {
@@ -66,8 +67,11 @@ func NewSQLStore(config string) (CertStorer, error) {
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare get: %v", err)
}
- if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
- return nil, fmt.Errorf("sqldb: prepare list: %v", err)
+ if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare listAll: %v", err)
+ }
+ if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err)
}
if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
@@ -132,12 +136,17 @@ func (db *sqldb) SetRecord(rec *CertRecord) error {
return err
}
-func (db *sqldb) List() ([]*CertRecord, error) {
+func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
var recs []*CertRecord
- rows, _ := db.revoked.Query(time.Now().UTC())
+ var rows *sql.Rows
+ if includeExpired {
+ rows, _ = db.listAll.Query()
+ } else {
+ rows, _ = db.listCurrent.Query(time.Now().UTC())
+ }
defer rows.Close()
for rows.Next() {
cert, err := scanCert(rows)
diff --git a/server/store/store.go b/server/store/store.go
index f6ac66e..a846bda 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -14,7 +14,7 @@ type CertStorer interface {
Get(id string) (*CertRecord, error)
SetCert(cert *ssh.Certificate) error
SetRecord(record *CertRecord) error
- List() ([]*CertRecord, error)
+ List(includeExpired bool) ([]*CertRecord, error)
Revoke(id string) error
GetRevoked() ([]*CertRecord, error)
Close() error
@@ -22,12 +22,12 @@ type CertStorer interface {
// A CertRecord is a representation of a ssh certificate used by a CertStorer.
type CertRecord struct {
- KeyID string
- Principals []string
- CreatedAt time.Time
- Expires time.Time
- Revoked bool
- Raw string
+ KeyID string `json:"key_id"`
+ Principals []string `json:"principals"`
+ CreatedAt time.Time `json:"created_at"`
+ Expires time.Time `json:"expires"`
+ Revoked bool `json:"revoked"`
+ Raw string `json:"-"`
}
func parseTime(t uint64) time.Time {