175 lines
5.2 KiB
Go
175 lines
5.2 KiB
Go
// Package middleware provides HTTP middleware for the MapleFile backend.
|
|
package middleware
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"

	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/httperror"
	"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/ratelimit"
	"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
)
|
|
|
|
// RateLimitMiddleware provides rate limiting functionality for HTTP endpoints
|
|
type RateLimitMiddleware struct {
|
|
logger *zap.Logger
|
|
loginRateLimiter ratelimit.LoginRateLimiter
|
|
}
|
|
|
|
// NewRateLimitMiddleware creates a new rate limit middleware
|
|
func NewRateLimitMiddleware(logger *zap.Logger, loginRateLimiter ratelimit.LoginRateLimiter) *RateLimitMiddleware {
|
|
return &RateLimitMiddleware{
|
|
logger: logger.Named("RateLimitMiddleware"),
|
|
loginRateLimiter: loginRateLimiter,
|
|
}
|
|
}
|
|
|
|
// LoginRateLimit applies login-specific rate limiting to auth endpoints
|
|
// CWE-307: Protects against brute force attacks on authentication endpoints
|
|
func (m *RateLimitMiddleware) LoginRateLimit(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Extract client IP
|
|
clientIP := m.extractClientIP(r)
|
|
|
|
// Extract email from request body (need to buffer and restore)
|
|
email := m.extractEmailFromRequest(r)
|
|
|
|
// Check rate limit
|
|
allowed, isLocked, remainingAttempts, err := m.loginRateLimiter.CheckAndRecordAttempt(ctx, email, clientIP)
|
|
if err != nil {
|
|
// Log error but allow request (fail open for availability)
|
|
m.logger.Warn("Rate limiter error, allowing request",
|
|
zap.Error(err),
|
|
zap.String("ip", validation.MaskIP(clientIP)))
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
// Check if account is locked
|
|
if isLocked {
|
|
m.logger.Warn("Login attempt on locked account",
|
|
zap.String("ip", validation.MaskIP(clientIP)),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
problem := httperror.NewTooManyRequestsError(
|
|
"Account temporarily locked due to too many failed attempts. Please try again later.")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// Check if IP rate limit exceeded
|
|
if !allowed {
|
|
m.logger.Warn("Rate limit exceeded",
|
|
zap.String("ip", validation.MaskIP(clientIP)),
|
|
zap.String("path", r.URL.Path),
|
|
zap.Int("remaining_attempts", remainingAttempts))
|
|
|
|
problem := httperror.NewTooManyRequestsError(
|
|
"Too many requests. Please slow down and try again later.")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// Add remaining attempts to response header for client awareness
|
|
if remainingAttempts > 0 && remainingAttempts <= 3 {
|
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remainingAttempts))
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// AuthRateLimit applies general rate limiting to auth endpoints
|
|
// For endpoints like registration, email verification, etc.
|
|
func (m *RateLimitMiddleware) AuthRateLimit(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract client IP for rate limiting key
|
|
clientIP := m.extractClientIP(r)
|
|
|
|
// Use the login rate limiter for IP-based checking only
|
|
// This provides basic protection against automated attacks
|
|
ctx := r.Context()
|
|
allowed, _, _, err := m.loginRateLimiter.CheckAndRecordAttempt(ctx, "", clientIP)
|
|
if err != nil {
|
|
// Fail open
|
|
m.logger.Warn("Rate limiter error, allowing request", zap.Error(err))
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
if !allowed {
|
|
m.logger.Warn("Auth rate limit exceeded",
|
|
zap.String("ip", validation.MaskIP(clientIP)),
|
|
zap.String("path", r.URL.Path))
|
|
|
|
problem := httperror.NewTooManyRequestsError(
|
|
"Too many requests from this IP. Please try again later.")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// extractClientIP extracts the real client IP from the request
|
|
func (m *RateLimitMiddleware) extractClientIP(r *http.Request) string {
|
|
// Check X-Forwarded-For header first (for reverse proxies)
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
// Take the first IP in the chain
|
|
ips := strings.Split(xff, ",")
|
|
if len(ips) > 0 {
|
|
return strings.TrimSpace(ips[0])
|
|
}
|
|
}
|
|
|
|
// Check X-Real-IP header
|
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
return xri
|
|
}
|
|
|
|
// Fall back to RemoteAddr
|
|
// Remove port if present
|
|
ip := r.RemoteAddr
|
|
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
|
ip = ip[:idx]
|
|
}
|
|
|
|
return ip
|
|
}
|
|
|
|
// extractEmailFromRequest extracts email from JSON request body
|
|
// It buffers the body so it can be read again by the handler
|
|
func (m *RateLimitMiddleware) extractEmailFromRequest(r *http.Request) string {
|
|
// Read body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
// Restore body for handler
|
|
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
|
|
// Parse JSON to extract email
|
|
var req struct {
|
|
Email string `json:"email"`
|
|
}
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
return ""
|
|
}
|
|
|
|
return strings.ToLower(strings.TrimSpace(req.Email))
|
|
}
|