aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-08-28 17:33:14 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-08-28 17:33:14 +0100
commita5602dd8cdec8cb8ce85cbc5fab29a91f533d2af (patch)
tree61a1bd2c941565039b3fe12aeee67a4506706617 /server
parent7dbcbcc73210d8efe15a72b51b3245860051a89a (diff)
List only certs which haven't expired
Diffstat (limited to 'server')
-rw-r--r--server/store/mem.go4
-rw-r--r--server/store/mongo.go4
-rw-r--r--server/store/sqldb.go4
-rw-r--r--server/store/store_test.go23
4 files changed, 14 insertions, 21 deletions
diff --git a/server/store/mem.go b/server/store/mem.go
index cd37071..92167a9 100644
--- a/server/store/mem.go
+++ b/server/store/mem.go
@@ -39,7 +39,9 @@ func (ms *memoryStore) List() ([]*CertRecord, error) {
ms.Lock()
defer ms.Unlock()
for _, value := range ms.certs {
- records = append(records, value)
+ if value.Expires.After(time.Now().UTC()) {
+ records = append(records, value)
+ }
}
return records, nil
}
diff --git a/server/store/mongo.go b/server/store/mongo.go
index c056171..79df69d 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -72,8 +72,8 @@ func (m *mongoDB) List() ([]*CertRecord, error) {
return nil, err
}
var result []*CertRecord
- m.collection.Find(nil).All(&result)
- return result, nil
+ err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+ return result, err
}
func (m *mongoDB) Revoke(id string) error {
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index 2ea5ea5..54a52c6 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -66,7 +66,7 @@ func NewSQLStore(config string) (CertStorer, error) {
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("sqldb: prepare get: %v", err)
}
- if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+ if db.list, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
return nil, fmt.Errorf("sqldb: prepare list: %v", err)
}
if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
@@ -137,7 +137,7 @@ func (db *sqldb) List() ([]*CertRecord, error) {
return nil, err
}
var recs []*CertRecord
- rows, _ := db.list.Query()
+ rows, _ := db.revoked.Query(time.Now().UTC())
defer rows.Close()
for rows.Next() {
cert, err := scanCert(rows)
diff --git a/server/store/store_test.go b/server/store/store_test.go
index 18fa0d1..3552d1c 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -42,27 +42,21 @@ func TestParseCertificate(t *testing.T) {
func testStore(t *testing.T, db CertStorer) {
defer db.Close()
- ids := []string{"a", "b"}
- for _, id := range ids {
- r := &CertRecord{
- KeyID: id,
- Expires: time.Now().UTC().Add(time.Second * -10),
- }
- if err := db.SetRecord(r); err != nil {
- t.Error(err)
- }
+ r := &CertRecord{
+ KeyID: "a",
+ Expires: time.Now().UTC().Add(1 * time.Minute),
}
- recs, err := db.List()
- if err != nil {
+ if err := db.SetRecord(r); err != nil {
t.Error(err)
}
- if len(recs) != len(ids) {
- t.Errorf("Want %d records, got %d", len(ids), len(recs))
+ if _, err := db.List(); err != nil {
+ t.Error(err)
}
c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert)
cert := c.(*ssh.Certificate)
cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix())
+ cert.ValidAfter = uint64(time.Now().Add(-5 * time.Minute).UTC().Unix())
if err := db.SetCert(cert); err != nil {
t.Error(err)
}
@@ -74,9 +68,6 @@ func testStore(t *testing.T, db CertStorer) {
t.Error(err)
}
- // A revoked key shouldn't get returned if it's already expired
- db.Revoke("a")
-
revoked, err := db.GetRevoked()
if err != nil {
t.Error(err)