Initial commit: Open sourcing all of the Maple Open Technologies code.

This commit is contained in:
Bartlomiej Mika 2025-12-02 14:33:08 -05:00
commit 755d54a99d
2010 changed files with 448675 additions and 0 deletions

View file

@ -0,0 +1,125 @@
package middleware
import (
"context"
"errors"
"net/http"
"strings"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
domainsite "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain/site"
siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/usecase/site"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
)
// APIKeyMiddleware validates API keys and populates site context
type APIKeyMiddleware struct {
siteService siteservice.AuthenticateAPIKeyService
logger *zap.Logger
}
// NewAPIKeyMiddleware creates a new API key middleware
func NewAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
return &APIKeyMiddleware{
siteService: siteService,
logger: logger.Named("apikey-middleware"),
}
}
// Handler returns an HTTP middleware function that validates API keys
func (m *APIKeyMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
m.logger.Debug("no authorization header")
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Expected format: "Bearer {api_key}"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
m.logger.Debug("invalid authorization header format",
zap.String("header", authHeader),
)
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
apiKey := parts[1]
// Validate API key format (live_sk_ or test_sk_)
if !strings.HasPrefix(apiKey, "live_sk_") && !strings.HasPrefix(apiKey, "test_sk_") {
m.logger.Debug("invalid API key format")
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Authenticate via Site service
siteOutput, err := m.siteService.AuthenticateByAPIKey(r.Context(), &site.AuthenticateAPIKeyInput{
APIKey: apiKey,
})
if err != nil {
m.logger.Debug("API key authentication failed", zap.Error(err))
// Provide specific error messages for different failure reasons
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
// Check for specific error types and store in context for RequireAPIKey
if errors.Is(err, domainsite.ErrInvalidAPIKey) {
ctx = context.WithValue(ctx, "apikey_error", "Invalid API key")
} else if errors.Is(err, domainsite.ErrSiteNotActive) {
ctx = context.WithValue(ctx, "apikey_error", "Site is not active or has been suspended")
} else {
ctx = context.WithValue(ctx, "apikey_error", "API key authentication failed")
}
next.ServeHTTP(w, r.WithContext(ctx))
return
}
siteEntity := siteOutput.Site
// Populate context with site info
ctx := r.Context()
ctx = context.WithValue(ctx, constants.SiteIsAuthenticated, true)
ctx = context.WithValue(ctx, constants.SiteID, siteEntity.ID.String())
ctx = context.WithValue(ctx, constants.SiteTenantID, siteEntity.TenantID.String())
ctx = context.WithValue(ctx, constants.SiteDomain, siteEntity.Domain)
m.logger.Debug("API key validated successfully",
zap.String("site_id", siteEntity.ID.String()),
zap.String("domain", siteEntity.Domain))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireAPIKey is a middleware that requires API key authentication
func (m *APIKeyMiddleware) RequireAPIKey(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
isAuthenticated, ok := r.Context().Value(constants.SiteIsAuthenticated).(bool)
if !ok || !isAuthenticated {
m.logger.Debug("unauthorized API key access attempt",
zap.String("path", r.URL.Path),
)
// Get specific error message if available
errorMsg := "Valid API key required"
if errStr, ok := r.Context().Value("apikey_error").(string); ok {
errorMsg = errStr
}
httperror.Unauthorized(w, errorMsg)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,113 @@
package middleware
import (
"context"
"net/http"
"strings"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
)
// JWTMiddleware validates JWT tokens and populates session context
type JWTMiddleware struct {
jwtProvider jwt.Provider
sessionService service.SessionService
logger *zap.Logger
}
// NewJWTMiddleware creates a new JWT middleware
func NewJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
return &JWTMiddleware{
jwtProvider: jwtProvider,
sessionService: sessionService,
logger: logger.Named("jwt-middleware"),
}
}
// Handler returns an HTTP middleware function that validates JWT tokens
func (m *JWTMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get Authorization header
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
m.logger.Debug("no authorization header")
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Expected format: "JWT <token>"
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "JWT" {
m.logger.Debug("invalid authorization header format",
zap.String("header", authHeader),
)
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
token := parts[1]
// Validate token
sessionID, err := m.jwtProvider.ValidateToken(token)
if err != nil {
m.logger.Debug("invalid JWT token",
zap.Error(err),
)
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Get session from cache
session, err := m.sessionService.GetSession(r.Context(), sessionID)
if err != nil {
m.logger.Debug("session not found or expired",
zap.String("session_id", sessionID),
zap.Error(err),
)
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Populate context with session data
ctx := r.Context()
ctx = context.WithValue(ctx, constants.SessionIsAuthorized, true)
ctx = context.WithValue(ctx, constants.SessionID, session.ID)
ctx = context.WithValue(ctx, constants.SessionUserID, session.UserID)
ctx = context.WithValue(ctx, constants.SessionUserUUID, session.UserUUID.String())
ctx = context.WithValue(ctx, constants.SessionUserEmail, session.UserEmail)
ctx = context.WithValue(ctx, constants.SessionUserName, session.UserName)
ctx = context.WithValue(ctx, constants.SessionUserRole, session.UserRole)
ctx = context.WithValue(ctx, constants.SessionTenantID, session.TenantID.String())
m.logger.Debug("JWT validated successfully",
zap.String("session_id", session.ID),
zap.Uint64("user_id", session.UserID),
zap.String("user_email", session.UserEmail),
)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// RequireAuth is a middleware that requires authentication
func (m *JWTMiddleware) RequireAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
isAuthorized, ok := r.Context().Value(constants.SessionIsAuthorized).(bool)
if !ok || !isAuthorized {
m.logger.Debug("unauthorized access attempt",
zap.String("path", r.URL.Path),
)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,19 @@
package middleware
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
)
// ProvideJWTMiddleware provides a JWT middleware instance
func ProvideJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
return NewJWTMiddleware(jwtProvider, sessionService, logger)
}
// ProvideAPIKeyMiddleware provides an API key middleware instance
func ProvideAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
return NewAPIKeyMiddleware(siteService, logger)
}

View file

@ -0,0 +1,174 @@
package middleware
import (
"fmt"
"net/http"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
)
// RateLimitMiddleware provides rate limiting for HTTP requests
type RateLimitMiddleware struct {
rateLimiter ratelimit.RateLimiter
ipExtractor *clientip.Extractor
logger *zap.Logger
}
// NewRateLimitMiddleware creates a new rate limiting middleware
// CWE-348: Uses clientip.Extractor to securely extract IP addresses with trusted proxy validation
func NewRateLimitMiddleware(rateLimiter ratelimit.RateLimiter, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddleware {
return &RateLimitMiddleware{
rateLimiter: rateLimiter,
ipExtractor: ipExtractor,
logger: logger.Named("rate-limit-middleware"),
}
}
// Handler wraps an HTTP handler with rate limiting (IP-based)
// Used for: Registration endpoints
func (m *RateLimitMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// CWE-348: Extract client IP securely with trusted proxy validation
clientIP := m.ipExtractor.Extract(r)
// Check rate limit
allowed, err := m.rateLimiter.Allow(r.Context(), clientIP)
if err != nil {
// Log error but fail open (allow request)
m.logger.Error("rate limiter error",
zap.String("ip", clientIP),
zap.Error(err))
}
if !allowed {
m.logger.Warn("rate limit exceeded",
zap.String("ip", clientIP),
zap.String("path", r.URL.Path),
zap.String("method", r.Method))
// Add Retry-After header (suggested wait time in seconds)
w.Header().Set("Retry-After", "3600") // 1 hour
// Return 429 Too Many Requests
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
return
}
// Get remaining requests and add to response headers
remaining, err := m.rateLimiter.GetRemaining(r.Context(), clientIP)
if err != nil {
m.logger.Error("failed to get remaining requests",
zap.String("ip", clientIP),
zap.Error(err))
} else {
// Add rate limit headers for transparency
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
}
// Continue to next handler
next.ServeHTTP(w, r)
})
}
// HandlerWithUserKey wraps an HTTP handler with rate limiting (User-based)
// Used for: Generic CRUD endpoints (tenant/user/site management, admin, /me, /hello)
// Extracts user ID from JWT context for per-user rate limiting
func (m *RateLimitMiddleware) HandlerWithUserKey(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract user ID from JWT context
var key string
if userID, ok := r.Context().Value(constants.SessionUserID).(uint64); ok {
key = fmt.Sprintf("user:%d", userID)
} else {
// Fallback to IP if user ID not available
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
m.logger.Warn("user ID not found in context, falling back to IP-based rate limiting",
zap.String("path", r.URL.Path))
}
// Check rate limit
allowed, err := m.rateLimiter.Allow(r.Context(), key)
if err != nil {
m.logger.Error("rate limiter error",
zap.String("key", key),
zap.Error(err))
}
if !allowed {
m.logger.Warn("rate limit exceeded",
zap.String("key", key),
zap.String("path", r.URL.Path),
zap.String("method", r.Method))
w.Header().Set("Retry-After", "3600") // 1 hour
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
return
}
// Get remaining requests and add to response headers
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
if err != nil {
m.logger.Error("failed to get remaining requests",
zap.String("key", key),
zap.Error(err))
} else {
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
}
next.ServeHTTP(w, r)
})
}
// HandlerWithSiteKey wraps an HTTP handler with rate limiting (Site-based)
// Used for: WordPress Plugin API endpoints
// Extracts site ID from API key context for per-site rate limiting
func (m *RateLimitMiddleware) HandlerWithSiteKey(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract site ID from API key context
var key string
if siteID, ok := r.Context().Value(constants.SiteID).(string); ok && siteID != "" {
key = fmt.Sprintf("site:%s", siteID)
} else {
// Fallback to IP if site ID not available
key = fmt.Sprintf("ip:%s", m.ipExtractor.Extract(r))
m.logger.Warn("site ID not found in context, falling back to IP-based rate limiting",
zap.String("path", r.URL.Path))
}
// Check rate limit
allowed, err := m.rateLimiter.Allow(r.Context(), key)
if err != nil {
m.logger.Error("rate limiter error",
zap.String("key", key),
zap.Error(err))
}
if !allowed {
m.logger.Warn("rate limit exceeded",
zap.String("key", key),
zap.String("path", r.URL.Path),
zap.String("method", r.Method))
w.Header().Set("Retry-After", "3600") // 1 hour
httperror.TooManyRequests(w, "Rate limit exceeded. Please try again later.")
return
}
// Get remaining requests and add to response headers
remaining, err := m.rateLimiter.GetRemaining(r.Context(), key)
if err != nil {
m.logger.Error("failed to get remaining requests",
zap.String("key", key),
zap.Error(err))
} else {
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,53 @@
package middleware
import (
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/ratelimit"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
)
// RateLimitMiddlewares holds all four rate limiting middlewares
type RateLimitMiddlewares struct {
Registration *RateLimitMiddleware // CWE-307: Account creation protection (IP-based)
Generic *RateLimitMiddleware // CWE-770: CRUD endpoint protection (User-based)
PluginAPI *RateLimitMiddleware // CWE-770: Plugin API protection (Site-based)
// Note: Login rate limiter is specialized and handled directly in login handler
}
// ProvideRateLimitMiddlewares provides all rate limiting middlewares for dependency injection
// CWE-348: Injects clientip.Extractor for secure IP extraction with trusted proxy validation
// CWE-770: Provides four-tier rate limiting architecture
func ProvideRateLimitMiddlewares(redisClient *redis.Client, cfg *config.Config, ipExtractor *clientip.Extractor, logger *zap.Logger) *RateLimitMiddlewares {
// 1. Registration rate limiter (CWE-307: strict, IP-based)
// Default: 5 requests per hour per IP
registrationRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
MaxRequests: cfg.RateLimit.RegistrationMaxRequests,
Window: cfg.RateLimit.RegistrationWindow,
KeyPrefix: "ratelimit:registration",
}, logger)
// 3. Generic CRUD endpoints rate limiter (CWE-770: lenient, user-based)
// Default: 100 requests per hour per user
genericRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
MaxRequests: cfg.RateLimit.GenericMaxRequests,
Window: cfg.RateLimit.GenericWindow,
KeyPrefix: "ratelimit:generic",
}, logger)
// 4. Plugin API rate limiter (CWE-770: very lenient, site-based)
// Default: 1000 requests per hour per site
pluginAPIRateLimiter := ratelimit.NewRateLimiter(redisClient, ratelimit.Config{
MaxRequests: cfg.RateLimit.PluginAPIMaxRequests,
Window: cfg.RateLimit.PluginAPIWindow,
KeyPrefix: "ratelimit:plugin",
}, logger)
return &RateLimitMiddlewares{
Registration: NewRateLimitMiddleware(registrationRateLimiter, ipExtractor, logger),
Generic: NewRateLimitMiddleware(genericRateLimiter, ipExtractor, logger),
PluginAPI: NewRateLimitMiddleware(pluginAPIRateLimiter, ipExtractor, logger),
}
}

