From 0ed23b71115ad2213bf7ea545f9f765052008872 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sat, 24 Sep 2016 23:25:06 +0100 Subject: Use a new session for each request --- server/store/mongo.go | 51 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/server/store/mongo.go b/server/store/mongo.go index 8a3ccda..1b13d7a 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -15,11 +15,6 @@ var ( issuedTable = "issued_certs" ) -type mongoDB struct { - collection *mgo.Collection - session *mgo.Session -} - func parseMongoConfig(config string) *mgo.DialInfo { s := strings.SplitN(config, ":", 4) _, user, passwd, hosts := s[0], s[1], s[2], s[3] @@ -33,25 +28,33 @@ func parseMongoConfig(config string) *mgo.DialInfo { return d } +func collection(session *mgo.Session) *mgo.Collection { + return session.DB(certsDB).C(issuedTable) +} + // NewMongoStore returns a MongoDB CertStorer. func NewMongoStore(config string) (CertStorer, error) { session, err := mgo.DialWithInfo(parseMongoConfig(config)) if err != nil { return nil, err } - c := session.DB(certsDB).C(issuedTable) return &mongoDB{ - collection: c, - session: session, + session: session, }, nil } +type mongoDB struct { + session *mgo.Session +} + func (m *mongoDB) Get(id string) (*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } c := &CertRecord{} - err := m.collection.Find(bson.M{"keyid": id}).One(c) + err := collection(s).Find(bson.M{"keyid": id}).One(c) return c, err } @@ -61,39 +64,49 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error { } func (m *mongoDB) SetRecord(record *CertRecord) error { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return err } - return m.collection.Insert(record) + return collection(s).Insert(record) } func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } var result []*CertRecord var err error + c := collection(s) if includeExpired { - err = m.collection.Find(nil).All(&result) + err = c.Find(nil).All(&result) } else { - err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + err = c.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) } return result, err } func (m *mongoDB) Revoke(id string) error { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return err } - return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) + c := collection(s) + return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) } func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } var result []*CertRecord - err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) + err := collection(s).Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) return result, err } -- cgit v1.2.3