package middleware

import (
	"fmt"
	"net/http"

	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
)

// rateLimitRetryAfterSeconds is the value sent in the Retry-After header of a
// 429 response (1 hour). It is a fixed suggestion; it is not derived from the
// rate limiter's actual window.
const rateLimitRetryAfterSeconds = "3600"

// RateLimitMiddleware provides rate limiting for HTTP requests.
type RateLimitMiddleware struct {
	rateLimiter ratelimit.RateLimiter
	ipExtractor *clientip.Extractor
	logger      *zap.Logger
}

// NewRateLimitMiddleware creates a new rate limiting middleware.
//
// CWE-348: Uses clientip.Extractor to securely extract IP addresses with
// trusted proxy validation.
//
// The supplied logger is wrapped in a child logger named
// "rate-limit-middleware".
func NewRateLimitMiddleware(rateLimiter ratelimit.RateLimiter, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddleware {
	return &RateLimitMiddleware{
		rateLimiter: rateLimiter,
		ipExtractor: ipExtractor,
		logger:      logger.Named("rate-limit-middleware"),
	}
}

// Handler wraps an HTTP handler with rate limiting (IP-based).
// Used for: Registration endpoints.
//
// The rate-limit key is the bare client IP returned by the extractor.
func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
	return m.limitBy(next, "ip", func(r *http.Request) string {
		// CWE-348: Extract client IP securely with trusted proxy validation.
		return m.ipExtractor.Extract(r)
	})
}

// HandlerWithUserKey wraps an HTTP handler with rate limiting (User-based).
// Used for: Generic CRUD endpoints (tenant/user/site management, admin, /me, /hello).
//
// The key is "user:<id>" when a uint64 user ID is stored in the request
// context under constants.SessionUserID; otherwise it falls back to
// "ip:<client ip>" and logs a warning.
func (m *RateLimitMiddleware) HandlerWithUserKey(next http.Handler) http.Handler {
	return m.limitBy(next, "key", func(r *http.Request) string {
		if userID, ok := r.Context().Value(constants.SessionUserID).(uint64); ok {
			return fmt.Sprintf("user:%d", userID)
		}

		// Fallback to IP if user ID not available.
		key := fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
		m.logger.Warn("user ID not found in context, falling back to IP-based rate limiting",
			zap.String("path", r.URL.Path))
		return key
	})
}

// HandlerWithSiteKey wraps an HTTP handler with rate limiting (Site-based).
// Used for: WordPress Plugin API endpoints.
//
// The key is "site:<id>" when a non-empty string site ID is stored in the
// request context under constants.SiteID; otherwise it falls back to
// "ip:<client ip>" and logs a warning.
func (m *RateLimitMiddleware) HandlerWithSiteKey(next http.Handler) http.Handler {
	return m.limitBy(next, "key", func(r *http.Request) string {
		if siteID, ok := r.Context().Value(constants.SiteID).(string); ok && siteID != "" {
			return fmt.Sprintf("site:%s", siteID)
		}

		// Fallback to IP if site ID not available.
		key := fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
		m.logger.Warn("site ID not found in context, falling back to IP-based rate limiting",
			zap.String("path", r.URL.Path))
		return key
	})
}

// limitBy is the shared implementation behind the exported handlers.
//
// keyFor derives the rate-limit key for each request; logField is the zap
// field name ("ip" or "key") under which that key is logged.
//
// Per request it:
//  1. asks the rate limiter whether the key is allowed;
//  2. on a limiter error, logs it and fails open (the request is allowed);
//  3. when not allowed, sets Retry-After and responds via
//     httperror.TooManyRequests without calling next;
//  4. otherwise sets X-RateLimit-Remaining when the remaining count can be
//     fetched, and calls next.
func (m *RateLimitMiddleware) limitBy(next http.Handler, logField string, keyFor func(r *http.Request) string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key := keyFor(r)

		allowed, err := m.rateLimiter.Allow(r.Context(), key)
		if err != nil {
			m.logger.Error("rate limiter error",
				zap.String(logField, key),
				zap.Error(err))
			// Fail open: the boolean returned alongside a non-nil error is not
			// trusted, so a limiter outage cannot turn into a blanket 429.
			allowed = true
		}

		if !allowed {
			m.logger.Warn("rate limit exceeded",
				zap.String(logField, key),
				zap.String("path", r.URL.Path),
				zap.String("method", r.Method))

			// Suggested wait time in seconds before retrying.
			w.Header().Set("Retry-After", rateLimitRetryAfterSeconds)
			httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
			return
		}

		// The remaining-count header is best effort: a lookup failure is logged
		// and the request still proceeds.
		remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
		if err != nil {
			m.logger.Error("failed to get remaining requests",
				zap.String(logField, key),
				zap.Error(err))
		} else {
			w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
		}

		next.ServeHTTP(w, r)
	})
}