diff options
Diffstat (limited to 'server/store')
-rw-r--r-- | server/store/config_test.go | 12 | ||||
-rw-r--r-- | server/store/sqldb.go (renamed from server/store/mysql.go) | 77 | ||||
-rw-r--r-- | server/store/store_test.go | 23 |
3 files changed, 77 insertions, 35 deletions
diff --git a/server/store/config_test.go b/server/store/config_test.go index 8e283f5..f262b57 100644 --- a/server/store/config_test.go +++ b/server/store/config_test.go @@ -11,15 +11,15 @@ import ( func TestMySQLConfig(t *testing.T) { var tests = []struct { in string - out string + out []string }{ - {"mysql:user:passwd:localhost", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}, - {"mysql:user:passwd:localhost:13306", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}, - {"mysql:root::localhost", "root@tcp(localhost:3306)/certs?parseTime=true"}, + {"mysql:user:passwd:localhost", []string{"mysql", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}}, + {"mysql:user:passwd:localhost:13306", []string{"mysql", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}}, + {"mysql:root::localhost", []string{"mysql", "root@tcp(localhost:3306)/certs?parseTime=true"}}, } for _, tt := range tests { - result := parseMySQLConfig(tt.in) - if result != tt.out { + result := parse(tt.in) + if !reflect.DeepEqual(result, tt.out) { t.Errorf("want %s, got %s", tt.out, result) } } diff --git a/server/store/mysql.go b/server/store/sqldb.go index 7a0b111..2ea5ea5 100644 --- a/server/store/mysql.go +++ b/server/store/sqldb.go @@ -10,9 +10,10 @@ import ( "golang.org/x/crypto/ssh" "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" // required by sql driver ) -type mysqlDB struct { +type sqldb struct { conn *sql.DB get *sql.Stmt @@ -22,8 +23,12 @@ type mysqlDB struct { revoked *sql.Stmt } -func parseMySQLConfig(config string) string { +func parse(config string) []string { s := strings.Split(config, ":") + if s[0] == "sqlite" { + s[0] = "sqlite3" + return s + } if len(s) == 4 { s = append(s, "3306") } @@ -36,38 +41,39 @@ func parseMySQLConfig(config string) string { DBName: "certs", ParseTime: true, } - return c.FormatDSN() + return []string{"mysql", c.FormatDSN()} } -// NewMySQLStore returns a MySQL CertStorer. -func NewMySQLStore(config string) (CertStorer, error) { - conn, err := sql.Open("mysql", parseMySQLConfig(config)) +// NewSQLStore returns a *sql.DB CertStorer. +func NewSQLStore(config string) (CertStorer, error) { + parsed := parse(config) + conn, err := sql.Open(parsed[0], parsed[1]) if err != nil { - return nil, fmt.Errorf("mysql: could not get a connection: %v", err) + return nil, fmt.Errorf("sqldb: could not get a connection: %v", err) } if err := conn.Ping(); err != nil { conn.Close() - return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err) + return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err) } - db := &mysqlDB{ + db := &sqldb{ conn: conn, } - if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil { - return nil, fmt.Errorf("mysql: prepare set: %v", err) + if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { + return nil, fmt.Errorf("sqldb: prepare set: %v", err) } if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("mysql: prepare get: %v", err) + return nil, fmt.Errorf("sqldb: prepare get: %v", err) } if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { - return nil, fmt.Errorf("mysql: prepare list: %v", err) + return nil, fmt.Errorf("sqldb: prepare list: %v", err) } - if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil { - return nil, fmt.Errorf("mysql: prepare revoke: %v", err) + if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { + return nil, fmt.Errorf("sqldb: prepare revoke: %v", err) } - if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil { - return nil, fmt.Errorf("mysql: prepare revoked: %v", err) + if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { + return nil, fmt.Errorf("sqldb: prepare revoked: %v", err) } return db, nil } @@ -81,8 +87,8 @@ func scanCert(s rowScanner) (*CertRecord, error) { var ( keyID sql.NullString principals sql.NullString - createdAt mysql.NullTime - expires mysql.NullTime + createdAt time.Time + expires time.Time revoked sql.NullBool raw sql.NullString ) @@ -96,31 +102,40 @@ func scanCert(s rowScanner) (*CertRecord, error) { return &CertRecord{ KeyID: keyID.String, Principals: p, - CreatedAt: createdAt.Time, - Expires: expires.Time, + CreatedAt: createdAt, + Expires: expires, Revoked: revoked.Bool, Raw: raw.String, }, nil } -func (db *mysqlDB) Get(id string) (*CertRecord, error) { +func (db *sqldb) Get(id string) (*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } return scanCert(db.get.QueryRow(id)) } -func (db *mysqlDB) SetCert(cert *ssh.Certificate) error { +func (db *sqldb) SetCert(cert *ssh.Certificate) error { return db.SetRecord(parseCertificate(cert)) } -func (db *mysqlDB) SetRecord(rec *CertRecord) error { +func (db *sqldb) SetRecord(rec *CertRecord) error { principals, err := json.Marshal(rec.Principals) if err != nil { return err } + if err := db.conn.Ping(); err != nil { + return err + } _, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw) return err } -func (db *mysqlDB) List() ([]*CertRecord, error) { +func (db *sqldb) List() ([]*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } var recs []*CertRecord rows, _ := db.list.Query() defer rows.Close() @@ -134,7 +149,10 @@ func (db *mysqlDB) List() ([]*CertRecord, error) { return recs, nil } -func (db *mysqlDB) Revoke(id string) error { +func (db *sqldb) Revoke(id string) error { + if err := db.conn.Ping(); err != nil { + return err + } _, err := db.revoke.Exec(id) if err != nil { return err @@ -142,7 +160,10 @@ func (db *mysqlDB) Revoke(id string) error { return nil } -func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { +func (db *sqldb) GetRevoked() ([]*CertRecord, error) { + if err := db.conn.Ping(); err != nil { + return nil, err + } var recs []*CertRecord rows, _ := db.revoked.Query(time.Now().UTC()) defer rows.Close() @@ -156,6 +177,6 @@ func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { return recs, nil } -func (db *mysqlDB) Close() error { +func (db *sqldb) Close() error { return db.conn.Close() } diff --git a/server/store/store_test.go b/server/store/store_test.go index bf16fa6..629230b 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -3,7 +3,10 @@ package store import ( "crypto/rand" "crypto/rsa" + "fmt" + "io/ioutil" "os" + "os/exec" "testing" "time" @@ -94,7 +97,7 @@ func TestMySQLStore(t *testing.T) { if config == "" { t.Skip("No MYSQL_TEST_CONFIG environment variable") } - db, err := NewMySQLStore(config) + db, err := NewSQLStore(config) if err != nil { t.Error(err) } @@ -112,3 +115,21 @@ func TestMongoStore(t *testing.T) { } testStore(t, db) } + +func TestSQLiteStore(t *testing.T) { + f, err := ioutil.TempFile("", "sqlite_test_db") + if err != nil { + t.Error(err) + } + defer os.Remove(f.Name()) + // This is so jank. + args := []string{"run", "../../cmd/dbinit/dbinit.go", "-db_type", "sqlite", "-db_path", f.Name()} + if err := exec.Command("go", args...).Run(); err != nil { + t.Error(err) + } + db, err := NewSQLStore(fmt.Sprintf("sqlite:%s", f.Name())) + if err != nil { + t.Error(err) + } + testStore(t, db) +} |