diff options
Diffstat (limited to 'server')
| -rw-r--r-- | server/config/config.go | 209 | ||||
| -rw-r--r-- | server/config/config_test.go | 119 | ||||
| -rw-r--r-- | server/config/testdata/config.go | 48 | ||||
| -rw-r--r-- | server/wkfs/s3fs/s3.go | 4 | ||||
| -rw-r--r-- | server/wkfs/vaultfs/vault.go | 4 | 
5 files changed, 223 insertions, 161 deletions
diff --git a/server/config/config.go b/server/config/config.go index fa580b0..2d821f9 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -10,31 +10,21 @@ import (  	"strings"  	"github.com/hashicorp/go-multierror" +	"github.com/mitchellh/mapstructure"  	"github.com/nsheridan/cashier/server/helpers/vault"  	"github.com/spf13/viper"  )  // Config holds the final server configuration.  type Config struct { -	Server *Server -	Auth   *Auth -	SSH    *SSH -	AWS    *AWS -	Vault  *Vault +	Server *Server `mapstructure:"server"` +	Auth   *Auth   `mapstructure:"auth"` +	SSH    *SSH    `mapstructure:"ssh"` +	AWS    *AWS    `mapstructure:"aws"` +	Vault  *Vault  `mapstructure:"vault"`  } -// unmarshalled holds the raw config. -// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there -// we perform some validation checks, other overrides and then produce a final Config struct. -type unmarshalled struct { -	Server []Server `mapstructure:"server"` -	Auth   []Auth   `mapstructure:"auth"` -	SSH    []SSH    `mapstructure:"ssh"` -	AWS    []AWS    `mapstructure:"aws"` -	Vault  []Vault  `mapstructure:"vault"` -} - -// Database config +// Database holds database configuration.  type Database map[string]string  // Server holds the configuration specific to the web server and sessions. @@ -49,7 +39,7 @@ type Server struct {  	CSRFSecret   string   `mapstructure:"csrf_secret"`  	HTTPLogFile  string   `mapstructure:"http_logfile"`  	Database     Database `mapstructure:"database"` -	Datastore    string   `mapstructure:"datastore"` // Deprecated. +	Datastore    string   `mapstructure:"datastore"` // Deprecated. TODO: remove.  }  // Auth holds the configuration specific to the OAuth provider. @@ -84,168 +74,153 @@ type Vault struct {  	Token   string `mapstructure:"token"`  } -func verifyConfig(u *unmarshalled) error { +func verifyConfig(c *Config) error {  	var err error -	if len(u.SSH) == 0 { +	if c.SSH == nil {  		err = multierror.Append(err, errors.New("missing ssh config section"))  	} -	if len(u.Auth) == 0 { +	if c.Auth == nil {  		err = multierror.Append(err, errors.New("missing auth config section"))  	} -	if len(u.Server) == 0 { +	if c.Server == nil {  		err = multierror.Append(err, errors.New("missing server config section"))  	} -	if len(u.AWS) == 0 { -		// AWS config is optional -		u.AWS = append(u.AWS, AWS{}) -	} -	if len(u.Vault) == 0 { -		// Vault config is optional -		u.Vault = append(u.Vault, Vault{}) -	} -	if u.Server[0].Datastore != "" { -		log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") -		log.Println("The new config (passwords have been redacted) should look something like:") -		fmt.Printf("server {\n  database {\n") -		for k, v := range u.Server[0].Database { -			if v == "" { -				continue -			} -			if k == "password" { -				fmt.Printf("    password = \"[ REDACTED ]\"\n") -				continue -			} -			fmt.Printf("    %s = \"%s\"\n", k, v) -		} -		fmt.Printf("  }\n}\n") -	}  	return err  } -func convertDatastoreConfig(u *unmarshalled) { +func convertDatastoreConfig(c *Config) {  	// Convert the deprecated 'datastore' config to the new 'database' config. -	if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" { -		c := u.Server[0].Datastore -		engine := strings.Split(c, ":")[0] +	if c.Server != nil && c.Server.Datastore != "" { +		conf := c.Server.Datastore +		engine := strings.Split(conf, ":")[0]  		switch engine {  		case "mysql", "mongo": -			s := strings.SplitN(c, ":", 4) +			s := strings.SplitN(conf, ":", 4)  			engine, user, passwd, addrs := s[0], s[1], s[2], s[3] -			u.Server[0].Database = map[string]string{ +			c.Server.Database = map[string]string{  				"type":     engine,  				"username": user,  				"password": passwd,  				"address":  addrs,  			}  		case "sqlite": -			s := strings.Split(c, ":") -			u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]} +			s := strings.Split(conf, ":") +			c.Server.Database = map[string]string{"type": s[0], "filename": s[1]}  		case "mem": -			u.Server[0].Database = map[string]string{"type": "mem"} +			c.Server.Database = map[string]string{"type": "mem"}  		} +		log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.") +		log.Println("The new config (passwords have been redacted) should look something like:") +		fmt.Printf("server {\n  database {\n") +		for k, v := range c.Server.Database { +			if v == "" { +				continue +			} +			if k == "password" { +				fmt.Printf("    password = \"[ REDACTED ]\"\n") +				continue +			} +			fmt.Printf("    %s = \"%s\"\n", k, v) +		} +		fmt.Printf("  }\n}\n")  	}  } -func setFromEnv(u *unmarshalled) { + +func setFromEnvironment(c *Config) {  	port, err := strconv.Atoi(os.Getenv("PORT"))  	if err == nil { -		u.Server[0].Port = port +		c.Server.Port = port  	}  	if os.Getenv("DATASTORE") != "" { -		u.Server[0].Datastore = os.Getenv("DATASTORE") +		c.Server.Datastore = os.Getenv("DATASTORE")  	}  	if os.Getenv("OAUTH_CLIENT_ID") != "" { -		u.Auth[0].OauthClientID = os.Getenv("OAUTH_CLIENT_ID") +		c.Auth.OauthClientID = os.Getenv("OAUTH_CLIENT_ID")  	}  	if os.Getenv("OAUTH_CLIENT_SECRET") != "" { -		u.Auth[0].OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET") +		c.Auth.OauthClientSecret = os.Getenv("OAUTH_CLIENT_SECRET")  	}  	if os.Getenv("CSRF_SECRET") != "" { -		u.Server[0].CSRFSecret = os.Getenv("CSRF_SECRET") +		c.Server.CSRFSecret = os.Getenv("CSRF_SECRET")  	}  	if os.Getenv("COOKIE_SECRET") != "" { -		u.Server[0].CookieSecret = os.Getenv("COOKIE_SECRET") +		c.Server.CookieSecret = os.Getenv("COOKIE_SECRET")  	}  } -func setFromVault(u *unmarshalled) error { -	if len(u.Vault) == 0 || u.Vault[0].Token == "" || u.Vault[0].Address == "" { +func setFromVault(c *Config) error { +	if c.Vault == nil || c.Vault.Token == "" || c.Vault.Address == "" {  		return nil  	} -	v, err := vault.NewClient(u.Vault[0].Address, u.Vault[0].Token) +	v, err := vault.NewClient(c.Vault.Address, c.Vault.Token)  	if err != nil {  		return err  	} -	get := func(value string) (string, error) { -		if len(value) > 0 && value[:7] == "/vault/" { -			return v.Read(value) -		} -		return value, nil -	}  	var errors error -	if len(u.Auth) > 0 { -		u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID) -		if err != nil { -			errors = multierror.Append(errors, err) -		} -		u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret) -		if err != nil { -			errors = multierror.Append(errors, err) -		} -	} -	if len(u.Server) > 0 { -		u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret) -		if err != nil { -			errors = multierror.Append(errors, err) -		} -		u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret) -		if err != nil { -			errors = multierror.Append(errors, err) -		} -		if len(u.Server[0].Database) > 0 { -			u.Server[0].Database["password"], err = get(u.Server[0].Database["password"]) +	get := func(value string) string { +		if strings.HasPrefix(value, "/vault/") { +			s, err := v.Read(value)  			if err != nil {  				errors = multierror.Append(errors, err)  			} +			return s  		} +		return value  	} -	if len(u.AWS) > 0 { -		u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey) -		if err != nil { -			errors = multierror.Append(errors, err) +	c.Auth.OauthClientID = get(c.Auth.OauthClientID) +	c.Auth.OauthClientSecret = get(c.Auth.OauthClientSecret) +	c.Server.CSRFSecret = get(c.Server.CSRFSecret) +	c.Server.CookieSecret = get(c.Server.CookieSecret) +	if len(c.Server.Database) != 0 { +		c.Server.Database["password"] = get(c.Server.Database["password"]) +	} +	if c.AWS != nil { +		c.AWS.AccessKey = get(c.AWS.AccessKey) +		c.AWS.SecretKey = get(c.AWS.SecretKey) +	} +	return errors +} + +// Unmarshal the config into a *Config +func decode() (*Config, error) { +	var errors error +	config := &Config{} +	configPieces := map[string]interface{}{ +		"auth":   &config.Auth, +		"aws":    &config.AWS, +		"server": &config.Server, +		"ssh":    &config.SSH, +		"vault":  &config.Vault, +	} +	for key, val := range configPieces { +		conf, ok := viper.Get(key).([]map[string]interface{}) +		if !ok { +			continue  		} -		u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey) -		if err != nil { +		if err := mapstructure.WeakDecode(conf[0], val); err != nil {  			errors = multierror.Append(errors, err)  		}  	} -	return errors +	return config, errors  } -// ReadConfig parses a JSON configuration file into a Config struct. +// ReadConfig parses a hcl configuration file into a Config struct.  func ReadConfig(r io.Reader) (*Config, error) { -	u := &unmarshalled{} -	v := viper.New() -	v.SetConfigType("hcl") -	if err := v.ReadConfig(r); err != nil { +	viper.SetConfigType("hcl") +	if err := viper.ReadConfig(r); err != nil {  		return nil, err  	} -	if err := v.Unmarshal(u); err != nil { +	config, err := decode() +	if err != nil {  		return nil, err  	} -	setFromEnv(u) -	if err := setFromVault(u); err != nil { +	if err := setFromVault(config); err != nil {  		return nil, err  	} -	convertDatastoreConfig(u) -	if err := verifyConfig(u); err != nil { +	setFromEnvironment(config) +	convertDatastoreConfig(config) +	if err := verifyConfig(config); err != nil {  		return nil, err  	} -	c := &Config{ -		Server: &u.Server[0], -		Auth:   &u.Auth[0], -		SSH:    &u.SSH[0], -		AWS:    &u.AWS[0], -		Vault:  &u.Vault[0], -	} -	return c, nil +	return config, nil  } diff --git a/server/config/config_test.go b/server/config/config_test.go index c6bdc3d..399e143 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -3,61 +3,92 @@ package config  import (  	"bytes"  	"testing" -	"time" -	"github.com/nsheridan/cashier/testdata" +	"github.com/nsheridan/cashier/server/config/testdata"  	"github.com/stretchr/testify/assert"  ) -func TestServerConfig(t *testing.T) { -	t.Parallel() -	a := assert.New(t) -	c, err := ReadConfig(bytes.NewBuffer(testdata.ServerConfig)) -	if err != nil { -		t.Error(err) +var ( +	parsedConfig = &Config{ +		Server: &Server{ +			UseTLS:       true, +			TLSKey:       "server.key", +			TLSCert:      "server.crt", +			Addr:         "127.0.0.1", +			Port:         443, +			User:         "nobody", +			CookieSecret: "supersecret", +			CSRFSecret:   "supersecret", +			HTTPLogFile:  "cashierd.log", +			Database:     Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"}, +			Datastore:    "mysql:user:passwd:localhost:3306", +		}, +		Auth: &Auth{ +			OauthClientID:     "client_id", +			OauthClientSecret: "secret", +			OauthCallbackURL:  "https://sshca.example.com/auth/callback", +			Provider:          "google", +			ProviderOpts:      map[string]string{"domain": "example.com"}, +			UsersWhitelist:    []string{"a_user"}, +		}, +		SSH: &SSH{ +			SigningKey:           "signing_key", +			AdditionalPrincipals: []string{"ec2-user", "ubuntu"}, +			MaxAge:               "720h", +			Permissions:          []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"}, +		}, +		AWS: &AWS{ +			Region:    "us-east-1", +			AccessKey: "abcdef", +			SecretKey: "omg123", +		}, +		Vault: &Vault{ +			Address: "https://vault:8200", +			Token:   "abc-def-456-789", +		},  	} -	server := c.Server -	a.IsType(server, &Server{}) -	a.True(server.UseTLS) -	a.Equal(server.TLSKey, "server.key") -	a.Equal(server.TLSCert, "server.crt") -	a.Equal(server.Port, 443) -	a.Equal(server.Addr, "127.0.0.1") -	a.Equal(server.CookieSecret, "supersecret") -} +) -func TestAuthConfig(t *testing.T) { -	t.Parallel() -	a := assert.New(t) -	c, err := ReadConfig(bytes.NewBuffer(testdata.AuthConfig)) +func TestConfigParser(t *testing.T) { +	c, err := ReadConfig(bytes.NewBuffer(testdata.Config))  	if err != nil {  		t.Error(err)  	} -	auth := c.Auth -	a.IsType(auth, &Auth{}) -	a.Equal(auth.Provider, "google") -	a.Equal(auth.ProviderOpts, map[string]string{"domain": "example.com"}) -	a.Equal(auth.OauthClientID, "client_id") -	a.Equal(auth.OauthClientSecret, "secret") -	a.Equal(auth.OauthCallbackURL, "https://sshca.example.com/auth/callback") +	assert.Equal(t, parsedConfig, c)  } -func TestSSHConfig(t *testing.T) { -	t.Parallel() -	a := assert.New(t) -	c, err := ReadConfig(bytes.NewBuffer(testdata.SSHConfig)) -	if err != nil { -		t.Error(err) +func TestConfigVerify(t *testing.T) { +	bad := bytes.NewBuffer([]byte("")) +	_, err := ReadConfig(bad) +	assert.Contains(t, err.Error(), "missing ssh config section", "missing server config section", "missing auth config section") +} + +func TestDatastoreConversion(t *testing.T) { +	tests := []struct { +		in  string +		out Database +	}{ +		{ +			"mysql:user:passwd:localhost:3306", Database{"type": "mysql", "username": "user", "password": "passwd", "address": "localhost:3306"}, +		}, +		{ +			"mongo:::host1,host2", Database{"type": "mongo", "username": "", "password": "", "address": "host1,host2"}, +		}, +		{ +			"mem", Database{"type": "mem"}, +		}, +		{ +			"sqlite:/data/certs.db", Database{"type": "sqlite", "filename": "/data/certs.db"}, +		},  	} -	ssh := c.SSH -	a.IsType(ssh, &SSH{}) -	a.Equal(ssh.SigningKey, "signing_key") -	a.Equal(ssh.AdditionalPrincipals, []string{"ec2-user", "ubuntu"}) -	a.Equal(ssh.Permissions, []string{"permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"}) -	a.Equal(ssh.MaxAge, "720h") -	d, err := time.ParseDuration(ssh.MaxAge) -	if err != nil { -		t.Error(err) + +	for _, tc := range tests { +		config := &Config{ +			Server: &Server{ +				Datastore: tc.in, +			}, +		} +		convertDatastoreConfig(config) +		assert.EqualValues(t, config.Server.Database, tc.out)  	} -	a.Equal(d.Hours(), float64(720))  } diff --git a/server/config/testdata/config.go b/server/config/testdata/config.go new file mode 100644 index 0000000..27cffcc --- /dev/null +++ b/server/config/testdata/config.go @@ -0,0 +1,48 @@ +package testdata + +var Config = []byte(` +	server { +		use_tls = true +		tls_key = "server.key" +		tls_cert = "server.crt" +		address = "127.0.0.1" +		port = 443 +		user = "nobody" +		cookie_secret = "supersecret" +		csrf_secret = "supersecret" +		http_logfile = "cashierd.log" +		datastore = "mysql:user:passwd:localhost:3306" +		database { +			type = "mysql" +			username = "user" +			password = "passwd" +			address = "localhost:3306" +		} +		datastore = "mysql:user:passwd:localhost:3306" +	} +	auth { +		provider = "google" +		oauth_client_id = "client_id" +		oauth_client_secret = "secret" +		oauth_callback_url = "https://sshca.example.com/auth/callback" +		provider_opts { +			domain = "example.com" +		} +		users_whitelist = ["a_user"] +	} +	ssh { +		signing_key = "signing_key" +		additional_principals = ["ec2-user", "ubuntu"] +		max_age = "720h" +		permissions = ["permit-pty", "permit-X11-forwarding", "permit-port-forwarding", "permit-user-rc"] +	} +	aws { +		region = "us-east-1" +		access_key = "abcdef" +		secret_key = "omg123" +	} +	vault { +		address = "https://vault:8200" +		token = "abc-def-456-789" +	} +`) diff --git a/server/wkfs/s3fs/s3.go b/server/wkfs/s3fs/s3.go index a71d874..331b55f 100644 --- a/server/wkfs/s3fs/s3.go +++ b/server/wkfs/s3fs/s3.go @@ -21,6 +21,10 @@ import (  // Register the /s3/ filesystem as a well-known filesystem.  func Register(config *config.AWS) { +	if config == nil { +		registerBrokenFS(errors.New("aws credentials not found")) +		return +	}  	ac := &aws.Config{}  	// If region is unset the SDK will attempt to read the region from the environment.  	if config.Region != "" { diff --git a/server/wkfs/vaultfs/vault.go b/server/wkfs/vaultfs/vault.go index 6f11057..f7c1360 100644 --- a/server/wkfs/vaultfs/vault.go +++ b/server/wkfs/vaultfs/vault.go @@ -14,6 +14,10 @@ import (  // Register the /vault/ filesystem as a well-known filesystem.  func Register(vc *config.Vault) { +	if vc == nil { +		registerBrokenFS(errors.New("no vault configuration found")) +		return +	}  	client, err := vault.NewClient(vc.Address, vc.Token)  	if err != nil {  		registerBrokenFS(err)  | 
