aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/config/config.go115
-rw-r--r--server/store/config_test.go70
-rw-r--r--server/store/mongo.go26
-rw-r--r--server/store/sqldb.go52
-rw-r--r--server/store/store.go14
-rw-r--r--server/store/store_test.go24
6 files changed, 152 insertions, 149 deletions
diff --git a/server/config/config.go b/server/config/config.go
index 9678f6d..fa580b0 100644
--- a/server/config/config.go
+++ b/server/config/config.go
@@ -2,25 +2,30 @@ package config
import (
"errors"
+ "fmt"
"io"
+ "log"
"os"
"strconv"
+ "strings"
"github.com/hashicorp/go-multierror"
"github.com/nsheridan/cashier/server/helpers/vault"
"github.com/spf13/viper"
)
-// Config holds the server configuration.
+// Config holds the final server configuration.
type Config struct {
- Server *Server `mapstructure:"server"`
- Auth *Auth `mapstructure:"auth"`
- SSH *SSH `mapstructure:"ssh"`
- AWS *AWS `mapstructure:"aws"`
- Vault *Vault `mapstructure:"vault"`
+ Server *Server
+ Auth *Auth
+ SSH *SSH
+ AWS *AWS
+ Vault *Vault
}
// unmarshalled holds the raw config.
+// The original hcl config is a series of slices. The config is unmarshalled from hcl into this structure and from there
+// we perform some validation checks, other overrides and then produce a final Config struct.
type unmarshalled struct {
Server []Server `mapstructure:"server"`
Auth []Auth `mapstructure:"auth"`
@@ -29,18 +34,22 @@ type unmarshalled struct {
Vault []Vault `mapstructure:"vault"`
}
+// Database config
+type Database map[string]string
+
// Server holds the configuration specific to the web server and sessions.
type Server struct {
- UseTLS bool `mapstructure:"use_tls"`
- TLSKey string `mapstructure:"tls_key"`
- TLSCert string `mapstructure:"tls_cert"`
- Addr string `mapstructure:"address"`
- Port int `mapstructure:"port"`
- User string `mapstructure:"user"`
- CookieSecret string `mapstructure:"cookie_secret"`
- CSRFSecret string `mapstructure:"csrf_secret"`
- HTTPLogFile string `mapstructure:"http_logfile"`
- Datastore string `mapstructure:"datastore"`
+ UseTLS bool `mapstructure:"use_tls"`
+ TLSKey string `mapstructure:"tls_key"`
+ TLSCert string `mapstructure:"tls_cert"`
+ Addr string `mapstructure:"address"`
+ Port int `mapstructure:"port"`
+ User string `mapstructure:"user"`
+ CookieSecret string `mapstructure:"cookie_secret"`
+ CSRFSecret string `mapstructure:"csrf_secret"`
+ HTTPLogFile string `mapstructure:"http_logfile"`
+ Database Database `mapstructure:"database"`
+ Datastore string `mapstructure:"datastore"` // Deprecated.
}
// Auth holds the configuration specific to the OAuth provider.
@@ -78,13 +87,13 @@ type Vault struct {
func verifyConfig(u *unmarshalled) error {
var err error
if len(u.SSH) == 0 {
- err = multierror.Append(errors.New("missing ssh config section"))
+ err = multierror.Append(err, errors.New("missing ssh config section"))
}
if len(u.Auth) == 0 {
- err = multierror.Append(errors.New("missing auth config section"))
+ err = multierror.Append(err, errors.New("missing auth config section"))
}
if len(u.Server) == 0 {
- err = multierror.Append(errors.New("missing server config section"))
+ err = multierror.Append(err, errors.New("missing server config section"))
}
if len(u.AWS) == 0 {
// AWS config is optional
@@ -94,9 +103,48 @@ func verifyConfig(u *unmarshalled) error {
// Vault config is optional
u.Vault = append(u.Vault, Vault{})
}
+ if u.Server[0].Datastore != "" {
+ log.Println("The `datastore` option has been deprecated in favour of the `database` option. You should update your config.")
+ log.Println("The new config (passwords have been redacted) should look something like:")
+ fmt.Printf("server {\n database {\n")
+ for k, v := range u.Server[0].Database {
+ if v == "" {
+ continue
+ }
+ if k == "password" {
+ fmt.Printf(" password = \"[ REDACTED ]\"\n")
+ continue
+ }
+ fmt.Printf(" %s = \"%s\"\n", k, v)
+ }
+ fmt.Printf(" }\n}\n")
+ }
return err
}
+func convertDatastoreConfig(u *unmarshalled) {
+ // Convert the deprecated 'datastore' config to the new 'database' config.
+ if len(u.Server[0].Database) == 0 && u.Server[0].Datastore != "" {
+ c := u.Server[0].Datastore
+ engine := strings.Split(c, ":")[0]
+ switch engine {
+ case "mysql", "mongo":
+ s := strings.SplitN(c, ":", 4)
+ engine, user, passwd, addrs := s[0], s[1], s[2], s[3]
+ u.Server[0].Database = map[string]string{
+ "type": engine,
+ "username": user,
+ "password": passwd,
+ "address": addrs,
+ }
+ case "sqlite":
+ s := strings.Split(c, ":")
+ u.Server[0].Database = map[string]string{"type": s[0], "filename": s[1]}
+ case "mem":
+ u.Server[0].Database = map[string]string{"type": "mem"}
+ }
+ }
+}
func setFromEnv(u *unmarshalled) {
port, err := strconv.Atoi(os.Getenv("PORT"))
if err == nil {
@@ -128,42 +176,49 @@ func setFromVault(u *unmarshalled) error {
return err
}
get := func(value string) (string, error) {
- if value[:7] == "/vault/" {
+ if len(value) > 0 && value[:7] == "/vault/" {
return v.Read(value)
}
return value, nil
}
+ var errors error
if len(u.Auth) > 0 {
u.Auth[0].OauthClientID, err = get(u.Auth[0].OauthClientID)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
}
u.Auth[0].OauthClientSecret, err = get(u.Auth[0].OauthClientSecret)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
}
}
if len(u.Server) > 0 {
u.Server[0].CSRFSecret, err = get(u.Server[0].CSRFSecret)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
}
u.Server[0].CookieSecret, err = get(u.Server[0].CookieSecret)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
+ }
+ if len(u.Server[0].Database) > 0 {
+ u.Server[0].Database["password"], err = get(u.Server[0].Database["password"])
+ if err != nil {
+ errors = multierror.Append(errors, err)
+ }
}
}
if len(u.AWS) > 0 {
u.AWS[0].AccessKey, err = get(u.AWS[0].AccessKey)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
}
u.AWS[0].SecretKey, err = get(u.AWS[0].SecretKey)
if err != nil {
- err = multierror.Append(err)
+ errors = multierror.Append(errors, err)
}
}
- return err
+ return errors
}
// ReadConfig parses a JSON configuration file into a Config struct.
@@ -181,14 +236,16 @@ func ReadConfig(r io.Reader) (*Config, error) {
if err := setFromVault(u); err != nil {
return nil, err
}
+ convertDatastoreConfig(u)
if err := verifyConfig(u); err != nil {
return nil, err
}
- return &Config{
+ c := &Config{
Server: &u.Server[0],
Auth: &u.Auth[0],
SSH: &u.SSH[0],
AWS: &u.AWS[0],
Vault: &u.Vault[0],
- }, nil
+ }
+ return c, nil
}
diff --git a/server/store/config_test.go b/server/store/config_test.go
deleted file mode 100644
index 9a77027..0000000
--- a/server/store/config_test.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package store
-
-import (
- "reflect"
- "testing"
- "time"
-
- mgo "gopkg.in/mgo.v2"
-)
-
-func TestMySQLConfig(t *testing.T) {
- t.Parallel()
- var tests = []struct {
- in string
- out []string
- }{
- {"mysql:user:passwd:localhost", []string{"mysql", "user:passwd@tcp(localhost:3306)/certs?parseTime=true"}},
- {"mysql:user:passwd:localhost:13306", []string{"mysql", "user:passwd@tcp(localhost:13306)/certs?parseTime=true"}},
- {"mysql:root::localhost", []string{"mysql", "root@tcp(localhost:3306)/certs?parseTime=true"}},
- }
- for _, tt := range tests {
- result := parse(tt.in)
- if !reflect.DeepEqual(result, tt.out) {
- t.Errorf("want %s, got %s", tt.out, result)
- }
- }
-}
-
-func TestMongoConfig(t *testing.T) {
- t.Parallel()
- var tests = []struct {
- in string
- out *mgo.DialInfo
- }{
- {"mongo:user:passwd:host", &mgo.DialInfo{
- Username: "user",
- Password: "passwd",
- Addrs: []string{"host"},
- Database: "certs",
- Timeout: 5 * time.Second,
- }},
- {"mongo:user:passwd:host1,host2", &mgo.DialInfo{
- Username: "user",
- Password: "passwd",
- Addrs: []string{"host1", "host2"},
- Database: "certs",
- Timeout: 5 * time.Second,
- }},
- {"mongo:user:passwd:host1:27017,host2:27017", &mgo.DialInfo{
- Username: "user",
- Password: "passwd",
- Addrs: []string{"host1:27017", "host2:27017"},
- Database: "certs",
- Timeout: 5 * time.Second,
- }},
- {"mongo:user:passwd:host1,host2:27017", &mgo.DialInfo{
- Username: "user",
- Password: "passwd",
- Addrs: []string{"host1", "host2:27017"},
- Database: "certs",
- Timeout: 5 * time.Second,
- }},
- }
- for _, tt := range tests {
- result := parseMongoConfig(tt.in)
- if !reflect.DeepEqual(result, tt.out) {
- t.Errorf("want:\n%+v\ngot:\n%+v", tt.out, result)
- }
- }
-}
diff --git a/server/store/mongo.go b/server/store/mongo.go
index 1b13d7a..fc4131f 100644
--- a/server/store/mongo.go
+++ b/server/store/mongo.go
@@ -4,6 +4,8 @@ import (
"strings"
"time"
+ "github.com/nsheridan/cashier/server/config"
+
"golang.org/x/crypto/ssh"
mgo "gopkg.in/mgo.v2"
@@ -15,26 +17,20 @@ var (
issuedTable = "issued_certs"
)
-func parseMongoConfig(config string) *mgo.DialInfo {
- s := strings.SplitN(config, ":", 4)
- _, user, passwd, hosts := s[0], s[1], s[2], s[3]
- d := &mgo.DialInfo{
- Addrs: strings.Split(hosts, ","),
- Username: user,
- Password: passwd,
- Database: certsDB,
- Timeout: time.Second * 5,
- }
- return d
-}
-
func collection(session *mgo.Session) *mgo.Collection {
return session.DB(certsDB).C(issuedTable)
}
// NewMongoStore returns a MongoDB CertStorer.
-func NewMongoStore(config string) (CertStorer, error) {
- session, err := mgo.DialWithInfo(parseMongoConfig(config))
+func NewMongoStore(c config.Database) (CertStorer, error) {
+ m := &mgo.DialInfo{
+ Addrs: strings.Split(c["address"], ","),
+ Username: c["username"],
+ Password: c["password"],
+ Database: certsDB,
+ Timeout: time.Second * 5,
+ }
+ session, err := mgo.DialWithInfo(m)
if err != nil {
return nil, err
}
diff --git a/server/store/sqldb.go b/server/store/sqldb.go
index f65f601..6c1be0e 100644
--- a/server/store/sqldb.go
+++ b/server/store/sqldb.go
@@ -4,13 +4,14 @@ import (
"database/sql"
"encoding/json"
"fmt"
- "strings"
+ "net"
"time"
"golang.org/x/crypto/ssh"
"github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3" // required by sql driver
+ "github.com/nsheridan/cashier/server/config"
)
type sqldb struct {
@@ -24,31 +25,32 @@ type sqldb struct {
revoked *sql.Stmt
}
-func parse(config string) []string {
- s := strings.Split(config, ":")
- if s[0] == "sqlite" {
- s[0] = "sqlite3"
- return s
- }
- if len(s) == 4 {
- s = append(s, "3306")
- }
- _, user, passwd, host, port := s[0], s[1], s[2], s[3], s[4]
- c := &mysql.Config{
- User: user,
- Passwd: passwd,
- Net: "tcp",
- Addr: fmt.Sprintf("%s:%s", host, port),
- DBName: "certs",
- ParseTime: true,
- }
- return []string{"mysql", c.FormatDSN()}
-}
-
// NewSQLStore returns a *sql.DB CertStorer.
-func NewSQLStore(config string) (CertStorer, error) {
- parsed := parse(config)
- conn, err := sql.Open(parsed[0], parsed[1])
+func NewSQLStore(c config.Database) (CertStorer, error) {
+ var driver string
+ var dsn string
+ switch c["type"] {
+ case "mysql":
+ driver = "mysql"
+ address := c["address"]
+ _, _, err := net.SplitHostPort(address)
+ if err != nil {
+ address = address + ":3306"
+ }
+ m := &mysql.Config{
+ User: c["username"],
+ Passwd: c["password"],
+ Net: "tcp",
+ Addr: address,
+ DBName: "certs",
+ ParseTime: true,
+ }
+ dsn = m.FormatDSN()
+ case "sqlite":
+ driver = "sqlite3"
+ dsn = c["filename"]
+ }
+ conn, err := sql.Open(driver, dsn)
if err != nil {
return nil, fmt.Errorf("sqldb: could not get a connection: %v", err)
}
diff --git a/server/store/store.go b/server/store/store.go
index c039d3c..a447e72 100644
--- a/server/store/store.go
+++ b/server/store/store.go
@@ -5,9 +5,23 @@ import (
"golang.org/x/crypto/ssh"
+ "github.com/nsheridan/cashier/server/config"
"github.com/nsheridan/cashier/server/util"
)
+// New returns a new configured database.
+func New(c config.Database) (CertStorer, error) {
+ switch c["type"] {
+ case "mongo":
+ return NewMongoStore(c)
+ case "mysql", "sqlite":
+ return NewSQLStore(c)
+ case "mem":
+ return NewMemoryStore(), nil
+ }
+ return NewMemoryStore(), nil
+}
+
// CertStorer records issued certs in a persistent store for audit and
// revocation purposes.
type CertStorer interface {
diff --git a/server/store/store_test.go b/server/store/store_test.go
index 594da37..dbe2d95 100644
--- a/server/store/store_test.go
+++ b/server/store/store_test.go
@@ -3,7 +3,6 @@ package store
import (
"crypto/rand"
"crypto/rsa"
- "fmt"
"io/ioutil"
"os"
"os/exec"
@@ -16,6 +15,10 @@ import (
"golang.org/x/crypto/ssh"
)
+var (
+ dbConfig = map[string]string{"username": "user", "password": "passwd", "address": "localhost"}
+)
+
func TestParseCertificate(t *testing.T) {
t.Parallel()
a := assert.New(t)
@@ -87,11 +90,11 @@ func TestMemoryStore(t *testing.T) {
func TestMySQLStore(t *testing.T) {
t.Parallel()
- config := os.Getenv("MYSQL_TEST_CONFIG")
- if config == "" {
- t.Skip("No MYSQL_TEST_CONFIG environment variable")
+ if os.Getenv("MYSQL_TEST") == "" {
+ t.Skip("No MYSQL_TEST environment variable")
}
- db, err := NewSQLStore(config)
+ dbConfig["type"] = "mysql"
+ db, err := NewSQLStore(dbConfig)
if err != nil {
t.Error(err)
}
@@ -100,11 +103,11 @@ func TestMySQLStore(t *testing.T) {
func TestMongoStore(t *testing.T) {
t.Parallel()
- config := os.Getenv("MONGO_TEST_CONFIG")
- if config == "" {
- t.Skip("No MONGO_TEST_CONFIG environment variable")
+ if os.Getenv("MONGO_TEST") == "" {
+ t.Skip("No MONGO_TEST environment variable")
}
- db, err := NewMongoStore(config)
+ dbConfig["type"] = "mongo"
+ db, err := NewMongoStore(dbConfig)
if err != nil {
t.Error(err)
}
@@ -123,7 +126,8 @@ func TestSQLiteStore(t *testing.T) {
if err := exec.Command("go", args...).Run(); err != nil {
t.Error(err)
}
- db, err := NewSQLStore(fmt.Sprintf("sqlite:%s", f.Name()))
+ config := map[string]string{"type": "sqlite", "filename": f.Name()}
+ db, err := NewSQLStore(config)
if err != nil {
t.Error(err)
}