diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/config/config.go | 1 | ||||
-rw-r--r-- | server/store/mem.go | 2 | ||||
-rw-r--r-- | server/store/mysql.go | 21 | ||||
-rw-r--r-- | server/store/store.go | 14 | ||||
-rw-r--r-- | server/store/store_test.go | 6 |
5 files changed, 26 insertions, 18 deletions
diff --git a/server/config/config.go b/server/config/config.go index 674ceee..107ebcc 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -31,6 +31,7 @@ type Server struct { TLSCert string `mapstructure:"tls_cert"` Port int `mapstructure:"port"` CookieSecret string `mapstructure:"cookie_secret"` + CSRFSecret string `mapstructure:"csrf_secret"` HTTPLogFile string `mapstructure:"http_logfile"` Datastore string `mapstructure:"datastore"` } diff --git a/server/store/mem.go b/server/store/mem.go index 8b78e27..cd37071 100644 --- a/server/store/mem.go +++ b/server/store/mem.go @@ -58,7 +58,7 @@ func (ms *memoryStore) GetRevoked() ([]*CertRecord, error) { var revoked []*CertRecord all, _ := ms.List() for _, r := range all { - if r.Revoked && uint64(time.Now().UTC().Unix()) <= r.Expires { + if r.Revoked && time.Now().UTC().Unix() <= r.Expires.UTC().Unix() { revoked = append(revoked, r) } } diff --git a/server/store/mysql.go b/server/store/mysql.go index b108fdc..a62af6b 100644 --- a/server/store/mysql.go +++ b/server/store/mysql.go @@ -29,11 +29,12 @@ func parseConfig(config string) string { } _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] c := &mysql.Config{ - User: user, - Passwd: passwd, - Net: "tcp", - Addr: fmt.Sprintf("%s:%s", host, port), - DBName: "certs", + User: user, + Passwd: passwd, + Net: "tcp", + Addr: fmt.Sprintf("%s:%s", host, port), + DBName: "certs", + ParseTime: true, } return c.FormatDSN() } @@ -80,8 +81,8 @@ func scanCert(s rowScanner) (*CertRecord, error) { var ( keyID sql.NullString principals sql.NullString - createdAt sql.NullInt64 - expires sql.NullInt64 + createdAt mysql.NullTime + expires mysql.NullTime revoked sql.NullBool raw sql.NullString ) @@ -95,8 +96,8 @@ func scanCert(s rowScanner) (*CertRecord, error) { return &CertRecord{ KeyID: keyID.String, Principals: p, - CreatedAt: uint64(createdAt.Int64), - Expires: uint64(expires.Int64), + CreatedAt: createdAt.Time, + Expires: expires.Time, Revoked: revoked.Bool, Raw: raw.String, }, nil @@ -143,7 +144,7 @@ func (db *mysqlDB) Revoke(id string) error { func (db *mysqlDB) GetRevoked() ([]*CertRecord, error) { var recs []*CertRecord - rows, _ := db.revoked.Query(time.Now().UTC().Unix()) + rows, _ := db.revoked.Query(time.Now().UTC()) defer rows.Close() for rows.Next() { cert, err := scanCert(rows) diff --git a/server/store/store.go b/server/store/store.go index ad4922a..f6ac66e 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -1,6 +1,8 @@ package store import ( + "time" + "golang.org/x/crypto/ssh" "github.com/nsheridan/cashier/server/certutil" @@ -22,18 +24,22 @@ type CertStorer interface { type CertRecord struct { KeyID string Principals []string - CreatedAt uint64 - Expires uint64 + CreatedAt time.Time + Expires time.Time Revoked bool Raw string } +func parseTime(t uint64) time.Time { + return time.Unix(int64(t), 0) +} + func parseCertificate(cert *ssh.Certificate) *CertRecord { return &CertRecord{ KeyID: cert.KeyId, Principals: cert.ValidPrincipals, - CreatedAt: cert.ValidAfter, - Expires: cert.ValidBefore, + CreatedAt: parseTime(cert.ValidAfter), + Expires: parseTime(cert.ValidBefore), Raw: certutil.GetPublicKey(cert), } } diff --git a/server/store/store_test.go b/server/store/store_test.go index d3aa3c1..ee80241 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -31,8 +31,8 @@ func TestParseCertificate(t *testing.T) { a.Equal(c.KeyId, rec.KeyID) a.Equal(c.ValidPrincipals, rec.Principals) - a.Equal(c.ValidBefore, rec.Expires) - a.Equal(c.ValidAfter, rec.CreatedAt) + a.Equal(c.ValidBefore, uint64(rec.Expires.Unix())) + a.Equal(c.ValidAfter, uint64(rec.CreatedAt.Unix())) } func testStore(t *testing.T, db CertStorer) { @@ -42,7 +42,7 @@ func testStore(t *testing.T, db CertStorer) { for _, id := range ids { r := &CertRecord{ KeyID: id, - Expires: uint64(time.Now().UTC().Unix()) - 10, + Expires: time.Now().UTC().Add(time.Second * -10), } if err := db.SetRecord(r); err != nil { t.Error(err) |