package middleware

import (
	"net/http"
	"strings"

	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)

// SecurityHeadersMiddleware adds security headers and CORS headers to HTTP
// responses.
// This addresses CWE-693 (Protection Mechanism Failure) and M-2 (Missing
// Security Headers).
type SecurityHeadersMiddleware struct {
	config *config.Config
	logger *zap.Logger
}

// NewSecurityHeadersMiddleware creates a new security headers middleware.
// The supplied logger is scoped with the name "security-headers".
func NewSecurityHeadersMiddleware(cfg *config.Config, logger *zap.Logger) *SecurityHeadersMiddleware {
	return &SecurityHeadersMiddleware{
		config: cfg,
		logger: logger.Named("security-headers"),
	}
}

// Handler wraps next so that every response carries the CORS headers (when
// the request origin is allowed) and the security headers.
//
// Any OPTIONS request is treated as a preflight: it is answered with
// 200 OK and next is not invoked.
func (m *SecurityHeadersMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		m.addCORSHeaders(w, r)

		// Set the security headers before the preflight short-circuit so
		// that OPTIONS responses are protected as well.
		m.addSecurityHeaders(w, r)

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// addCORSHeaders sets the Access-Control-* response headers when the request's
// Origin header matches an allowed origin. Requests from a non-empty,
// disallowed origin are logged at warn level and receive no CORS headers.
func (m *SecurityHeadersMiddleware) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
	header := w.Header()

	// Access-Control-Allow-Origin echoes the request origin, so the response
	// depends on the Origin request header. Announce that to caches so a
	// response produced for one origin is not reused for another.
	header.Add("Vary", "Origin")

	origin := r.Header.Get("Origin")
	if origin == "" {
		// Not a cross-origin browser request; nothing to add.
		return
	}

	if !m.isOriginAllowed(origin) {
		// Log rejected origins for debugging.
		m.logger.Warn("CORS request from disallowed origin",
			zap.String("origin", origin),
			zap.String("path", r.URL.Path),
			zap.Strings("allowed_origins", m.config.Security.AllowedOrigins))
		return
	}

	header.Set("Access-Control-Allow-Origin", origin)
	header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
	header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Tenant-ID")
	header.Set("Access-Control-Allow-Credentials", "true")
	header.Set("Access-Control-Max-Age", "3600") // Cache preflight for 1 hour

	m.logger.Debug("CORS headers added",
		zap.String("origin", origin),
		zap.String("path", r.URL.Path))
}

// isOriginAllowed reports whether origin may receive CORS headers. An empty
// origin is never allowed. In the "development" environment a fixed set of
// localhost origins is accepted in addition to the configured origins.
func (m *SecurityHeadersMiddleware) isOriginAllowed(origin string) bool {
	if origin == "" {
		return false
	}

	if m.config.App.Environment == "development" {
		switch origin {
		case "http://localhost:5173", // Vite dev server
			"http://localhost:5174", // Alternative Vite port
			"http://localhost:3000", // Common React port
			"http://127.0.0.1:5173",
			"http://127.0.0.1:5174",
			"http://127.0.0.1:3000":
			return true
		}
	}

	// Origins from configuration. Empty entries can never match because
	// origin is known to be non-empty here.
	for _, allowed := range m.config.Security.AllowedOrigins {
		if allowed == origin {
			return true
		}
	}
	return false
}

