monorepo/cloud/maplepress-backend/internal/http/middleware/jwt.go

113 lines
3.7 KiB
Go

package middleware
import (
"context"
"net/http"
"strings"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
)
// JWTMiddleware authenticates requests carrying a JWT in the Authorization
// header and copies the resolved session's data into the request context.
type JWTMiddleware struct {
	// jwtProvider validates the raw token and yields the session ID it encodes.
	jwtProvider jwt.Provider
	// sessionService resolves a session ID into the stored session record.
	sessionService service.SessionService
	// logger is the middleware-scoped logger (named "jwt-middleware").
	logger *zap.Logger
}
// NewJWTMiddleware wires a JWTMiddleware from its collaborators. The given
// logger is scoped under the name "jwt-middleware" so that every line the
// middleware emits can be attributed to it.
func NewJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
	mw := new(JWTMiddleware)
	mw.jwtProvider = jwtProvider
	mw.sessionService = sessionService
	mw.logger = logger.Named("jwt-middleware")
	return mw
}
// Handler returns an HTTP middleware that inspects the Authorization header
// and populates the request context with session data.
//
// The middleware never rejects a request itself. The request is forwarded
// with constants.SessionIsAuthorized set to false in any of these cases:
//   - the header is missing or malformed;
//   - the token fails validation;
//   - the session cannot be loaded.
//
// Enforcement is left to RequireAuth (or to the downstream handler).
//
// The expected header format is "JWT <token>". The scheme is matched
// case-insensitively, as RFC 7235 §2.1 requires for auth-schemes.
//
// On success the context carries SessionIsAuthorized=true plus the session's
// ID, user ID, user UUID (as a string), email, name, role and tenant ID
// (as a string).
func (m *JWTMiddleware) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// serveAnonymous forwards the request explicitly marked as unauthenticated.
		serveAnonymous := func() {
			ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
			next.ServeHTTP(w, r.WithContext(ctx))
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			m.logger.Debug("no authorization header")
			serveAnonymous()
			return
		}

		// Expected format: "JWT <token>". Never log the header value itself:
		// it carries the bearer credential and could be replayed by anyone
		// with access to the logs.
		scheme, token, ok := strings.Cut(authHeader, " ")
		if !ok || !strings.EqualFold(scheme, "JWT") || token == "" || strings.Contains(token, " ") {
			m.logger.Debug("invalid authorization header format",
				zap.Int("header_length", len(authHeader)),
			)
			serveAnonymous()
			return
		}

		sessionID, err := m.jwtProvider.ValidateToken(token)
		if err != nil {
			m.logger.Debug("invalid JWT token",
				zap.Error(err),
			)
			serveAnonymous()
			return
		}

		session, err := m.sessionService.GetSession(r.Context(), sessionID)
		if err != nil {
			m.logger.Debug("session not found or expired",
				zap.String("session_id", sessionID),
				zap.Error(err),
			)
			serveAnonymous()
			return
		}

		ctx := r.Context()
		ctx = context.WithValue(ctx, constants.SessionIsAuthorized, true)
		ctx = context.WithValue(ctx, constants.SessionID, session.ID)
		ctx = context.WithValue(ctx, constants.SessionUserID, session.UserID)
		ctx = context.WithValue(ctx, constants.SessionUserUUID, session.UserUUID.String())
		ctx = context.WithValue(ctx, constants.SessionUserEmail, session.UserEmail)
		ctx = context.WithValue(ctx, constants.SessionUserName, session.UserName)
		ctx = context.WithValue(ctx, constants.SessionUserRole, session.UserRole)
		ctx = context.WithValue(ctx, constants.SessionTenantID, session.TenantID.String())

		// Identify the session by opaque IDs only; the email is personal data
		// and is kept out of the logs.
		m.logger.Debug("JWT validated successfully",
			zap.String("session_id", session.ID),
			zap.Uint64("user_id", session.UserID),
		)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// RequireAuth returns middleware that only lets authenticated requests reach
// next. It reads the constants.SessionIsAuthorized context value, which
// Handler sets, so it is intended to be chained after Handler.
//
// A request whose context lacks that value, or carries false, receives
// 401 Unauthorized and next is not called. RFC 7235 §3.1 requires a 401
// response to include a WWW-Authenticate challenge, so the "JWT" scheme
// expected by Handler is advertised.
func (m *JWTMiddleware) RequireAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A missing or non-bool value type-asserts to false, which is treated
		// exactly like an explicit false.
		if isAuthorized, _ := r.Context().Value(constants.SessionIsAuthorized).(bool); !isAuthorized {
			m.logger.Debug("unauthorized access attempt",
				zap.String("path", r.URL.Path),
			)
			w.Header().Set("WWW-Authenticate", "JWT")
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}