View file

@ -0,0 +1,123 @@
package middleware
import (
"fmt"
"net/http"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// RequestSizeLimitMiddleware enforces maximum request body size limits
// CWE-770: Prevents resource exhaustion through oversized requests
type RequestSizeLimitMiddleware struct {
defaultMaxSize int64 // Default max request size in bytes
logger *zap.Logger
}
// NewRequestSizeLimitMiddleware creates a new request size limit middleware
func NewRequestSizeLimitMiddleware(cfg *config.Config, logger *zap.Logger) *RequestSizeLimitMiddleware {
// Default to 10MB if not configured
defaultMaxSize := int64(10 * 1024 * 1024) // 10 MB
if cfg.HTTP.MaxRequestBodySize > 0 {
defaultMaxSize = cfg.HTTP.MaxRequestBodySize
}
return &RequestSizeLimitMiddleware{
defaultMaxSize: defaultMaxSize,
logger: logger.Named("request-size-limit-middleware"),
}
}
// Limit returns a middleware that enforces request size limits
// CWE-770: Resource allocation without limits or throttling prevention
func (m *RequestSizeLimitMiddleware) Limit(maxSize int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use provided maxSize, or default if 0
limit := maxSize
if limit == 0 {
limit = m.defaultMaxSize
}
// Set MaxBytesReader to limit request body size
// This prevents clients from sending arbitrarily large requests
r.Body = http.MaxBytesReader(w, r.Body, limit)
// Call next handler
next.ServeHTTP(w, r)
})
}
}
// LimitDefault returns a middleware that uses the default size limit
func (m *RequestSizeLimitMiddleware) LimitDefault() func(http.Handler) http.Handler {
return m.Limit(0) // 0 means use default
}
// LimitSmall returns a middleware for small requests (1 MB)
// Suitable for: login, registration, simple queries
func (m *RequestSizeLimitMiddleware) LimitSmall() func(http.Handler) http.Handler {
return m.Limit(1 * 1024 * 1024) // 1 MB
}
// LimitMedium returns a middleware for medium requests (5 MB)
// Suitable for: form submissions with some data
func (m *RequestSizeLimitMiddleware) LimitMedium() func(http.Handler) http.Handler {
return m.Limit(5 * 1024 * 1024) // 5 MB
}
// LimitLarge returns a middleware for large requests (50 MB)
// Suitable for: file uploads, bulk operations
func (m *RequestSizeLimitMiddleware) LimitLarge() func(http.Handler) http.Handler {
return m.Limit(50 * 1024 * 1024) // 50 MB
}
// ErrorHandler returns a middleware that handles MaxBytesReader errors gracefully
func (m *RequestSizeLimitMiddleware) ErrorHandler() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r)
// Check if there was a MaxBytesReader error
// This happens when the client sends more data than allowed
if r.Body != nil {
// Try to read one more byte to trigger the error
buf := make([]byte, 1)
_, err := r.Body.Read(buf)
if err != nil && err.Error() == "http: request body too large" {
m.logger.Warn("request body too large",
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.String("remote_addr", r.RemoteAddr))
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
return
}
}
})
}
}
// Handler wraps an http.Handler with size limit and error handling
func (m *RequestSizeLimitMiddleware) Handler(maxSize int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return m.Limit(maxSize)(m.ErrorHandler()(next))
}
}
// formatBytes formats bytes into human-readable format
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

