aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorNiall Sheridan <nsheridan@gmail.com>2017-01-20 00:52:56 +0000
committerNiall Sheridan <nsheridan@gmail.com>2017-01-22 22:25:35 +0000
commit51cc4c07b2a2b6345b1496baac865f5faf955e7d (patch)
treeedd51d045954eb802c470be4481a1d130d5f988c /server
parentfb4a1232be3b2d00483a7399e7131c211d8cd551 (diff)
Switch from database/sql to sqlx
Diffstat (limited to 'server')
-rw-r--r--server/store/sqldb.go98
-rw-r--r--server/store/store.go15
-rw-r--r--server/store/store_test.go5
-rw-r--r--server/store/types/string_slice.go37
4 files changed, 73 insertions, 82 deletions
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index a51678e..2efca0e 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -1,8 +1,6 @@
package store
import (
- "database/sql"
- "encoding/json"
"fmt"
"net"
"time"
@@ -10,6 +8,7 @@ import (
"golang.org/x/crypto/ssh"
"github.com/go-sql-driver/mysql"
+ "github.com/jmoiron/sqlx"
"github.com/nsheridan/cashier/server/config"
)
@@ -17,14 +16,14 @@ var _ CertStorer = (*SQLStore)(nil)
// SQLStore is an sql-based CertStorer
type SQLStore struct {
- conn *sql.DB
-
- get *sql.Stmt
- set *sql.Stmt
- listAll *sql.Stmt
- listCurrent *sql.Stmt
- revoke *sql.Stmt
- revoked *sql.Stmt
+ conn *sqlx.DB
+
+ get *sqlx.Stmt
+ set *sqlx.Stmt
+ listAll *sqlx.Stmt
+ listCurrent *sqlx.Stmt
+ revoke *sqlx.Stmt
+ revoked *sqlx.Stmt
}
// NewSQLStore returns a *sql.DB CertStorer.
@@ -52,7 +51,7 @@ func NewSQLStore(c config.Database) (*SQLStore, error) {
driver = "sqlite3"
dsn = c["filename"]
}
- conn, err := sql.Open(driver, dsn)
+ conn, err := sqlx.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("SQLStore: could not get a connection: %v", err)
}
@@ -65,22 +64,22 @@ func NewSQLStore(c config.Database) (*SQLStore, error) {
conn: conn,
}
- if db.set, err = conn.Prepare("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
+ if db.set, err = conn.Preparex("INSERT INTO issued_certs (key_id, principals, created_at, expires_at, raw_key) VALUES (?, ?, ?, ?, ?)"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare set: %v", err)
}
- if db.get, err = conn.Prepare("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
+ if db.get, err = conn.Preparex("SELECT * FROM issued_certs WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare get: %v", err)
}
- if db.listAll, err = conn.Prepare("SELECT * FROM issued_certs"); err != nil {
+ if db.listAll, err = conn.Preparex("SELECT * FROM issued_certs"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare listAll: %v", err)
}
- if db.listCurrent, err = conn.Prepare("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
+ if db.listCurrent, err = conn.Preparex("SELECT * FROM issued_certs WHERE ? <= expires_at"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare listCurrent: %v", err)
}
- if db.revoke, err = conn.Prepare("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
+ if db.revoke, err = conn.Preparex("UPDATE issued_certs SET revoked = 1 WHERE key_id = ?"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare revoke: %v", err)
}
- if db.revoked, err = conn.Prepare("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
+ if db.revoked, err = conn.Preparex("SELECT * FROM issued_certs WHERE revoked = 1 AND ? <= expires_at"); err != nil {
return nil, fmt.Errorf("SQLStore: prepare revoked: %v", err)
}
return db, nil
@@ -91,38 +90,13 @@ type rowScanner interface {
Scan(dest ...interface{}) error
}
-func scanCert(s rowScanner) (*CertRecord, error) {
- var (
- keyID sql.NullString
- principals sql.NullString
- createdAt time.Time
- expires time.Time
- revoked sql.NullBool
- raw sql.NullString
- )
- if err := s.Scan(&keyID, &principals, &createdAt, &expires, &revoked, &raw); err != nil {
- return nil, err
- }
- var p []string
- if err := json.Unmarshal([]byte(principals.String), &p); err != nil {
- return nil, err
- }
- return &CertRecord{
- KeyID: keyID.String,
- Principals: p,
- CreatedAt: createdAt,
- Expires: expires,
- Revoked: revoked.Bool,
- Raw: raw.String,
- }, nil
-}
-
// Get a single *CertRecord
func (db *SQLStore) Get(id string) (*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
- return scanCert(db.get.QueryRow(id))
+ r := &CertRecord{}
+ return r, db.get.Get(r, id)
}
// SetCert parses a *ssh.Certificate and records it
@@ -132,14 +106,10 @@ func (db *SQLStore) SetCert(cert *ssh.Certificate) error {
// SetRecord records a *CertRecord
func (db *SQLStore) SetRecord(rec *CertRecord) error {
- principals, err := json.Marshal(rec.Principals)
- if err != nil {
- return err
- }
if err := db.conn.Ping(); err != nil {
return err
}
- _, err = db.set.Exec(rec.KeyID, principals, rec.CreatedAt, rec.Expires, rec.Raw)
+ _, err := db.set.Exec(rec.KeyID, rec.Principals, rec.CreatedAt, rec.Expires, rec.Raw)
return err
}
@@ -149,20 +119,9 @@ func (db *SQLStore) List(includeExpired bool) ([]*CertRecord, error) {
if err := db.conn.Ping(); err != nil {
return nil, err
}
- var recs []*CertRecord
- var rows *sql.Rows
- if includeExpired {
- rows, _ = db.listAll.Query()
- } else {
- rows, _ = db.listCurrent.Query(time.Now().UTC())
- }
- defer rows.Close()
- for rows.Next() {
- cert, err := scanCert(rows)
- if err != nil {
- return nil, err
- }
- recs = append(recs, cert)
+ recs := []*CertRecord{}
+ if err := db.listAll.Select(&recs); err != nil {
+ return nil, err
}
return recs, nil
}
@@ -172,8 +131,7 @@ func (db *SQLStore) Revoke(id string) error {
if err := db.conn.Ping(); err != nil {
return err
}
- _, err := db.revoke.Exec(id)
- if err != nil {
+ if _, err := db.revoke.Exec(id); err != nil {
return err
}
return nil
@@ -185,14 +143,8 @@ func (db *SQLStore) GetRevoked() ([]*CertRecord, error) {
return nil, err
}
var recs []*CertRecord
- rows, _ := db.revoked.Query(time.Now().UTC())
- defer rows.Close()
- for rows.Next() {
- cert, err := scanCert(rows)
- if err != nil {
- return nil, err
- }
- recs = append(recs, cert)
+ if err := db.revoked.Select(&recs, time.Now().UTC()); err != nil {
+ return nil, err
}
return recs, nil
}
diff --git a/server/store/store.go b/server/store/store.go
index 8af77e3..249489a 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -7,6 +7,7 @@ import (
"github.com/nsheridan/cashier/lib"
"github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/store/types"
)
// New returns a new configured database.
@@ -36,12 +37,12 @@ type CertStorer interface {
// A CertRecord is a representation of a ssh certificate used by a CertStorer.
type CertRecord struct {
- KeyID string `json:"key_id"`
- Principals []string `json:"principals"`
- CreatedAt time.Time `json:"created_at"`
- Expires time.Time `json:"expires"`
- Revoked bool `json:"revoked"`
- Raw string `json:"-"`
+ KeyID string `json:"key_id" db:"key_id"`
+ Principals types.StringSlice `json:"principals" db:"principals"`
+ CreatedAt time.Time `json:"created_at" db:"created_at"`
+ Expires time.Time `json:"expires" db:"expires_at"`
+ Revoked bool `json:"revoked" db:"revoked"`
+ Raw string `json:"-" db:"raw_key"`
}
func parseTime(t uint64) time.Time {
@@ -51,7 +52,7 @@ func parseTime(t uint64) time.Time {
func parseCertificate(cert *ssh.Certificate) *CertRecord {
return &CertRecord{
KeyID: cert.KeyId,
- Principals: cert.ValidPrincipals,
+ Principals: types.StringSlice(cert.ValidPrincipals),
CreatedAt: parseTime(cert.ValidAfter),
Expires: parseTime(cert.ValidBefore),
Raw: lib.GetPublicKey(cert),
diff --git a/server/store/store_test.go b/server/store/store_test.go
index afe6c03..4196c37 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -11,6 +11,7 @@ import (
"testing"
"time"
+ "github.com/nsheridan/cashier/server/store/types"
"github.com/nsheridan/cashier/testdata"
"github.com/stretchr/testify/assert"
@@ -25,7 +26,7 @@ func TestParseCertificate(t *testing.T) {
pub, _ := ssh.NewPublicKey(r.Public())
c := &ssh.Certificate{
KeyId: "id",
- ValidPrincipals: []string{"principal"},
+ ValidPrincipals: types.StringSlice{"principal"},
ValidBefore: now,
CertType: ssh.UserCert,
Key: pub,
@@ -35,7 +36,7 @@ func TestParseCertificate(t *testing.T) {
rec := parseCertificate(c)
a.Equal(c.KeyId, rec.KeyID)
- a.Equal(c.ValidPrincipals, rec.Principals)
+ a.Equal(c.ValidPrincipals, []string(rec.Principals))
a.Equal(c.ValidBefore, uint64(rec.Expires.Unix()))
a.Equal(c.ValidAfter, uint64(rec.CreatedAt.Unix()))
}
diff --git a/server/store/types/string_slice.go b/server/store/types/string_slice.go
new file mode 100644
index 0000000..81b38c3
--- /dev/null
+++ b/server/store/types/string_slice.go
@@ -0,0 +1,37 @@
+package types
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+)
+
+// StringSlice is a []string which will be stored in a database as a JSON array.
+type StringSlice []string
+
+var _ driver.Valuer = (*StringSlice)(nil)
+
+// Value implements the driver.Valuer interface, marshalling the raw value to
+// a JSON array.
+func (s StringSlice) Value() (driver.Value, error) {
+ v, err := json.Marshal(s)
+ if err != nil {
+ return nil, err
+ }
+ return string(v), err
+}
+
+// Scan implements the sql.Scanner interface, unmarshalling the value coming
+// off the wire and storing the result in the StringSlice.
+func (s *StringSlice) Scan(value interface{}) error {
+ if value == nil {
+ s = &StringSlice{}
+ return nil
+ }
+ var err error
+ if v, err := driver.String.ConvertValue(value); err == nil {
+ if v, ok := v.([]byte); ok {
+ err = json.Unmarshal(v, s)
+ }
+ }
+ return err
+}