diff options
Diffstat (limited to 'server/store/store_test.go')
-rw-r--r-- | server/store/store_test.go | 162 |
1 files changed, 0 insertions, 162 deletions
diff --git a/server/store/store_test.go b/server/store/store_test.go deleted file mode 100644 index 90a494e..0000000 --- a/server/store/store_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package store - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/json" - "io/ioutil" - "os" - "os/user" - "testing" - "time" - - "github.com/nsheridan/cashier/testdata" - "github.com/stretchr/testify/assert" - - "golang.org/x/crypto/ssh" -) - -func TestParseCertificate(t *testing.T) { - a := assert.New(t) - now := uint64(time.Now().Unix()) - r, _ := rsa.GenerateKey(rand.Reader, 1024) - pub, _ := ssh.NewPublicKey(r.Public()) - c := &ssh.Certificate{ - KeyId: "id", - ValidPrincipals: StringSlice{"principal"}, - ValidBefore: now, - CertType: ssh.UserCert, - Key: pub, - } - s, _ := ssh.NewSignerFromKey(r) - c.SignCert(rand.Reader, s) - rec := MakeRecord(c) - - a.Equal(c.KeyId, rec.KeyID) - a.Equal(c.ValidPrincipals, []string(rec.Principals)) - a.Equal(c.ValidBefore, uint64(rec.Expires.Unix())) - a.Equal(c.ValidAfter, uint64(rec.CreatedAt.Unix())) -} - -func testStore(t *testing.T, db CertStorer) { - defer db.Close() - - r := &CertRecord{ - KeyID: "a", - Principals: []string{"b"}, - CreatedAt: time.Now().UTC(), - Expires: time.Now().UTC().Add(-1 * time.Second), - Raw: "AAAAAA", - } - if err := db.SetRecord(r); err != nil { - t.Error(err) - } - - // includeExpired = false should return 0 results - recs, err := db.List(false) - if err != nil { - t.Error(err) - } - if len(recs) > 0 { - t.Errorf("Expected 0 results, got %d", len(recs)) - } - // includeExpired = false should return 1 result - recs, err = db.List(true) - if err != nil { - t.Error(err) - } - if recs[0].KeyID != r.KeyID { - t.Error("key mismatch") - } - - c, _, _, _, _ := ssh.ParseAuthorizedKey(testdata.Cert) - cert := c.(*ssh.Certificate) - cert.ValidBefore = uint64(time.Now().Add(1 * time.Hour).UTC().Unix()) - cert.ValidAfter = uint64(time.Now().Add(-5 * time.Minute).UTC().Unix()) - rec := MakeRecord(cert) - if err := db.SetRecord(rec); err != nil { - t.Error(err) - } - - ret, err := db.Get("key") - if err != nil { - t.Error(err) - } - if ret.KeyID != cert.KeyId { - t.Error("key mismatch") - } - if err := db.Revoke([]string{"key"}); err != nil { - t.Error(err) - } - - revoked, err := db.GetRevoked() - if err != nil { - t.Error(err) - } - if len(revoked) != 1 { - t.Errorf("Expected 1 revoked key, got %d", len(revoked)) - } - for _, k := range revoked { - if k.KeyID != "key" { - t.Errorf("Unexpected key: %s", k.KeyID) - } - } -} - -func TestMemoryStore(t *testing.T) { - db := newMemoryStore() - testStore(t, db) -} - -func TestMySQLStore(t *testing.T) { - if os.Getenv("MYSQL_TEST") == "" { - t.Skip("No MYSQL_TEST environment variable") - } - u, _ := user.Current() - sqlConfig := map[string]string{ - "type": "mysql", - "password": os.Getenv("MYSQL_TEST_PASS"), - "address": os.Getenv("MYSQL_TEST_HOST"), - } - if testUser, ok := os.LookupEnv("MYSQL_TEST_USER"); ok { - sqlConfig["username"] = testUser - } else { - sqlConfig["username"] = u.Username - } - db, err := newSQLStore(sqlConfig) - if err != nil { - t.Error(err) - } - testStore(t, db) -} - -func TestSQLiteStore(t *testing.T) { - f, err := ioutil.TempFile("", "sqlite_test_db") - if err != nil { - t.Error(err) - } - defer os.Remove(f.Name()) - config := map[string]string{"type": "sqlite", "filename": f.Name()} - db, err := newSQLStore(config) - if err != nil { - t.Error(err) - } - testStore(t, db) -} - -func TestMarshalCert(t *testing.T) { - a := assert.New(t) - c := &CertRecord{ - KeyID: "id", - Principals: []string{"user"}, - CreatedAt: time.Date(2017, time.April, 10, 13, 0, 0, 0, time.UTC), - Expires: time.Date(2017, time.April, 11, 10, 0, 0, 0, time.UTC), - Raw: "ABCDEF", - } - b, err := json.Marshal(c) - if err != nil { - t.Error(err) - } - want := `{"key_id":"id","principals":["user"],"revoked":false,"created_at":"2017-04-10 13:00:00 +0000","expires":"2017-04-11 10:00:00 +0000","message":""}` - a.JSONEq(want, string(b)) -} |