// addSecurityHeaders sets the security-related response headers on w. The
// Strict-Transport-Security header is only sent for requests that arrived over
// HTTPS, and the cache-suppressing headers only when shouldPreventCaching
// reports true for r.
func (m *SecurityHeadersMiddleware) addSecurityHeaders(w http.ResponseWriter, r *http.Request) {
	header := w.Header()

	// X-Content-Type-Options: prevent MIME-sniffing, i.e. browsers guessing
	// the content type.
	header.Set("X-Content-Type-Options", "nosniff")

	// X-Frame-Options: prevent clickjacking by forbidding embedding in frames.
	header.Set("X-Frame-Options", "DENY")

	// X-XSS-Protection: enable the legacy browser XSS filter.
	// NOTE(review): current OWASP guidance recommends "0" for this header;
	// confirm whether legacy-browser support still justifies this value.
	header.Set("X-XSS-Protection", "1; mode=block")

	// Strict-Transport-Security: force HTTPS. Only sent when the request is
	// over TLS directly or a proxy reports HTTPS via X-Forwarded-Proto.
	if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
		// max-age=31536000 (1 year), includeSubDomains, preload
		header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
	}

	// Content-Security-Policy: a strict policy suitable for an API backend.
	header.Set("Content-Security-Policy", m.buildContentSecurityPolicy())

	// Referrer-Policy: control how much referrer information is sent.
	header.Set("Referrer-Policy", "strict-origin-when-cross-origin")

	// Permissions-Policy: disable browser features an API doesn't need.
	header.Set("Permissions-Policy", m.buildPermissionsPolicy())

	// X-Permitted-Cross-Domain-Policies: restrict cross-domain policy files
	// (Adobe Flash / PDF clients).
	header.Set("X-Permitted-Cross-Domain-Policies", "none")

	// Cache-Control: prevent caching of potentially sensitive responses.
	if m.shouldPreventCaching(r) {
		header.Set("Cache-Control", "no-store, no-cache, must-revalidate, private")
		header.Set("Pragma", "no-cache")
		header.Set("Expires", "0")
	}

	// CORS headers are not set here; addCORSHeaders handles them and only
	// for explicitly allowed origins.

	m.logger.Debug("security headers added",
		zap.String("path", r.URL.Path),
		zap.String("method", r.Method))
}

// buildContentSecurityPolicy returns the Content-Security-Policy header value:
// the directives below joined with "; ".
func (m *SecurityHeadersMiddleware) buildContentSecurityPolicy() string {
	// For an API backend, we want a very restrictive CSP.
	policies := []string{
		"default-src 'none'",        // Block everything by default
		"img-src 'self'",            // Allow images only from same origin (for potential future use)
		"font-src 'none'",           // No fonts needed for API
		"style-src 'none'",          // No styles needed for API
		"script-src 'none'",         // No scripts needed for API
		"connect-src 'self'",        // Allow API calls to self
		"frame-ancestors 'none'",    // Prevent embedding (same as X-Frame-Options: DENY)
		"base-uri 'self'",           // Restrict <base> tag
		"form-action 'self'",        // Restrict form submissions
		"upgrade-insecure-requests", // Upgrade HTTP to HTTPS
	}
	return strings.Join(policies, "; ")
}

// buildPermissionsPolicy returns the Permissions-Policy header value: every
// listed feature with an empty allowlist, joined with ", ".
func (m *SecurityHeadersMiddleware) buildPermissionsPolicy() string {
	// Disable all features that an API doesn't need.
	features := []string{
		"accelerometer=()",
		"ambient-light-sensor=()",
		"autoplay=()",
		"battery=()",
		"camera=()",
		"cross-origin-isolated=()",
		"display-capture=()",
		"document-domain=()",
		"encrypted-media=()",
		"execution-while-not-rendered=()",
		"execution-while-out-of-viewport=()",
		"fullscreen=()",
		"geolocation=()",
		"gyroscope=()",
		"keyboard-map=()",
		"magnetometer=()",
		"microphone=()",
		"midi=()",
		"navigation-override=()",
		"payment=()",
		"picture-in-picture=()",
		"publickey-credentials-get=()",
		"screen-wake-lock=()",
		"sync-xhr=()",
		"usb=()",
		"web-share=()",
		"xr-spatial-tracking=()",
	}
	return strings.Join(features, ", ")
}

// shouldPreventCaching reports whether the cache-suppressing headers should be
// sent for r. It returns false only for GET or HEAD requests to "/health"
// that carry no Authorization header; every other request returns true.
func (m *SecurityHeadersMiddleware) shouldPreventCaching(r *http.Request) bool {
	// Mutating (and any other non-GET/HEAD) requests are never cacheable.
	if r.Method != http.MethodGet && r.Method != http.MethodHead {
		return true
	}

	// Authenticated requests (JWT or API key) may contain sensitive data.
	if r.Header.Get("Authorization") != "" {
		return true
	}

	// API endpoints (paths starting with /api/) may contain sensitive data.
	if strings.HasPrefix(r.URL.Path, "/api/") {
		return true
	}

	// The health check may be cached; everything else defaults to no caching.
	return r.URL.Path != "/health"
}