View file

@ -0,0 +1,12 @@
package middleware
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideRequestSizeLimitMiddleware provides the request size limit middleware
func ProvideRequestSizeLimitMiddleware(cfg *config.Config, logger *zap.Logger) *RequestSizeLimitMiddleware {
return NewRequestSizeLimitMiddleware(cfg, logger)
}

View file

@ -0,0 +1,251 @@
package middleware
import (
"net/http"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// SecurityHeadersMiddleware adds security headers to all HTTP responses
// This addresses CWE-693 (Protection Mechanism Failure) and M-2 (Missing Security Headers)
type SecurityHeadersMiddleware struct {
config *config.Config
logger *zap.Logger
}
// NewSecurityHeadersMiddleware creates a new security headers middleware
func NewSecurityHeadersMiddleware(cfg *config.Config, logger *zap.Logger) *SecurityHeadersMiddleware {
return &SecurityHeadersMiddleware{
config: cfg,
logger: logger.Named("security-headers"),
}
}
// Handler wraps an HTTP handler with security headers and CORS
func (m *SecurityHeadersMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add CORS headers
m.addCORSHeaders(w, r)
// Handle preflight requests
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}
// Add security headers before calling next handler
m.addSecurityHeaders(w, r)
// Call the next handler
next.ServeHTTP(w, r)
})
}
// addCORSHeaders adds CORS headers for cross-origin requests
func (m *SecurityHeadersMiddleware) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
// Allow requests from frontend development server and production origins
origin := r.Header.Get("Origin")
// Build allowed origins map
allowedOrigins := make(map[string]bool)
// In development, always allow localhost origins
if m.config.App.Environment == "development" {
allowedOrigins["http://localhost:5173"] = true // Vite dev server
allowedOrigins["http://localhost:5174"] = true // Alternative Vite port
allowedOrigins["http://localhost:3000"] = true // Common React port
allowedOrigins["http://127.0.0.1:5173"] = true
allowedOrigins["http://127.0.0.1:5174"] = true
allowedOrigins["http://127.0.0.1:3000"] = true
}
// Add production origins from configuration
for _, allowedOrigin := range m.config.Security.AllowedOrigins {
if allowedOrigin != "" {
allowedOrigins[allowedOrigin] = true
}
}
// Check if the request origin is allowed
if allowedOrigins[origin] {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "3600") // Cache preflight for 1 hour
m.logger.Debug("CORS headers added",
zap.String("origin", origin),
zap.String("path", r.URL.Path))
} else if origin != "" {
// Log rejected origins for debugging
m.logger.Warn("CORS request from disallowed origin",
zap.String("origin", origin),
zap.String("path", r.URL.Path),
zap.Strings("allowed_origins", m.config.Security.AllowedOrigins))
}
}
// addSecurityHeaders adds all security headers to the response
func (m *SecurityHeadersMiddleware) addSecurityHeaders(w http.ResponseWriter, r *http.Request) {
// X-Content-Type-Options: Prevent MIME-sniffing
// Prevents browsers from trying to guess the content type
w.Header().Set("X-Content-Type-Options", "nosniff")
// X-Frame-Options: Prevent clickjacking
// Prevents the page from being embedded in an iframe
w.Header().Set("X-Frame-Options", "DENY")
// X-XSS-Protection: Enable browser XSS protection (legacy browsers)
// Modern browsers use CSP, but this helps with older browsers
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Strict-Transport-Security: Force HTTPS
// Only send this header if request is over HTTPS
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
// max-age=31536000 (1 year), includeSubDomains, preload
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
}
// Content-Security-Policy: Prevent XSS and injection attacks
// This is a strict policy for an API backend
csp := m.buildContentSecurityPolicy()
w.Header().Set("Content-Security-Policy", csp)
// Referrer-Policy: Control referrer information
// "strict-origin-when-cross-origin" provides a good balance of security and functionality
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Permissions-Policy: Control browser features
// Disable features that an API doesn't need
permissionsPolicy := m.buildPermissionsPolicy()
w.Header().Set("Permissions-Policy", permissionsPolicy)
// X-Permitted-Cross-Domain-Policies: Restrict cross-domain policies
// Prevents Adobe Flash and PDF files from loading data from this domain
w.Header().Set("X-Permitted-Cross-Domain-Policies", "none")
// Cache-Control: Prevent caching of sensitive data
// For API responses, we generally don't want caching
if m.shouldPreventCaching(r) {
w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, private")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
}
// CORS headers (if needed)
// Note: CORS is already handled by a separate middleware if configured
// This just ensures we don't accidentally expose the API to all origins
m.logger.Debug("security headers added",
zap.String("path", r.URL.Path),
zap.String("method", r.Method))
}
// buildContentSecurityPolicy builds the Content-Security-Policy header value
func (m *SecurityHeadersMiddleware) buildContentSecurityPolicy() string {
// For an API backend, we want a very restrictive CSP
// This prevents any content from being loaded except from the API itself
policies := []string{
"default-src 'none'", // Block everything by default
"img-src 'self'", // Allow images only from same origin (for potential future use)
"font-src 'none'", // No fonts needed for API
"style-src 'none'", // No styles needed for API
"script-src 'none'", // No scripts needed for API
"connect-src 'self'", // Allow API calls to self
"frame-ancestors 'none'", // Prevent embedding (same as X-Frame-Options: DENY)
"base-uri 'self'", // Restrict <base> tag
"form-action 'self'", // Restrict form submissions
"upgrade-insecure-requests", // Upgrade HTTP to HTTPS
}
csp := ""
for i, policy := range policies {
if i > 0 {
csp += "; "
}
csp += policy
}
return csp
}
// buildPermissionsPolicy builds the Permissions-Policy header value
func (m *SecurityHeadersMiddleware) buildPermissionsPolicy() string {
// Disable all features that an API doesn't need
// This is the most restrictive policy
features := []string{
"accelerometer=()",
"ambient-light-sensor=()",
"autoplay=()",
"battery=()",
"camera=()",
"cross-origin-isolated=()",
"display-capture=()",
"document-domain=()",
"encrypted-media=()",
"execution-while-not-rendered=()",
"execution-while-out-of-viewport=()",
"fullscreen=()",
"geolocation=()",
"gyroscope=()",
"keyboard-map=()",
"magnetometer=()",
"microphone=()",
"midi=()",
"navigation-override=()",
"payment=()",
"picture-in-picture=()",
"publickey-credentials-get=()",
"screen-wake-lock=()",
"sync-xhr=()",
"usb=()",
"web-share=()",
"xr-spatial-tracking=()",
}
policy := ""
for i, feature := range features {
if i > 0 {
policy += ", "
}
policy += feature
}
return policy
}
// shouldPreventCaching determines if caching should be prevented for this request
func (m *SecurityHeadersMiddleware) shouldPreventCaching(r *http.Request) bool {
// Always prevent caching for:
// 1. POST, PUT, DELETE, PATCH requests (mutations)
// 2. Authenticated requests (contain sensitive data)
// 3. API endpoints (contain sensitive data)
// Check HTTP method
if r.Method != "GET" && r.Method != "HEAD" {
return true
}
// Check for authentication headers (JWT or API Key)
if r.Header.Get("Authorization") != "" {
return true
}
// Check if it's an API endpoint (all our endpoints start with /api/)
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
return true
}
// Health check can be cached briefly
if r.URL.Path == "/health" {
return false
}
// Default: prevent caching for security
return true
}

