From 030ff273473f0a5620ba276a370e5119f57179df Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Mon, 8 Aug 2016 23:39:46 +0100 Subject: SQLite DB support --- server/store/config_test.go | 12 +-- server/store/mysql.go | 161 --------------------------------------- server/store/sqldb.go | 182 ++++++++++++++++++++++++++++++++++++++++++++ server/store/store_test.go | 23 +++++- 4 files changed, 210 insertions(+), 168 deletions(-) delete mode 100644 server/store/mysql.go create mode 100644 server/store/sqldb.go (limited to 'server/store') diff --git a/server/store/config_test.go b/server/store/config_test.go index 8e283f5..f262b57 100644 --- a/server/store/config_test.go +++ b/server/store/config_test.go @@ -11,15 +11,15 @@ import ( func TestMySQLConfig(t *testing.T) { var tests = []struct { in string - out string + out []string }{ - {"mysql:user:passwd:localhost", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}, - {"mysql:user:passwd:localhost:13306", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}, - {"mysql:root::localhost", "root@tcp(localhost:3306)/certs?parseTime=true"}, + {"mysql:user:passwd:localhost", []string{"mysql", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}}, + {"mysql:user:passwd:localhost:13306", []string{"mysql", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}}, + {"mysql:root::localhost", []string{"mysql", "root@tcp(localhost:3306)/certs?parseTime=true"}}, } for _, tt := range tests { - result := parseMySQLConfig(tt.in) - if result != tt.out { + result := parse(tt.in) + if !reflect.DeepEqual(result, tt.out) { t.Errorf("want %s, got %s", tt.out, result) } } diff --git a/server/store/mysql.go b/server/store/mysql.go deleted file mode 100644 index 7a0b111..0000000 --- a/server/store/mysql.go +++ /dev/null @@ -1,161 +0,0 @@ -package store - -import ( - "database/sql" - "encoding/json" - "fmt" - "strings" - "time" - - "golang.org/x/crypto/ssh" - - "github.com/go-sql-driver/mysql" -) - -type mysqlDB struct { - conn *sql.DB - - get *sql.Stmt - set *sql.Stmt - list *sql.Stmt - revoke *sql.Stmt - revoked *sql.Stmt -} - -func parseMySQLConfig(config string) string { - s := strings.Split(config, ":") - if len(s) == 4 { - s = append(s, "3306") - } - _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] - c := &mysql.Config{ - User: user, - Passwd: passwd, - Net: "tcp", - Addr: fmt.Sprintf("%s:%s", host, port), - DBName: "certs", - ParseTime: true, - } - return c.FormatDSN() -} - -// NewMySQLStore returns a MySQL CertStorer. -func NewMySQLStore(config string) (CertStorer, error) { - conn, err := sql.Open("mysql", parseMySQLConfig(config)) - if err != nil { - return nil, fmt.Errorf("mysql: could not get a connection: %v", err) - } - if err := conn.Ping(); err != nil { - conn.Close() - return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err) - } - - db := &mysqlDB{ - conn: conn, - } - - if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil { - return nil, fmt.Errorf("mysql: prepare set: %v", err) - } - if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("mysql: prepare get: %v", err) - } - if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { - return nil, fmt.Errorf("mysql: prepare list: %v", err) - } - if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("mysql: prepare revoke: %v", err) - } - if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil { - return nil, fmt.Errorf("mysql: prepare revoked: %v", err) - } - return db, nil -} - -// rowScanner is implemented by sql.Row and sql.Rows -type rowScanner interface { - Scan(dest ...interface{}) error -} - -func scanCert(s rowScanner) (*CertRecord, error) { - var ( - keyID sql.NullString - principals sql.NullString - createdAt mysql.NullTime - expires mysql.NullTime - revoked sql.NullBool - raw sql.NullString - ) - if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil { - return nil, err - } - var p []string - if err := json.Unmarshal([]byte(principals.String), &p); err != nil { - return nil, err - } - return &CertRecord{ - KeyID: keyID.String, - Principals: p, - CreatedAt: createdAt.Time, - Expires: expires.Time, - Revoked: revoked.Bool, - Raw: raw.String, - }, nil -} - -func (db *mysqlDB) Get(id string) (*CertRecord, error) { - return scanCert(db.get.QueryRow(id)) -} - -func (db *mysqlDB) SetCert(cert *ssh.Certificate) error { - return db.SetRecord(parseCertificate(cert)) -} - -func (db *mysqlDB) SetRecord(rec *CertRecord) error { - principals, err := json.Marshal(rec.Principals) - if err != nil { - return err - } - _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw) - return err -} - -func (db *mysqlDB) List() ([]*CertRecord, error) { - var recs []*CertRecord - rows, _ := db.list.Query() - defer rows.Close() - for rows.Next() { - cert, err := scanCert(rows) - if err != nil { - return nil, err - } - recs = append(recs, cert) - } - return recs, nil -} - -func (db *mysqlDB) Revoke(id string) error { - _, err := db.revoke.Exec(id) - if err != nil { - return err - } - return nil -} - -func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { - var recs []*CertRecord - rows, _ := db.revoked.Query(time.Now().UTC()) - defer rows.Close() - for rows.Next() { - cert, err := scanCert(rows) - if err != nil { - return nil, err - } - recs = append(recs, cert) - } - return recs, nil -} - -func (db *mysqlDB) Close() error { - return db.conn.Close() -} diff --git a/server/store/sqldb.go b/server/store/sqldb.go new file mode 100644 index 0000000..2ea5ea5 --- /dev/null +++ b/server/store/sqldb.go @@ -0,0 +1,182 @@ +package store + +import ( + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" // required by sql driver +) + +type sqldb struct { + conn *sql.DB + + get *sql.Stmt + set *sql.Stmt + list *sql.Stmt + revoke *sql.Stmt + revoked *sql.Stmt +} + +func parse(config string) []string { + s := strings.Split(config, ":") + if s[0] == "sqlite" { + s[0] = "sqlite3" + return s + } + if len(s) == 4 { + s = append(s, "3306") + } + _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] + c := &mysql.Config{ + User: user, + Passwd: passwd, + Net: "tcp", + Addr: fmt.Sprintf("%s:%s", host, port), + DBName: "certs", + ParseTime: true, + } + return []string{"mysql", c.FormatDSN()} +} + +// NewSQLStore returns a *sql.DB CertStorer. +func NewSQLStore(config string) (CertStorer, error) { + parsed := parse(config) + conn, err := sql.Open(parsed[0], parsed[1]) + if err != nil { + return nil, fmt.Errorf("sqldb: could not get a connection: %v", err) + } + if err := conn.Ping(); err != nil { + conn.Close() + return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err) + } + + db := &sqldb{ + conn: conn, + } + + if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { + return nil, fmt.Errorf("sqldb: prepare set: %v", err) + } + if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { + return nil, fmt.Errorf("sqldb: prepare get: %v", err) + } + if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { + return nil, fmt.Errorf("sqldb: prepare list: %v", err) + } + if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { + return nil, fmt.Errorf("sqldb: prepare revoke: %v", err) + } + if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { + return nil, fmt.Errorf("sqldb: prepare revoked: %v", err) + } + return db, nil +} + +// rowScanner is implemented by sql.Row and sql.Rows +type rowScanner interface { + Scan(dest ...interface{}) error +} + +func scanCert(s rowScanner) (*CertRecord, error) { + var ( + keyID sql.NullString + principals sql.NullString + createdAt time.Time + expires time.Time + revoked sql.NullBool + raw sql.NullString + ) + if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil { + return nil, err + } + var p []string + if err := json.Unmarshal([]byte(principals.String), &p); err != nil { + return nil, err + } + return &CertRecord{ + KeyID: keyID.String, + Principals: p, + CreatedAt: createdAt, + Expires: expires, + Revoked: revoked.Bool, + Raw: raw.String, + }, nil +} + +func (db *sqldb) Get(id string) (*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } + return scanCert(db.get.QueryRow(id)) +} + +func (db *sqldb) SetCert(cert *ssh.Certificate) error { + return db.SetRecord(parseCertificate(cert)) +} + +func (db *sqldb) SetRecord(rec *CertRecord) error { + principals, err := json.Marshal(rec.Principals) + if err != nil { + return err + } + if err := db.conn.Ping(); err != nil { + return err + } + _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw) + return err +} + +func (db *sqldb) List() ([]*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } + var recs []*CertRecord + rows, _ := db.list.Query() + defer rows.Close() + for rows.Next() { + cert, err := scanCert(rows) + if err != nil { + return nil, err + } + recs = append(recs, cert) + } + return recs, nil +} + +func (db *sqldb) Revoke(id string) error { + if err := db.conn.Ping(); err != nil { + return err + } + _, err := db.revoke.Exec(id) + if err != nil { + return err + } + return nil +} + +func (db *sqldb) GetRevoked() ([]*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } + var recs []*CertRecord + rows, _ := db.revoked.Query(time.Now().UTC()) + defer rows.Close() + for rows.Next() { + cert, err := scanCert(rows) + if err != nil { + return nil, err + } + recs = append(recs, cert) + } + return recs, nil +} + +func (db *sqldb) Close() error { + return db.conn.Close() +} diff --git a/server/store/store_test.go b/server/store/store_test.go index bf16fa6..629230b 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -3,7 +3,10 @@ package store import ( "crypto/rand" "crypto/rsa" + "fmt" + "io/ioutil" "os" + "os/exec" "testing" "time" @@ -94,7 +97,7 @@ func TestMySQLStore(t *testing.T) { if config == "" { t.Skip("No MYSQL_TEST_CONFIG environment variable") } - db, err := NewMySQLStore(config) + db, err := NewSQLStore(config) if err != nil { t.Error(err) } @@ -112,3 +115,21 @@ func TestMongoStore(t *testing.T) { } testStore(t, db) } + +func TestSQLiteStore(t *testing.T) { + f, err := ioutil.TempFile("", "sqlite_test_db") + if err != nil { + t.Error(err) + } + defer os.Remove(f.Name()) + // This is so jank. + args := []string{"run", "../../cmd/dbinit/dbinit.go", "-db_type", "sqlite", "-db_path", f.Name()} + if err := exec.Command("go", args...).Run(); err != nil { + t.Error(err) + } + db, err := NewSQLStore(fmt.Sprintf("sqlite:%s", f.Name())) + if err != nil { + t.Error(err) + } + testStore(t, db) +} -- cgit v1.2.3