diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2016-09-10 20:14:20 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2016-09-11 20:41:32 +0100 |
commit | 2e7c8c2f521c9e50bb3aea4df16771c22fe70e58 (patch) | |
tree | 44daf7fea192d0e2368b2bb93545098c0adf610a | |
parent | adc3c7f16051d51a58d96e32082aaeb051e3da20 (diff) |
Allow filtering results
-rw-r--r-- | client/keys.go | 3 | ||||
-rw-r--r-- | server/store/mem.go | 14 | ||||
-rw-r--r-- | server/store/mongo.go | 9 | ||||
-rw-r--r-- | server/store/sqldb.go | 27 | ||||
-rw-r--r-- | server/store/store.go | 14 |
5 files changed, 41 insertions, 26 deletions
diff --git a/client/keys.go b/client/keys.go index 4b3b69e..0ec0f1d 100644 --- a/client/keys.go +++ b/client/keys.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/rsa" "fmt" + "strings" "golang.org/x/crypto/ed25519" "golang.org/x/crypto/ssh" @@ -78,7 +79,7 @@ func GenerateKey(keytype string, bits int) (Key, ssh.PublicKey, error) { for k := range keytypes { valid = append(valid, k) } - return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, valid) + return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are %s", keytype, strings.Join(valid, "|")) } return f(bits) } diff --git a/server/store/mem.go b/server/store/mem.go index 92167a9..e63d00a 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -34,14 +34,16 @@ func (ms *memoryStore) SetRecord(record *CertRecord) error { return nil } -func (ms *memoryStore) List() ([]*CertRecord, error) { +func (ms *memoryStore) List(includeExpired bool) ([]*CertRecord, error) { var records []*CertRecord ms.Lock() defer ms.Unlock() + for _, value := range ms.certs { - if value.Expires.After(time.Now().UTC()) { - records = append(records, value) + if !includeExpired && value.Expires.After(time.Now().UTC()) { + continue } + records = append(records, value) } return records, nil } @@ -58,11 +60,9 @@ func (ms *memoryStore) Revoke(id string) error { func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { var revoked []*CertRecord - all, _ := ms.List() + all, _ := ms.List(false) for _, r := range all { - if r.Revoked && time.Now().UTC().Unix() <= r.Expires.UTC().Unix() { - revoked = append(revoked, r) - } + revoked = append(revoked, r) } return revoked, nil } diff --git a/server/store/mongo.go b/server/store/mongo.go index 79df69d..8a3ccda 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -67,12 +67,17 @@ func (m *mongoDB) SetRecord(record *CertRecord) error { return m.collection.Insert(record) } -func (m *mongoDB) List() ([]*CertRecord, error) { +func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { if err := m.session.Ping(); err != nil { return nil, err } var result []*CertRecord - err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + var err error + if includeExpired { + err = m.collection.Find(nil).All(&result) + } else { + err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + } return result, err } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index 54a52c6..81784b0 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -16,11 +16,12 @@ import ( type sqldb struct { conn *sql.DB - get *sql.Stmt - set *sql.Stmt - list *sql.Stmt - revoke *sql.Stmt - revoked *sql.Stmt + get *sql.Stmt + set *sql.Stmt + listAll *sql.Stmt + listCurrent *sql.Stmt + revoke *sql.Stmt + revoked *sql.Stmt } func parse(config string) []string { @@ -66,8 +67,11 @@ func NewSQLStore(config string) (CertStorer, error) { if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("sqldb: prepare get: %v", err) } - if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { - return nil, fmt.Errorf("sqldb: prepare list: %v", err) + if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { + return nil, fmt.Errorf("sqldb: prepare listAll: %v", err) + } + if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { + return nil, fmt.Errorf("sqldb: prepare listCurrent: %v", err) } if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("sqldb: prepare revoke: %v", err) @@ -132,12 +136,17 @@ func (db *sqldb) SetRecord(rec *CertRecord) error { return err } -func (db *sqldb) List() ([]*CertRecord, error) { +func (db *sqldb) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } var recs []*CertRecord - rows, _ := db.revoked.Query(time.Now().UTC()) + var rows *sql.Rows + if includeExpired { + rows, _ = db.listAll.Query() + } else { + rows, _ = db.listCurrent.Query(time.Now().UTC()) + } defer rows.Close() for rows.Next() { cert, err := scanCert(rows) diff --git a/server/store/store.go b/server/store/store.go index f6ac66e..a846bda 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -14,7 +14,7 @@ type CertStorer interface { Get(id string) (*CertRecord, error) SetCert(cert *ssh.Certificate) error SetRecord(record *CertRecord) error - List() ([]*CertRecord, error) + List(includeExpired bool) ([]*CertRecord, error) Revoke(id string) error GetRevoked() ([]*CertRecord, error) Close() error @@ -22,12 +22,12 @@ type CertStorer interface { // A CertRecord is a representation of a ssh certificate used by a CertStorer. type CertRecord struct { - KeyID string - Principals []string - CreatedAt time.Time - Expires time.Time - Revoked bool - Raw string + KeyID string `json:"key_id"` + Principals []string `json:"principals"` + CreatedAt time.Time `json:"created_at"` + Expires time.Time `json:"expires"` + Revoked bool `json:"revoked"` + Raw string `json:"-"` } func parseTime(t uint64) time.Time { |