csrf/middleware.go
2025-06-11 00:24:24 +05:00

83 lines
2.2 KiB
Go

package csrf
import (
	"crypto/subtle"
	"net/http"
	"strings"
)
// Middleware holds the settings for the CSRF protection handler built by
// its Middleware method. New returns a value with defaults filled in; the
// fields may be adjusted afterwards.
type Middleware struct {
// CookieName is the name of the cookie that carries the first CSRF token,
// the value the submitted (second) token is compared against.
CookieName string
// Secure is the `Secure` attribute for the token cookie.
// NOTE(review): presumably applied by setTokenCookie — confirm there.
Secure bool
// SameSite is the `SameSite` attribute for the token cookie.
// NOTE(review): presumably applied by setTokenCookie — confirm there.
SameSite http.SameSite
// FormFieldName is the name of the form field carrying the second token.
// On requests with unsafe methods it is consulted before HeaderName.
FormFieldName string
// HeaderName is the name of the request header carrying the second token;
// it is used only when the form field is absent or blank after trimming.
HeaderName string
// TokenLength is passed to generateCSRFToken when a new token is minted
// (New sets it to 32).
// NOTE(review): whether this counts raw random bytes or encoded characters
// depends on generateCSRFToken — confirm.
TokenLength uint
}
// New returns a Middleware populated with the package defaults:
//   - the token cookie and the form field are both named "csrfmiddlewaretoken",
//   - the second token may also arrive in the "X-CSRFToken" header,
//   - the cookie uses SameSite=Lax and Secure=false,
//   - TokenLength is 32.
//
// The returned value may be customized before its Middleware method is used.
func New() Middleware {
	const defaultTokenName = "csrfmiddlewaretoken"

	var m Middleware
	m.CookieName = defaultTokenName
	m.FormFieldName = defaultTokenName
	m.HeaderName = "X-CSRFToken"
	// Secure is left off by default; turn it on for HTTPS-only deployments.
	m.Secure = false
	m.SameSite = http.SameSiteLaxMode
	m.TokenLength = 32
	return m
}
// Middleware wraps next with CSRF protection using the double-submit cookie
// pattern.
//
// On every request it reads the token from the CSRF cookie. When none is
// present it mints a new one with generateCSRFToken and sends it to the
// client via setTokenCookie. The token is stored in the request context
// (setCSRFToken) so downstream handlers can embed it in their responses.
//
// For methods that isMethodSafe does not accept, the request must also carry
// the same token in the configured form field or header (extractSecondToken).
// Otherwise the request is rejected with 400 Bad Request and next is not
// called. If a token cannot be generated, the request fails with
// 500 Internal Server Error.
func (m *Middleware) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookieToken := m.getTokenFromCookies(r)
		if cookieToken == "" {
			// No token yet: mint one and hand it to the client so that later
			// unsafe requests can echo it back.
			var err error
			cookieToken, err = generateCSRFToken(m.TokenLength)
			if err != nil {
				http.Error(w, "Cryptographic functions are unavailable", http.StatusInternalServerError)
				return
			}
			m.setTokenCookie(w, cookieToken)
		}

		// Embed CSRF token in context for request handlers.
		r = r.WithContext(setCSRFToken(r.Context(), cookieToken))

		if !isMethodSafe(r) {
			// Enforce CSRF protection.
			submitted := m.extractSecondToken(r)
			// An empty submitted token is never valid. Without this check, a
			// configuration that yields an empty cookie token (e.g.
			// TokenLength 0) would accept requests that omit the token, and
			// ConstantTimeCompare also reports two empty inputs as equal.
			// The constant-time comparison keeps the time taken from
			// depending on how much of a guessed token is correct.
			if submitted == "" ||
				subtle.ConstantTimeCompare([]byte(cookieToken), []byte(submitted)) != 1 {
				http.Error(w, "CSRF tokens do not match. Make sure you have cookies enabled.", http.StatusBadRequest)
				return
			}
		}
		next.ServeHTTP(w, r)
	})
}
// extractSecondToken returns the token the client echoed back with the
// request. It consults the form value named m.FormFieldName first, via
// r.FormValue. When that is blank after trimming whitespace, it falls back
// to the m.HeaderName request header. The result is whitespace-trimmed and
// is empty when neither source supplies a value.
func (m *Middleware) extractSecondToken(r *http.Request) string {
	if fromForm := strings.TrimSpace(r.FormValue(m.FormFieldName)); fromForm != "" {
		return fromForm
	}
	return strings.TrimSpace(r.Header.Get(m.HeaderName))
}
// isMethodSafe reports whether r uses one of the HTTP methods that are
// exempt from the CSRF token check: GET, HEAD, OPTIONS or TRACE. The match
// is exact and therefore case-sensitive.
func isMethodSafe(r *http.Request) bool {
	switch r.Method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
		return true
	default:
		return false
	}
}