diff options
Diffstat (limited to 'server/config')
-rw-r--r-- | server/config/config.go | 209 | ||||
-rw-r--r-- | server/config/config_test.go | 119 | ||||
-rw-r--r-- | server/config/testdata/config.go | 48 |
3 files changed, 215 insertions, 161 deletions
diff --git a/server/config/config.go b/server/config/config.go index fa580b0..2d821f9 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -10,31 +10,21 @@ import ( "strings" "github.com/hashicorp/go-multierror" + "github.com/mitchellh/mapstructure" "github.com/nsheridan/cashier/server/helpers/vault" "github.com/spf13/viper" ) // Config holds the final server configuration. type Config struct { - Server *Server - Auth *Auth - SSH *SSH - AWS *AWS - Vault *Vault + Server *Server `mapstructure:"server"` + Auth *Auth `mapstructure:"auth"` + SSH *SSH `mapstructure:"ssh"` + AWS *AWS `mapstructure:"aws"` + Vault *Vault `mapstructure:"vault"` } -// unmarshalled holds the raw config. -// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there -// we perform some validation checks, other overrides and then produce a final Config struct. -type unmarshalled struct { - Server []Server `mapstructure:"server"` - Auth []Auth `mapstructure:"auth"` - SSH []SSH `mapstructure:"ssh"` - AWS []AWS `mapstructure:"aws"` - Vault []Vault `mapstructure:"vault"` -} - -// Database config +// Database holds database configuration. type Database map[string]string // Server holds the configuration specific to the web server and sessions. @@ -49,7 +39,7 @@ type Server struct { CSRFSecret string `mapstructure:"csrf_secret"` HTTPLogFile string `mapstructure:"http_logfile"` Database Database `mapstructure:"database"` - Datastore string `mapstructure:"datastore"` // Deprecated. + Datastore string `mapstructure:"datastore"` // Deprecated. TODO: remove. } // Auth holds the configuration specific to the OAuth provider. @@ -84,168 +74,153 @@ type Vault struct { Token string `mapstructure:"token"` } -func verifyConfig(u *unmarshalled) error { +func verifyConfig(c *Config) error { var err error - if len(u.SSH) == 0 { + if c.SSH == nil { err = multierror.Append(err, errors.New("missing ssh config section")) } - if len(u.Auth) == 0 { + if c.Auth == nil { err = multierror.Append(err, errors.New("missing auth config section")) } - if len(u.Server) == 0 { + if c.Server == nil { err = multierror.Append(err, errors.New("missing server config section")) } - if len(u.AWS) == 0 { - // AWS config is optional - u.AWS = append(u.AWS, AWS{}) - } - if len(u.Vault) == 0 { - // Vault config is optional - u.Vault = append(u.Vault, Vault{}) - } - if u.Server[0].Datastore != "" { - log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") - log.Println("The new config (passwords have been redacted) should look something like:") - fmt.Printf("server {\n database {\n") - for k, v := range u.Server[0].Database { - if v == "" { - continue - } - if k == "password" { - fmt.Printf(" password = \"[ REDACTED ]\"\n") - continue - } - fmt.Printf(" %s = \"%s\"\n", k, v) - } - fmt.Printf(" }\n}\n") - } return err } -func convertDatastoreConfig(u *unmarshalled) { +func convertDatastoreConfig(c *Config) { // Convert the deprecated 'datastore' config to the new 'database' config. - if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" { - c := u.Server[0].Datastore - engine := strings.Split(c, ":")[0] + if c.Server != nil && c.Server.Datastore != "" { + conf := c.Server.Datastore + engine := strings.Split(conf, ":")[0] switch engine { case "mysql", "mongo": - s := strings.SplitN(c, ":", 4) + s := strings.SplitN(conf, ":", 4) engine, user, passwd, addrs := s[0], s[1], s[2], s[3] - u.Server[0].Database = map[string]string{ + c.Server.Database = map[string]string{ "type": engine, "username": user, "password": passwd, "address": addrs, } case "sqlite": - s := strings.Split(c, ":") - u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]} + s := strings.Split(conf, ":") + c.Server.Database = map[string]string{"type": s[0], "filename": s[1]} case "mem": - u.Server[0].Database = map[string]string{"type": "mem"} + c.Server.Database = map[string]string{"type": "mem"} } + log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") + log.Println("The new config (passwords have been redacted) should look something like:") + fmt.Printf("server {\n database {\n") + for k, v := range c.Server.Database { + if v == "" { + continue + } + if k == "password" { + fmt.Printf(" password = \"[ REDACTED ]\"\n") + continue + } + fmt.Printf(" %s = \"%s\"\n", k, v) + } + fmt.Printf(" }\n}\n") } } -func setFromEnv(u *unmarshalled) { + +func setFromEnvironment(c *Config) { port, err := strconv.Atoi(os.Getenv("PORT")) if err == nil { - u.Server[0].Port = port + c.Server.Port = port } if os.Getenv("DATASTORE") != "" { - u.Server[0].Datastore = os.Getenv("DATASTORE") + c.Server.Datastore = os.Getenv("DATASTORE") } if os.Getenv("OAUTH_CLIENT_ID") != "" { - u.Auth[0].OauthClientID = os.Getenv("OAUTH_CLIENT_ID") + c.Auth.OauthClientID = os.Getenv("OAUTH_CLIENT_ID") } if os.Getenv("OAUTH_CLIENT_SECRET") != "" { - u.Auth[0].OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET") + c.Auth.OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET") } if os.Getenv("CSRF_SECRET") != "" { - u.Server[0].CSRFSecret = os.Getenv("CSRF_SECRET") + c.Server.CSRFSecret = os.Getenv("CSRF_SECRET") } if os.Getenv("COOKIE_SECRET") != "" { - u.Server[0].CookieSecret = os.Getenv("COOKIE_SECRET") + c.Server.CookieSecret = os.Getenv("COOKIE_SECRET") } } -func setFromVault(u *unmarshalled) error { - if len(u.Vault) == 0 || u.Vault[0].Token == "" || u.Vault[0].Address == "" { +func setFromVault(c *Config) error { + if c.Vault == nil || c.Vault.Token == "" || c.Vault.Address == "" { return nil } - v, err := vault.NewClient(u.Vault[0].Address, u.Vault[0].Token) + v, err := vault.NewClient(c.Vault.Address, c.Vault.Token) if err != nil { return err } - get := func(value string) (string, error) { - if len(value) > 0 && value[:7] == "/vault/" { - return v.Read(value) - } - return value, nil - } var errors error - if len(u.Auth) > 0 { - u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID) - if err != nil { - errors = multierror.Append(errors, err) - } - u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - } - if len(u.Server) > 0 { - u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - if len(u.Server[0].Database) > 0 { - u.Server[0].Database["password"], err = get(u.Server[0].Database["password"]) + get := func(value string) string { + if strings.HasPrefix(value, "/vault/") { + s, err := v.Read(value) if err != nil { errors = multierror.Append(errors, err) } + return s } + return value } - if len(u.AWS) > 0 { - u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey) - if err != nil { - errors = multierror.Append(errors, err) + c.Auth.OauthClientID = get(c.Auth.OauthClientID) + c.Auth.OauthClientSecret = get(c.Auth.OauthClientSecret) + c.Server.CSRFSecret = get(c.Server.CSRFSecret) + c.Server.CookieSecret = get(c.Server.CookieSecret) + if len(c.Server.Database) != 0 { + c.Server.Database["password"] = get(c.Server.Database["password"]) + } + if c.AWS != nil { + c.AWS.AccessKey = get(c.AWS.AccessKey) + c.AWS.SecretKey = get(c.AWS.SecretKey) + } + return errors +} + +// Unmarshal the config into a *Config +func decode() (*Config, error) { + var errors error + config := &Config{} + configPieces := map[string]interface{}{ + "auth": &config.Auth, + "aws": &config.AWS, + "server": &config.Server, + "ssh": &config.SSH, + "vault": &config.Vault, + } + for key, val := range configPieces { + conf, ok := viper.Get(key).([]map[string]interface{}) + if !ok { + continue } - u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey) - if err != nil { + if err := mapstructure.WeakDecode(conf[0], val); err != nil { errors = multierror.Append(errors, err) } } - return errors + return config, errors } -// ReadConfig parses a JSON configuration file into a Config struct. +// ReadConfig parses a hcl configuration file into a Config struct. func ReadConfig(r io.Reader) (*Config, error) { - u := &unmarshalled{} - v := viper.New() - v.SetConfigType("hcl") - if err := v.ReadConfig(r); err != nil { + viper.SetConfigType("hcl") + if err := viper.ReadConfig(r); err != nil { return nil, err } - if err := v.Unmarshal(u); err != nil { + config, err := decode() + if err != nil { return nil, err } - setFromEnv(u) - if err := setFromVault(u); err != nil { + if err := setFromVault(config); err != nil { return nil, err } - convertDatastoreConfig(u) - if err := verifyConfig(u); err != nil { + setFromEnvironment(config) + convertDatastoreConfig(config) + if err := verifyConfig(config); err != nil { return nil, err } - c := &Config{ - Server: &u.Server[0], - Auth: &u.Auth[0], - SSH: &u.SSH[0], - AWS: &u.AWS[0], - Vault: &u.Vault[0], - } - return c, nil + return config, nil } diff --git a/server/config/config_test.go b/server/config/config_test.go index c6bdc3d..399e143 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -3,61 +3,92 @@ package config import ( "bytes" "testing" - "time" - "github.com/nsheridan/cashier/testdata" + "github.com/nsheridan/cashier/server/config/testdata" "github.com/stretchr/testify/assert" ) -func TestServerConfig(t *testing.T) { - t.Parallel() - a := assert.New(t) - c, err := ReadConfig(bytes.NewBuffer(testdata.ServerConfig)) - if err != nil { - t.Error(err) +var ( + parsedConfig = &Config{ + Server: &Server{ + UseTLS: true, + TLSKey: "server.key", + TLSCert: "server.crt", + Addr: "127.0.0.1", + Port: 443, + User: "nobody", + CookieSecret: "supersecret", + CSRFSecret: "supersecret", + HTTPLogFile: "cashierd.log", + Database: Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"}, + Datastore: "mysql:user:passwd:localhost:3306", + }, + Auth: &Auth{ + OauthClientID: "client_id", + OauthClientSecret: "secret", + OauthCallbackURL: "https://sshca.example.com/auth/callback", + Provider: "google", + ProviderOpts: map[string]string{"domain": "example.com"}, + UsersWhitelist: []string{"a_user"}, + }, + SSH: &SSH{ + SigningKey: "signing_key", + AdditionalPrincipals: []string{"ec2-user", "ubuntu"}, + MaxAge: "720h", + Permissions: []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"}, + }, + AWS: &AWS{ + Region: "us-east-1", + AccessKey: "abcdef", + SecretKey: "omg123", + }, + Vault: &Vault{ + Address: "https://vault:8200", + Token: "abc-def-456-789", + }, } - server := c.Server - a.IsType(server, &Server{}) - a.True(server.UseTLS) - a.Equal(server.TLSKey, "server.key") - a.Equal(server.TLSCert, "server.crt") - a.Equal(server.Port, 443) - a.Equal(server.Addr, "127.0.0.1") - a.Equal(server.CookieSecret, "supersecret") -} +) -func TestAuthConfig(t *testing.T) { - t.Parallel() - a := assert.New(t) - c, err := ReadConfig(bytes.NewBuffer(testdata.AuthConfig)) +func TestConfigParser(t *testing.T) { + c, err := ReadConfig(bytes.NewBuffer(testdata.Config)) if err != nil { t.Error(err) } - auth := c.Auth - a.IsType(auth, &Auth{}) - a.Equal(auth.Provider, "google") - a.Equal(auth.ProviderOpts, map[string]string{"domain": "example.com"}) - a.Equal(auth.OauthClientID, "client_id") - a.Equal(auth.OauthClientSecret, "secret") - a.Equal(auth.OauthCallbackURL, "https://sshca.example.com/auth/callback") + assert.Equal(t, parsedConfig, c) } -func TestSSHConfig(t *testing.T) { - t.Parallel() - a := assert.New(t) - c, err := ReadConfig(bytes.NewBuffer(testdata.SSHConfig)) - if err != nil { - t.Error(err) +func TestConfigVerify(t *testing.T) { + bad := bytes.NewBuffer([]byte("")) + _, err := ReadConfig(bad) + assert.Contains(t, err.Error(), "missing ssh config section", "missing server config section", "missing auth config section") +} + +func TestDatastoreConversion(t *testing.T) { + tests := []struct { + in string + out Database + }{ + { + "mysql:user:passwd:localhost:3306", Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"}, + }, + { + "mongo:::host1,host2", Database{"type": "mongo", "username": "", "password": "", "address": "host1,host2"}, + }, + { + "mem", Database{"type": "mem"}, + }, + { + "sqlite:/data/certs.db", Database{"type": "sqlite", "filename": "/data/certs.db"}, + }, } - ssh := c.SSH - a.IsType(ssh, &SSH{}) - a.Equal(ssh.SigningKey, "signing_key") - a.Equal(ssh.AdditionalPrincipals, []string{"ec2-user", "ubuntu"}) - a.Equal(ssh.Permissions, []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"}) - a.Equal(ssh.MaxAge, "720h") - d, err := time.ParseDuration(ssh.MaxAge) - if err != nil { - t.Error(err) + + for _, tc := range tests { + config := &Config{ + Server: &Server{ + Datastore: tc.in, + }, + } + convertDatastoreConfig(config) + assert.EqualValues(t, config.Server.Database, tc.out) } - a.Equal(d.Hours(), float64(720)) } diff --git a/server/config/testdata/config.go b/server/config/testdata/config.go new file mode 100644 index 0000000..27cffcc --- /dev/null +++ b/server/config/testdata/config.go @@ -0,0 +1,48 @@ +package testdata + +var Config = []byte(` + server { + use_tls = true + tls_key = "server.key" + tls_cert = "server.crt" + address = "127.0.0.1" + port = 443 + user = "nobody" + cookie_secret = "supersecret" + csrf_secret = "supersecret" + http_logfile = "cashierd.log" + datastore = "mysql:user:passwd:localhost:3306" + database { + type = "mysql" + username = "user" + password = "passwd" + address = "localhost:3306" + } + datastore = "mysql:user:passwd:localhost:3306" + } + auth { + provider = "google" + oauth_client_id = "client_id" + oauth_client_secret = "secret" + oauth_callback_url = "https://sshca.example.com/auth/callback" + provider_opts { + domain = "example.com" + } + users_whitelist = ["a_user"] + } + ssh { + signing_key = "signing_key" + additional_principals = ["ec2-user", "ubuntu"] + max_age = "720h" + permissions = ["permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"] + } + aws { + region = "us-east-1" + access_key = "abcdef" + secret_key = "omg123" + } + vault { + address = "https://vault:8200" + token = "abc-def-456-789" + } +`) |