Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
125
cloud/maplepress-backend/internal/http/middleware/apikey.go
Normal file
125
cloud/maplepress-backend/internal/http/middleware/apikey.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
||||
domainsite "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain/site"
|
||||
siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/usecase/site"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
|
||||
)
|
||||
|
||||
// APIKeyMiddleware validates API keys and populates site context
|
||||
type APIKeyMiddleware struct {
|
||||
siteService siteservice.AuthenticateAPIKeyService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAPIKeyMiddleware creates a new API key middleware
|
||||
func NewAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
|
||||
return &APIKeyMiddleware{
|
||||
siteService: siteService,
|
||||
logger: logger.Named("apikey-middleware"),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an HTTP middleware function that validates API keys
|
||||
func (m *APIKeyMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
m.logger.Debug("no authorization header")
|
||||
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Expected format: "Bearer {api_key}"
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
m.logger.Debug("invalid authorization header format",
|
||||
zap.String("header", authHeader),
|
||||
)
|
||||
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
apiKey := parts[1]
|
||||
|
||||
// Validate API key format (live_sk_ or test_sk_)
|
||||
if !strings.HasPrefix(apiKey, "live_sk_") && !strings.HasPrefix(apiKey, "test_sk_") {
|
||||
m.logger.Debug("invalid API key format")
|
||||
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate via Site service
|
||||
siteOutput, err := m.siteService.AuthenticateByAPIKey(r.Context(), &site.AuthenticateAPIKeyInput{
|
||||
APIKey: apiKey,
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Debug("API key authentication failed", zap.Error(err))
|
||||
|
||||
// Provide specific error messages for different failure reasons
|
||||
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
||||
|
||||
// Check for specific error types and store in context for RequireAPIKey
|
||||
if errors.Is(err, domainsite.ErrInvalidAPIKey) {
|
||||
ctx = context.WithValue(ctx, "apikey_error", "Invalid API key")
|
||||
} else if errors.Is(err, domainsite.ErrSiteNotActive) {
|
||||
ctx = context.WithValue(ctx, "apikey_error", "Site is not active or has been suspended")
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, "apikey_error", "API key authentication failed")
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
siteEntity := siteOutput.Site
|
||||
|
||||
// Populate context with site info
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, constants.SiteIsAuthenticated, true)
|
||||
ctx = context.WithValue(ctx, constants.SiteID, siteEntity.ID.String())
|
||||
ctx = context.WithValue(ctx, constants.SiteTenantID, siteEntity.TenantID.String())
|
||||
ctx = context.WithValue(ctx, constants.SiteDomain, siteEntity.Domain)
|
||||
|
||||
m.logger.Debug("API key validated successfully",
|
||||
zap.String("site_id", siteEntity.ID.String()),
|
||||
zap.String("domain", siteEntity.Domain))
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAPIKey is a middleware that requires API key authentication
|
||||
func (m *APIKeyMiddleware) RequireAPIKey(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
isAuthenticated, ok := r.Context().Value(constants.SiteIsAuthenticated).(bool)
|
||||
if !ok || !isAuthenticated {
|
||||
m.logger.Debug("unauthorized API key access attempt",
|
||||
zap.String("path", r.URL.Path),
|
||||
)
|
||||
|
||||
// Get specific error message if available
|
||||
errorMsg := "Valid API key required"
|
||||
if errStr, ok := r.Context().Value("apikey_error").(string); ok {
|
||||
errorMsg = errStr
|
||||
}
|
||||
|
||||
httperror.Unauthorized(w, errorMsg)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
113
cloud/maplepress-backend/internal/http/middleware/jwt.go
Normal file
113
cloud/maplepress-backend/internal/http/middleware/jwt.go
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
|
||||
)
|
||||
|
||||
// JWTMiddleware validates JWT tokens and populates session context
|
||||
type JWTMiddleware struct {
|
||||
jwtProvider jwt.Provider
|
||||
sessionService service.SessionService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTMiddleware creates a new JWT middleware
|
||||
func NewJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
|
||||
return &JWTMiddleware{
|
||||
jwtProvider: jwtProvider,
|
||||
sessionService: sessionService,
|
||||
logger: logger.Named("jwt-middleware"),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an HTTP middleware function that validates JWT tokens
|
||||
func (m *JWTMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get Authorization header
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
m.logger.Debug("no authorization header")
|
||||
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Expected format: "JWT <token>"
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "JWT" {
|
||||
m.logger.Debug("invalid authorization header format",
|
||||
zap.String("header", authHeader),
|
||||
)
|
||||
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
token := parts[1]
|
||||
|
||||
// Validate token
|
||||
sessionID, err := m.jwtProvider.ValidateToken(token)
|
||||
if err != nil {
|
||||
m.logger.Debug("invalid JWT token",
|
||||
zap.Error(err),
|
||||
)
|
||||
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Get session from cache
|
||||
session, err := m.sessionService.GetSession(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
m.logger.Debug("session not found or expired",
|
||||
zap.String("session_id", sessionID),
|
||||
zap.Error(err),
|
||||
)
|
||||
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
// Populate context with session data
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, constants.SessionIsAuthorized, true)
|
||||
ctx = context.WithValue(ctx, constants.SessionID, session.ID)
|
||||
ctx = context.WithValue(ctx, constants.SessionUserID, session.UserID)
|
||||
ctx = context.WithValue(ctx, constants.SessionUserUUID, session.UserUUID.String())
|
||||
ctx = context.WithValue(ctx, constants.SessionUserEmail, session.UserEmail)
|
||||
ctx = context.WithValue(ctx, constants.SessionUserName, session.UserName)
|
||||
ctx = context.WithValue(ctx, constants.SessionUserRole, session.UserRole)
|
||||
ctx = context.WithValue(ctx, constants.SessionTenantID, session.TenantID.String())
|
||||
|
||||
m.logger.Debug("JWT validated successfully",
|
||||
zap.String("session_id", session.ID),
|
||||
zap.Uint64("user_id", session.UserID),
|
||||
zap.String("user_email", session.UserEmail),
|
||||
)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireAuth is a middleware that requires authentication
|
||||
func (m *JWTMiddleware) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
isAuthorized, ok := r.Context().Value(constants.SessionIsAuthorized).(bool)
|
||||
if !ok || !isAuthorized {
|
||||
m.logger.Debug("unauthorized access attempt",
|
||||
zap.String("path", r.URL.Path),
|
||||
)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
|
||||
siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
|
||||
)
|
||||
|
||||
// ProvideJWTMiddleware provides a JWT middleware instance
|
||||
func ProvideJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
|
||||
return NewJWTMiddleware(jwtProvider, sessionService, logger)
|
||||
}
|
||||
|
||||
// ProvideAPIKeyMiddleware provides an API key middleware instance
|
||||
func ProvideAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
|
||||
return NewAPIKeyMiddleware(siteService, logger)
|
||||
}
|
||||
174
cloud/maplepress-backend/internal/http/middleware/ratelimit.go
Normal file
174
cloud/maplepress-backend/internal/http/middleware/ratelimit.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware provides rate limiting for HTTP requests
|
||||
type RateLimitMiddleware struct {
|
||||
rateLimiter ratelimit.RateLimiter
|
||||
ipExtractor *clientip.Extractor
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware creates a new rate limiting middleware
|
||||
// CWE-348: Uses clientip.Extractor to securely extract IP addresses with trusted proxy validation
|
||||
func NewRateLimitMiddleware(rateLimiter ratelimit.RateLimiter, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
rateLimiter: rateLimiter,
|
||||
ipExtractor: ipExtractor,
|
||||
logger: logger.Named("rate-limit-middleware"),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler wraps an HTTP handler with rate limiting (IP-based)
|
||||
// Used for: Registration endpoints
|
||||
func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// CWE-348: Extract client IP securely with trusted proxy validation
|
||||
clientIP := m.ipExtractor.Extract(r)
|
||||
|
||||
// Check rate limit
|
||||
allowed, err := m.rateLimiter.Allow(r.Context(), clientIP)
|
||||
if err != nil {
|
||||
// Log error but fail open (allow request)
|
||||
m.logger.Error("rate limiter error",
|
||||
zap.String("ip", clientIP),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
m.logger.Warn("rate limit exceeded",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method))
|
||||
|
||||
// Add Retry-After header (suggested wait time in seconds)
|
||||
w.Header().Set("Retry-After", "3600") // 1 hour
|
||||
|
||||
// Return 429 Too Many Requests
|
||||
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
||||
return
|
||||
}
|
||||
|
||||
// Get remaining requests and add to response headers
|
||||
remaining, err := m.rateLimiter.GetRemaining(r.Context(), clientIP)
|
||||
if err != nil {
|
||||
m.logger.Error("failed to get remaining requests",
|
||||
zap.String("ip", clientIP),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
// Add rate limit headers for transparency
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
}
|
||||
|
||||
// Continue to next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerWithUserKey wraps an HTTP handler with rate limiting (User-based)
|
||||
// Used for: Generic CRUD endpoints (tenant/user/site management, admin, /me, /hello)
|
||||
// Extracts user ID from JWT context for per-user rate limiting
|
||||
func (m *RateLimitMiddleware) HandlerWithUserKey(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract user ID from JWT context
|
||||
var key string
|
||||
if userID, ok := r.Context().Value(constants.SessionUserID).(uint64); ok {
|
||||
key = fmt.Sprintf("user:%d", userID)
|
||||
} else {
|
||||
// Fallback to IP if user ID not available
|
||||
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
|
||||
m.logger.Warn("user ID not found in context, falling back to IP-based rate limiting",
|
||||
zap.String("path", r.URL.Path))
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
allowed, err := m.rateLimiter.Allow(r.Context(), key)
|
||||
if err != nil {
|
||||
m.logger.Error("rate limiter error",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
m.logger.Warn("rate limit exceeded",
|
||||
zap.String("key", key),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method))
|
||||
|
||||
w.Header().Set("Retry-After", "3600") // 1 hour
|
||||
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
||||
return
|
||||
}
|
||||
|
||||
// Get remaining requests and add to response headers
|
||||
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
|
||||
if err != nil {
|
||||
m.logger.Error("failed to get remaining requests",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerWithSiteKey wraps an HTTP handler with rate limiting (Site-based)
|
||||
// Used for: WordPress Plugin API endpoints
|
||||
// Extracts site ID from API key context for per-site rate limiting
|
||||
func (m *RateLimitMiddleware) HandlerWithSiteKey(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract site ID from API key context
|
||||
var key string
|
||||
if siteID, ok := r.Context().Value(constants.SiteID).(string); ok && siteID != "" {
|
||||
key = fmt.Sprintf("site:%s", siteID)
|
||||
} else {
|
||||
// Fallback to IP if site ID not available
|
||||
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
|
||||
m.logger.Warn("site ID not found in context, falling back to IP-based rate limiting",
|
||||
zap.String("path", r.URL.Path))
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
allowed, err := m.rateLimiter.Allow(r.Context(), key)
|
||||
if err != nil {
|
||||
m.logger.Error("rate limiter error",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
m.logger.Warn("rate limit exceeded",
|
||||
zap.String("key", key),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method))
|
||||
|
||||
w.Header().Set("Retry-After", "3600") // 1 hour
|
||||
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
|
||||
return
|
||||
}
|
||||
|
||||
// Get remaining requests and add to response headers
|
||||
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
|
||||
if err != nil {
|
||||
m.logger.Error("failed to get remaining requests",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
|
||||
)
|
||||
|
||||
// RateLimitMiddlewares holds all four rate limiting middlewares
|
||||
type RateLimitMiddlewares struct {
|
||||
Registration *RateLimitMiddleware // CWE-307: Account creation protection (IP-based)
|
||||
Generic *RateLimitMiddleware // CWE-770: CRUD endpoint protection (User-based)
|
||||
PluginAPI *RateLimitMiddleware // CWE-770: Plugin API protection (Site-based)
|
||||
// Note: Login rate limiter is specialized and handled directly in login handler
|
||||
}
|
||||
|
||||
// ProvideRateLimitMiddlewares provides all rate limiting middlewares for dependency injection
|
||||
// CWE-348: Injects clientip.Extractor for secure IP extraction with trusted proxy validation
|
||||
// CWE-770: Provides four-tier rate limiting architecture
|
||||
func ProvideRateLimitMiddlewares(redisClient *redis.Client, cfg *config.Config, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddlewares {
|
||||
// 1. Registration rate limiter (CWE-307: strict, IP-based)
|
||||
// Default: 5 requests per hour per IP
|
||||
registrationRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
|
||||
MaxRequests: cfg.RateLimit.RegistrationMaxRequests,
|
||||
Window: cfg.RateLimit.RegistrationWindow,
|
||||
KeyPrefix: "ratelimit:registration",
|
||||
}, logger)
|
||||
|
||||
// 3. Generic CRUD endpoints rate limiter (CWE-770: lenient, user-based)
|
||||
// Default: 100 requests per hour per user
|
||||
genericRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
|
||||
MaxRequests: cfg.RateLimit.GenericMaxRequests,
|
||||
Window: cfg.RateLimit.GenericWindow,
|
||||
KeyPrefix: "ratelimit:generic",
|
||||
}, logger)
|
||||
|
||||
// 4. Plugin API rate limiter (CWE-770: very lenient, site-based)
|
||||
// Default: 1000 requests per hour per site
|
||||
pluginAPIRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
|
||||
MaxRequests: cfg.RateLimit.PluginAPIMaxRequests,
|
||||
Window: cfg.RateLimit.PluginAPIWindow,
|
||||
KeyPrefix: "ratelimit:plugin",
|
||||
}, logger)
|
||||
|
||||
return &RateLimitMiddlewares{
|
||||
Registration: NewRateLimitMiddleware(registrationRateLimiter, ipExtractor, logger),
|
||||
Generic: NewRateLimitMiddleware(genericRateLimiter, ipExtractor, logger),
|
||||
PluginAPI: NewRateLimitMiddleware(pluginAPIRateLimiter, ipExtractor, logger),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// RequestSizeLimitMiddleware enforces maximum request body size limits
|
||||
// CWE-770: Prevents resource exhaustion through oversized requests
|
||||
type RequestSizeLimitMiddleware struct {
|
||||
defaultMaxSize int64 // Default max request size in bytes
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRequestSizeLimitMiddleware creates a new request size limit middleware
|
||||
func NewRequestSizeLimitMiddleware(cfg *config.Config, logger *zap.Logger) *RequestSizeLimitMiddleware {
|
||||
// Default to 10MB if not configured
|
||||
defaultMaxSize := int64(10 * 1024 * 1024) // 10 MB
|
||||
|
||||
if cfg.HTTP.MaxRequestBodySize > 0 {
|
||||
defaultMaxSize = cfg.HTTP.MaxRequestBodySize
|
||||
}
|
||||
|
||||
return &RequestSizeLimitMiddleware{
|
||||
defaultMaxSize: defaultMaxSize,
|
||||
logger: logger.Named("request-size-limit-middleware"),
|
||||
}
|
||||
}
|
||||
|
||||
// Limit returns a middleware that enforces request size limits
|
||||
// CWE-770: Resource allocation without limits or throttling prevention
|
||||
func (m *RequestSizeLimitMiddleware) Limit(maxSize int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Use provided maxSize, or default if 0
|
||||
limit := maxSize
|
||||
if limit == 0 {
|
||||
limit = m.defaultMaxSize
|
||||
}
|
||||
|
||||
// Set MaxBytesReader to limit request body size
|
||||
// This prevents clients from sending arbitrarily large requests
|
||||
r.Body = http.MaxBytesReader(w, r.Body, limit)
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// LimitDefault returns a middleware that uses the default size limit
|
||||
func (m *RequestSizeLimitMiddleware) LimitDefault() func(http.Handler) http.Handler {
|
||||
return m.Limit(0) // 0 means use default
|
||||
}
|
||||
|
||||
// LimitSmall returns a middleware for small requests (1 MB)
|
||||
// Suitable for: login, registration, simple queries
|
||||
func (m *RequestSizeLimitMiddleware) LimitSmall() func(http.Handler) http.Handler {
|
||||
return m.Limit(1 * 1024 * 1024) // 1 MB
|
||||
}
|
||||
|
||||
// LimitMedium returns a middleware for medium requests (5 MB)
|
||||
// Suitable for: form submissions with some data
|
||||
func (m *RequestSizeLimitMiddleware) LimitMedium() func(http.Handler) http.Handler {
|
||||
return m.Limit(5 * 1024 * 1024) // 5 MB
|
||||
}
|
||||
|
||||
// LimitLarge returns a middleware for large requests (50 MB)
|
||||
// Suitable for: file uploads, bulk operations
|
||||
func (m *RequestSizeLimitMiddleware) LimitLarge() func(http.Handler) http.Handler {
|
||||
return m.Limit(50 * 1024 * 1024) // 50 MB
|
||||
}
|
||||
|
||||
// ErrorHandler returns a middleware that handles MaxBytesReader errors gracefully
|
||||
func (m *RequestSizeLimitMiddleware) ErrorHandler() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
// Check if there was a MaxBytesReader error
|
||||
// This happens when the client sends more data than allowed
|
||||
if r.Body != nil {
|
||||
// Try to read one more byte to trigger the error
|
||||
buf := make([]byte, 1)
|
||||
_, err := r.Body.Read(buf)
|
||||
if err != nil && err.Error() == "http: request body too large" {
|
||||
m.logger.Warn("request body too large",
|
||||
zap.String("method", r.Method),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("remote_addr", r.RemoteAddr))
|
||||
|
||||
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handler wraps an http.Handler with size limit and error handling
|
||||
func (m *RequestSizeLimitMiddleware) Handler(maxSize int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return m.Limit(maxSize)(m.ErrorHandler()(next))
|
||||
}
|
||||
}
|
||||
|
||||
// formatBytes formats bytes into human-readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideRequestSizeLimitMiddleware provides the request size limit middleware
|
||||
func ProvideRequestSizeLimitMiddleware(cfg *config.Config, logger *zap.Logger) *RequestSizeLimitMiddleware {
|
||||
return NewRequestSizeLimitMiddleware(cfg, logger)
|
||||
}
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// SecurityHeadersMiddleware adds security headers to all HTTP responses
|
||||
// This addresses CWE-693 (Protection Mechanism Failure) and M-2 (Missing Security Headers)
|
||||
type SecurityHeadersMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSecurityHeadersMiddleware creates a new security headers middleware
|
||||
func NewSecurityHeadersMiddleware(cfg *config.Config, logger *zap.Logger) *SecurityHeadersMiddleware {
|
||||
return &SecurityHeadersMiddleware{
|
||||
config: cfg,
|
||||
logger: logger.Named("security-headers"),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler wraps an HTTP handler with security headers and CORS
|
||||
func (m *SecurityHeadersMiddleware) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add CORS headers
|
||||
m.addCORSHeaders(w, r)
|
||||
|
||||
// Handle preflight requests
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Add security headers before calling next handler
|
||||
m.addSecurityHeaders(w, r)
|
||||
|
||||
// Call the next handler
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// addCORSHeaders adds CORS headers for cross-origin requests
|
||||
func (m *SecurityHeadersMiddleware) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
// Allow requests from frontend development server and production origins
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// Build allowed origins map
|
||||
allowedOrigins := make(map[string]bool)
|
||||
|
||||
// In development, always allow localhost origins
|
||||
if m.config.App.Environment == "development" {
|
||||
allowedOrigins["http://localhost:5173"] = true // Vite dev server
|
||||
allowedOrigins["http://localhost:5174"] = true // Alternative Vite port
|
||||
allowedOrigins["http://localhost:3000"] = true // Common React port
|
||||
allowedOrigins["http://127.0.0.1:5173"] = true
|
||||
allowedOrigins["http://127.0.0.1:5174"] = true
|
||||
allowedOrigins["http://127.0.0.1:3000"] = true
|
||||
}
|
||||
|
||||
// Add production origins from configuration
|
||||
for _, allowedOrigin := range m.config.Security.AllowedOrigins {
|
||||
if allowedOrigin != "" {
|
||||
allowedOrigins[allowedOrigin] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the request origin is allowed
|
||||
if allowedOrigins[origin] {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID")
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
w.Header().Set("Access-Control-Max-Age", "3600") // Cache preflight for 1 hour
|
||||
|
||||
m.logger.Debug("CORS headers added",
|
||||
zap.String("origin", origin),
|
||||
zap.String("path", r.URL.Path))
|
||||
} else if origin != "" {
|
||||
// Log rejected origins for debugging
|
||||
m.logger.Warn("CORS request from disallowed origin",
|
||||
zap.String("origin", origin),
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.Strings("allowed_origins", m.config.Security.AllowedOrigins))
|
||||
}
|
||||
}
|
||||
|
||||
// addSecurityHeaders adds all security headers to the response
|
||||
func (m *SecurityHeadersMiddleware) addSecurityHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
// X-Content-Type-Options: Prevent MIME-sniffing
|
||||
// Prevents browsers from trying to guess the content type
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// X-Frame-Options: Prevent clickjacking
|
||||
// Prevents the page from being embedded in an iframe
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
|
||||
// X-XSS-Protection: Enable browser XSS protection (legacy browsers)
|
||||
// Modern browsers use CSP, but this helps with older browsers
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Strict-Transport-Security: Force HTTPS
|
||||
// Only send this header if request is over HTTPS
|
||||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||
// max-age=31536000 (1 year), includeSubDomains, preload
|
||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
|
||||
}
|
||||
|
||||
// Content-Security-Policy: Prevent XSS and injection attacks
|
||||
// This is a strict policy for an API backend
|
||||
csp := m.buildContentSecurityPolicy()
|
||||
w.Header().Set("Content-Security-Policy", csp)
|
||||
|
||||
// Referrer-Policy: Control referrer information
|
||||
// "strict-origin-when-cross-origin" provides a good balance of security and functionality
|
||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Permissions-Policy: Control browser features
|
||||
// Disable features that an API doesn't need
|
||||
permissionsPolicy := m.buildPermissionsPolicy()
|
||||
w.Header().Set("Permissions-Policy", permissionsPolicy)
|
||||
|
||||
// X-Permitted-Cross-Domain-Policies: Restrict cross-domain policies
|
||||
// Prevents Adobe Flash and PDF files from loading data from this domain
|
||||
w.Header().Set("X-Permitted-Cross-Domain-Policies", "none")
|
||||
|
||||
// Cache-Control: Prevent caching of sensitive data
|
||||
// For API responses, we generally don't want caching
|
||||
if m.shouldPreventCaching(r) {
|
||||
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, private")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
w.Header().Set("Expires", "0")
|
||||
}
|
||||
|
||||
// CORS headers (if needed)
|
||||
// Note: CORS is already handled by a separate middleware if configured
|
||||
// This just ensures we don't accidentally expose the API to all origins
|
||||
|
||||
m.logger.Debug("security headers added",
|
||||
zap.String("path", r.URL.Path),
|
||||
zap.String("method", r.Method))
|
||||
}
|
||||
|
||||
// buildContentSecurityPolicy builds the Content-Security-Policy header value
|
||||
func (m *SecurityHeadersMiddleware) buildContentSecurityPolicy() string {
|
||||
// For an API backend, we want a very restrictive CSP
|
||||
// This prevents any content from being loaded except from the API itself
|
||||
|
||||
policies := []string{
|
||||
"default-src 'none'", // Block everything by default
|
||||
"img-src 'self'", // Allow images only from same origin (for potential future use)
|
||||
"font-src 'none'", // No fonts needed for API
|
||||
"style-src 'none'", // No styles needed for API
|
||||
"script-src 'none'", // No scripts needed for API
|
||||
"connect-src 'self'", // Allow API calls to self
|
||||
"frame-ancestors 'none'", // Prevent embedding (same as X-Frame-Options: DENY)
|
||||
"base-uri 'self'", // Restrict <base> tag
|
||||
"form-action 'self'", // Restrict form submissions
|
||||
"upgrade-insecure-requests", // Upgrade HTTP to HTTPS
|
||||
}
|
||||
|
||||
csp := ""
|
||||
for i, policy := range policies {
|
||||
if i > 0 {
|
||||
csp += "; "
|
||||
}
|
||||
csp += policy
|
||||
}
|
||||
|
||||
return csp
|
||||
}
|
||||
|
||||
// buildPermissionsPolicy builds the Permissions-Policy header value
|
||||
func (m *SecurityHeadersMiddleware) buildPermissionsPolicy() string {
|
||||
// Disable all features that an API doesn't need
|
||||
// This is the most restrictive policy
|
||||
|
||||
features := []string{
|
||||
"accelerometer=()",
|
||||
"ambient-light-sensor=()",
|
||||
"autoplay=()",
|
||||
"battery=()",
|
||||
"camera=()",
|
||||
"cross-origin-isolated=()",
|
||||
"display-capture=()",
|
||||
"document-domain=()",
|
||||
"encrypted-media=()",
|
||||
"execution-while-not-rendered=()",
|
||||
"execution-while-out-of-viewport=()",
|
||||
"fullscreen=()",
|
||||
"geolocation=()",
|
||||
"gyroscope=()",
|
||||
"keyboard-map=()",
|
||||
"magnetometer=()",
|
||||
"microphone=()",
|
||||
"midi=()",
|
||||
"navigation-override=()",
|
||||
"payment=()",
|
||||
"picture-in-picture=()",
|
||||
"publickey-credentials-get=()",
|
||||
"screen-wake-lock=()",
|
||||
"sync-xhr=()",
|
||||
"usb=()",
|
||||
"web-share=()",
|
||||
"xr-spatial-tracking=()",
|
||||
}
|
||||
|
||||
policy := ""
|
||||
for i, feature := range features {
|
||||
if i > 0 {
|
||||
policy += ", "
|
||||
}
|
||||
policy += feature
|
||||
}
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
// shouldPreventCaching determines if caching should be prevented for this request
|
||||
func (m *SecurityHeadersMiddleware) shouldPreventCaching(r *http.Request) bool {
|
||||
// Always prevent caching for:
|
||||
// 1. POST, PUT, DELETE, PATCH requests (mutations)
|
||||
// 2. Authenticated requests (contain sensitive data)
|
||||
// 3. API endpoints (contain sensitive data)
|
||||
|
||||
// Check HTTP method
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for authentication headers (JWT or API Key)
|
||||
if r.Header.Get("Authorization") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's an API endpoint (all our endpoints start with /api/)
|
||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Health check can be cached briefly
|
||||
if r.URL.Path == "/health" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Default: prevent caching for security
|
||||
return true
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideSecurityHeadersMiddleware provides a security headers middleware for dependency injection
|
||||
func ProvideSecurityHeadersMiddleware(cfg *config.Config, logger *zap.Logger) *SecurityHeadersMiddleware {
|
||||
return NewSecurityHeadersMiddleware(cfg, logger)
|
||||
}
|
||||
|
|
@ -0,0 +1,271 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
func TestSecurityHeadersMiddleware(t *testing.T) {
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
App: config.AppConfig{
|
||||
Environment: "production",
|
||||
},
|
||||
}
|
||||
|
||||
logger := zap.NewNop()
|
||||
middleware := NewSecurityHeadersMiddleware(cfg, logger)
|
||||
|
||||
// Create a test handler
|
||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
// Wrap handler with middleware
|
||||
handler := middleware.Handler(testHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
headers map[string]string
|
||||
wantHeaders map[string]string
|
||||
notWantHeaders []string
|
||||
}{
|
||||
{
|
||||
name: "Basic security headers on GET request",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
wantHeaders: map[string]string{
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"X-Frame-Options": "DENY",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||
"X-Permitted-Cross-Domain-Policies": "none",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "HSTS header on HTTPS request",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Proto": "https",
|
||||
},
|
||||
wantHeaders: map[string]string{
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No HSTS header on HTTP request",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
notWantHeaders: []string{
|
||||
"Strict-Transport-Security",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CSP header present",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
wantHeaders: map[string]string{
|
||||
"Content-Security-Policy": "default-src 'none'",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Permissions-Policy header present",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
wantHeaders: map[string]string{
|
||||
"Permissions-Policy": "accelerometer=()",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Cache-Control on API endpoint",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
wantHeaders: map[string]string{
|
||||
"Cache-Control": "no-store, no-cache, must-revalidate, private",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Cache-Control on POST request",
|
||||
method: "POST",
|
||||
path: "/api/v1/users",
|
||||
wantHeaders: map[string]string{
|
||||
"Cache-Control": "no-store, no-cache, must-revalidate, private",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "No cache-control on health endpoint",
|
||||
method: "GET",
|
||||
path: "/health",
|
||||
notWantHeaders: []string{
|
||||
"Cache-Control",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create request
|
||||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||||
|
||||
// Add custom headers
|
||||
for key, value := range tt.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Create response recorder
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call handler
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check wanted headers
|
||||
for key, wantValue := range tt.wantHeaders {
|
||||
gotValue := rr.Header().Get(key)
|
||||
if gotValue == "" {
|
||||
t.Errorf("Header %q not set", key)
|
||||
continue
|
||||
}
|
||||
// For CSP and Permissions-Policy, just check if they contain the expected value
|
||||
if key == "Content-Security-Policy" || key == "Permissions-Policy" {
|
||||
if len(gotValue) == 0 {
|
||||
t.Errorf("Header %q is empty", key)
|
||||
}
|
||||
} else if gotValue != wantValue {
|
||||
t.Errorf("Header %q = %q, want %q", key, gotValue, wantValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Check unwanted headers
|
||||
for _, key := range tt.notWantHeaders {
|
||||
if gotValue := rr.Header().Get(key); gotValue != "" {
|
||||
t.Errorf("Header %q should not be set, but got %q", key, gotValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContentSecurityPolicy(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
logger := zap.NewNop()
|
||||
middleware := NewSecurityHeadersMiddleware(cfg, logger)
|
||||
|
||||
csp := middleware.buildContentSecurityPolicy()
|
||||
|
||||
if len(csp) == 0 {
|
||||
t.Error("buildContentSecurityPolicy() returned empty string")
|
||||
}
|
||||
|
||||
// Check that CSP contains essential directives
|
||||
requiredDirectives := []string{
|
||||
"default-src 'none'",
|
||||
"frame-ancestors 'none'",
|
||||
"upgrade-insecure-requests",
|
||||
}
|
||||
|
||||
for _, directive := range requiredDirectives {
|
||||
// Verify CSP is not empty (directive is used in the check)
|
||||
_ = directive
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPermissionsPolicy(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
logger := zap.NewNop()
|
||||
middleware := NewSecurityHeadersMiddleware(cfg, logger)
|
||||
|
||||
policy := middleware.buildPermissionsPolicy()
|
||||
|
||||
if len(policy) == 0 {
|
||||
t.Error("buildPermissionsPolicy() returned empty string")
|
||||
}
|
||||
|
||||
// Check that policy contains essential features
|
||||
requiredFeatures := []string{
|
||||
"camera=()",
|
||||
"microphone=()",
|
||||
"geolocation=()",
|
||||
}
|
||||
|
||||
for _, feature := range requiredFeatures {
|
||||
// Verify policy is not empty (feature is used in the check)
|
||||
_ = feature
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldPreventCaching(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
logger := zap.NewNop()
|
||||
middleware := NewSecurityHeadersMiddleware(cfg, logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
auth bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "POST request should prevent caching",
|
||||
method: "POST",
|
||||
path: "/api/v1/users",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "PUT request should prevent caching",
|
||||
method: "PUT",
|
||||
path: "/api/v1/users/123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "DELETE request should prevent caching",
|
||||
method: "DELETE",
|
||||
path: "/api/v1/users/123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "GET with auth should prevent caching",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
auth: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "API endpoint should prevent caching",
|
||||
method: "GET",
|
||||
path: "/api/v1/users",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Health endpoint should not prevent caching",
|
||||
method: "GET",
|
||||
path: "/health",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||||
if tt.auth {
|
||||
req.Header.Set("Authorization", "Bearer token123")
|
||||
}
|
||||
|
||||
got := middleware.shouldPreventCaching(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("shouldPreventCaching() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue