From d836a4496de7b24a9d3317e274800d35053a04f6 Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 5 Aug 2018 23:00:58 +0100 Subject: Manage db schema with rubenv/sql-migrate It's currently hard to make changes to the database schema. Use sql-migrate to make incremental changes. Stop hard-coding the database name (the default is still "certs" for backward-compatibility) The `automigrate()` function will automatically run pending migrations. Use a different migration directory per database driver. This carries a cost of duplication, but is easier than creating migrations which will cleanly execute in both SQLite and MySQL. Migrations are shipped using the packr utility. --- server/store/sqldb.go | 48 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 10 deletions(-) (limited to 'server/store/sqldb.go') diff --git a/server/store/sqldb.go b/server/store/sqldb.go index 0b91023..3526a2b 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -2,15 +2,20 @@ package store import ( "fmt" + "log" "net" + "path" "time" "golang.org/x/crypto/ssh" "github.com/go-sql-driver/mysql" + "github.com/gobuffalo/packr" + multierror "github.com/hashicorp/go-multierror" "github.com/jmoiron/sqlx" "github.com/nsheridan/cashier/server/config" "github.com/pkg/errors" + migrate "github.com/rubenv/sql-migrate" ) var _ CertStorer = (*SQLStore)(nil) @@ -42,20 +47,24 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { m.User = c["username"] m.Passwd = c["password"] m.Addr = address - m.DBName = "certs" + m.Net = "tcp" + m.DBName = c["dbname"] + if m.DBName == "" { + m.DBName = "certs" // Legacy database name + } m.ParseTime = true dsn = m.FormatDSN() case "sqlite": driver = "sqlite3" dsn = c["filename"] } - conn, err := sqlx.Open(driver, dsn) + + conn, err := sqlx.Connect(driver, dsn) if err != nil { return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err) } - if err := conn.Ping(); err != nil { - conn.Close() - return nil, fmt.Errorf("SQLStore: could not establish a good connection: %v", err) + if err := autoMigrate(driver, conn); err != nil { + return nil, fmt.Errorf("SQLStore: could not update schema: %v", err) } db := &SQLStore{ @@ -71,7 +80,7 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err) } - if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { + if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE expires_at >= ?"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) } if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { @@ -80,6 +89,25 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { return db, nil } +func autoMigrate(driver string, conn *sqlx.DB) error { + log.Print("Executing any pending schema migrations") + var err error + migrate.SetTable("schema_migrations") + srcs := &migrate.PackrMigrationSource{ + Box: packr.NewBox(path.Join("migrations", driver)), + } + n, err := migrate.Exec(conn.DB, driver, srcs, migrate.Up) + if err != nil { + err = multierror.Append(err) + return err + } + log.Printf("Executed %d migrations", n) + if err != nil { + log.Fatalf("Errors were found running migrations: %v", err) + } + return nil +} + // rowScanner is implemented by sql.Row and sql.Rows type rowScanner interface { Scan(dest ...interface{}) error @@ -88,7 +116,7 @@ type rowScanner interface { // Get a single *CertRecord func (db *SQLStore) Get(id string) (*CertRecord, error) { if err := db.conn.Ping(); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to connect to database") } r := &CertRecord{} return r, db.get.Get(r, id) @@ -102,7 +130,7 @@ func (db *SQLStore) SetCert(cert *ssh.Certificate) error { // SetRecord records a *CertRecord func (db *SQLStore) SetRecord(rec *CertRecord) error { if err := db.conn.Ping(); err != nil { - return err + return errors.Wrap(err, "unable to connect to database") } _, err := db.set.Exec(rec.KeyID, rec.Principals, rec.CreatedAt, rec.Expires, rec.Raw) return err @@ -112,7 +140,7 @@ func (db *SQLStore) SetRecord(rec *CertRecord) error { // By default only active certs are returned. func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to connect to database") } recs := []*CertRecord{} if includeExpired { @@ -140,7 +168,7 @@ func (db *SQLStore) Revoke(ids []string) error { // GetRevoked returns all revoked certs func (db *SQLStore) GetRevoked() ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { - return nil, err + return nil, errors.Wrap(err, "unable to connect to database") } var recs []*CertRecord if err := db.revoked.Select(&recs, time.Now().UTC()); err != nil { -- cgit v1.2.3