package csrf import ( "net/http" "strings" ) type Middleware struct { // Name of the cookie with the first csrf token. CookieName string // `Secure` cookie attribute. Secure bool // `SameSite` cookie attribute. SameSite http.SameSite // Name of the form field with the second token. FormFieldName string // Name of the header with the second token. HeaderName string // Length of generated CSRF tokens in bytes (symbols). TokenLength uint } // Create a new instance of CSRF middleware. // You can tweak settings after initialization. func New() Middleware { return Middleware{ CookieName: "csrfmiddlewaretoken", Secure: false, SameSite: http.SameSiteLaxMode, FormFieldName: "csrfmiddlewaretoken", HeaderName: "X-CSRFToken", TokenLength: 32, } } // Middleware to prevent CSRF attacks. func (m *Middleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenFromCookies := m.getTokenFromCookies(r) if tokenFromCookies == "" { // Generate a new token if one is not provided. var err error tokenFromCookies, err = generateCSRFToken(m.TokenLength) if err != nil { http.Error(w, "Cryptographic functions are unavailable", http.StatusInternalServerError) return } m.setTokenCookie(w, tokenFromCookies) } // Embed CSRF token in context for request handlers. ctx := setCSRFToken(r.Context(), tokenFromCookies) r = r.WithContext(ctx) if !isMethodSafe(r) { // Enforce CSRF protection. secondToken := m.extractSecondToken(r) if tokenFromCookies != secondToken { http.Error(w, "CSRF tokens do not match. Make sure you have cookies enabled.", http.StatusBadRequest) return } } next.ServeHTTP(w, r) }) } func (m *Middleware) extractSecondToken(r *http.Request) string { token := r.FormValue(m.FormFieldName) token = strings.TrimSpace(token) if token != "" { return token } token = r.Header.Get(m.HeaderName) token = strings.TrimSpace(token) return token } func isMethodSafe(r *http.Request) bool { return r.Method == http.MethodGet || r.Method == http.MethodOptions || r.Method == http.MethodTrace || r.Method == http.MethodHead }