aboutsummaryrefslogtreecommitdiff
path: root/server/config
diff options
context:
space:
mode:
Diffstat (limited to 'server/config')
-rw-r--r--server/config/config.go209
-rw-r--r--server/config/config_test.go119
-rw-r--r--server/config/testdata/config.go48
3 files changed, 215 insertions, 161 deletions
diff --git a/server/config/config.go b/server/config/config.go
index fa580b0..2d821f9 100644
--- a/server/config/config.go
+++ b/server/config/config.go
@@ -10,31 +10,21 @@ import (
"strings"
"github.com/hashicorp/go-multierror"
+ "github.com/mitchellh/mapstructure"
"github.com/nsheridan/cashier/server/helpers/vault"
"github.com/spf13/viper"
)
// Config holds the final server configuration.
type Config struct {
- Server *Server
- Auth *Auth
- SSH *SSH
- AWS *AWS
- Vault *Vault
+ Server *Server `mapstructure:"server"`
+ Auth *Auth `mapstructure:"auth"`
+ SSH *SSH `mapstructure:"ssh"`
+ AWS *AWS `mapstructure:"aws"`
+ Vault *Vault `mapstructure:"vault"`
}
-// unmarshalled holds the raw config.
-// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there
-// we perform some validation checks, other overrides and then produce a final Config struct.
-type unmarshalled struct {
- Server []Server `mapstructure:"server"`
- Auth []Auth `mapstructure:"auth"`
- SSH []SSH `mapstructure:"ssh"`
- AWS []AWS `mapstructure:"aws"`
- Vault []Vault `mapstructure:"vault"`
-}
-
-// Database config
+// Database holds database configuration.
type Database map[string]string
// Server holds the configuration specific to the web server and sessions.
@@ -49,7 +39,7 @@ type Server struct {
CSRFSecret string `mapstructure:"csrf_secret"`
HTTPLogFile string `mapstructure:"http_logfile"`
Database Database `mapstructure:"database"`
- Datastore string `mapstructure:"datastore"` // Deprecated.
+ Datastore string `mapstructure:"datastore"` // Deprecated. TODO: remove.
}
// Auth holds the configuration specific to the OAuth provider.
@@ -84,168 +74,153 @@ type Vault struct {
Token string `mapstructure:"token"`
}
-func verifyConfig(u *unmarshalled) error {
+func verifyConfig(c *Config) error {
var err error
- if len(u.SSH) == 0 {
+ if c.SSH == nil {
err = multierror.Append(err, errors.New("missing ssh config section"))
}
- if len(u.Auth) == 0 {
+ if c.Auth == nil {
err = multierror.Append(err, errors.New("missing auth config section"))
}
- if len(u.Server) == 0 {
+ if c.Server == nil {
err = multierror.Append(err, errors.New("missing server config section"))
}
- if len(u.AWS) == 0 {
- // AWS config is optional
- u.AWS = append(u.AWS, AWS{})
- }
- if len(u.Vault) == 0 {
- // Vault config is optional
- u.Vault = append(u.Vault, Vault{})
- }
- if u.Server[0].Datastore != "" {
- log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.")
- log.Println("The new config (passwords have been redacted) should look something like:")
- fmt.Printf("server {\n database {\n")
- for k, v := range u.Server[0].Database {
- if v == "" {
- continue
- }
- if k == "password" {
- fmt.Printf(" password = \"[ REDACTED ]\"\n")
- continue
- }
- fmt.Printf(" %s = \"%s\"\n", k, v)
- }
- fmt.Printf(" }\n}\n")
- }
return err
}
-func convertDatastoreConfig(u *unmarshalled) {
+func convertDatastoreConfig(c *Config) {
// Convert the deprecated 'datastore' config to the new 'database' config.
- if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" {
- c := u.Server[0].Datastore
- engine := strings.Split(c, ":")[0]
+ if c.Server != nil && c.Server.Datastore != "" {
+ conf := c.Server.Datastore
+ engine := strings.Split(conf, ":")[0]
switch engine {
case "mysql", "mongo":
- s := strings.SplitN(c, ":", 4)
+ s := strings.SplitN(conf, ":", 4)
engine, user, passwd, addrs := s[0], s[1], s[2], s[3]
- u.Server[0].Database = map[string]string{
+ c.Server.Database = map[string]string{
"type": engine,
"username": user,
"password": passwd,
"address": addrs,
}
case "sqlite":
- s := strings.Split(c, ":")
- u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]}
+ s := strings.Split(conf, ":")
+ c.Server.Database = map[string]string{"type": s[0], "filename": s[1]}
case "mem":
- u.Server[0].Database = map[string]string{"type": "mem"}
+ c.Server.Database = map[string]string{"type": "mem"}
}
+ log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.")
+ log.Println("The new config (passwords have been redacted) should look something like:")
+ fmt.Printf("server {\n database {\n")
+ for k, v := range c.Server.Database {
+ if v == "" {
+ continue
+ }
+ if k == "password" {
+ fmt.Printf(" password = \"[ REDACTED ]\"\n")
+ continue
+ }
+ fmt.Printf(" %s = \"%s\"\n", k, v)
+ }
+ fmt.Printf(" }\n}\n")
}
}
-func setFromEnv(u *unmarshalled) {
+
+func setFromEnvironment(c *Config) {
port, err := strconv.Atoi(os.Getenv("PORT"))
if err == nil {
- u.Server[0].Port = port
+ c.Server.Port = port
}
if os.Getenv("DATASTORE") != "" {
- u.Server[0].Datastore = os.Getenv("DATASTORE")
+ c.Server.Datastore = os.Getenv("DATASTORE")
}
if os.Getenv("OAUTH_CLIENT_ID") != "" {
- u.Auth[0].OauthClientID = os.Getenv("OAUTH_CLIENT_ID")
+ c.Auth.OauthClientID = os.Getenv("OAUTH_CLIENT_ID")
}
if os.Getenv("OAUTH_CLIENT_SECRET") != "" {
- u.Auth[0].OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET")
+ c.Auth.OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET")
}
if os.Getenv("CSRF_SECRET") != "" {
- u.Server[0].CSRFSecret = os.Getenv("CSRF_SECRET")
+ c.Server.CSRFSecret = os.Getenv("CSRF_SECRET")
}
if os.Getenv("COOKIE_SECRET") != "" {
- u.Server[0].CookieSecret = os.Getenv("COOKIE_SECRET")
+ c.Server.CookieSecret = os.Getenv("COOKIE_SECRET")
}
}
-func setFromVault(u *unmarshalled) error {
- if len(u.Vault) == 0 || u.Vault[0].Token == "" || u.Vault[0].Address == "" {
+func setFromVault(c *Config) error {
+ if c.Vault == nil || c.Vault.Token == "" || c.Vault.Address == "" {
return nil
}
- v, err := vault.NewClient(u.Vault[0].Address, u.Vault[0].Token)
+ v, err := vault.NewClient(c.Vault.Address, c.Vault.Token)
if err != nil {
return err
}
- get := func(value string) (string, error) {
- if len(value) > 0 && value[:7] == "/vault/" {
- return v.Read(value)
- }
- return value, nil
- }
var errors error
- if len(u.Auth) > 0 {
- u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID)
- if err != nil {
- errors = multierror.Append(errors, err)
- }
- u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret)
- if err != nil {
- errors = multierror.Append(errors, err)
- }
- }
- if len(u.Server) > 0 {
- u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret)
- if err != nil {
- errors = multierror.Append(errors, err)
- }
- u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret)
- if err != nil {
- errors = multierror.Append(errors, err)
- }
- if len(u.Server[0].Database) > 0 {
- u.Server[0].Database["password"], err = get(u.Server[0].Database["password"])
+ get := func(value string) string {
+ if strings.HasPrefix(value, "/vault/") {
+ s, err := v.Read(value)
if err != nil {
errors = multierror.Append(errors, err)
}
+ return s
}
+ return value
}
- if len(u.AWS) > 0 {
- u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey)
- if err != nil {
- errors = multierror.Append(errors, err)
+ c.Auth.OauthClientID = get(c.Auth.OauthClientID)
+ c.Auth.OauthClientSecret = get(c.Auth.OauthClientSecret)
+ c.Server.CSRFSecret = get(c.Server.CSRFSecret)
+ c.Server.CookieSecret = get(c.Server.CookieSecret)
+ if len(c.Server.Database) != 0 {
+ c.Server.Database["password"] = get(c.Server.Database["password"])
+ }
+ if c.AWS != nil {
+ c.AWS.AccessKey = get(c.AWS.AccessKey)
+ c.AWS.SecretKey = get(c.AWS.SecretKey)
+ }
+ return errors
+}
+
+// Unmarshal the config into a *Config
+func decode() (*Config, error) {
+ var errors error
+ config := &Config{}
+ configPieces := map[string]interface{}{
+ "auth": &config.Auth,
+ "aws": &config.AWS,
+ "server": &config.Server,
+ "ssh": &config.SSH,
+ "vault": &config.Vault,
+ }
+ for key, val := range configPieces {
+ conf, ok := viper.Get(key).([]map[string]interface{})
+ if !ok {
+ continue
}
- u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey)
- if err != nil {
+ if err := mapstructure.WeakDecode(conf[0], val); err != nil {
errors = multierror.Append(errors, err)
}
}
- return errors
+ return config, errors
}
-// ReadConfig parses a JSON configuration file into a Config struct.
+// ReadConfig parses a hcl configuration file into a Config struct.
func ReadConfig(r io.Reader) (*Config, error) {
- u := &unmarshalled{}
- v := viper.New()
- v.SetConfigType("hcl")
- if err := v.ReadConfig(r); err != nil {
+ viper.SetConfigType("hcl")
+ if err := viper.ReadConfig(r); err != nil {
return nil, err
}
- if err := v.Unmarshal(u); err != nil {
+ config, err := decode()
+ if err != nil {
return nil, err
}
- setFromEnv(u)
- if err := setFromVault(u); err != nil {
+ if err := setFromVault(config); err != nil {
return nil, err
}
- convertDatastoreConfig(u)
- if err := verifyConfig(u); err != nil {
+ setFromEnvironment(config)
+ convertDatastoreConfig(config)
+ if err := verifyConfig(config); err != nil {
return nil, err
}
- c := &Config{
- Server: &u.Server[0],
- Auth: &u.Auth[0],
- SSH: &u.SSH[0],
- AWS: &u.AWS[0],
- Vault: &u.Vault[0],
- }
- return c, nil
+ return config, nil
}
diff --git a/server/config/config_test.go b/server/config/config_test.go
index c6bdc3d..399e143 100644
--- a/server/config/config_test.go
+++ b/server/config/config_test.go
@@ -3,61 +3,92 @@ package config
import (
"bytes"
"testing"
- "time"
- "github.com/nsheridan/cashier/testdata"
+ "github.com/nsheridan/cashier/server/config/testdata"
"github.com/stretchr/testify/assert"
)
-func TestServerConfig(t *testing.T) {
- t.Parallel()
- a := assert.New(t)
- c, err := ReadConfig(bytes.NewBuffer(testdata.ServerConfig))
- if err != nil {
- t.Error(err)
+var (
+ parsedConfig = &Config{
+ Server: &Server{
+ UseTLS: true,
+ TLSKey: "server.key",
+ TLSCert: "server.crt",
+ Addr: "127.0.0.1",
+ Port: 443,
+ User: "nobody",
+ CookieSecret: "supersecret",
+ CSRFSecret: "supersecret",
+ HTTPLogFile: "cashierd.log",
+ Database: Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"},
+ Datastore: "mysql:user:passwd:localhost:3306",
+ },
+ Auth: &Auth{
+ OauthClientID: "client_id",
+ OauthClientSecret: "secret",
+ OauthCallbackURL: "https://sshca.example.com/auth/callback",
+ Provider: "google",
+ ProviderOpts: map[string]string{"domain": "example.com"},
+ UsersWhitelist: []string{"a_user"},
+ },
+ SSH: &SSH{
+ SigningKey: "signing_key",
+ AdditionalPrincipals: []string{"ec2-user", "ubuntu"},
+ MaxAge: "720h",
+ Permissions: []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"},
+ },
+ AWS: &AWS{
+ Region: "us-east-1",
+ AccessKey: "abcdef",
+ SecretKey: "omg123",
+ },
+ Vault: &Vault{
+ Address: "https://vault:8200",
+ Token: "abc-def-456-789",
+ },
}
- server := c.Server
- a.IsType(server, &Server{})
- a.True(server.UseTLS)
- a.Equal(server.TLSKey, "server.key")
- a.Equal(server.TLSCert, "server.crt")
- a.Equal(server.Port, 443)
- a.Equal(server.Addr, "127.0.0.1")
- a.Equal(server.CookieSecret, "supersecret")
-}
+)
-func TestAuthConfig(t *testing.T) {
- t.Parallel()
- a := assert.New(t)
- c, err := ReadConfig(bytes.NewBuffer(testdata.AuthConfig))
+func TestConfigParser(t *testing.T) {
+ c, err := ReadConfig(bytes.NewBuffer(testdata.Config))
if err != nil {
t.Error(err)
}
- auth := c.Auth
- a.IsType(auth, &Auth{})
- a.Equal(auth.Provider, "google")
- a.Equal(auth.ProviderOpts, map[string]string{"domain": "example.com"})
- a.Equal(auth.OauthClientID, "client_id")
- a.Equal(auth.OauthClientSecret, "secret")
- a.Equal(auth.OauthCallbackURL, "https://sshca.example.com/auth/callback")
+ assert.Equal(t, parsedConfig, c)
}
-func TestSSHConfig(t *testing.T) {
- t.Parallel()
- a := assert.New(t)
- c, err := ReadConfig(bytes.NewBuffer(testdata.SSHConfig))
- if err != nil {
- t.Error(err)
+func TestConfigVerify(t *testing.T) {
+ bad := bytes.NewBuffer([]byte(""))
+ _, err := ReadConfig(bad)
+ assert.Contains(t, err.Error(), "missing ssh config section", "missing server config section", "missing auth config section")
+}
+
+func TestDatastoreConversion(t *testing.T) {
+ tests := []struct {
+ in string
+ out Database
+ }{
+ {
+ "mysql:user:passwd:localhost:3306", Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"},
+ },
+ {
+ "mongo:::host1,host2", Database{"type": "mongo", "username": "", "password": "", "address": "host1,host2"},
+ },
+ {
+ "mem", Database{"type": "mem"},
+ },
+ {
+ "sqlite:/data/certs.db", Database{"type": "sqlite", "filename": "/data/certs.db"},
+ },
}
- ssh := c.SSH
- a.IsType(ssh, &SSH{})
- a.Equal(ssh.SigningKey, "signing_key")
- a.Equal(ssh.AdditionalPrincipals, []string{"ec2-user", "ubuntu"})
- a.Equal(ssh.Permissions, []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"})
- a.Equal(ssh.MaxAge, "720h")
- d, err := time.ParseDuration(ssh.MaxAge)
- if err != nil {
- t.Error(err)
+
+ for _, tc := range tests {
+ config := &Config{
+ Server: &Server{
+ Datastore: tc.in,
+ },
+ }
+ convertDatastoreConfig(config)
+ assert.EqualValues(t, config.Server.Database, tc.out)
}
- a.Equal(d.Hours(), float64(720))
}
diff --git a/server/config/testdata/config.go b/server/config/testdata/config.go
new file mode 100644
index 0000000..27cffcc
--- /dev/null
+++ b/server/config/testdata/config.go
@@ -0,0 +1,48 @@
+package testdata
+
+var Config = []byte(`
+ server {
+ use_tls = true
+ tls_key = "server.key"
+ tls_cert = "server.crt"
+ address = "127.0.0.1"
+ port = 443
+ user = "nobody"
+ cookie_secret = "supersecret"
+ csrf_secret = "supersecret"
+ http_logfile = "cashierd.log"
+ datastore = "mysql:user:passwd:localhost:3306"
+ database {
+ type = "mysql"
+ username = "user"
+ password = "passwd"
+ address = "localhost:3306"
+ }
+ datastore = "mysql:user:passwd:localhost:3306"
+ }
+ auth {
+ provider = "google"
+ oauth_client_id = "client_id"
+ oauth_client_secret = "secret"
+ oauth_callback_url = "https://sshca.example.com/auth/callback"
+ provider_opts {
+ domain = "example.com"
+ }
+ users_whitelist = ["a_user"]
+ }
+ ssh {
+ signing_key = "signing_key"
+ additional_principals = ["ec2-user", "ubuntu"]
+ max_age = "720h"
+ permissions = ["permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"]
+ }
+ aws {
+ region = "us-east-1"
+ access_key = "abcdef"
+ secret_key = "omg123"
+ }
+ vault {
+ address = "https://vault:8200"
+ token = "abc-def-456-789"
+ }
+`)