From a6e42d899cde380f513710d07787ba11dfbe229a Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 7 Aug 2016 17:55:05 +0100 Subject: Ping the db before attempting to query it --- server/store/mongo.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) (limited to 'server') diff --git a/server/store/mongo.go b/server/store/mongo.go index 9773da7..c056171 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -16,7 +16,8 @@ var ( ) type mongoDB struct { - conn *mgo.Collection + collection *mgo.Collection + session *mgo.Session } func parseMongoConfig(config string) *mgo.DialInfo { @@ -40,13 +41,17 @@ func NewMongoStore(config string) (CertStorer, error) { } c := session.DB(certsDB).C(issuedTable) return &mongoDB{ - conn: c, + collection: c, + session: session, }, nil } func (m *mongoDB) Get(id string) (*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } c := &CertRecord{} - err := m.conn.Find(bson.M{"keyid": id}).One(c) + err := m.collection.Find(bson.M{"keyid": id}).One(c) return c, err } @@ -56,26 +61,38 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error { } func (m *mongoDB) SetRecord(record *CertRecord) error { - return m.conn.Insert(record) + if err := m.session.Ping(); err != nil { + return err + } + return m.collection.Insert(record) } func (m *mongoDB) List() ([]*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } var result []*CertRecord - m.conn.Find(nil).All(&result) + m.collection.Find(nil).All(&result) return result, nil } func (m *mongoDB) Revoke(id string) error { - return m.conn.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) + if err := m.session.Ping(); err != nil { + return err + } + return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) } func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { + if err := m.session.Ping(); err != nil { + return nil, err + } var result []*CertRecord - err := m.conn.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) + err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) return result, err } func (m *mongoDB) Close() error { - m.conn.Database.Session.Close() + m.session.Close() return nil } -- cgit v1.2.3