diff options
| -rw-r--r-- | client/keys.go | 64 | ||||
| -rw-r--r-- | client/main.go | 121 | ||||
| -rw-r--r-- | exampleconfig.json | 24 | ||||
| -rw-r--r-- | lib/const.go | 18 | ||||
| -rw-r--r-- | server/auth/google/google.go | 100 | ||||
| -rw-r--r-- | server/auth/provider.go | 27 | ||||
| -rw-r--r-- | server/config/config.go | 51 | ||||
| -rw-r--r-- | server/main.go | 218 | ||||
| -rw-r--r-- | server/signer/signer.go | 85 | ||||
| -rw-r--r-- | templates/token.html | 46 | 
10 files changed, 754 insertions, 0 deletions
diff --git a/client/keys.go b/client/keys.go new file mode 100644 index 0000000..4acfbb9 --- /dev/null +++ b/client/keys.go @@ -0,0 +1,64 @@ +package main + +import ( +	"crypto/ecdsa" +	"crypto/elliptic" +	"crypto/rand" +	"crypto/rsa" +	"fmt" + +	"golang.org/x/crypto/ssh" +) + +const ( +	rsaKey   = "rsa" +	ecdsaKey = "ecdsa" +) + +type key interface{} + +func generateRSAKey(bits int) (*rsa.PrivateKey, ssh.PublicKey, error) { +	k, err := rsa.GenerateKey(rand.Reader, bits) +	if err != nil { +		return nil, nil, err +	} +	pub, err := ssh.NewPublicKey(&k.PublicKey) +	if err != nil { +		return nil, nil, err +	} +	return k, pub, nil +} + +func generateECDSAKey(bits int) (*ecdsa.PrivateKey, ssh.PublicKey, error) { +	var curve elliptic.Curve +	switch bits { +	case 256: +		curve = elliptic.P256() +	case 384: +		curve = elliptic.P384() +	case 521: +		curve = elliptic.P521() +	default: +		return nil, nil, fmt.Errorf("Unsupported key size. Valid sizes are '256', '384', '521'") +	} +	k, err := ecdsa.GenerateKey(curve, rand.Reader) +	if err != nil { +		return nil, nil, err +	} +	pub, err := ssh.NewPublicKey(&k.PublicKey) +	if err != nil { +		return nil, nil, err +	} +	return k, pub, nil +} + +func generateKey(keytype string, bits int) (key, ssh.PublicKey, error) { +	switch keytype { +	case rsaKey: +		return generateRSAKey(bits) +	case ecdsaKey: +		return generateECDSAKey(bits) +	default: +		return nil, nil, fmt.Errorf("Unsupported key type %s. Valid choices are [%s, %s]", keytype, rsaKey, ecdsaKey) +	} +} diff --git a/client/main.go b/client/main.go new file mode 100644 index 0000000..10a3646 --- /dev/null +++ b/client/main.go @@ -0,0 +1,121 @@ +package main + +import ( +	"bytes" +	"encoding/json" +	"flag" +	"fmt" +	"io/ioutil" +	"log" +	"net" +	"net/http" +	"os" +	"time" + +	"github.com/nsheridan/cashier/lib" +	"golang.org/x/crypto/ssh" +	"golang.org/x/crypto/ssh/agent" +) + +var ( +	url      = flag.String("url", "http://localhost:10000/sign", "Signing URL") +	keybits  = flag.Int("bits", 4096, "Key size") +	validity = flag.Duration("validity", time.Hour*24, "Key validity") +	keytype  = flag.String("key_type", "rsa", "Type of private key to generate - rsa or ecdsa") +) + +func installCert(a agent.Agent, cert *ssh.Certificate, key key) error { +	pubcert := agent.AddedKey{ +		PrivateKey:  key, +		Certificate: cert, +		Comment:     cert.KeyId, +	} +	if err := a.Add(pubcert); err != nil { +		return fmt.Errorf("error importing certificate: %s", err) +	} +	return nil +} + +func send(s []byte, token string) (*lib.SignResponse, error) { +	req, err := http.NewRequest("POST", *url, bytes.NewReader(s)) +	if err != nil { +		return nil, err +	} +	req.Header.Set("Content-Type", "application/json") +	req.Header.Add("Accept", "application/json") +	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) +	client := &http.Client{} +	resp, err := client.Do(req) +	if resp.StatusCode != http.StatusOK { +		return nil, fmt.Errorf("Bad response from server: %s", resp.Status) +	} +	if err != nil { +		return nil, err +	} +	defer resp.Body.Close() +	body, err := ioutil.ReadAll(resp.Body) +	if err != nil { +		return nil, err +	} +	c := &lib.SignResponse{} +	if err := json.Unmarshal(body, c); err != nil { +		return nil, err +	} +	return c, nil +} + +func sign(pub ssh.PublicKey, token string) (*ssh.Certificate, error) { +	marshaled := ssh.MarshalAuthorizedKey(pub) +	marshaled = marshaled[:len(marshaled)-1] +	s, err := json.Marshal(&lib.SignRequest{ +		Key:        string(marshaled), +		ValidUntil: time.Now().Add(*validity), +	}) +	if err != nil { +		return nil, err +	} +	resp, err := send(s, token) +	if err != nil { +		return nil, err +	} +	if resp.Status != "ok" { +		return nil, fmt.Errorf("error: %s", resp.Response) +	} +	k, _, _, _, err := ssh.ParseAuthorizedKey([]byte(resp.Response)) +	if err != nil { +		return nil, err +	} +	cert, ok := k.(*ssh.Certificate) +	if !ok { +		return nil, fmt.Errorf("did not receive a certificate from server") +	} +	return cert, nil +} + +func main() { +	flag.Parse() + +	priv, pub, err := generateKey(*keytype, *keybits) +	if err != nil { +		log.Fatalln("Error generating key pair: ", err) +	} + +	fmt.Print("Enter token: ") +	var token string +	fmt.Scanln(&token) + +	cert, err := sign(pub, token) +	if err != nil { +		log.Fatalln(err) +	} +	sock, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) +	if err != nil { +		log.Fatalln("Error connecting to agent: %s", err) +	} +	defer sock.Close() +	a := agent.NewClient(sock) +	if err := installCert(a, cert, priv); err != nil { +		log.Fatalln(err) +	} +	fmt.Println("Certificate added.") +} diff --git a/exampleconfig.json b/exampleconfig.json new file mode 100644 index 0000000..97d3af5 --- /dev/null +++ b/exampleconfig.json @@ -0,0 +1,24 @@ +{ +  "server": { +    "tls_key": "server.key", +    "tls_cert": "server.crt", +    "port": 443, +    "cookie_secret": "supersecret" +  }, +  "auth": { +    "provider": "google", +    "oauth_client_id": "nnnnnnnnnnnnnnnn.apps.googleusercontent.com", +    "oauth_client_secret": "yyyyyyyyyyyyyyyyyyyyyy", +    "oauth_callback_url": "https://sshca.example.com/auth/callback", +    "google_opts": { +      "domain": "example.com" +    }, +    "jwt_signing_key": "supersecret" +  }, +  "ssh": { +    "signing_key": "signing_key", +    "additional_principals": ["ec2-user"], +    "max_age": "720h", +    "permissions": ["permit-pty"] +  } +} diff --git a/lib/const.go b/lib/const.go new file mode 100644 index 0000000..fd771a0 --- /dev/null +++ b/lib/const.go @@ -0,0 +1,18 @@ +package lib + +import "time" + +// SignRequest represents a signing request sent to the server. +type SignRequest struct { +	Key        string    `json:"key"` +	Principal  string    `json:"principal"` +	ValidUntil time.Time `json:"valid_until"` +} + +// SignResponse is sent by the server. +// `Status' is "ok" or "error". +// `Response' contains a signed certificate or an error message. +type SignResponse struct { +	Status   string `json:"status"` +	Response string `json:"response"` +} diff --git a/server/auth/google/google.go b/server/auth/google/google.go new file mode 100644 index 0000000..9944d58 --- /dev/null +++ b/server/auth/google/google.go @@ -0,0 +1,100 @@ +package google + +import ( +	"fmt" +	"net/http" +	"strings" + +	"github.com/nsheridan/cashier/server/auth" +	"github.com/nsheridan/cashier/server/config" + +	"golang.org/x/oauth2" +	"golang.org/x/oauth2/google" +	googleapi "google.golang.org/api/oauth2/v2" +) + +const ( +	revokeURL = "https://accounts.google.com/o/oauth2/revoke?token=%s" +	name      = "google" +) + +type Config struct { +	config *oauth2.Config +	domain string +} + +func New(c config.Auth) auth.Provider { +	return &Config{ +		config: &oauth2.Config{ +			ClientID:     c.OauthClientID, +			ClientSecret: c.OauthClientSecret, +			RedirectURL:  c.OauthCallbackURL, +			Endpoint:     google.Endpoint, +			Scopes:       []string{googleapi.UserinfoEmailScope, googleapi.UserinfoProfileScope}, +		}, +		domain: c.GoogleOpts["domain"].(string), +	} +} + +func (c *Config) newClient(token *oauth2.Token) *http.Client { +	return c.config.Client(oauth2.NoContext, token) +} + +func (c *Config) Name() string { +	return name +} + +func (c *Config) Valid(token *oauth2.Token) bool { +	if !token.Valid() { +		return false +	} +	svc, err := googleapi.New(c.newClient(token)) +	if err != nil { +		return false +	} +	t := svc.Tokeninfo() +	t.AccessToken(token.AccessToken) +	ti, err := t.Do() +	if err != nil { +		return false +	} +	ui, err := svc.Userinfo.Get().Do() +	if err != nil { +		return false +	} +	switch { +	case ti.Audience != c.config.ClientID: +	case ui.Hd != c.domain: +		return false +	} +	return true +} + +func (c *Config) Revoke(token *oauth2.Token) error { +	h := c.newClient(token) +	_, err := h.Get(fmt.Sprintf(revokeURL, token.AccessToken)) +	return err +} + +func (c *Config) StartSession(state string) *auth.Session { +	return &auth.Session{ +		AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)), +		State:   state, +	} +} + +func (c *Config) Exchange(code string) (*oauth2.Token, error) { +	return c.config.Exchange(oauth2.NoContext, code) +} + +func (c *Config) Username(token *oauth2.Token) string { +	svc, err := googleapi.New(c.newClient(token)) +	if err != nil { +		return "" +	} +	ui, err := svc.Userinfo.Get().Do() +	if err != nil { +		return "" +	} +	return strings.Split(ui.Email, "@")[0] +} diff --git a/server/auth/provider.go b/server/auth/provider.go new file mode 100644 index 0000000..ae512bd --- /dev/null +++ b/server/auth/provider.go @@ -0,0 +1,27 @@ +package auth + +import "golang.org/x/oauth2" + +type Provider interface { +	Name() string +	StartSession(string) *Session +	Exchange(string) (*oauth2.Token, error) +	Username(*oauth2.Token) string +	Valid(*oauth2.Token) bool +	Revoke(*oauth2.Token) error +} + +type Session struct { +	AuthURL string +	Token   *oauth2.Token +	State   string +} + +func (s *Session) Authorize(provider Provider, code string) error { +	t, err := provider.Exchange(code) +	if err != nil { +		return err +	} +	s.Token = t +	return nil +} diff --git a/server/config/config.go b/server/config/config.go new file mode 100644 index 0000000..b65d171 --- /dev/null +++ b/server/config/config.go @@ -0,0 +1,51 @@ +package config + +import "github.com/spf13/viper" + +// Config holds the values from the json config file. +type Config struct { +	Server Server `mapstructure:"server"` +	Auth   Auth   `mapstructure:"auth"` +	SSH    SSH    `mapstructure:"ssh"` +} + +// Server holds the configuration specific to the web server and sessions. +type Server struct { +	UseTLS       bool   `mapstructure:"use_tls"` +	TLSKey       string `mapstructure:"tls_key"` +	TLSCert      string `mapstructure:"tls_cert"` +	Port         int    `mapstructure:"port"` +	CookieSecret string `mapstructure:"cookie_secret"` +} + +// Auth holds the configuration specific to the OAuth provider. +type Auth struct { +	OauthClientID     string                 `mapstructure:"oauth_client_id"` +	OauthClientSecret string                 `mapstructure:"oauth_client_secret"` +	OauthCallbackURL  string                 `mapstructure:"oauth_callback_url"` +	Provider          string                 `mapstructure:"provider"` +	GoogleOpts        map[string]interface{} `mapstructure:"google_opts"` +	JWTSigningKey     string                 `mapstructure:"jwt_signing_key"` +} + +// SSH holds the configuration specific to signing ssh keys. +type SSH struct { +	SigningKey  string   `mapstructure:"signing_key"` +	Principals  []string `mapstructure:"principals"` +	MaxAge      string   `mapstructure:"max_age"` +	Permissions []string `mapstructure:"permissions"` +} + +// ReadConfig parses a JSON configuration file into a Config struct. +func ReadConfig(filename string) (*Config, error) { +	config := &Config{} +	v := viper.New() +	v.SetConfigFile(filename) +	if err := v.ReadInConfig(); err != nil { +		return nil, err +	} +	if err := v.Unmarshal(config); err != nil { +		return nil, err +	} +	return config, nil +} diff --git a/server/main.go b/server/main.go new file mode 100644 index 0000000..0125ca8 --- /dev/null +++ b/server/main.go @@ -0,0 +1,218 @@ +package main + +import ( +	"crypto/rand" +	"encoding/hex" +	"encoding/json" +	"errors" +	"flag" +	"fmt" +	"html/template" +	"io" +	"io/ioutil" +	"log" +	"net/http" +	"time" + +	"golang.org/x/oauth2" + +	"github.com/dgrijalva/jwt-go" +	"github.com/gorilla/mux" +	"github.com/gorilla/sessions" +	"github.com/nsheridan/cashier/lib" +	"github.com/nsheridan/cashier/server/auth" +	"github.com/nsheridan/cashier/server/auth/google" +	"github.com/nsheridan/cashier/server/config" +	"github.com/nsheridan/cashier/server/signer" +) + +var ( +	cfg = flag.String("config_file", "config.json", "Path to configuration file.") +) + +type appContext struct { +	cookiestore   *sessions.CookieStore +	authprovider  auth.Provider +	authsession   *auth.Session +	views         *template.Template +	sshKeySigner  *signer.KeySigner +	jwtSigningKey []byte +} + +func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token { +	session, _ := a.cookiestore.Get(r, "tok") +	t, ok := session.Values["token"] +	if !ok { +		return nil +	} +	var tok oauth2.Token +	if err := json.Unmarshal(t.([]byte), &tok); err != nil { +		return nil +	} +	if !a.authprovider.Valid(&tok) { +		return nil +	} +	return &tok +} + +func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) { +	session, _ := a.cookiestore.Get(r, "tok") +	val, _ := json.Marshal(t) +	session.Values["token"] = val +	session.Save(r, w) +} + +func parseKey(r *http.Request) (*lib.SignRequest, error) { +	var s lib.SignRequest +	body, err := ioutil.ReadAll(r.Body) +	if err != nil { +		return nil, err +	} +	if err := json.Unmarshal(body, &s); err != nil { +		return nil, err +	} +	return &s, nil +} +func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { +	jwtoken, err := jwt.ParseFromRequest(r, func(t *jwt.Token) (interface{}, error) { +		return a.jwtSigningKey, nil +	}) +	if err != nil { +		return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) +	} +	if !jwtoken.Valid { +		log.Printf("Token %v not valid", jwtoken) +		return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) +	} +	expiry := int64(jwtoken.Claims["exp"].(float64)) +	token := &oauth2.Token{ +		AccessToken: jwtoken.Claims["token"].(string), +		Expiry:      time.Unix(expiry, 0), +	} +	ok := a.authprovider.Valid(token) +	if !ok { +		return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) +	} +	// finally sign the pubkey and issue the cert. +	req, err := parseKey(r) +	req.Principal = a.authprovider.Username(token) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	signed, err := a.sshKeySigner.Sign(req) +	a.authprovider.Revoke(token) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	json.NewEncoder(w).Encode(&lib.SignResponse{ +		Status:   "ok", +		Response: signed, +	}) +	return http.StatusOK, nil +} + +func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { +	a.authsession = a.authprovider.StartSession(hex.EncodeToString(random(32))) +	http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound) +	return http.StatusFound, nil +} + +func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { +	if r.FormValue("state") != a.authsession.State { +		return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized)) +	} +	code := r.FormValue("code") +	if err := a.authsession.Authorize(a.authprovider, code); err != nil { +		return http.StatusInternalServerError, err +	} +	a.setAuthCookie(w, r, a.authsession.Token) +	http.Redirect(w, r, "/", http.StatusFound) +	return http.StatusFound, nil +} + +func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) { +	tok := a.getAuthCookie(r) +	if !tok.Valid() { +		http.Redirect(w, r, "/auth/login", http.StatusSeeOther) +		return http.StatusSeeOther, nil +	} +	j := jwt.New(jwt.SigningMethodHS256) +	j.Claims["token"] = tok.AccessToken +	j.Claims["exp"] = tok.Expiry.Unix() +	t, err := j.SignedString(a.jwtSigningKey) +	if err != nil { +		return http.StatusInternalServerError, err +	} +	page := struct { +		Token string +	}{t} +	a.views.ExecuteTemplate(w, "token.html", page) +	return http.StatusOK, nil +} + +type appHandler struct { +	*appContext +	h func(*appContext, http.ResponseWriter, *http.Request) (int, error) +} + +func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +	status, err := ah.h(ah.appContext, w, r) +	if err != nil { +		log.Printf("HTTP %d: %q", status, err) +		switch status { +		case http.StatusNotFound: +			http.NotFound(w, r) +		case http.StatusInternalServerError: +			http.Error(w, http.StatusText(status), status) +		default: +			http.Error(w, http.StatusText(status), status) +		} +	} +} + +func random(length int) []byte { +	k := make([]byte, length) +	if _, err := io.ReadFull(rand.Reader, k); err != nil { +		return nil +	} +	return k +} + +func main() { +	flag.Parse() +	config, err := config.ReadConfig(*cfg) +	if err != nil { +		log.Fatal(err) +	} +	signer, err := signer.NewSigner(config.SSH) +	if err != nil { +		log.Fatal(err) +	} +	authprovider := google.New(config.Auth) +	ctx := &appContext{ +		cookiestore:   sessions.NewCookieStore([]byte(config.Server.CookieSecret)), +		authprovider:  authprovider, +		views:         template.Must(template.ParseGlob("templates/*")), +		sshKeySigner:  signer, +		jwtSigningKey: []byte(config.Auth.JWTSigningKey), +	} +	ctx.cookiestore.Options = &sessions.Options{ +		MaxAge:   900, +		Path:     "/", +		Secure:   config.Server.UseTLS, +		HttpOnly: true, +	} + +	m := mux.NewRouter() +	m.Handle("/", appHandler{ctx, rootHandler}) +	m.Handle("/auth/login", appHandler{ctx, loginHandler}) +	m.Handle("/auth/callback", appHandler{ctx, callbackHandler}) +	m.Handle("/sign", appHandler{ctx, signHandler}) + +	fmt.Println("Starting server...") +	l := fmt.Sprintf(":%d", config.Server.Port) +	if config.Server.UseTLS { +		log.Fatal(http.ListenAndServeTLS(l, config.Server.TLSCert, config.Server.TLSKey, m)) +	} +	log.Fatal(http.ListenAndServe(l, m)) +} diff --git a/server/signer/signer.go b/server/signer/signer.go new file mode 100644 index 0000000..4ae5058 --- /dev/null +++ b/server/signer/signer.go @@ -0,0 +1,85 @@ +package signer + +import ( +	"crypto/rand" +	"fmt" +	"io/ioutil" +	"time" + +	"github.com/nsheridan/cashier/lib" +	"github.com/nsheridan/cashier/server/config" +	"golang.org/x/crypto/ssh" +) + +type KeySigner struct { +	ca          ssh.Signer +	validity    time.Duration +	principals  []string +	permissions map[string]string +} + +func (s *KeySigner) Sign(req *lib.SignRequest) (string, error) { +	pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(req.Key)) +	if err != nil { +		return "", err +	} +	expires := time.Now().Add(s.validity) +	if req.ValidUntil.After(expires) { +		req.ValidUntil = expires +	} +	cert := &ssh.Certificate{ +		CertType:    ssh.UserCert, +		Key:         pubkey, +		KeyId:       req.Principal, +		ValidBefore: uint64(req.ValidUntil.Unix()), +		ValidAfter:  uint64(time.Now().Add(-5 * time.Minute).Unix()), +	} +	cert.ValidPrincipals = append(cert.ValidPrincipals, req.Principal) +	cert.ValidPrincipals = append(cert.ValidPrincipals, s.principals...) +	cert.Extensions = s.permissions +	if err := cert.SignCert(rand.Reader, s.ca); err != nil { +		return "", err +	} +	marshaled := ssh.MarshalAuthorizedKey(cert) +	// Remove the trailing newline. +	marshaled = marshaled[:len(marshaled)-1] +	return string(marshaled), nil +} + +func makeperms(perms []string) map[string]string { +	if len(perms) > 0 { +		m := make(map[string]string) +		for _, p := range perms { +			m[p] = "" +		} +		return m +	} +	return map[string]string{ +		"permit-X11-forwarding":   "", +		"permit-agent-forwarding": "", +		"permit-port-forwarding":  "", +		"permit-pty":              "", +		"permit-user-rc":          "", +	} +} + +func NewSigner(conf config.SSH) (*KeySigner, error) { +	data, err := ioutil.ReadFile(conf.SigningKey) +	if err != nil { +		return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err) +	} +	key, err := ssh.ParsePrivateKey(data) +	if err != nil { +		return nil, fmt.Errorf("unable to parse CA key: %v", err) +	} +	validity, err := time.ParseDuration(conf.MaxAge) +	if err != nil { +		return nil, fmt.Errorf("error parsing duration '%s': %v", conf.MaxAge, err) +	} +	return &KeySigner{ +		ca:          key, +		validity:    validity, +		principals:  conf.Principals, +		permissions: makeperms(conf.Permissions), +	}, nil +} diff --git a/templates/token.html b/templates/token.html new file mode 100644 index 0000000..8553c84 --- /dev/null +++ b/templates/token.html @@ -0,0 +1,46 @@ +<html> +  <head> +    <title>YOUR TOKEN!</title> +    <style> +      <!-- +      body { +        text-align: center; +        font-family: sans-serif; +        background-color: #edece4; +        margin-top: 120px; +      } +      .code { +        background-color: #26292B; +        border: none; +        color: #fff; +        font-family: monospace; +        font-size: 13; +        font-weight: bold; +        height: auto; +        margin: 12px 12px 12px 12px; +        padding: 12px 12px 12px 12px; +        resize: none; +        text-align: center; +        width: 960px; +      } +      ::selection { +        background: #32d0ff; +        color: #000; +      } +      ::-moz-selection { +        background: #32d0ff; +        color: #000; +      } +      --> +    </style> +  </head> +  <body> +    <h2> +      This is your token. There are many like it but this one is yours. +    </h2> +    <textarea class="code" readonly spellcheck="false" onclick="this.focus();this.select();">{{.Token}}</textarea> +    <h2> +      The token will expire in < 1 hour. +    </h2> +  </body> +</html>  | 
