package middleware

import (
	"context"
	"errors"
	"net/http"
	"strings"

	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
	domainsite "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain/site"
	siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/usecase/site"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
)

// apiKeyErrorCtxKey is the context key under which Handler stores a
// client-facing reason for a failed API key authentication, so that
// RequireAPIKey can return it in its 401 response. An unexported struct type
// is used instead of a bare string so the key cannot collide with context
// values set by any other package.
type apiKeyErrorCtxKey struct{}

// APIKeyMiddleware validates API keys and populates site context.
type APIKeyMiddleware struct {
	siteService siteservice.AuthenticateAPIKeyService
	logger      *zap.Logger
}

// NewAPIKeyMiddleware creates a new API key middleware that authenticates keys
// through siteService and logs under the "apikey-middleware" logger name.
func NewAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
	return &APIKeyMiddleware{
		siteService: siteService,
		logger:      logger.Named("apikey-middleware"),
	}
}

// Handler returns an HTTP middleware that attempts to authenticate the request
// from an "Authorization: Bearer <api_key>" header.
//
// It never rejects a request itself; it always calls next. The request context
// is annotated with constants.SiteIsAuthenticated (true or false). On success it
// also carries constants.SiteID, constants.SiteTenantID and constants.SiteDomain.
// When the site service rejects the key, a client-facing reason is stored for
// RequireAPIKey to report.
func (m *APIKeyMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			m.logger.Debug("no authorization header")
			m.forwardUnauthenticated(w, r, next, "")
			return
		}

		apiKey, ok := apiKeyFromAuthHeader(authHeader)
		if !ok {
			// Never log the header value itself: it may contain a live
			// credential (an API key sent with the wrong scheme, Basic auth, ...).
			m.logger.Debug("invalid authorization header format",
				zap.Int("header_length", len(authHeader)),
			)
			m.forwardUnauthenticated(w, r, next, "")
			return
		}

		// Cheap syntactic check before calling the site service.
		if !hasValidAPIKeyPrefix(apiKey) {
			m.logger.Debug("invalid API key format")
			m.forwardUnauthenticated(w, r, next, "")
			return
		}

		siteOutput, err := m.siteService.AuthenticateByAPIKey(r.Context(), &site.AuthenticateAPIKeyInput{
			APIKey: apiKey,
		})
		if err != nil {
			m.logger.Debug("API key authentication failed", zap.Error(err))
			m.forwardUnauthenticated(w, r, next, apiKeyAuthFailureMessage(err))
			return
		}

		siteEntity := siteOutput.Site

		ctx := r.Context()
		ctx = context.WithValue(ctx, constants.SiteIsAuthenticated, true)
		ctx = context.WithValue(ctx, constants.SiteID, siteEntity.ID.String())
		ctx = context.WithValue(ctx, constants.SiteTenantID, siteEntity.TenantID.String())
		ctx = context.WithValue(ctx, constants.SiteDomain, siteEntity.Domain)

		m.logger.Debug("API key validated successfully",
			zap.String("site_id", siteEntity.ID.String()),
			zap.String("domain", siteEntity.Domain))

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// RequireAPIKey is a middleware that requires API key authentication.
//
// It expects Handler to have run earlier in the chain. Requests whose context
// does not have constants.SiteIsAuthenticated set to true are answered with
// 401 Unauthorized. The message is the specific failure reason recorded by
// Handler when one exists, otherwise "Valid API key required".
func (m *APIKeyMiddleware) RequireAPIKey(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if isAuthenticated, ok := r.Context().Value(constants.SiteIsAuthenticated).(bool); ok && isAuthenticated {
			next.ServeHTTP(w, r)
			return
		}

		m.logger.Debug("unauthorized API key access attempt",
			zap.String("path", r.URL.Path),
		)

		errorMsg := "Valid API key required"
		if reason, ok := r.Context().Value(apiKeyErrorCtxKey{}).(string); ok && reason != "" {
			errorMsg = reason
		}
		httperror.Unauthorized(w, errorMsg)
	})
}

// forwardUnauthenticated passes r on to next with constants.SiteIsAuthenticated
// set to false. A non-empty reason is also stored under apiKeyErrorCtxKey so
// RequireAPIKey can surface it to the client.
func (m *APIKeyMiddleware) forwardUnauthenticated(w http.ResponseWriter, r *http.Request, next http.Handler, reason string) {
	ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
	if reason != "" {
		ctx = context.WithValue(ctx, apiKeyErrorCtxKey{}, reason)
	}
	next.ServeHTTP(w, r.WithContext(ctx))
}

// apiKeyFromAuthHeader extracts the token from an Authorization header of the
// form "Bearer <token>".
//
// The scheme is matched case-insensitively, as HTTP authentication schemes are
// case-insensitive. The token must not contain further spaces. The returned
// token may be empty (e.g. "Bearer "); callers validate its format separately.
func apiKeyFromAuthHeader(header string) (string, bool) {
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") || strings.Contains(token, " ") {
		return "", false
	}
	return token, true
}

// hasValidAPIKeyPrefix reports whether key has one of the accepted API key
// prefixes ("live_sk_" or "test_sk_").
func hasValidAPIKeyPrefix(key string) bool {
	return strings.HasPrefix(key, "live_sk_") || strings.HasPrefix(key, "test_sk_")
}

// apiKeyAuthFailureMessage maps a site-service authentication error to the
// client-facing message reported by RequireAPIKey.
func apiKeyAuthFailureMessage(err error) string {
	switch {
	case errors.Is(err, domainsite.ErrInvalidAPIKey):
		return "Invalid API key"
	case errors.Is(err, domainsite.ErrSiteNotActive):
		return "Site is not active or has been suspended"
	default:
		return "API key authentication failed"
	}
}