aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/auth/google/google.go100
-rw-r--r--server/auth/provider.go27
-rw-r--r--server/config/config.go51
-rw-r--r--server/main.go218
-rw-r--r--server/signer/signer.go85
5 files changed, 481 insertions, 0 deletions
diff --git a/server/auth/google/google.go b/server/auth/google/google.go
new file mode 100644
index 0000000..9944d58
--- /dev/null
+++ b/server/auth/google/google.go
@@ -0,0 +1,100 @@
+package google
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/config"
+
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/google"
+ googleapi "google.golang.org/api/oauth2/v2"
+)
+
+const (
+ revokeURL = "https://accounts.google.com/o/oauth2/revoke?token=%s"
+ name = "google"
+)
+
+type Config struct {
+ config *oauth2.Config
+ domain string
+}
+
+func New(c config.Auth) auth.Provider {
+ return &Config{
+ config: &oauth2.Config{
+ ClientID: c.OauthClientID,
+ ClientSecret: c.OauthClientSecret,
+ RedirectURL: c.OauthCallbackURL,
+ Endpoint: google.Endpoint,
+ Scopes: []string{googleapi.UserinfoEmailScope, googleapi.UserinfoProfileScope},
+ },
+ domain: c.GoogleOpts["domain"].(string),
+ }
+}
+
+func (c *Config) newClient(token *oauth2.Token) *http.Client {
+ return c.config.Client(oauth2.NoContext, token)
+}
+
+func (c *Config) Name() string {
+ return name
+}
+
+func (c *Config) Valid(token *oauth2.Token) bool {
+ if !token.Valid() {
+ return false
+ }
+ svc, err := googleapi.New(c.newClient(token))
+ if err != nil {
+ return false
+ }
+ t := svc.Tokeninfo()
+ t.AccessToken(token.AccessToken)
+ ti, err := t.Do()
+ if err != nil {
+ return false
+ }
+ ui, err := svc.Userinfo.Get().Do()
+ if err != nil {
+ return false
+ }
+ switch {
+ case ti.Audience != c.config.ClientID:
+ case ui.Hd != c.domain:
+ return false
+ }
+ return true
+}
+
+func (c *Config) Revoke(token *oauth2.Token) error {
+ h := c.newClient(token)
+ _, err := h.Get(fmt.Sprintf(revokeURL, token.AccessToken))
+ return err
+}
+
+func (c *Config) StartSession(state string) *auth.Session {
+ return &auth.Session{
+ AuthURL: c.config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.domain)),
+ State: state,
+ }
+}
+
+func (c *Config) Exchange(code string) (*oauth2.Token, error) {
+ return c.config.Exchange(oauth2.NoContext, code)
+}
+
+func (c *Config) Username(token *oauth2.Token) string {
+ svc, err := googleapi.New(c.newClient(token))
+ if err != nil {
+ return ""
+ }
+ ui, err := svc.Userinfo.Get().Do()
+ if err != nil {
+ return ""
+ }
+ return strings.Split(ui.Email, "@")[0]
+}
diff --git a/server/auth/provider.go b/server/auth/provider.go
new file mode 100644
index 0000000..ae512bd
--- /dev/null
+++ b/server/auth/provider.go
@@ -0,0 +1,27 @@
+package auth
+
+import "golang.org/x/oauth2"
+
+type Provider interface {
+ Name() string
+ StartSession(string) *Session
+ Exchange(string) (*oauth2.Token, error)
+ Username(*oauth2.Token) string
+ Valid(*oauth2.Token) bool
+ Revoke(*oauth2.Token) error
+}
+
+type Session struct {
+ AuthURL string
+ Token *oauth2.Token
+ State string
+}
+
+func (s *Session) Authorize(provider Provider, code string) error {
+ t, err := provider.Exchange(code)
+ if err != nil {
+ return err
+ }
+ s.Token = t
+ return nil
+}
diff --git a/server/config/config.go b/server/config/config.go
new file mode 100644
index 0000000..b65d171
--- /dev/null
+++ b/server/config/config.go
@@ -0,0 +1,51 @@
+package config
+
+import "github.com/spf13/viper"
+
+// Config holds the values from the json config file.
+type Config struct {
+ Server Server `mapstructure:"server"`
+ Auth Auth `mapstructure:"auth"`
+ SSH SSH `mapstructure:"ssh"`
+}
+
+// Server holds the configuration specific to the web server and sessions.
+type Server struct {
+ UseTLS bool `mapstructure:"use_tls"`
+ TLSKey string `mapstructure:"tls_key"`
+ TLSCert string `mapstructure:"tls_cert"`
+ Port int `mapstructure:"port"`
+ CookieSecret string `mapstructure:"cookie_secret"`
+}
+
+// Auth holds the configuration specific to the OAuth provider.
+type Auth struct {
+ OauthClientID string `mapstructure:"oauth_client_id"`
+ OauthClientSecret string `mapstructure:"oauth_client_secret"`
+ OauthCallbackURL string `mapstructure:"oauth_callback_url"`
+ Provider string `mapstructure:"provider"`
+ GoogleOpts map[string]interface{} `mapstructure:"google_opts"`
+ JWTSigningKey string `mapstructure:"jwt_signing_key"`
+}
+
+// SSH holds the configuration specific to signing ssh keys.
+type SSH struct {
+ SigningKey string `mapstructure:"signing_key"`
+ Principals []string `mapstructure:"principals"`
+ MaxAge string `mapstructure:"max_age"`
+ Permissions []string `mapstructure:"permissions"`
+}
+
+// ReadConfig parses a JSON configuration file into a Config struct.
+func ReadConfig(filename string) (*Config, error) {
+ config := &Config{}
+ v := viper.New()
+ v.SetConfigFile(filename)
+ if err := v.ReadInConfig(); err != nil {
+ return nil, err
+ }
+ if err := v.Unmarshal(config); err != nil {
+ return nil, err
+ }
+ return config, nil
+}
diff --git a/server/main.go b/server/main.go
new file mode 100644
index 0000000..0125ca8
--- /dev/null
+++ b/server/main.go
@@ -0,0 +1,218 @@
+package main
+
+import (
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "flag"
+ "fmt"
+ "html/template"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "time"
+
+ "golang.org/x/oauth2"
+
+ "github.com/dgrijalva/jwt-go"
+ "github.com/gorilla/mux"
+ "github.com/gorilla/sessions"
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/auth"
+ "github.com/nsheridan/cashier/server/auth/google"
+ "github.com/nsheridan/cashier/server/config"
+ "github.com/nsheridan/cashier/server/signer"
+)
+
+var (
+ cfg = flag.String("config_file", "config.json", "Path to configuration file.")
+)
+
+type appContext struct {
+ cookiestore *sessions.CookieStore
+ authprovider auth.Provider
+ authsession *auth.Session
+ views *template.Template
+ sshKeySigner *signer.KeySigner
+ jwtSigningKey []byte
+}
+
+func (a *appContext) getAuthCookie(r *http.Request) *oauth2.Token {
+ session, _ := a.cookiestore.Get(r, "tok")
+ t, ok := session.Values["token"]
+ if !ok {
+ return nil
+ }
+ var tok oauth2.Token
+ if err := json.Unmarshal(t.([]byte), &tok); err != nil {
+ return nil
+ }
+ if !a.authprovider.Valid(&tok) {
+ return nil
+ }
+ return &tok
+}
+
+func (a *appContext) setAuthCookie(w http.ResponseWriter, r *http.Request, t *oauth2.Token) {
+ session, _ := a.cookiestore.Get(r, "tok")
+ val, _ := json.Marshal(t)
+ session.Values["token"] = val
+ session.Save(r, w)
+}
+
+func parseKey(r *http.Request) (*lib.SignRequest, error) {
+ var s lib.SignRequest
+ body, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ return nil, err
+ }
+ if err := json.Unmarshal(body, &s); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ jwtoken, err := jwt.ParseFromRequest(r, func(t *jwt.Token) (interface{}, error) {
+ return a.jwtSigningKey, nil
+ })
+ if err != nil {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ if !jwtoken.Valid {
+ log.Printf("Token %v not valid", jwtoken)
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ expiry := int64(jwtoken.Claims["exp"].(float64))
+ token := &oauth2.Token{
+ AccessToken: jwtoken.Claims["token"].(string),
+ Expiry: time.Unix(expiry, 0),
+ }
+ ok := a.authprovider.Valid(token)
+ if !ok {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ // finally sign the pubkey and issue the cert.
+ req, err := parseKey(r)
+ req.Principal = a.authprovider.Username(token)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ signed, err := a.sshKeySigner.Sign(req)
+ a.authprovider.Revoke(token)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ json.NewEncoder(w).Encode(&lib.SignResponse{
+ Status: "ok",
+ Response: signed,
+ })
+ return http.StatusOK, nil
+}
+
+func loginHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ a.authsession = a.authprovider.StartSession(hex.EncodeToString(random(32)))
+ http.Redirect(w, r, a.authsession.AuthURL, http.StatusFound)
+ return http.StatusFound, nil
+}
+
+func callbackHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ if r.FormValue("state") != a.authsession.State {
+ return http.StatusUnauthorized, errors.New(http.StatusText(http.StatusUnauthorized))
+ }
+ code := r.FormValue("code")
+ if err := a.authsession.Authorize(a.authprovider, code); err != nil {
+ return http.StatusInternalServerError, err
+ }
+ a.setAuthCookie(w, r, a.authsession.Token)
+ http.Redirect(w, r, "/", http.StatusFound)
+ return http.StatusFound, nil
+}
+
+func rootHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, error) {
+ tok := a.getAuthCookie(r)
+ if !tok.Valid() {
+ http.Redirect(w, r, "/auth/login", http.StatusSeeOther)
+ return http.StatusSeeOther, nil
+ }
+ j := jwt.New(jwt.SigningMethodHS256)
+ j.Claims["token"] = tok.AccessToken
+ j.Claims["exp"] = tok.Expiry.Unix()
+ t, err := j.SignedString(a.jwtSigningKey)
+ if err != nil {
+ return http.StatusInternalServerError, err
+ }
+ page := struct {
+ Token string
+ }{t}
+ a.views.ExecuteTemplate(w, "token.html", page)
+ return http.StatusOK, nil
+}
+
+type appHandler struct {
+ *appContext
+ h func(*appContext, http.ResponseWriter, *http.Request) (int, error)
+}
+
+func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ status, err := ah.h(ah.appContext, w, r)
+ if err != nil {
+ log.Printf("HTTP %d: %q", status, err)
+ switch status {
+ case http.StatusNotFound:
+ http.NotFound(w, r)
+ case http.StatusInternalServerError:
+ http.Error(w, http.StatusText(status), status)
+ default:
+ http.Error(w, http.StatusText(status), status)
+ }
+ }
+}
+
+func random(length int) []byte {
+ k := make([]byte, length)
+ if _, err := io.ReadFull(rand.Reader, k); err != nil {
+ return nil
+ }
+ return k
+}
+
+func main() {
+ flag.Parse()
+ config, err := config.ReadConfig(*cfg)
+ if err != nil {
+ log.Fatal(err)
+ }
+ signer, err := signer.NewSigner(config.SSH)
+ if err != nil {
+ log.Fatal(err)
+ }
+ authprovider := google.New(config.Auth)
+ ctx := &appContext{
+ cookiestore: sessions.NewCookieStore([]byte(config.Server.CookieSecret)),
+ authprovider: authprovider,
+ views: template.Must(template.ParseGlob("templates/*")),
+ sshKeySigner: signer,
+ jwtSigningKey: []byte(config.Auth.JWTSigningKey),
+ }
+ ctx.cookiestore.Options = &sessions.Options{
+ MaxAge: 900,
+ Path: "/",
+ Secure: config.Server.UseTLS,
+ HttpOnly: true,
+ }
+
+ m := mux.NewRouter()
+ m.Handle("/", appHandler{ctx, rootHandler})
+ m.Handle("/auth/login", appHandler{ctx, loginHandler})
+ m.Handle("/auth/callback", appHandler{ctx, callbackHandler})
+ m.Handle("/sign", appHandler{ctx, signHandler})
+
+ fmt.Println("Starting server...")
+ l := fmt.Sprintf(":%d", config.Server.Port)
+ if config.Server.UseTLS {
+ log.Fatal(http.ListenAndServeTLS(l, config.Server.TLSCert, config.Server.TLSKey, m))
+ }
+ log.Fatal(http.ListenAndServe(l, m))
+}
diff --git a/server/signer/signer.go b/server/signer/signer.go
new file mode 100644
index 0000000..4ae5058
--- /dev/null
+++ b/server/signer/signer.go
@@ -0,0 +1,85 @@
+package signer
+
+import (
+ "crypto/rand"
+ "fmt"
+ "io/ioutil"
+ "time"
+
+ "github.com/nsheridan/cashier/lib"
+ "github.com/nsheridan/cashier/server/config"
+ "golang.org/x/crypto/ssh"
+)
+
+type KeySigner struct {
+ ca ssh.Signer
+ validity time.Duration
+ principals []string
+ permissions map[string]string
+}
+
+func (s *KeySigner) Sign(req *lib.SignRequest) (string, error) {
+ pubkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(req.Key))
+ if err != nil {
+ return "", err
+ }
+ expires := time.Now().Add(s.validity)
+ if req.ValidUntil.After(expires) {
+ req.ValidUntil = expires
+ }
+ cert := &ssh.Certificate{
+ CertType: ssh.UserCert,
+ Key: pubkey,
+ KeyId: req.Principal,
+ ValidBefore: uint64(req.ValidUntil.Unix()),
+ ValidAfter: uint64(time.Now().Add(-5 * time.Minute).Unix()),
+ }
+ cert.ValidPrincipals = append(cert.ValidPrincipals, req.Principal)
+ cert.ValidPrincipals = append(cert.ValidPrincipals, s.principals...)
+ cert.Extensions = s.permissions
+ if err := cert.SignCert(rand.Reader, s.ca); err != nil {
+ return "", err
+ }
+ marshaled := ssh.MarshalAuthorizedKey(cert)
+ // Remove the trailing newline.
+ marshaled = marshaled[:len(marshaled)-1]
+ return string(marshaled), nil
+}
+
+func makeperms(perms []string) map[string]string {
+ if len(perms) > 0 {
+ m := make(map[string]string)
+ for _, p := range perms {
+ m[p] = ""
+ }
+ return m
+ }
+ return map[string]string{
+ "permit-X11-forwarding": "",
+ "permit-agent-forwarding": "",
+ "permit-port-forwarding": "",
+ "permit-pty": "",
+ "permit-user-rc": "",
+ }
+}
+
+func NewSigner(conf config.SSH) (*KeySigner, error) {
+ data, err := ioutil.ReadFile(conf.SigningKey)
+ if err != nil {
+ return nil, fmt.Errorf("unable to read CA key %s: %v", conf.SigningKey, err)
+ }
+ key, err := ssh.ParsePrivateKey(data)
+ if err != nil {
+ return nil, fmt.Errorf("unable to parse CA key: %v", err)
+ }
+ validity, err := time.ParseDuration(conf.MaxAge)
+ if err != nil {
+ return nil, fmt.Errorf("error parsing duration '%s': %v", conf.MaxAge, err)
+ }
+ return &KeySigner{
+ ca: key,
+ validity: validity,
+ principals: conf.Principals,
+ permissions: makeperms(conf.Permissions),
+ }, nil
+}