From 51cc4c07b2a2b6345b1496baac865f5faf955e7d Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Fri, 20 Jan 2017 00:52:56 +0000 Subject: Switch from database/sql to sqlx --- server/store/sqldb.go | 98 ++++++++++---------------------------- server/store/store.go | 15 +++--- server/store/store_test.go | 5 +- server/store/types/string_slice.go | 37 ++++++++++++++ 4 files changed, 73 insertions(+), 82 deletions(-) create mode 100644 server/store/types/string_slice.go (limited to 'server') diff --git a/server/store/sqldb.go b/server/store/sqldb.go index a51678e..2efca0e 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -1,8 +1,6 @@ package store import ( - "database/sql" - "encoding/json" "fmt" "net" "time" @@ -10,6 +8,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx" "github.com/nsheridan/cashier/server/config" ) @@ -17,14 +16,14 @@ var _ CertStorer = (*SQLStore)(nil) // SQLStore is an sql-based CertStorer type SQLStore struct { - conn *sql.DB - - get *sql.Stmt - set *sql.Stmt - listAll *sql.Stmt - listCurrent *sql.Stmt - revoke *sql.Stmt - revoked *sql.Stmt + conn *sqlx.DB + + get *sqlx.Stmt + set *sqlx.Stmt + listAll *sqlx.Stmt + listCurrent *sqlx.Stmt + revoke *sqlx.Stmt + revoked *sqlx.Stmt } // NewSQLStore returns a *sql.DB CertStorer. @@ -52,7 +51,7 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { driver = "sqlite3" dsn = c["filename"] } - conn, err := sql.Open(driver, dsn) + conn, err := sqlx.Open(driver, dsn) if err != nil { return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err) } @@ -65,22 +64,22 @@ func NewSQLStore(c config.Database) (*SQLStore, error) { conn: conn, } - if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { + if db.set, err = conn.Preparex("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil { return nil, fmt.Errorf("SQLStore: prepare set: %v", err) } - if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { + if db.get, err = conn.Preparex("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("SQLStore: prepare get: %v", err) } - if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil { + if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err) } - if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { + if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err) } - if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { + if db.revoke, err = conn.Preparex("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil { return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err) } - if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { + if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil { return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err) } return db, nil @@ -91,38 +90,13 @@ type rowScanner interface { Scan(dest ...interface{}) error } -func scanCert(s rowScanner) (*CertRecord, error) { - var ( - keyID sql.NullString - principals sql.NullString - createdAt time.Time - expires time.Time - revoked sql.NullBool - raw sql.NullString - ) - if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil { - return nil, err - } - var p []string - if err := json.Unmarshal([]byte(principals.String), &p); err != nil { - return nil, err - } - return &CertRecord{ - KeyID: keyID.String, - Principals: p, - CreatedAt: createdAt, - Expires: expires, - Revoked: revoked.Bool, - Raw: raw.String, - }, nil -} - // Get a single *CertRecord func (db *SQLStore) Get(id string) (*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } - return scanCert(db.get.QueryRow(id)) + r := &CertRecord{} + return r, db.get.Get(r, id) } // SetCert parses a *ssh.Certificate and records it @@ -132,14 +106,10 @@ func (db *SQLStore) SetCert(cert *ssh.Certificate) error { // SetRecord records a *CertRecord func (db *SQLStore) SetRecord(rec *CertRecord) error { - principals, err := json.Marshal(rec.Principals) - if err != nil { - return err - } if err := db.conn.Ping(); err != nil { return err } - _, err = db.set.Exec(rec.KeyID, principals, rec.CreatedAt, rec.Expires, rec.Raw) + _, err := db.set.Exec(rec.KeyID, rec.Principals, rec.CreatedAt, rec.Expires, rec.Raw) return err } @@ -149,20 +119,9 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) { if err := db.conn.Ping(); err != nil { return nil, err } - var recs []*CertRecord - var rows *sql.Rows - if includeExpired { - rows, _ = db.listAll.Query() - } else { - rows, _ = db.listCurrent.Query(time.Now().UTC()) - } - defer rows.Close() - for rows.Next() { - cert, err := scanCert(rows) - if err != nil { - return nil, err - } - recs = append(recs, cert) + recs := []*CertRecord{} + if err := db.listAll.Select(&recs); err != nil { + return nil, err } return recs, nil } @@ -172,8 +131,7 @@ func (db *SQLStore) Revoke(id string) error { if err := db.conn.Ping(); err != nil { return err } - _, err := db.revoke.Exec(id) - if err != nil { + if _, err := db.revoke.Exec(id); err != nil { return err } return nil @@ -185,14 +143,8 @@ func (db *SQLStore) GetRevoked() ([]*CertRecord, error) { return nil, err } var recs []*CertRecord - rows, _ := db.revoked.Query(time.Now().UTC()) - defer rows.Close() - for rows.Next() { - cert, err := scanCert(rows) - if err != nil { - return nil, err - } - recs = append(recs, cert) + if err := db.revoked.Select(&recs, time.Now().UTC()); err != nil { + return nil, err } return recs, nil } diff --git a/server/store/store.go b/server/store/store.go index 8af77e3..249489a 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -7,6 +7,7 @@ import ( "github.com/nsheridan/cashier/lib" "github.com/nsheridan/cashier/server/config" + "github.com/nsheridan/cashier/server/store/types" ) // New returns a new configured database. @@ -36,12 +37,12 @@ type CertStorer interface { // A CertRecord is a representation of a ssh certificate used by a CertStorer. type CertRecord struct { - KeyID string `json:"key_id"` - Principals []string `json:"principals"` - CreatedAt time.Time `json:"created_at"` - Expires time.Time `json:"expires"` - Revoked bool `json:"revoked"` - Raw string `json:"-"` + KeyID string `json:"key_id" db:"key_id"` + Principals types.StringSlice `json:"principals" db:"principals"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + Expires time.Time `json:"expires" db:"expires_at"` + Revoked bool `json:"revoked" db:"revoked"` + Raw string `json:"-" db:"raw_key"` } func parseTime(t uint64) time.Time { @@ -51,7 +52,7 @@ func parseTime(t uint64) time.Time { func parseCertificate(cert *ssh.Certificate) *CertRecord { return &CertRecord{ KeyID: cert.KeyId, - Principals: cert.ValidPrincipals, + Principals: types.StringSlice(cert.ValidPrincipals), CreatedAt: parseTime(cert.ValidAfter), Expires: parseTime(cert.ValidBefore), Raw: lib.GetPublicKey(cert), diff --git a/server/store/store_test.go b/server/store/store_test.go index afe6c03..4196c37 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/nsheridan/cashier/server/store/types" "github.com/nsheridan/cashier/testdata" "github.com/stretchr/testify/assert" @@ -25,7 +26,7 @@ func TestParseCertificate(t *testing.T) { pub, _ := ssh.NewPublicKey(r.Public()) c := &ssh.Certificate{ KeyId: "id", - ValidPrincipals: []string{"principal"}, + ValidPrincipals: types.StringSlice{"principal"}, ValidBefore: now, CertType: ssh.UserCert, Key: pub, @@ -35,7 +36,7 @@ func TestParseCertificate(t *testing.T) { rec := parseCertificate(c) a.Equal(c.KeyId, rec.KeyID) - a.Equal(c.ValidPrincipals, rec.Principals) + a.Equal(c.ValidPrincipals, []string(rec.Principals)) a.Equal(c.ValidBefore, uint64(rec.Expires.Unix())) a.Equal(c.ValidAfter, uint64(rec.CreatedAt.Unix())) } diff --git a/server/store/types/string_slice.go b/server/store/types/string_slice.go new file mode 100644 index 0000000..81b38c3 --- /dev/null +++ b/server/store/types/string_slice.go @@ -0,0 +1,37 @@ +package types + +import ( + "database/sql/driver" + "encoding/json" +) + +// StringSlice is a []string which will be stored in a database as a JSON array. +type StringSlice []string + +var _ driver.Valuer = (*StringSlice)(nil) + +// Value implements the driver.Valuer interface, marshalling the raw value to +// a JSON array. +func (s StringSlice) Value() (driver.Value, error) { + v, err := json.Marshal(s) + if err != nil { + return nil, err + } + return string(v), err +} + +// Scan implements the sql.Scanner interface, unmarshalling the value coming +// off the wire and storing the result in the StringSlice. +func (s *StringSlice) Scan(value interface{}) error { + if value == nil { + s = &StringSlice{} + return nil + } + var err error + if v, err := driver.String.ConvertValue(value); err == nil { + if v, ok := v.([]byte); ok { + err = json.Unmarshal(v, s) + } + } + return err +} -- cgit v1.2.3