// Package middleware provides HTTP middleware for the MapleFile backend. package middleware import ( "bytes" "encoding/json" "fmt" "io" "net/http" "strings" "go.uber.org/zap" "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/httperror" "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/ratelimit" "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation" ) // RateLimitMiddleware provides rate limiting functionality for HTTP endpoints type RateLimitMiddleware struct { logger *zap.Logger loginRateLimiter ratelimit.LoginRateLimiter } // NewRateLimitMiddleware creates a new rate limit middleware func NewRateLimitMiddleware(logger *zap.Logger, loginRateLimiter ratelimit.LoginRateLimiter) *RateLimitMiddleware { return &RateLimitMiddleware{ logger: logger.Named("RateLimitMiddleware"), loginRateLimiter: loginRateLimiter, } } // LoginRateLimit applies login-specific rate limiting to auth endpoints // CWE-307: Protects against brute force attacks on authentication endpoints func (m *RateLimitMiddleware) LoginRateLimit(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Extract client IP clientIP := m.extractClientIP(r) // Extract email from request body (need to buffer and restore) email := m.extractEmailFromRequest(r) // Check rate limit allowed, isLocked, remainingAttempts, err := m.loginRateLimiter.CheckAndRecordAttempt(ctx, email, clientIP) if err != nil { // Log error but allow request (fail open for availability) m.logger.Warn("Rate limiter error, allowing request", zap.Error(err), zap.String("ip", validation.MaskIP(clientIP))) next(w, r) return } // Check if account is locked if isLocked { m.logger.Warn("Login attempt on locked account", zap.String("ip", validation.MaskIP(clientIP)), zap.String("path", r.URL.Path)) problem := httperror.NewTooManyRequestsError( "Account temporarily locked due to too many failed attempts. Please try again later.") problem.WithInstance(r.URL.Path). WithTraceID(httperror.ExtractRequestID(r)) httperror.RespondWithProblem(w, problem) return } // Check if IP rate limit exceeded if !allowed { m.logger.Warn("Rate limit exceeded", zap.String("ip", validation.MaskIP(clientIP)), zap.String("path", r.URL.Path), zap.Int("remaining_attempts", remainingAttempts)) problem := httperror.NewTooManyRequestsError( "Too many requests. Please slow down and try again later.") problem.WithInstance(r.URL.Path). WithTraceID(httperror.ExtractRequestID(r)) httperror.RespondWithProblem(w, problem) return } // Add remaining attempts to response header for client awareness if remainingAttempts > 0 && remainingAttempts <= 3 { w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remainingAttempts)) } next(w, r) } } // AuthRateLimit applies general rate limiting to auth endpoints // For endpoints like registration, email verification, etc. func (m *RateLimitMiddleware) AuthRateLimit(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Extract client IP for rate limiting key clientIP := m.extractClientIP(r) // Use the login rate limiter for IP-based checking only // This provides basic protection against automated attacks ctx := r.Context() allowed, _, _, err := m.loginRateLimiter.CheckAndRecordAttempt(ctx, "", clientIP) if err != nil { // Fail open m.logger.Warn("Rate limiter error, allowing request", zap.Error(err)) next(w, r) return } if !allowed { m.logger.Warn("Auth rate limit exceeded", zap.String("ip", validation.MaskIP(clientIP)), zap.String("path", r.URL.Path)) problem := httperror.NewTooManyRequestsError( "Too many requests from this IP. Please try again later.") problem.WithInstance(r.URL.Path). WithTraceID(httperror.ExtractRequestID(r)) httperror.RespondWithProblem(w, problem) return } next(w, r) } } // extractClientIP extracts the real client IP from the request func (m *RateLimitMiddleware) extractClientIP(r *http.Request) string { // Check X-Forwarded-For header first (for reverse proxies) if xff := r.Header.Get("X-Forwarded-For"); xff != "" { // Take the first IP in the chain ips := strings.Split(xff, ",") if len(ips) > 0 { return strings.TrimSpace(ips[0]) } } // Check X-Real-IP header if xri := r.Header.Get("X-Real-IP"); xri != "" { return xri } // Fall back to RemoteAddr // Remove port if present ip := r.RemoteAddr if idx := strings.LastIndex(ip, ":"); idx != -1 { ip = ip[:idx] } return ip } // extractEmailFromRequest extracts email from JSON request body // It buffers the body so it can be read again by the handler func (m *RateLimitMiddleware) extractEmailFromRequest(r *http.Request) string { // Read body body, err := io.ReadAll(r.Body) if err != nil { return "" } // Restore body for handler r.Body = io.NopCloser(bytes.NewBuffer(body)) // Parse JSON to extract email var req struct { Email string `json:"email"` } if err := json.Unmarshal(body, &req); err != nil { return "" } return strings.ToLower(strings.TrimSpace(req.Email)) }