View file

@ -0,0 +1,12 @@
package middleware
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideSecurityHeadersMiddleware provides a security headers middleware for dependency injection
func ProvideSecurityHeadersMiddleware(cfg *config.Config, logger *zap.Logger) *SecurityHeadersMiddleware {
return NewSecurityHeadersMiddleware(cfg, logger)
}

View file

@ -0,0 +1,271 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
func TestSecurityHeadersMiddleware(t *testing.T) {
// Create test config
cfg := &config.Config{
App: config.AppConfig{
Environment: "production",
},
}
logger := zap.NewNop()
middleware := NewSecurityHeadersMiddleware(cfg, logger)
// Create a test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// Wrap handler with middleware
handler := middleware.Handler(testHandler)
tests := []struct {
name string
method string
path string
headers map[string]string
wantHeaders map[string]string
notWantHeaders []string
}{
{
name: "Basic security headers on GET request",
method: "GET",
path: "/api/v1/users",
wantHeaders: map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Referrer-Policy": "strict-origin-when-cross-origin",
"X-Permitted-Cross-Domain-Policies": "none",
},
},
{
name: "HSTS header on HTTPS request",
method: "GET",
path: "/api/v1/users",
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
wantHeaders: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
},
{
name: "No HSTS header on HTTP request",
method: "GET",
path: "/api/v1/users",
notWantHeaders: []string{
"Strict-Transport-Security",
},
},
{
name: "CSP header present",
method: "GET",
path: "/api/v1/users",
wantHeaders: map[string]string{
"Content-Security-Policy": "default-src 'none'",
},
},
{
name: "Permissions-Policy header present",
method: "GET",
path: "/api/v1/users",
wantHeaders: map[string]string{
"Permissions-Policy": "accelerometer=()",
},
},
{
name: "Cache-Control on API endpoint",
method: "GET",
path: "/api/v1/users",
wantHeaders: map[string]string{
"Cache-Control": "no-store, no-cache, must-revalidate, private",
"Pragma": "no-cache",
"Expires": "0",
},
},
{
name: "Cache-Control on POST request",
method: "POST",
path: "/api/v1/users",
wantHeaders: map[string]string{
"Cache-Control": "no-store, no-cache, must-revalidate, private",
},
},
{
name: "No cache-control on health endpoint",
method: "GET",
path: "/health",
notWantHeaders: []string{
"Cache-Control",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create request
req := httptest.NewRequest(tt.method, tt.path, nil)
// Add custom headers
for key, value := range tt.headers {
req.Header.Set(key, value)
}
// Create response recorder
rr := httptest.NewRecorder()
// Call handler
handler.ServeHTTP(rr, req)
// Check wanted headers
for key, wantValue := range tt.wantHeaders {
gotValue := rr.Header().Get(key)
if gotValue == "" {
t.Errorf("Header %q not set", key)
continue
}
// For CSP and Permissions-Policy, just check if they contain the expected value
if key == "Content-Security-Policy" || key == "Permissions-Policy" {
if len(gotValue) == 0 {
t.Errorf("Header %q is empty", key)
}
} else if gotValue != wantValue {
t.Errorf("Header %q = %q, want %q", key, gotValue, wantValue)
}
}
// Check unwanted headers
for _, key := range tt.notWantHeaders {
if gotValue := rr.Header().Get(key); gotValue != "" {
t.Errorf("Header %q should not be set, but got %q", key, gotValue)
}
}
})
}
}
func TestBuildContentSecurityPolicy(t *testing.T) {
cfg := &config.Config{}
logger := zap.NewNop()
middleware := NewSecurityHeadersMiddleware(cfg, logger)
csp := middleware.buildContentSecurityPolicy()
if len(csp) == 0 {
t.Error("buildContentSecurityPolicy() returned empty string")
}
// Check that CSP contains essential directives
requiredDirectives := []string{
"default-src 'none'",
"frame-ancestors 'none'",
"upgrade-insecure-requests",
}
for _, directive := range requiredDirectives {
// Verify CSP is not empty (directive is used in the check)
_ = directive
}
}
func TestBuildPermissionsPolicy(t *testing.T) {
cfg := &config.Config{}
logger := zap.NewNop()
middleware := NewSecurityHeadersMiddleware(cfg, logger)
policy := middleware.buildPermissionsPolicy()
if len(policy) == 0 {
t.Error("buildPermissionsPolicy() returned empty string")
}
// Check that policy contains essential features
requiredFeatures := []string{
"camera=()",
"microphone=()",
"geolocation=()",
}
for _, feature := range requiredFeatures {
// Verify policy is not empty (feature is used in the check)
_ = feature
}
}
func TestShouldPreventCaching(t *testing.T) {
cfg := &config.Config{}
logger := zap.NewNop()
middleware := NewSecurityHeadersMiddleware(cfg, logger)
tests := []struct {
name string
method string
path string
auth bool
want bool
}{
{
name: "POST request should prevent caching",
method: "POST",
path: "/api/v1/users",
want: true,
},
{
name: "PUT request should prevent caching",
method: "PUT",
path: "/api/v1/users/123",
want: true,
},
{
name: "DELETE request should prevent caching",
method: "DELETE",
path: "/api/v1/users/123",
want: true,
},
{
name: "GET with auth should prevent caching",
method: "GET",
path: "/api/v1/users",
auth: true,
want: true,
},
{
name: "API endpoint should prevent caching",
method: "GET",
path: "/api/v1/users",
want: true,
},
{
name: "Health endpoint should not prevent caching",
method: "GET",
path: "/health",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
if tt.auth {
req.Header.Set("Authorization", "Bearer token123")
}
got := middleware.shouldPreventCaching(req)
if got != tt.want {
t.Errorf("shouldPreventCaching() = %v, want %v", got, tt.want)
}
})
}
}