125 lines
4.3 KiB
Go
125 lines
4.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
|
domainsite "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain/site"
|
|
siteservice "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service/site"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/usecase/site"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/httperror"
|
|
)
|
|
|
|
// APIKeyMiddleware validates API keys and populates site context
|
|
type APIKeyMiddleware struct {
|
|
siteService siteservice.AuthenticateAPIKeyService
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewAPIKeyMiddleware creates a new API key middleware
|
|
func NewAPIKeyMiddleware(siteService siteservice.AuthenticateAPIKeyService, logger *zap.Logger) *APIKeyMiddleware {
|
|
return &APIKeyMiddleware{
|
|
siteService: siteService,
|
|
logger: logger.Named("apikey-middleware"),
|
|
}
|
|
}
|
|
|
|
// Handler returns an HTTP middleware function that validates API keys
|
|
func (m *APIKeyMiddleware) Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
m.logger.Debug("no authorization header")
|
|
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Expected format: "Bearer {api_key}"
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
|
m.logger.Debug("invalid authorization header format",
|
|
zap.String("header", authHeader),
|
|
)
|
|
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
apiKey := parts[1]
|
|
|
|
// Validate API key format (live_sk_ or test_sk_)
|
|
if !strings.HasPrefix(apiKey, "live_sk_") && !strings.HasPrefix(apiKey, "test_sk_") {
|
|
m.logger.Debug("invalid API key format")
|
|
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Authenticate via Site service
|
|
siteOutput, err := m.siteService.AuthenticateByAPIKey(r.Context(), &site.AuthenticateAPIKeyInput{
|
|
APIKey: apiKey,
|
|
})
|
|
if err != nil {
|
|
m.logger.Debug("API key authentication failed", zap.Error(err))
|
|
|
|
// Provide specific error messages for different failure reasons
|
|
ctx := context.WithValue(r.Context(), constants.SiteIsAuthenticated, false)
|
|
|
|
// Check for specific error types and store in context for RequireAPIKey
|
|
if errors.Is(err, domainsite.ErrInvalidAPIKey) {
|
|
ctx = context.WithValue(ctx, "apikey_error", "Invalid API key")
|
|
} else if errors.Is(err, domainsite.ErrSiteNotActive) {
|
|
ctx = context.WithValue(ctx, "apikey_error", "Site is not active or has been suspended")
|
|
} else {
|
|
ctx = context.WithValue(ctx, "apikey_error", "API key authentication failed")
|
|
}
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
siteEntity := siteOutput.Site
|
|
|
|
// Populate context with site info
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(ctx, constants.SiteIsAuthenticated, true)
|
|
ctx = context.WithValue(ctx, constants.SiteID, siteEntity.ID.String())
|
|
ctx = context.WithValue(ctx, constants.SiteTenantID, siteEntity.TenantID.String())
|
|
ctx = context.WithValue(ctx, constants.SiteDomain, siteEntity.Domain)
|
|
|
|
m.logger.Debug("API key validated successfully",
|
|
zap.String("site_id", siteEntity.ID.String()),
|
|
zap.String("domain", siteEntity.Domain))
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireAPIKey is a middleware that requires API key authentication
|
|
func (m *APIKeyMiddleware) RequireAPIKey(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
isAuthenticated, ok := r.Context().Value(constants.SiteIsAuthenticated).(bool)
|
|
if !ok || !isAuthenticated {
|
|
m.logger.Debug("unauthorized API key access attempt",
|
|
zap.String("path", r.URL.Path),
|
|
)
|
|
|
|
// Get specific error message if available
|
|
errorMsg := "Valid API key required"
|
|
if errStr, ok := r.Context().Value("apikey_error").(string); ok {
|
|
errorMsg = errStr
|
|
}
|
|
|
|
httperror.Unauthorized(w, errorMsg)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|