aboutsummaryrefslogtreecommitdiff
path: root/server/store
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2016-08-08 23:39:46 +0100
committerNiall Sheridan <nsheridan@gmail.com>2016-08-09 22:07:58 +0100
commit030ff273473f0a5620ba276a370e5119f57179df (patch)
tree035f1d5f629a228f5d4a170583e92726dbefb604 /server/store
parent66a7d51577c83da7cc3cf385a188799fe885cd3a (diff)
SQLite DB support
Diffstat (limited to 'server/store')
-rw-r--r--server/store/config_test.go12
-rw-r--r--server/store/sqldb.go (renamed from server/store/mysql.go)77
-rw-r--r--server/store/store_test.go23
3 files changed, 77 insertions, 35 deletions
diff --git a/server/store/config_test.go b/server/store/config_test.go
index 8e283f5..f262b57 100644
--- a/server/store/config_test.go
+++ b/server/store/config_test.go
@@ -11,15 +11,15 @@ import (
func TestMySQLConfig(t *testing.T) {
var tests = []struct {
in string
- out string
+ out []string
}{
- {"mysql:user:passwd:localhost", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"},
- {"mysql:user:passwd:localhost:13306", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"},
- {"mysql:root::localhost", "root@tcp(localhost:3306)/certs?parseTime=true"},
+ {"mysql:user:passwd:localhost", []string{"mysql", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}},
+ {"mysql:user:passwd:localhost:13306", []string{"mysql", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}},
+ {"mysql:root::localhost", []string{"mysql", "root@tcp(localhost:3306)/certs?parseTime=true"}},
}
for _, tt := range tests {
- result := parseMySQLConfig(tt.in)
- if result != tt.out {
+ result := parse(tt.in)
+ if !reflect.DeepEqual(result, tt.out) {
t.Errorf("want %s, got %s", tt.out, result)
}
}
diff --git a/server/store/mysql.go b/server/store/sqldb.go
index 7a0b111..2ea5ea5 100644
--- a/server/store/mysql.go
+++ b/server/store/sqldb.go
@@ -10,9 +10,10 @@ import (
"golang.org/x/crypto/ssh"
"github.com/go-sql-driver/mysql"
+ _ "github.com/mattn/go-sqlite3" // required by sql driver
)
-type mysqlDB struct {
+type sqldb struct {
conn *sql.DB
get *sql.Stmt
@@ -22,8 +23,12 @@ type mysqlDB struct {
revoked *sql.Stmt
}
-func parseMySQLConfig(config string) string {
+func parse(config string) []string {
s := strings.Split(config, ":")
+ if s[0] == "sqlite" {
+ s[0] = "sqlite3"
+ return s
+ }
if len(s) == 4 {
s = append(s, "3306")
}
@@ -36,38 +41,39 @@ func parseMySQLConfig(config string) string {
DBName: "certs",
ParseTime: true,
}
- return c.FormatDSN()
+ return []string{"mysql", c.FormatDSN()}
}
-// NewMySQLStore returns a MySQL CertStorer.
-func NewMySQLStore(config string) (CertStorer, error) {
- conn, err := sql.Open("mysql", parseMySQLConfig(config))
+// NewSQLStore returns a *sql.DB CertStorer.
+func NewSQLStore(config string) (CertStorer, error) {
+ parsed := parse(config)
+ conn, err := sql.Open(parsed[0], parsed[1])
if err != nil {
- return nil, fmt.Errorf("mysql: could not get a connection: %v", err)
+ return nil, fmt.Errorf("sqldb: could not get a connection: %v", err)
}
if err := conn.Ping(); err != nil {
conn.Close()
- return nil, fmt.Errorf("mysql: could not establish a good connection: %v", err)
+ return nil, fmt.Errorf("sqldb: could not establish a good connection: %v", err)
}
- db := &mysqlDB{
+ db := &sqldb{
conn: conn,
}
- if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?) ON DUPLICATE KEY UPDATE key_id = VALUES(key_id), principals = VALUES(principals), created_at = VALUES(created_at), expires_at = VALUES(expires_at), raw_key = VALUES(raw_key)"); err != nil {
- return nil, fmt.Errorf("mysql: prepare set: %v", err)
+ if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare set: %v", err)
}
if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
- return nil, fmt.Errorf("mysql: prepare get: %v", err)
+ return nil, fmt.Errorf("sqldb: prepare get: %v", err)
}
if db.list, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
- return nil, fmt.Errorf("mysql: prepare list: %v", err)
+ return nil, fmt.Errorf("sqldb: prepare list: %v", err)
}
- if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = TRUE WHERE key_id = ?"); err != nil {
- return nil, fmt.Errorf("mysql: prepare revoke: %v", err)
+ if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare revoke: %v", err)
}
- if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = TRUE AND ? <= expires_at"); err != nil {
- return nil, fmt.Errorf("mysql: prepare revoked: %v", err)
+ if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
+ return nil, fmt.Errorf("sqldb: prepare revoked: %v", err)
}
return db, nil
}
@@ -81,8 +87,8 @@ func scanCert(s rowScanner) (*CertRecord, error) {
var (
keyID sql.NullString
principals sql.NullString
- createdAt mysql.NullTime
- expires mysql.NullTime
+ createdAt time.Time
+ expires time.Time
revoked sql.NullBool
raw sql.NullString
)
@@ -96,31 +102,40 @@ func scanCert(s rowScanner) (*CertRecord, error) {
return &CertRecord{
KeyID: keyID.String,
Principals: p,
- CreatedAt: createdAt.Time,
- Expires: expires.Time,
+ CreatedAt: createdAt,
+ Expires: expires,
Revoked: revoked.Bool,
Raw: raw.String,
}, nil
}
-func (db *mysqlDB) Get(id string) (*CertRecord, error) {
+func (db *sqldb) Get(id string) (*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
return scanCert(db.get.QueryRow(id))
}
-func (db *mysqlDB) SetCert(cert *ssh.Certificate) error {
+func (db *sqldb) SetCert(cert *ssh.Certificate) error {
return db.SetRecord(parseCertificate(cert))
}
-func (db *mysqlDB) SetRecord(rec *CertRecord) error {
+func (db *sqldb) SetRecord(rec *CertRecord) error {
principals, err := json.Marshal(rec.Principals)
if err != nil {
return err
}
+ if err := db.conn.Ping(); err != nil {
+ return err
+ }
_, err = db.set.Exec(rec.KeyID, string(principals), rec.CreatedAt, rec.Expires, rec.Raw)
return err
}
-func (db *mysqlDB) List() ([]*CertRecord, error) {
+func (db *sqldb) List() ([]*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
var recs []*CertRecord
rows, _ := db.list.Query()
defer rows.Close()
@@ -134,7 +149,10 @@ func (db *mysqlDB) List() ([]*CertRecord, error) {
return recs, nil
}
-func (db *mysqlDB) Revoke(id string) error {
+func (db *sqldb) Revoke(id string) error {
+ if err := db.conn.Ping(); err != nil {
+ return err
+ }
_, err := db.revoke.Exec(id)
if err != nil {
return err
@@ -142,7 +160,10 @@ func (db *mysqlDB) Revoke(id string) error {
return nil
}
-func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) {
+func (db *sqldb) GetRevoked() ([]*CertRecord, error) {
+ if err := db.conn.Ping(); err != nil {
+ return nil, err
+ }
var recs []*CertRecord
rows, _ := db.revoked.Query(time.Now().UTC())
defer rows.Close()
@@ -156,6 +177,6 @@ func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) {
return recs, nil
}
-func (db *mysqlDB) Close() error {
+func (db *sqldb) Close() error {
return db.conn.Close()
}
diff --git a/server/store/store_test.go b/server/store/store_test.go
index bf16fa6..629230b 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -3,7 +3,10 @@ package store
import (
"crypto/rand"
"crypto/rsa"
+ "fmt"
+ "io/ioutil"
"os"
+ "os/exec"
"testing"
"time"
@@ -94,7 +97,7 @@ func TestMySQLStore(t *testing.T) {
if config == "" {
t.Skip("No MYSQL_TEST_CONFIG environment variable")
}
- db, err := NewMySQLStore(config)
+ db, err := NewSQLStore(config)
if err != nil {
t.Error(err)
}
@@ -112,3 +115,21 @@ func TestMongoStore(t *testing.T) {
}
testStore(t, db)
}
+
+func TestSQLiteStore(t *testing.T) {
+ f, err := ioutil.TempFile("", "sqlite_test_db")
+ if err != nil {
+ t.Error(err)
+ }
+ defer os.Remove(f.Name())
+ // This is so jank.
+ args := []string{"run", "../../cmd/dbinit/dbinit.go", "-db_type", "sqlite", "-db_path", f.Name()}
+ if err := exec.Command("go", args...).Run(); err != nil {
+ t.Error(err)
+ }
+ db, err := NewSQLStore(fmt.Sprintf("sqlite:%s", f.Name()))
+ if err != nil {
+ t.Error(err)
+ }
+ testStore(t, db)
+}