diff options
Diffstat (limited to 'cmd')
-rw-r--r-- | cmd/cashier/client_test.go | 18 | ||||
-rw-r--r-- | cmd/cashier/config.go | 36 | ||||
-rw-r--r-- | cmd/cashier/main.go | 48 |
3 files changed, 80 insertions, 22 deletions
diff --git a/cmd/cashier/client_test.go b/cmd/cashier/client_test.go index 492f4fc..f0176c6 100644 --- a/cmd/cashier/client_test.go +++ b/cmd/cashier/client_test.go @@ -58,8 +58,7 @@ func TestSignGood(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - *ca = ts.URL - _, err := send([]byte(`{}`), "token") + _, err := send([]byte(`{}`), "token", ts.URL, true) if err != nil { t.Fatal(err) } @@ -67,7 +66,11 @@ func TestSignGood(t *testing.T) { if err != nil { t.Fatal(err) } - cert, err := sign(k, "token") + c := &config{ + CA: ts.URL, + Validity: "24h", + } + cert, err := sign(k, "token", c) if cert == nil && err != nil { t.Fatal(err) } @@ -83,8 +86,7 @@ func TestSignBad(t *testing.T) { fmt.Fprintln(w, string(j)) })) defer ts.Close() - *ca = ts.URL - _, err := send([]byte(`{}`), "token") + _, err := send([]byte(`{}`), "token", ts.URL, true) if err != nil { t.Fatal(err) } @@ -92,7 +94,11 @@ func TestSignBad(t *testing.T) { if err != nil { t.Fatal(err) } - cert, err := sign(k, "token") + c := &config{ + CA: ts.URL, + Validity: "24h", + } + cert, err := sign(k, "token", c) if cert != nil && err == nil { t.Fatal(err) } diff --git a/cmd/cashier/config.go b/cmd/cashier/config.go new file mode 100644 index 0000000..eed98e1 --- /dev/null +++ b/cmd/cashier/config.go @@ -0,0 +1,36 @@ +package main + +import ( + "github.com/spf13/pflag" + "github.com/spf13/viper" +) + +type config struct { + CA string `mapstructure:"ca"` + Keytype string `mapstructure:"key_type"` + Keysize int `mapstructure:"key_size"` + Validity string `mapstructure:"validity"` + ValidateTLSCertificate bool `mapstructure:"validate_tls_certificate"` +} + +func setDefaults() { + viper.BindPFlag("ca", pflag.Lookup("ca")) + viper.BindPFlag("key_type", pflag.Lookup("key_type")) + viper.BindPFlag("key_size", pflag.Lookup("key_size")) + viper.BindPFlag("validity", pflag.Lookup("validity")) + viper.SetDefault("validateTLSCertificate", true) +} + +func readConfig(path string) (*config, error) { + setDefaults() + viper.SetConfigFile(path) + viper.SetConfigType("hcl") + if err := viper.ReadInConfig(); err != nil { + return nil, err + } + c := &config{} + if err := viper.Unmarshal(c); err != nil { + return nil, err + } + return c, nil +} diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index 8bcc3e7..768ebcd 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -2,27 +2,32 @@ package main import ( "bytes" + "crypto/tls" "encoding/json" - "flag" "fmt" "io/ioutil" "log" "net" "net/http" "os" + "os/user" + "path" "time" "github.com/nsheridan/cashier/lib" "github.com/pkg/browser" + "github.com/spf13/pflag" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) var ( - ca = flag.String("ca", "http://localhost:10000", "CA server") - keybits = flag.Int("bits", 2048, "Key size. Ignored for ed25519 keys") - validity = flag.Duration("validity", time.Hour*24, "Key validity") - keytype = flag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519") + u, _ = user.Current() + cfg = pflag.String("config", path.Join(u.HomeDir, ".cashier.conf"), "Path to config file") + ca = pflag.String("ca", "http://localhost:10000", "CA server") + keysize = pflag.Int("key_size", 2048, "Key size. Ignored for ed25519 keys") + validity = pflag.Duration("validity", time.Hour*24, "Key validity") + keytype = pflag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519") ) func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { @@ -37,15 +42,18 @@ func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { return nil } -func send(s []byte, token string) (*lib.SignResponse, error) { - req, err := http.NewRequest("POST", *ca+"/sign", bytes.NewReader(s)) +func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, + } + client := &http.Client{Transport: transport} + req, err := http.NewRequest("POST", ca+"/sign", bytes.NewReader(s)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Add("Accept", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} resp, err := client.Do(req) if err != nil { return nil, err @@ -65,17 +73,21 @@ func send(s []byte, token string) (*lib.SignResponse, error) { return c, nil } -func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) { +func sign(pub ssh.PublicKey, token string, conf *config) (*ssh.Certificate, error) { + validity, err := time.ParseDuration(conf.Validity) + if err != nil { + return nil, err + } marshaled := ssh.MarshalAuthorizedKey(pub) marshaled = marshaled[:len(marshaled)-1] s, err := json.Marshal(&lib.SignRequest{ Key: string(marshaled), - ValidUntil: time.Now().Add(*validity), + ValidUntil: time.Now().Add(validity), }) if err != nil { return nil, err } - resp, err := send(s, token) + resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate) if err != nil { return nil, err } @@ -94,14 +106,18 @@ func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) { } func main() { - flag.Parse() + pflag.Parse() - fmt.Printf("Your browser has been opened to visit %s\n", *ca) - if err := browser.OpenURL(*ca); err != nil { + c, err := readConfig(*cfg) + if err != nil { + log.Fatalf("Error parsing config file: %v\n", err) + } + fmt.Printf("Your browser has been opened to visit %s\n", c.CA) + if err := browser.OpenURL(c.CA); err != nil { fmt.Println("Error launching web browser. Go to the link in your web browser") } fmt.Println("Generating new key pair") - priv, pub, err := generateKey(*keytype, *keybits) + priv, pub, err := generateKey(c.Keytype, c.Keysize) if err != nil { log.Fatalln("Error generating key pair: ", err) } @@ -110,7 +126,7 @@ func main() { var token string fmt.Scanln(&token) - cert, err := sign(pub, token) + cert, err := sign(pub, token, c) if err != nil { log.Fatalln(err) } |