diff options
author | Niall Sheridan <nsheridan@gmail.com> | 2016-09-24 23:25:06 +0100 |
---|---|---|
committer | Niall Sheridan <nsheridan@gmail.com> | 2016-09-24 23:25:06 +0100 |
commit | 0ed23b71115ad2213bf7ea545f9f765052008872 (patch) | |
tree | dfc836b66ed803fafc4c274a7c196551ce6ba387 /server/store | |
parent | 46514b6826b6097b5618c95ac240ee8c24d2c6e8 (diff) |
Use a new session for each request
Diffstat (limited to 'server/store')
-rw-r--r-- | server/store/mongo.go | 51 |
1 files changed, 32 insertions, 19 deletions
diff --git a/server/store/mongo.go b/server/store/mongo.go index 8a3ccda..1b13d7a 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -15,11 +15,6 @@ var ( issuedTable = "issued_certs" ) -type mongoDB struct { - collection *mgo.Collection - session *mgo.Session -} - func parseMongoConfig(config string) *mgo.DialInfo { s := strings.SplitN(config, ":", 4) _, user, passwd, hosts := s[0], s[1], s[2], s[3] @@ -33,25 +28,33 @@ func parseMongoConfig(config string) *mgo.DialInfo { return d } +func collection(session *mgo.Session) *mgo.Collection { + return session.DB(certsDB).C(issuedTable) +} + // NewMongoStore returns a MongoDB CertStorer. func NewMongoStore(config string) (CertStorer, error) { session, err := mgo.DialWithInfo(parseMongoConfig(config)) if err != nil { return nil, err } - c := session.DB(certsDB).C(issuedTable) return &mongoDB{ - collection: c, - session: session, + session: session, }, nil } +type mongoDB struct { + session *mgo.Session +} + func (m *mongoDB) Get(id string) (*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } c := &CertRecord{} - err := m.collection.Find(bson.M{"keyid": id}).One(c) + err := collection(s).Find(bson.M{"keyid": id}).One(c) return c, err } @@ -61,39 +64,49 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error { } func (m *mongoDB) SetRecord(record *CertRecord) error { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return err } - return m.collection.Insert(record) + return collection(s).Insert(record) } func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } var result []*CertRecord var err error + c := collection(s) if includeExpired { - err = m.collection.Find(nil).All(&result) + err = c.Find(nil).All(&result) } else { - err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) + err = c.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result) } return result, err } func (m *mongoDB) Revoke(id string) error { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return err } - return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) + c := collection(s) + return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}}) } func (m *mongoDB) GetRevoked() ([]*CertRecord, error) { - if err := m.session.Ping(); err != nil { + s := m.session.Copy() + defer s.Close() + if err := s.Ping(); err != nil { return nil, err } var result []*CertRecord - err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) + err := collection(s).Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result) return result, err } |