aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--server/store/mongo.go51
1 files changed, 32 insertions, 19 deletions
diff --git a/server/store/mongo.go b/server/store/mongo.go
index 8a3ccda..1b13d7a 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -15,11 +15,6 @@ var (
issuedTable = "issued_certs"
)
-type mongoDB struct {
- collection *mgo.Collection
- session *mgo.Session
-}
-
func parseMongoConfig(config string) *mgo.DialInfo {
s := strings.SplitN(config, ":", 4)
_, user, passwd, hosts := s[0], s[1], s[2], s[3]
@@ -33,25 +28,33 @@ func parseMongoConfig(config string) *mgo.DialInfo {
return d
}
+func collection(session *mgo.Session) *mgo.Collection {
+ return session.DB(certsDB).C(issuedTable)
+}
+
// NewMongoStore returns a MongoDB CertStorer.
func NewMongoStore(config string) (CertStorer, error) {
session, err := mgo.DialWithInfo(parseMongoConfig(config))
if err != nil {
return nil, err
}
- c := session.DB(certsDB).C(issuedTable)
return &mongoDB{
- collection: c,
- session: session,
+ session: session,
}, nil
}
+type mongoDB struct {
+ session *mgo.Session
+}
+
func (m *mongoDB) Get(id string) (*CertRecord, error) {
- if err := m.session.Ping(); err != nil {
+ s := m.session.Copy()
+ defer s.Close()
+ if err := s.Ping(); err != nil {
return nil, err
}
c := &CertRecord{}
- err := m.collection.Find(bson.M{"keyid": id}).One(c)
+ err := collection(s).Find(bson.M{"keyid": id}).One(c)
return c, err
}
@@ -61,39 +64,49 @@ func (m *mongoDB) SetCert(cert *ssh.Certificate) error {
}
func (m *mongoDB) SetRecord(record *CertRecord) error {
- if err := m.session.Ping(); err != nil {
+ s := m.session.Copy()
+ defer s.Close()
+ if err := s.Ping(); err != nil {
return err
}
- return m.collection.Insert(record)
+ return collection(s).Insert(record)
}
func (m *mongoDB) List(includeExpired bool) ([]*CertRecord, error) {
- if err := m.session.Ping(); err != nil {
+ s := m.session.Copy()
+ defer s.Close()
+ if err := s.Ping(); err != nil {
return nil, err
}
var result []*CertRecord
var err error
+ c := collection(s)
if includeExpired {
- err = m.collection.Find(nil).All(&result)
+ err = c.Find(nil).All(&result)
} else {
- err = m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
+ err = c.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}}).All(&result)
}
return result, err
}
func (m *mongoDB) Revoke(id string) error {
- if err := m.session.Ping(); err != nil {
+ s := m.session.Copy()
+ defer s.Close()
+ if err := s.Ping(); err != nil {
return err
}
- return m.collection.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
+ c := collection(s)
+ return c.Update(bson.M{"keyid": id}, bson.M{"$set": bson.M{"revoked": true}})
}
func (m *mongoDB) GetRevoked() ([]*CertRecord, error) {
- if err := m.session.Ping(); err != nil {
+ s := m.session.Copy()
+ defer s.Close()
+ if err := s.Ping(); err != nil {
return nil, err
}
var result []*CertRecord
- err := m.collection.Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
+ err := collection(s).Find(bson.M{"expires": bson.M{"$gte": time.Now().UTC()}, "revoked": true}).All(&result)
return result, err
}