package middleware

import (
	"context"
	"net/http"
	"strings"

	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
)

// JWTMiddleware validates JWT tokens and populates the request context with
// session data.
//
// Handler never rejects a request by itself; it only records in the context
// whether the request is authorized. Chain RequireAuth after Handler on routes
// that must be authenticated.
type JWTMiddleware struct {
	jwtProvider    jwt.Provider
	sessionService service.SessionService
	logger         *zap.Logger
}

// NewJWTMiddleware creates a new JWT middleware. The given logger is scoped
// with the name "jwt-middleware".
func NewJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
	return &JWTMiddleware{
		jwtProvider:    jwtProvider,
		sessionService: sessionService,
		logger:         logger.Named("jwt-middleware"),
	}
}

// Handler returns an HTTP middleware that authenticates requests carrying an
// "Authorization: JWT <token>" header.
//
// On success it stores in the request context:
//   - constants.SessionIsAuthorized = true
//   - the session ID
//   - the user ID, UUID (as a string), email, name and role
//   - the tenant ID (as a string)
//
// On any failure (missing or malformed header, invalid token, session lookup
// error) it stores constants.SessionIsAuthorized = false. In every case the
// request is still forwarded to next.
func (m *JWTMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Get Authorization header
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			m.logger.Debug("no authorization header")
			m.serveUnauthenticated(w, r, next)
			return
		}

		// Expected format: "JWT <token>". The auth-scheme is matched
		// case-insensitively (RFC 7235 §2.1), and any run of whitespace between
		// scheme and token is tolerated. An empty token yields a single field
		// and is rejected here instead of being handed to the validator.
		fields := strings.Fields(authHeader)
		if len(fields) != 2 || !strings.EqualFold(fields[0], "JWT") {
			// Never log the header value itself: it may carry a bearer token
			// or Basic-auth credentials.
			m.logger.Debug("invalid authorization header format",
				zap.String("path", r.URL.Path),
			)
			m.serveUnauthenticated(w, r, next)
			return
		}
		token := fields[1]

		// Validate token
		sessionID, err := m.jwtProvider.ValidateToken(token)
		if err != nil {
			m.logger.Debug("invalid JWT token",
				zap.Error(err),
			)
			m.serveUnauthenticated(w, r, next)
			return
		}

		// Get session from cache
		session, err := m.sessionService.GetSession(r.Context(), sessionID)
		if err != nil {
			m.logger.Debug("session not found or expired",
				zap.String("session_id", sessionID),
				zap.Error(err),
			)
			m.serveUnauthenticated(w, r, next)
			return
		}

		// Populate context with session data
		ctx := r.Context()
		ctx = context.WithValue(ctx, constants.SessionIsAuthorized, true)
		ctx = context.WithValue(ctx, constants.SessionID, session.ID)
		ctx = context.WithValue(ctx, constants.SessionUserID, session.UserID)
		ctx = context.WithValue(ctx, constants.SessionUserUUID, session.UserUUID.String())
		ctx = context.WithValue(ctx, constants.SessionUserEmail, session.UserEmail)
		ctx = context.WithValue(ctx, constants.SessionUserName, session.UserName)
		ctx = context.WithValue(ctx, constants.SessionUserRole, session.UserRole)
		ctx = context.WithValue(ctx, constants.SessionTenantID, session.TenantID.String())

		m.logger.Debug("JWT validated successfully",
			zap.String("session_id", session.ID),
			zap.Uint64("user_id", session.UserID),
			zap.String("user_email", session.UserEmail),
		)

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// serveUnauthenticated forwards the request to next with
// constants.SessionIsAuthorized explicitly set to false, so downstream
// handlers see an anonymous request rather than a missing value.
func (m *JWTMiddleware) serveUnauthenticated(w http.ResponseWriter, r *http.Request, next http.Handler) {
	ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
	next.ServeHTTP(w, r.WithContext(ctx))
}

// RequireAuth is a middleware that requires authentication.
//
// It responds with 401 Unauthorized unless the request context holds
// constants.SessionIsAuthorized == true (as set by Handler). A missing value
// (for example, Handler was not chained before it) is treated as unauthorized.
func (m *JWTMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		isAuthorized, ok := r.Context().Value(constants.SessionIsAuthorized).(bool)
		if !ok || !isAuthorized {
			m.logger.Debug("unauthorized access attempt",
				zap.String("path", r.URL.Path),
			)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		next.ServeHTTP(w, r)
	})
}