diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2016-08-07 17:55:05 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2016-08-07 17:55:05 +0100 |
commit | a6e42d899cde380f513710d07787ba11dfbe229a (patch) | |
tree | 39279f5aadd89a680d010ffe60c595cad3457c12 /server | |
parent | 823dd5ced21db2879899dbeda1bc1104c4d37e91 (diff) |
Ping the db before attempting to query it
Diffstat (limited to 'server')
-rw-r--r-- | server/store/mongo.go | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/server/store/mongo.go b/server/store/mongo.go index 9773da7..c056171 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -16,7 +16,8 @@ var ( ) type mongoDB struct { - conn *mgo.Collection + collection *mgo.Collection + session *mgo.Session } func parseMongoConfig(config string) *mgo.DialInfo { @@ -40,13 +41,17 @@ func NewMongoStore(config string) (CertStorer, error) { } c := session.DB(certsDB).C(issuedTable) return &mongoDB{ - conn: c, + collection: c, + session: session, }, nil } func (m *mongoDB) Get(id string) (*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } c := &CertRecord{} - err := m.conn.Find(bson.M{"keyid": id}).One(c) + err := m.collection.Find(bson.M{"keyid": id}).One(c) return c, err } @@ -56,26 +61,38 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error { } func (m *mongoDB) SetRecord(record *CertRecord) error { - return m.conn.Insert(record) + if err := m.session.Ping(); err != nil { + return err + } + return m.collection.Insert(record) } func (m *mongoDB) List() ([]*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } var result []*CertRecord - m.conn.Find(nil).All(&result) + m.collection.Find(nil).All(&result) return result, nil } func (m *mongoDB) Revoke(id string) error { - return m.conn.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) + if err := m.session.Ping(); err != nil { + return err + } + return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) } func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } var result []*CertRecord - err := m.conn.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) + err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) return result, err } func (m *mongoDB) Close() error { - m.conn.Database.Session.Close() + m.session.Close() return nil } |