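// Package csrf implements an HTTP middleware that guards against CSRF by
// comparing a token stored in a cookie with a second copy of that token
// submitted in a form field or a request header.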
package csrf
|
|
|
|
import (
	"crypto/subtle"
	"net/http"
	"strings"
)
|
|
|
|
// Middleware holds the settings of the CSRF protection middleware.
//
// Obtain a value with sensible defaults from New and adjust individual
// fields before wiring the Middleware method into a handler chain.
type Middleware struct {
	// CookieName is the name of the cookie that carries the first CSRF token.
	CookieName string
	// Secure is the value of the cookie's `Secure` attribute.
	Secure bool
	// SameSite is the value of the cookie's `SameSite` attribute.
	SameSite http.SameSite
	// FormFieldName is the name of the form field that carries the second
	// token.
	FormFieldName string
	// HeaderName is the name of the request header that carries the second
	// token.
	HeaderName string
	// TokenLength is the length of generated CSRF tokens in bytes (symbols).
	TokenLength uint
}
|
|
|
|
// Create a new instance of CSRF middleware.
|
|
// You can tweak settings after initialization.
|
|
func New() Middleware {
|
|
return Middleware{
|
|
CookieName: "csrfmiddlewaretoken",
|
|
Secure: false,
|
|
SameSite: http.SameSiteLaxMode,
|
|
FormFieldName: "csrfmiddlewaretoken",
|
|
HeaderName: "X-CSRFToken",
|
|
TokenLength: 32,
|
|
}
|
|
}
|
|
|
|
// Middleware to prevent CSRF attacks.
|
|
func (m *Middleware) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenFromCookies := m.getTokenFromCookies(r)
|
|
if tokenFromCookies == "" {
|
|
// Generate a new token if one is not provided.
|
|
var err error
|
|
tokenFromCookies, err = generateCSRFToken(m.TokenLength)
|
|
if err != nil {
|
|
http.Error(w, "Cryptographic functions are unavailable", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
m.setTokenCookie(w, tokenFromCookies)
|
|
}
|
|
|
|
// Embed CSRF token in context for request handlers.
|
|
ctx := setCSRFToken(r.Context(), tokenFromCookies)
|
|
r = r.WithContext(ctx)
|
|
|
|
if !isMethodSafe(r) {
|
|
// Enforce CSRF protection.
|
|
secondToken := m.extractSecondToken(r)
|
|
if tokenFromCookies != secondToken {
|
|
http.Error(w, "CSRF tokens do not match. Make sure you have cookies enabled.", http.StatusBadRequest)
|
|
return
|
|
}
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (m *Middleware) extractSecondToken(r *http.Request) string {
|
|
token := r.FormValue(m.FormFieldName)
|
|
token = strings.TrimSpace(token)
|
|
if token != "" {
|
|
return token
|
|
}
|
|
|
|
token = r.Header.Get(m.HeaderName)
|
|
token = strings.TrimSpace(token)
|
|
return token
|
|
}
|
|
|
|
// isMethodSafe reports whether the request method is one of GET, HEAD,
// OPTIONS or TRACE. Requests with any other method are subject to the CSRF
// token check. The comparison is exact, so differently-cased method names
// are treated as unsafe.
func isMethodSafe(r *http.Request) bool {
	switch r.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	default:
		return false
	}
}
|