174 lines
5.6 KiB
Go
174 lines
5.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
|
|
)
|
|
|
|
// RateLimitMiddleware provides rate limiting for HTTP requests
|
|
type RateLimitMiddleware struct {
|
|
rateLimiter ratelimit.RateLimiter
|
|
ipExtractor *clientip.Extractor
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewRateLimitMiddleware creates a new rate limiting middleware
|
|
// CWE-348: Uses clientip.Extractor to securely extract IP addresses with trusted proxy validation
|
|
func NewRateLimitMiddleware(rateLimiter ratelimit.RateLimiter, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddleware {
|
|
return &RateLimitMiddleware{
|
|
rateLimiter: rateLimiter,
|
|
ipExtractor: ipExtractor,
|
|
logger: logger.Named("rate-limit-middleware"),
|
|
}
|
|
}
|
|
|
|
// Handler wraps an HTTP handler with rate limiting (IP-based)
|
|
// Used for: Registration endpoints
|
|
func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// CWE-348: Extract client IP securely with trusted proxy validation
|
|
clientIP := m.ipExtractor.Extract(r)
|
|
|
|
// Check rate limit
|
|
allowed, err := m.rateLimiter.Allow(r.Context(), clientIP)
|
|
if err != nil {
|
|
// Log error but fail open (allow request)
|
|
m.logger.Error("rate limiter error",
|
|
zap.String("ip", clientIP),
|
|
zap.Error(err))
|
|
}
|
|
|
|
if !allowed {
|
|
m.logger.Warn("rate limit exceeded",
|
|
zap.String("ip", clientIP),
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("method", r.Method))
|
|
|
|
// Add Retry-After header (suggested wait time in seconds)
|
|
w.Header().Set("Retry-After", "3600") // 1 hour
|
|
|
|
// Return 429 Too Many Requests
|
|
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
|
return
|
|
}
|
|
|
|
// Get remaining requests and add to response headers
|
|
remaining, err := m.rateLimiter.GetRemaining(r.Context(), clientIP)
|
|
if err != nil {
|
|
m.logger.Error("failed to get remaining requests",
|
|
zap.String("ip", clientIP),
|
|
zap.Error(err))
|
|
} else {
|
|
// Add rate limit headers for transparency
|
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
}
|
|
|
|
// Continue to next handler
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// HandlerWithUserKey wraps an HTTP handler with rate limiting (User-based)
|
|
// Used for: Generic CRUD endpoints (tenant/user/site management, admin, /me, /hello)
|
|
// Extracts user ID from JWT context for per-user rate limiting
|
|
func (m *RateLimitMiddleware) HandlerWithUserKey(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract user ID from JWT context
|
|
var key string
|
|
if userID, ok := r.Context().Value(constants.SessionUserID).(uint64); ok {
|
|
key = fmt.Sprintf("user:%d", userID)
|
|
} else {
|
|
// Fallback to IP if user ID not available
|
|
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
|
|
m.logger.Warn("user ID not found in context, falling back to IP-based rate limiting",
|
|
zap.String("path", r.URL.Path))
|
|
}
|
|
|
|
// Check rate limit
|
|
allowed, err := m.rateLimiter.Allow(r.Context(), key)
|
|
if err != nil {
|
|
m.logger.Error("rate limiter error",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
}
|
|
|
|
if !allowed {
|
|
m.logger.Warn("rate limit exceeded",
|
|
zap.String("key", key),
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("method", r.Method))
|
|
|
|
w.Header().Set("Retry-After", "3600") // 1 hour
|
|
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
|
return
|
|
}
|
|
|
|
// Get remaining requests and add to response headers
|
|
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
|
|
if err != nil {
|
|
m.logger.Error("failed to get remaining requests",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
} else {
|
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// HandlerWithSiteKey wraps an HTTP handler with rate limiting (Site-based)
|
|
// Used for: WordPress Plugin API endpoints
|
|
// Extracts site ID from API key context for per-site rate limiting
|
|
func (m *RateLimitMiddleware) HandlerWithSiteKey(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract site ID from API key context
|
|
var key string
|
|
if siteID, ok := r.Context().Value(constants.SiteID).(string); ok && siteID != "" {
|
|
key = fmt.Sprintf("site:%s", siteID)
|
|
} else {
|
|
// Fallback to IP if site ID not available
|
|
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
|
|
m.logger.Warn("site ID not found in context, falling back to IP-based rate limiting",
|
|
zap.String("path", r.URL.Path))
|
|
}
|
|
|
|
// Check rate limit
|
|
allowed, err := m.rateLimiter.Allow(r.Context(), key)
|
|
if err != nil {
|
|
m.logger.Error("rate limiter error",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
}
|
|
|
|
if !allowed {
|
|
m.logger.Warn("rate limit exceeded",
|
|
zap.String("key", key),
|
|
zap.String("path", r.URL.Path),
|
|
zap.String("method", r.Method))
|
|
|
|
w.Header().Set("Retry-After", "3600") // 1 hour
|
|
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
|
return
|
|
}
|
|
|
|
// Get remaining requests and add to response headers
|
|
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
|
|
if err != nil {
|
|
m.logger.Error("failed to get remaining requests",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
} else {
|
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|