From a427038700c0f1c080090a8158c1a793923aa03c Mon Sep 17 00:00:00 2001 From: Niall Sheridan Date: Sun, 16 Oct 2016 16:47:09 +0100 Subject: Unmarshal the config using mapstructure directly. Avoid unmarshalling into an intermediate struct. Better tests. --- server/config/config.go | 209 +++++++++++++++++++++--------------------------- 1 file changed, 92 insertions(+), 117 deletions(-) (limited to 'server/config/config.go') diff --git a/server/config/config.go b/server/config/config.go index fa580b0..2d821f9 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -10,31 +10,21 @@ import ( "strings" "github.com/hashicorp/go-multierror" + "github.com/mitchellh/mapstructure" "github.com/nsheridan/cashier/server/helpers/vault" "github.com/spf13/viper" ) // Config holds the final server configuration. type Config struct { - Server *Server - Auth *Auth - SSH *SSH - AWS *AWS - Vault *Vault + Server *Server `mapstructure:"server"` + Auth *Auth `mapstructure:"auth"` + SSH *SSH `mapstructure:"ssh"` + AWS *AWS `mapstructure:"aws"` + Vault *Vault `mapstructure:"vault"` } -// unmarshalled holds the raw config. -// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there -// we perform some validation checks, other overrides and then produce a final Config struct. -type unmarshalled struct { - Server []Server `mapstructure:"server"` - Auth []Auth `mapstructure:"auth"` - SSH []SSH `mapstructure:"ssh"` - AWS []AWS `mapstructure:"aws"` - Vault []Vault `mapstructure:"vault"` -} - -// Database config +// Database holds database configuration. type Database map[string]string // Server holds the configuration specific to the web server and sessions. @@ -49,7 +39,7 @@ type Server struct { CSRFSecret string `mapstructure:"csrf_secret"` HTTPLogFile string `mapstructure:"http_logfile"` Database Database `mapstructure:"database"` - Datastore string `mapstructure:"datastore"` // Deprecated. + Datastore string `mapstructure:"datastore"` // Deprecated. TODO: remove. } // Auth holds the configuration specific to the OAuth provider. @@ -84,168 +74,153 @@ type Vault struct { Token string `mapstructure:"token"` } -func verifyConfig(u *unmarshalled) error { +func verifyConfig(c *Config) error { var err error - if len(u.SSH) == 0 { + if c.SSH == nil { err = multierror.Append(err, errors.New("missing ssh config section")) } - if len(u.Auth) == 0 { + if c.Auth == nil { err = multierror.Append(err, errors.New("missing auth config section")) } - if len(u.Server) == 0 { + if c.Server == nil { err = multierror.Append(err, errors.New("missing server config section")) } - if len(u.AWS) == 0 { - // AWS config is optional - u.AWS = append(u.AWS, AWS{}) - } - if len(u.Vault) == 0 { - // Vault config is optional - u.Vault = append(u.Vault, Vault{}) - } - if u.Server[0].Datastore != "" { - log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") - log.Println("The new config (passwords have been redacted) should look something like:") - fmt.Printf("server {\n database {\n") - for k, v := range u.Server[0].Database { - if v == "" { - continue - } - if k == "password" { - fmt.Printf(" password = \"[ REDACTED ]\"\n") - continue - } - fmt.Printf(" %s = \"%s\"\n", k, v) - } - fmt.Printf(" }\n}\n") - } return err } -func convertDatastoreConfig(u *unmarshalled) { +func convertDatastoreConfig(c *Config) { // Convert the deprecated 'datastore' config to the new 'database' config. - if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" { - c := u.Server[0].Datastore - engine := strings.Split(c, ":")[0] + if c.Server != nil && c.Server.Datastore != "" { + conf := c.Server.Datastore + engine := strings.Split(conf, ":")[0] switch engine { case "mysql", "mongo": - s := strings.SplitN(c, ":", 4) + s := strings.SplitN(conf, ":", 4) engine, user, passwd, addrs := s[0], s[1], s[2], s[3] - u.Server[0].Database = map[string]string{ + c.Server.Database = map[string]string{ "type": engine, "username": user, "password": passwd, "address": addrs, } case "sqlite": - s := strings.Split(c, ":") - u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]} + s := strings.Split(conf, ":") + c.Server.Database = map[string]string{"type": s[0], "filename": s[1]} case "mem": - u.Server[0].Database = map[string]string{"type": "mem"} + c.Server.Database = map[string]string{"type": "mem"} } + log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") + log.Println("The new config (passwords have been redacted) should look something like:") + fmt.Printf("server {\n database {\n") + for k, v := range c.Server.Database { + if v == "" { + continue + } + if k == "password" { + fmt.Printf(" password = \"[ REDACTED ]\"\n") + continue + } + fmt.Printf(" %s = \"%s\"\n", k, v) + } + fmt.Printf(" }\n}\n") } } -func setFromEnv(u *unmarshalled) { + +func setFromEnvironment(c *Config) { port, err := strconv.Atoi(os.Getenv("PORT")) if err == nil { - u.Server[0].Port = port + c.Server.Port = port } if os.Getenv("DATASTORE") != "" { - u.Server[0].Datastore = os.Getenv("DATASTORE") + c.Server.Datastore = os.Getenv("DATASTORE") } if os.Getenv("OAUTH_CLIENT_ID") != "" { - u.Auth[0].OauthClientID = os.Getenv("OAUTH_CLIENT_ID") + c.Auth.OauthClientID = os.Getenv("OAUTH_CLIENT_ID") } if os.Getenv("OAUTH_CLIENT_SECRET") != "" { - u.Auth[0].OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET") + c.Auth.OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET") } if os.Getenv("CSRF_SECRET") != "" { - u.Server[0].CSRFSecret = os.Getenv("CSRF_SECRET") + c.Server.CSRFSecret = os.Getenv("CSRF_SECRET") } if os.Getenv("COOKIE_SECRET") != "" { - u.Server[0].CookieSecret = os.Getenv("COOKIE_SECRET") + c.Server.CookieSecret = os.Getenv("COOKIE_SECRET") } } -func setFromVault(u *unmarshalled) error { - if len(u.Vault) == 0 || u.Vault[0].Token == "" || u.Vault[0].Address == "" { +func setFromVault(c *Config) error { + if c.Vault == nil || c.Vault.Token == "" || c.Vault.Address == "" { return nil } - v, err := vault.NewClient(u.Vault[0].Address, u.Vault[0].Token) + v, err := vault.NewClient(c.Vault.Address, c.Vault.Token) if err != nil { return err } - get := func(value string) (string, error) { - if len(value) > 0 && value[:7] == "/vault/" { - return v.Read(value) - } - return value, nil - } var errors error - if len(u.Auth) > 0 { - u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID) - if err != nil { - errors = multierror.Append(errors, err) - } - u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - } - if len(u.Server) > 0 { - u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret) - if err != nil { - errors = multierror.Append(errors, err) - } - if len(u.Server[0].Database) > 0 { - u.Server[0].Database["password"], err = get(u.Server[0].Database["password"]) + get := func(value string) string { + if strings.HasPrefix(value, "/vault/") { + s, err := v.Read(value) if err != nil { errors = multierror.Append(errors, err) } + return s } + return value } - if len(u.AWS) > 0 { - u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey) - if err != nil { - errors = multierror.Append(errors, err) + c.Auth.OauthClientID = get(c.Auth.OauthClientID) + c.Auth.OauthClientSecret = get(c.Auth.OauthClientSecret) + c.Server.CSRFSecret = get(c.Server.CSRFSecret) + c.Server.CookieSecret = get(c.Server.CookieSecret) + if len(c.Server.Database) != 0 { + c.Server.Database["password"] = get(c.Server.Database["password"]) + } + if c.AWS != nil { + c.AWS.AccessKey = get(c.AWS.AccessKey) + c.AWS.SecretKey = get(c.AWS.SecretKey) + } + return errors +} + +// Unmarshal the config into a *Config +func decode() (*Config, error) { + var errors error + config := &Config{} + configPieces := map[string]interface{}{ + "auth": &config.Auth, + "aws": &config.AWS, + "server": &config.Server, + "ssh": &config.SSH, + "vault": &config.Vault, + } + for key, val := range configPieces { + conf, ok := viper.Get(key).([]map[string]interface{}) + if !ok { + continue } - u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey) - if err != nil { + if err := mapstructure.WeakDecode(conf[0], val); err != nil { errors = multierror.Append(errors, err) } } - return errors + return config, errors } -// ReadConfig parses a JSON configuration file into a Config struct. +// ReadConfig parses a hcl configuration file into a Config struct. func ReadConfig(r io.Reader) (*Config, error) { - u := &unmarshalled{} - v := viper.New() - v.SetConfigType("hcl") - if err := v.ReadConfig(r); err != nil { + viper.SetConfigType("hcl") + if err := viper.ReadConfig(r); err != nil { return nil, err } - if err := v.Unmarshal(u); err != nil { + config, err := decode() + if err != nil { return nil, err } - setFromEnv(u) - if err := setFromVault(u); err != nil { + if err := setFromVault(config); err != nil { return nil, err } - convertDatastoreConfig(u) - if err := verifyConfig(u); err != nil { + setFromEnvironment(config) + convertDatastoreConfig(config) + if err := verifyConfig(config); err != nil { return nil, err } - c := &Config{ - Server: &u.Server[0], - Auth: &u.Auth[0], - SSH: &u.SSH[0], - AWS: &u.AWS[0], - Vault: &u.Vault[0], - } - return c, nil + return config, nil } -- cgit v1.2.3