211 lines
5.7 KiB
Go
211 lines
5.7 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"crypto/subtle"
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/labstack/echo"
|
||
|
"github.com/labstack/gommon/random"
|
||
|
)
|
||
|
|
||
|
type (
|
||
|
// CSRFConfig defines the config for CSRF middleware.
|
||
|
CSRFConfig struct {
|
||
|
// Skipper defines a function to skip middleware.
|
||
|
Skipper Skipper
|
||
|
|
||
|
// TokenLength is the length of the generated token.
|
||
|
TokenLength uint8 `yaml:"token_length"`
|
||
|
// Optional. Default value 32.
|
||
|
|
||
|
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||
|
// to extract token from the request.
|
||
|
// Optional. Default value "header:X-CSRF-Token".
|
||
|
// Possible values:
|
||
|
// - "header:<name>"
|
||
|
// - "form:<name>"
|
||
|
// - "query:<name>"
|
||
|
TokenLookup string `yaml:"token_lookup"`
|
||
|
|
||
|
// Context key to store generated CSRF token into context.
|
||
|
// Optional. Default value "csrf".
|
||
|
ContextKey string `yaml:"context_key"`
|
||
|
|
||
|
// Name of the CSRF cookie. This cookie will store CSRF token.
|
||
|
// Optional. Default value "csrf".
|
||
|
CookieName string `yaml:"cookie_name"`
|
||
|
|
||
|
// Domain of the CSRF cookie.
|
||
|
// Optional. Default value none.
|
||
|
CookieDomain string `yaml:"cookie_domain"`
|
||
|
|
||
|
// Path of the CSRF cookie.
|
||
|
// Optional. Default value none.
|
||
|
CookiePath string `yaml:"cookie_path"`
|
||
|
|
||
|
// Max age (in seconds) of the CSRF cookie.
|
||
|
// Optional. Default value 86400 (24hr).
|
||
|
CookieMaxAge int `yaml:"cookie_max_age"`
|
||
|
|
||
|
// Indicates if CSRF cookie is secure.
|
||
|
// Optional. Default value false.
|
||
|
CookieSecure bool `yaml:"cookie_secure"`
|
||
|
|
||
|
// Indicates if CSRF cookie is HTTP only.
|
||
|
// Optional. Default value false.
|
||
|
CookieHTTPOnly bool `yaml:"cookie_http_only"`
|
||
|
}
|
||
|
|
||
|
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
||
|
// either a token or an error.
|
||
|
csrfTokenExtractor func(echo.Context) (string, error)
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
||
|
DefaultCSRFConfig = CSRFConfig{
|
||
|
Skipper: DefaultSkipper,
|
||
|
TokenLength: 32,
|
||
|
TokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||
|
ContextKey: "csrf",
|
||
|
CookieName: "_csrf",
|
||
|
CookieMaxAge: 86400,
|
||
|
}
|
||
|
)
|
||
|
|
||
|
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
|
||
|
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
|
||
|
func CSRF() echo.MiddlewareFunc {
|
||
|
c := DefaultCSRFConfig
|
||
|
return CSRFWithConfig(c)
|
||
|
}
|
||
|
|
||
|
// CSRFWithConfig returns a CSRF middleware with config.
|
||
|
// See `CSRF()`.
|
||
|
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||
|
// Defaults
|
||
|
if config.Skipper == nil {
|
||
|
config.Skipper = DefaultCSRFConfig.Skipper
|
||
|
}
|
||
|
if config.TokenLength == 0 {
|
||
|
config.TokenLength = DefaultCSRFConfig.TokenLength
|
||
|
}
|
||
|
if config.TokenLookup == "" {
|
||
|
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
||
|
}
|
||
|
if config.ContextKey == "" {
|
||
|
config.ContextKey = DefaultCSRFConfig.ContextKey
|
||
|
}
|
||
|
if config.CookieName == "" {
|
||
|
config.CookieName = DefaultCSRFConfig.CookieName
|
||
|
}
|
||
|
if config.CookieMaxAge == 0 {
|
||
|
config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
|
||
|
}
|
||
|
|
||
|
// Initialize
|
||
|
parts := strings.Split(config.TokenLookup, ":")
|
||
|
extractor := csrfTokenFromHeader(parts[1])
|
||
|
switch parts[0] {
|
||
|
case "form":
|
||
|
extractor = csrfTokenFromForm(parts[1])
|
||
|
case "query":
|
||
|
extractor = csrfTokenFromQuery(parts[1])
|
||
|
}
|
||
|
|
||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||
|
return func(c echo.Context) error {
|
||
|
if config.Skipper(c) {
|
||
|
return next(c)
|
||
|
}
|
||
|
|
||
|
req := c.Request()
|
||
|
k, err := c.Cookie(config.CookieName)
|
||
|
token := ""
|
||
|
|
||
|
// Generate token
|
||
|
if err != nil {
|
||
|
token = random.String(config.TokenLength)
|
||
|
} else {
|
||
|
// Reuse token
|
||
|
token = k.Value
|
||
|
}
|
||
|
|
||
|
switch req.Method {
|
||
|
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
|
||
|
default:
|
||
|
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
||
|
clientToken, err := extractor(c)
|
||
|
if err != nil {
|
||
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||
|
}
|
||
|
if !validateCSRFToken(token, clientToken) {
|
||
|
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Set CSRF cookie
|
||
|
cookie := new(http.Cookie)
|
||
|
cookie.Name = config.CookieName
|
||
|
cookie.Value = token
|
||
|
if config.CookiePath != "" {
|
||
|
cookie.Path = config.CookiePath
|
||
|
}
|
||
|
if config.CookieDomain != "" {
|
||
|
cookie.Domain = config.CookieDomain
|
||
|
}
|
||
|
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
|
||
|
cookie.Secure = config.CookieSecure
|
||
|
cookie.HttpOnly = config.CookieHTTPOnly
|
||
|
c.SetCookie(cookie)
|
||
|
|
||
|
// Store token in the context
|
||
|
c.Set(config.ContextKey, token)
|
||
|
|
||
|
// Protect clients from caching the response
|
||
|
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
|
||
|
|
||
|
return next(c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||
|
// provided request header.
|
||
|
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
||
|
return func(c echo.Context) (string, error) {
|
||
|
return c.Request().Header.Get(header), nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||
|
// provided form parameter.
|
||
|
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
||
|
return func(c echo.Context) (string, error) {
|
||
|
token := c.FormValue(param)
|
||
|
if token == "" {
|
||
|
return "", errors.New("missing csrf token in the form parameter")
|
||
|
}
|
||
|
return token, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
||
|
// provided query parameter.
|
||
|
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
||
|
return func(c echo.Context) (string, error) {
|
||
|
token := c.QueryParam(param)
|
||
|
if token == "" {
|
||
|
return "", errors.New("missing csrf token in the query string")
|
||
|
}
|
||
|
return token, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// validateCSRFToken reports whether clientToken matches the server-side token.
//
// For equal-length inputs the comparison takes time independent of the token
// contents (crypto/subtle), so timing does not reveal how many bytes matched.
// An empty token on either side is never valid. Two empty strings compare
// equal, and accepting them would let a request carrying no token pass
// whenever the stored token is also empty.
func validateCSRFToken(token, clientToken string) bool {
	if token == "" || clientToken == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
|