aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/gorilla/csrf/helpers.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/csrf/helpers.go')
-rw-r--r--vendor/github.com/gorilla/csrf/helpers.go205
1 files changed, 205 insertions, 0 deletions
diff --git a/vendor/github.com/gorilla/csrf/helpers.go b/vendor/github.com/gorilla/csrf/helpers.go
new file mode 100644
index 0000000..7adb5ff
--- /dev/null
+++ b/vendor/github.com/gorilla/csrf/helpers.go
@@ -0,0 +1,205 @@
+package csrf
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "fmt"
+ "html/template"
+ "net/http"
+ "net/url"
+
+ "github.com/gorilla/context"
+)
+
+// Token returns a masked CSRF token ready for passing into HTML template or
+// a JSON response body. An empty token will be returned if the middleware
+// has not been applied (which will fail subsequent validation).
+func Token(r *http.Request) string {
+ if val, err := contextGet(r, tokenKey); err == nil {
+ if maskedToken, ok := val.(string); ok {
+ return maskedToken
+ }
+ }
+
+ return ""
+}
+
+// FailureReason makes CSRF validation errors available in the request context.
+// This is useful when you want to log the cause of the error or report it to
+// client.
+func FailureReason(r *http.Request) error {
+ if val, err := contextGet(r, errorKey); err == nil {
+ if err, ok := val.(error); ok {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// UnsafeSkipCheck will skip the CSRF check for any requests. This must be
+// called before the CSRF middleware.
+//
+// Note: You should not set this without otherwise securing the request from
+// CSRF attacks. The primary use-case for this function is to turn off CSRF
+// checks for non-browser clients using authorization tokens against your API.
+func UnsafeSkipCheck(r *http.Request) *http.Request {
+ return contextSave(r, skipCheckKey, true)
+}
+
+// TemplateField is a template helper for html/template that provides an <input> field
+// populated with a CSRF token.
+//
+// Example:
+//
+// // The following tag in our form.tmpl template:
+// {{ .csrfField }}
+//
+// // ... becomes:
+// <input type="hidden" name="gorilla.csrf.Token" value="<token>">
+//
+func TemplateField(r *http.Request) template.HTML {
+ if name, err := contextGet(r, formKey); err == nil {
+ fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
+ name, Token(r))
+
+ return template.HTML(fragment)
+ }
+
+ return template.HTML("")
+}
+
+// mask returns a unique-per-request token to mitigate the BREACH attack
+// as per http://breachattack.com/#mitigations
+//
+// The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
+// token and returning them together as a 64-byte slice. This effectively
+// randomises the token on a per-request basis without breaking multiple browser
+// tabs/windows.
+func mask(realToken []byte, r *http.Request) string {
+ otp, err := generateRandomBytes(tokenLength)
+ if err != nil {
+ return ""
+ }
+
+ // XOR the OTP with the real token to generate a masked token. Append the
+ // OTP to the front of the masked token to allow unmasking in the subsequent
+ // request.
+ return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
+}
+
+// unmask splits the issued token (one-time-pad + masked token) and returns the
+// unmasked request token for comparison.
+func unmask(issued []byte) []byte {
+ // Issued tokens are always masked and combined with the pad.
+ if len(issued) != tokenLength*2 {
+ return nil
+ }
+
+ // We now know the length of the byte slice.
+ otp := issued[tokenLength:]
+ masked := issued[:tokenLength]
+
+ // Unmask the token by XOR'ing it against the OTP used to mask it.
+ return xorToken(otp, masked)
+}
+
+// requestToken returns the issued token (pad + masked token) from the HTTP POST
+// body or HTTP header. It will return nil if the token fails to decode.
+func (cs *csrf) requestToken(r *http.Request) []byte {
+ // 1. Check the HTTP header first.
+ issued := r.Header.Get(cs.opts.RequestHeader)
+
+ // 2. Fall back to the POST (form) value.
+ if issued == "" {
+ issued = r.PostFormValue(cs.opts.FieldName)
+ }
+
+ // 3. Finally, fall back to the multipart form (if set).
+ if issued == "" && r.MultipartForm != nil {
+ vals := r.MultipartForm.Value[cs.opts.FieldName]
+
+ if len(vals) > 0 {
+ issued = vals[0]
+ }
+ }
+
+ // Decode the "issued" (pad + masked) token sent in the request. Return a
+ // nil byte slice on a decoding error (this will fail upstream).
+ decoded, err := base64.StdEncoding.DecodeString(issued)
+ if err != nil {
+ return nil
+ }
+
+ return decoded
+}
+
+// generateRandomBytes returns securely generated random bytes.
+// It will return an error if the system's secure random number generator
+// fails to function correctly.
+func generateRandomBytes(n int) ([]byte, error) {
+ b := make([]byte, n)
+ _, err := rand.Read(b)
+ // err == nil only if len(b) == n
+ if err != nil {
+ return nil, err
+ }
+
+ return b, nil
+
+}
+
+// sameOrigin returns true if URLs a and b share the same origin. The same
+// origin is defined as host (which includes the port) and scheme.
+func sameOrigin(a, b *url.URL) bool {
+ return (a.Scheme == b.Scheme && a.Host == b.Host)
+}
+
+// compare securely (constant-time) compares the unmasked token from the request
+// against the real token from the session.
+func compareTokens(a, b []byte) bool {
+ // This is required as subtle.ConstantTimeCompare does not check for equal
+ // lengths in Go versions prior to 1.3.
+ if len(a) != len(b) {
+ return false
+ }
+
+ return subtle.ConstantTimeCompare(a, b) == 1
+}
+
+// xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
+// will return a masked token if the base token is XOR'ed with a one-time-pad.
+// An unmasked token will be returned if a masked token is XOR'ed with the
+// one-time-pad used to mask it.
+func xorToken(a, b []byte) []byte {
+ n := len(a)
+ if len(b) < n {
+ n = len(b)
+ }
+
+ res := make([]byte, n)
+
+ for i := 0; i < n; i++ {
+ res[i] = a[i] ^ b[i]
+ }
+
+ return res
+}
+
+// contains is a helper function to check if a string exists in a slice - e.g.
+// whether a HTTP method exists in a list of safe methods.
+func contains(vals []string, s string) bool {
+ for _, v := range vals {
+ if v == s {
+ return true
+ }
+ }
+
+ return false
+}
+
+// envError stores a CSRF error in the request context.
+func envError(r *http.Request, err error) {
+ context.Set(r, errorKey, err)
+}