aboutsummaryrefslogtreecommitdiff
path: root/server/store/sqldb.go
diff options
context:
space:
mode:
Diffstat (limited to 'server/store/sqldb.go')
-rw-r--r--server/store/sqldb.go182
1 files changed, 182 insertions, 0 deletions
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
new file mode 100644
index 0000000..2ea5ea5
--- /dev/null
+++ b/server/store/sqldb.go
@@ -0,0 +1,182 @@
+package store
+
+import (
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+
+ "github.com/go-sql-driver/mysql"
+ _ "github.com/mattn/go-sqlite3" // required by sql driver
+)
+
+type sqldb struct {
+ conn *sql.DB
+
+ get *sql.Stmt
+ set *sql.Stmt
+ list *sql.Stmt
+ revoke *sql.Stmt
+ revoked *sql.Stmt
+}
+
+func parse(config string) []string {
+ s := strings.Split(config, ":")
+ if s[0] == "sqlite" {
+ s[0] = "sqlite3"
+ return s
+ }
+ if len(s) == 4 {
+ s = append(s, "3306")
+ }
+ _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4]
+ c := &mysql.Config{
+ User: user,
+ Passwd: passwd,
+ Net: "tcp",
+ Addr: fmt.Sprintf("%s:%s", host, port),
+ DBName: "certs",
+ ParseTime: true,
+ }
+ return []string{"mysql", c.FormatDSN()}
+}
+
+// NewSQLStore returns a *sql.DB CertStorer.
+func NewSQLStore(config string) (CertStorer, error) {
+ parsed := parse(config)
+ conn, err := sql.Open(parsed[0], parsed[1])
+ if err != nil {
+ return nil, fmt.Errorf("sqldb: could not get a connection: %v", err)
+ }
+ if err := conn.Ping(); err != nil {
+ conn.Close()
+ return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err)
+ }
+
+ db := &sqldb{
+ conn: conn,
+ }
+
+ if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare set: %v", err)
+ }
+ if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare get: %v", err)
+ }
+ if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare list: %v", err)
+ }
+ if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
+ }
+ if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare revoked: %v", err)
+ }
+ return db, nil
+}
+
+// rowScanner is implemented by sql.Row and sql.Rows
+type rowScanner interface {
+ Scan(dest ...interface{}) error
+}
+
+func scanCert(s rowScanner) (*CertRecord, error) {
+ var (
+ keyID sql.NullString
+ principals sql.NullString
+ createdAt time.Time
+ expires time.Time
+ revoked sql.NullBool
+ raw sql.NullString
+ )
+ if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil {
+ return nil, err
+ }
+ var p []string
+ if err := json.Unmarshal([]byte(principals.String), &p); err != nil {
+ return nil, err
+ }
+ return &CertRecord{
+ KeyID: keyID.String,
+ Principals: p,
+ CreatedAt: createdAt,
+ Expires: expires,
+ Revoked: revoked.Bool,
+ Raw: raw.String,
+ }, nil
+}
+
+func (db *sqldb) Get(id string) (*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
+ return scanCert(db.get.QueryRow(id))
+}
+
+func (db *sqldb) SetCert(cert *ssh.Certificate) error {
+ return db.SetRecord(parseCertificate(cert))
+}
+
+func (db *sqldb) SetRecord(rec *CertRecord) error {
+ principals, err := json.Marshal(rec.Principals)
+ if err != nil {
+ return err
+ }
+ if err := db.conn.Ping(); err != nil {
+ return err
+ }
+ _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw)
+ return err
+}
+
+func (db *sqldb) List() ([]*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
+ var recs []*CertRecord
+ rows, _ := db.list.Query()
+ defer rows.Close()
+ for rows.Next() {
+ cert, err := scanCert(rows)
+ if err != nil {
+ return nil, err
+ }
+ recs = append(recs, cert)
+ }
+ return recs, nil
+}
+
+func (db *sqldb) Revoke(id string) error {
+ if err := db.conn.Ping(); err != nil {
+ return err
+ }
+ _, err := db.revoke.Exec(id)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
+ var recs []*CertRecord
+ rows, _ := db.revoked.Query(time.Now().UTC())
+ defer rows.Close()
+ for rows.Next() {
+ cert, err := scanCert(rows)
+ if err != nil {
+ return nil, err
+ }
+ recs = append(recs, cert)
+ }
+ return recs, nil
+}
+
+func (db *sqldb) Close() error {
+ return db.conn.Close()
+}