diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/config/config.go | 115 | ||||
-rw-r--r-- | server/store/config_test.go | 70 | ||||
-rw-r--r-- | server/store/mongo.go | 26 | ||||
-rw-r--r-- | server/store/sqldb.go | 52 | ||||
-rw-r--r-- | server/store/store.go | 14 | ||||
-rw-r--r-- | server/store/store_test.go | 24 |
6 files changed, 152 insertions, 149 deletions
diff --git a/server/config/config.go b/server/config/config.go index 9678f6d..fa580b0 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -2,25 +2,30 @@ package config import ( "errors" + "fmt" "io" + "log" "os" "strconv" + "strings" "github.com/hashicorp/go-multierror" "github.com/nsheridan/cashier/server/helpers/vault" "github.com/spf13/viper" ) -// Config holds the server configuration. +// Config holds the final server configuration. type Config struct { - Server *Server `mapstructure:"server"` - Auth *Auth `mapstructure:"auth"` - SSH *SSH `mapstructure:"ssh"` - AWS *AWS `mapstructure:"aws"` - Vault *Vault `mapstructure:"vault"` + Server *Server + Auth *Auth + SSH *SSH + AWS *AWS + Vault *Vault } // unmarshalled holds the raw config. +// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there +// we perform some validation checks, other overrides and then produce a final Config struct. type unmarshalled struct { Server []Server `mapstructure:"server"` Auth []Auth `mapstructure:"auth"` @@ -29,18 +34,22 @@ type unmarshalled struct { Vault []Vault `mapstructure:"vault"` } +// Database config +type Database map[string]string + // Server holds the configuration specific to the web server and sessions. type Server struct { - UseTLS bool `mapstructure:"use_tls"` - TLSKey string `mapstructure:"tls_key"` - TLSCert string `mapstructure:"tls_cert"` - Addr string `mapstructure:"address"` - Port int `mapstructure:"port"` - User string `mapstructure:"user"` - CookieSecret string `mapstructure:"cookie_secret"` - CSRFSecret string `mapstructure:"csrf_secret"` - HTTPLogFile string `mapstructure:"http_logfile"` - Datastore string `mapstructure:"datastore"` + UseTLS bool `mapstructure:"use_tls"` + TLSKey string `mapstructure:"tls_key"` + TLSCert string `mapstructure:"tls_cert"` + Addr string `mapstructure:"address"` + Port int `mapstructure:"port"` + User string `mapstructure:"user"` + CookieSecret string `mapstructure:"cookie_secret"` + CSRFSecret string `mapstructure:"csrf_secret"` + HTTPLogFile string `mapstructure:"http_logfile"` + Database Database `mapstructure:"database"` + Datastore string `mapstructure:"datastore"` // Deprecated. } // Auth holds the configuration specific to the OAuth provider. @@ -78,13 +87,13 @@ type Vault struct { func verifyConfig(u *unmarshalled) error { var err error if len(u.SSH) == 0 { - err = multierror.Append(errors.New("missing ssh config section")) + err = multierror.Append(err, errors.New("missing ssh config section")) } if len(u.Auth) == 0 { - err = multierror.Append(errors.New("missing auth config section")) + err = multierror.Append(err, errors.New("missing auth config section")) } if len(u.Server) == 0 { - err = multierror.Append(errors.New("missing server config section")) + err = multierror.Append(err, errors.New("missing server config section")) } if len(u.AWS) == 0 { // AWS config is optional @@ -94,9 +103,48 @@ func verifyConfig(u *unmarshalled) error { // Vault config is optional u.Vault = append(u.Vault, Vault{}) } + if u.Server[0].Datastore != "" { + log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") + log.Println("The new config (passwords have been redacted) should look something like:") + fmt.Printf("server {\n database {\n") + for k, v := range u.Server[0].Database { + if v == "" { + continue + } + if k == "password" { + fmt.Printf(" password = \"[ REDACTED ]\"\n") + continue + } + fmt.Printf(" %s = \"%s\"\n", k, v) + } + fmt.Printf(" }\n}\n") + } return err } +func convertDatastoreConfig(u *unmarshalled) { + // Convert the deprecated 'datastore' config to the new 'database' config. + if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" { + c := u.Server[0].Datastore + engine := strings.Split(c, ":")[0] + switch engine { + case "mysql", "mongo": + s := strings.SplitN(c, ":", 4) + engine, user, passwd, addrs := s[0], s[1], s[2], s[3] + u.Server[0].Database = map[string]string{ + "type": engine, + "username": user, + "password": passwd, + "address": addrs, + } + case "sqlite": + s := strings.Split(c, ":") + u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]} + case "mem": + u.Server[0].Database = map[string]string{"type": "mem"} + } + } +} func setFromEnv(u *unmarshalled) { port, err := strconv.Atoi(os.Getenv("PORT")) if err == nil { @@ -128,42 +176,49 @@ func setFromVault(u *unmarshalled) error { return err } get := func(value string) (string, error) { - if value[:7] == "/vault/" { + if len(value) > 0 && value[:7] == "/vault/" { return v.Read(value) } return value, nil } + var errors error if len(u.Auth) > 0 { u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) } u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) } } if len(u.Server) > 0 { u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) } u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) + } + if len(u.Server[0].Database) > 0 { + u.Server[0].Database["password"], err = get(u.Server[0].Database["password"]) + if err != nil { + errors = multierror.Append(errors, err) + } } } if len(u.AWS) > 0 { u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) } u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey) if err != nil { - err = multierror.Append(err) + errors = multierror.Append(errors, err) } } - return err + return errors } // ReadConfig parses a JSON configuration file into a Config struct. @@ -181,14 +236,16 @@ func ReadConfig(r io.Reader) (*Config, error) { if err := setFromVault(u); err != nil { return nil, err } + convertDatastoreConfig(u) if err := verifyConfig(u); err != nil { return nil, err } - return &Config{ + c := &Config{ Server: &u.Server[0], Auth: &u.Auth[0], SSH: &u.SSH[0], AWS: &u.AWS[0], Vault: &u.Vault[0], - }, nil + } + return c, nil } diff --git a/server/store/config_test.go b/server/store/config_test.go deleted file mode 100644 index 9a77027..0000000 --- a/server/store/config_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package store - -import ( - "reflect" - "testing" - "time" - - mgo "gopkg.in/mgo.v2" -) - -func TestMySQLConfig(t *testing.T) { - t.Parallel() - var tests = []struct { - in string - out []string - }{ - {"mysql:user:passwd:localhost", []string{"mysql", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}}, - {"mysql:user:passwd:localhost:13306", []string{"mysql", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}}, - {"mysql:root::localhost", []string{"mysql", "root@tcp(localhost:3306)/certs?parseTime=true"}}, - } - for _, tt := range tests { - result := parse(tt.in) - if !reflect.DeepEqual(result, tt.out) { - t.Errorf("want %s, got %s", tt.out, result) - } - } -} - -func TestMongoConfig(t *testing.T) { - t.Parallel() - var tests = []struct { - in string - out *mgo.DialInfo - }{ - {"mongo:user:passwd:host", &mgo.DialInfo{ - Username: "user", - Password: "passwd", - Addrs: []string{"host"}, - Database: "certs", - Timeout: 5 * time.Second, - }}, - {"mongo:user:passwd:host1,host2", &mgo.DialInfo{ - Username: "user", - Password: "passwd", - Addrs: []string{"host1", "host2"}, - Database: "certs", - Timeout: 5 * time.Second, - }}, - {"mongo:user:passwd:host1:27017,host2:27017", &mgo.DialInfo{ - Username: "user", - Password: "passwd", - Addrs: []string{"host1:27017", "host2:27017"}, - Database: "certs", - Timeout: 5 * time.Second, - }}, - {"mongo:user:passwd:host1,host2:27017", &mgo.DialInfo{ - Username: "user", - Password: "passwd", - Addrs: []string{"host1", "host2:27017"}, - Database: "certs", - Timeout: 5 * time.Second, - }}, - } - for _, tt := range tests { - result := parseMongoConfig(tt.in) - if !reflect.DeepEqual(result, tt.out) { - t.Errorf("want:\n%+v\ngot:\n%+v", tt.out, result) - } - } -} diff --git a/server/store/mongo.go b/server/store/mongo.go index 1b13d7a..fc4131f 100644 --- a/server/store/mongo.go +++ b/server/store/mongo.go @@ -4,6 +4,8 @@ import ( "strings" "time" + "github.com/nsheridan/cashier/server/config" + "golang.org/x/crypto/ssh" mgo "gopkg.in/mgo.v2" @@ -15,26 +17,20 @@ var ( issuedTable = "issued_certs" ) -func parseMongoConfig(config string) *mgo.DialInfo { - s := strings.SplitN(config, ":", 4) - _, user, passwd, hosts := s[0], s[1], s[2], s[3] - d := &mgo.DialInfo{ - Addrs: strings.Split(hosts, ","), - Username: user, - Password: passwd, - Database: certsDB, - Timeout: time.Second * 5, - } - return d -} - func collection(session *mgo.Session) *mgo.Collection { return session.DB(certsDB).C(issuedTable) } // NewMongoStore returns a MongoDB CertStorer. -func NewMongoStore(config string) (CertStorer, error) { - session, err := mgo.DialWithInfo(parseMongoConfig(config)) +func NewMongoStore(c config.Database) (CertStorer, error) { + m := &mgo.DialInfo{ + Addrs: strings.Split(c["address"], ","), + Username: c["username"], + Password: c["password"], + Database: certsDB, + Timeout: time.Second * 5, + } + session, err := mgo.DialWithInfo(m) if err != nil { return nil, err } diff --git a/server/store/sqldb.go b/server/store/sqldb.go index f65f601..6c1be0e 100644 --- a/server/store/sqldb.go +++ b/server/store/sqldb.go @@ -4,13 +4,14 @@ import ( "database/sql" "encoding/json" "fmt" - "strings" + "net" "time" "golang.org/x/crypto/ssh" "github.com/go-sql-driver/mysql" _ "github.com/mattn/go-sqlite3" // required by sql driver + "github.com/nsheridan/cashier/server/config" ) type sqldb struct { @@ -24,31 +25,32 @@ type sqldb struct { revoked *sql.Stmt } -func parse(config string) []string { - s := strings.Split(config, ":") - if s[0] == "sqlite" { - s[0] = "sqlite3" - return s - } - if len(s) == 4 { - s = append(s, "3306") - } - _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4] - c := &mysql.Config{ - User: user, - Passwd: passwd, - Net: "tcp", - Addr: fmt.Sprintf("%s:%s", host, port), - DBName: "certs", - ParseTime: true, - } - return []string{"mysql", c.FormatDSN()} -} - // NewSQLStore returns a *sql.DB CertStorer. -func NewSQLStore(config string) (CertStorer, error) { - parsed := parse(config) - conn, err := sql.Open(parsed[0], parsed[1]) +func NewSQLStore(c config.Database) (CertStorer, error) { + var driver string + var dsn string + switch c["type"] { + case "mysql": + driver = "mysql" + address := c["address"] + _, _, err := net.SplitHostPort(address) + if err != nil { + address = address + ":3306" + } + m := &mysql.Config{ + User: c["username"], + Passwd: c["password"], + Net: "tcp", + Addr: address, + DBName: "certs", + ParseTime: true, + } + dsn = m.FormatDSN() + case "sqlite": + driver = "sqlite3" + dsn = c["filename"] + } + conn, err := sql.Open(driver, dsn) if err != nil { return nil, fmt.Errorf("sqldb: could not get a connection: %v", err) } diff --git a/server/store/store.go b/server/store/store.go index c039d3c..a447e72 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -5,9 +5,23 @@ import ( "golang.org/x/crypto/ssh" + "github.com/nsheridan/cashier/server/config" "github.com/nsheridan/cashier/server/util" ) +// New returns a new configured database. +func New(c config.Database) (CertStorer, error) { + switch c["type"] { + case "mongo": + return NewMongoStore(c) + case "mysql", "sqlite": + return NewSQLStore(c) + case "mem": + return NewMemoryStore(), nil + } + return NewMemoryStore(), nil +} + // CertStorer records issued certs in a persistent store for audit and // revocation purposes. type CertStorer interface { diff --git a/server/store/store_test.go b/server/store/store_test.go index 594da37..dbe2d95 100644 --- a/server/store/store_test.go +++ b/server/store/store_test.go @@ -3,7 +3,6 @@ package store import ( "crypto/rand" "crypto/rsa" - "fmt" "io/ioutil" "os" "os/exec" @@ -16,6 +15,10 @@ import ( "golang.org/x/crypto/ssh" ) +var ( + dbConfig = map[string]string{"username": "user", "password": "passwd", "address": "localhost"} +) + func TestParseCertificate(t *testing.T) { t.Parallel() a := assert.New(t) @@ -87,11 +90,11 @@ func TestMemoryStore(t *testing.T) { func TestMySQLStore(t *testing.T) { t.Parallel() - config := os.Getenv("MYSQL_TEST_CONFIG") - if config == "" { - t.Skip("No MYSQL_TEST_CONFIG environment variable") + if os.Getenv("MYSQL_TEST") == "" { + t.Skip("No MYSQL_TEST environment variable") } - db, err := NewSQLStore(config) + dbConfig["type"] = "mysql" + db, err := NewSQLStore(dbConfig) if err != nil { t.Error(err) } @@ -100,11 +103,11 @@ func TestMySQLStore(t *testing.T) { func TestMongoStore(t *testing.T) { t.Parallel() - config := os.Getenv("MONGO_TEST_CONFIG") - if config == "" { - t.Skip("No MONGO_TEST_CONFIG environment variable") + if os.Getenv("MONGO_TEST") == "" { + t.Skip("No MONGO_TEST environment variable") } - db, err := NewMongoStore(config) + dbConfig["type"] = "mongo" + db, err := NewMongoStore(dbConfig) if err != nil { t.Error(err) } @@ -123,7 +126,8 @@ func TestSQLiteStore(t *testing.T) { if err := exec.Command("go", args...).Run(); err != nil { t.Error(err) } - db, err := NewSQLStore(fmt.Sprintf("sqlite:%s", f.Name())) + config := map[string]string{"type": "sqlite", "filename": f.Name()} + db, err := NewSQLStore(config) if err != nil { t.Error(err) } |