diff options
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/cashier/client_test.go | 18 | ||||
| -rw-r--r-- | cmd/cashier/config.go | 36 | ||||
| -rw-r--r-- | cmd/cashier/main.go | 48 | 
3 files changed, 80 insertions, 22 deletions
diff --git a/cmd/cashier/client_test.go b/cmd/cashier/client_test.go index 492f4fc..f0176c6 100644 --- a/cmd/cashier/client_test.go +++ b/cmd/cashier/client_test.go @@ -58,8 +58,7 @@ func TestSignGood(t *testing.T) {  		fmt.Fprintln(w, string(j))  	}))  	defer ts.Close() -	*ca = ts.URL -	_, err := send([]byte(`{}`), "token") +	_, err := send([]byte(`{}`), "token", ts.URL, true)  	if err != nil {  		t.Fatal(err)  	} @@ -67,7 +66,11 @@ func TestSignGood(t *testing.T) {  	if err != nil {  		t.Fatal(err)  	} -	cert, err := sign(k, "token") +	c := &config{ +		CA:       ts.URL, +		Validity: "24h", +	} +	cert, err := sign(k, "token", c)  	if cert == nil && err != nil {  		t.Fatal(err)  	} @@ -83,8 +86,7 @@ func TestSignBad(t *testing.T) {  		fmt.Fprintln(w, string(j))  	}))  	defer ts.Close() -	*ca = ts.URL -	_, err := send([]byte(`{}`), "token") +	_, err := send([]byte(`{}`), "token", ts.URL, true)  	if err != nil {  		t.Fatal(err)  	} @@ -92,7 +94,11 @@ func TestSignBad(t *testing.T) {  	if err != nil {  		t.Fatal(err)  	} -	cert, err := sign(k, "token") +	c := &config{ +		CA:       ts.URL, +		Validity: "24h", +	} +	cert, err := sign(k, "token", c)  	if cert != nil && err == nil {  		t.Fatal(err)  	} diff --git a/cmd/cashier/config.go b/cmd/cashier/config.go new file mode 100644 index 0000000..eed98e1 --- /dev/null +++ b/cmd/cashier/config.go @@ -0,0 +1,36 @@ +package main + +import ( +	"github.com/spf13/pflag" +	"github.com/spf13/viper" +) + +type config struct { +	CA                     string `mapstructure:"ca"` +	Keytype                string `mapstructure:"key_type"` +	Keysize                int    `mapstructure:"key_size"` +	Validity               string `mapstructure:"validity"` +	ValidateTLSCertificate bool   `mapstructure:"validate_tls_certificate"` +} + +func setDefaults() { +	viper.BindPFlag("ca", pflag.Lookup("ca")) +	viper.BindPFlag("key_type", pflag.Lookup("key_type")) +	viper.BindPFlag("key_size", pflag.Lookup("key_size")) +	viper.BindPFlag("validity", pflag.Lookup("validity")) +	viper.SetDefault("validateTLSCertificate", true) +} + +func readConfig(path string) (*config, error) { +	setDefaults() +	viper.SetConfigFile(path) +	viper.SetConfigType("hcl") +	if err := viper.ReadInConfig(); err != nil { +		return nil, err +	} +	c := &config{} +	if err := viper.Unmarshal(c); err != nil { +		return nil, err +	} +	return c, nil +} diff --git a/cmd/cashier/main.go b/cmd/cashier/main.go index 8bcc3e7..768ebcd 100644 --- a/cmd/cashier/main.go +++ b/cmd/cashier/main.go @@ -2,27 +2,32 @@ package main  import (  	"bytes" +	"crypto/tls"  	"encoding/json" -	"flag"  	"fmt"  	"io/ioutil"  	"log"  	"net"  	"net/http"  	"os" +	"os/user" +	"path"  	"time"  	"github.com/nsheridan/cashier/lib"  	"github.com/pkg/browser" +	"github.com/spf13/pflag"  	"golang.org/x/crypto/ssh"  	"golang.org/x/crypto/ssh/agent"  )  var ( -	ca       = flag.String("ca", "http://localhost:10000", "CA server") -	keybits  = flag.Int("bits", 2048, "Key size. Ignored for ed25519 keys") -	validity = flag.Duration("validity", time.Hour*24, "Key validity") -	keytype  = flag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519") +	u, _     = user.Current() +	cfg      = pflag.String("config", path.Join(u.HomeDir, ".cashier.conf"), "Path to config file") +	ca       = pflag.String("ca", "http://localhost:10000", "CA server") +	keysize  = pflag.Int("key_size", 2048, "Key size. Ignored for ed25519 keys") +	validity = pflag.Duration("validity", time.Hour*24, "Key validity") +	keytype  = pflag.String("key_type", "rsa", "Type of private key to generate - rsa, ecdsa or ed25519")  )  func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { @@ -37,15 +42,18 @@ func installCert(a agent.Agent, cert *ssh.Certificate, key key) error {  	return nil  } -func send(s []byte, token string) (*lib.SignResponse, error) { -	req, err := http.NewRequest("POST", *ca+"/sign", bytes.NewReader(s)) +func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) { +	transport := &http.Transport{ +		TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate}, +	} +	client := &http.Client{Transport: transport} +	req, err := http.NewRequest("POST", ca+"/sign", bytes.NewReader(s))  	if err != nil {  		return nil, err  	}  	req.Header.Set("Content-Type", "application/json")  	req.Header.Add("Accept", "application/json")  	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) -	client := &http.Client{}  	resp, err := client.Do(req)  	if err != nil {  		return nil, err @@ -65,17 +73,21 @@ func send(s []byte, token string) (*lib.SignResponse, error) {  	return c, nil  } -func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) { +func sign(pub ssh.PublicKey, token string, conf *config) (*ssh.Certificate, error) { +	validity, err := time.ParseDuration(conf.Validity) +	if err != nil { +		return nil, err +	}  	marshaled := ssh.MarshalAuthorizedKey(pub)  	marshaled = marshaled[:len(marshaled)-1]  	s, err := json.Marshal(&lib.SignRequest{  		Key:        string(marshaled), -		ValidUntil: time.Now().Add(*validity), +		ValidUntil: time.Now().Add(validity),  	})  	if err != nil {  		return nil, err  	} -	resp, err := send(s, token) +	resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate)  	if err != nil {  		return nil, err  	} @@ -94,14 +106,18 @@ func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) {  }  func main() { -	flag.Parse() +	pflag.Parse() -	fmt.Printf("Your browser has been opened to visit %s\n", *ca) -	if err := browser.OpenURL(*ca); err != nil { +	c, err := readConfig(*cfg) +	if err != nil { +		log.Fatalf("Error parsing config file: %v\n", err) +	} +	fmt.Printf("Your browser has been opened to visit %s\n", c.CA) +	if err := browser.OpenURL(c.CA); err != nil {  		fmt.Println("Error launching web browser. Go to the link in your web browser")  	}  	fmt.Println("Generating new key pair") -	priv, pub, err := generateKey(*keytype, *keybits) +	priv, pub, err := generateKey(c.Keytype, c.Keysize)  	if err != nil {  		log.Fatalln("Error generating key pair: ", err)  	} @@ -110,7 +126,7 @@ func main() {  	var token string  	fmt.Scanln(&token) -	cert, err := sign(pub, token) +	cert, err := sign(pub, token, c)  	if err != nil {  		log.Fatalln(err)  	}  | 
