113 lines
3.7 KiB
Go
113 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config/constants"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/service"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
|
|
)
|
|
|
|
// JWTMiddleware validates JWT tokens and populates session context
|
|
type JWTMiddleware struct {
|
|
jwtProvider jwt.Provider
|
|
sessionService service.SessionService
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewJWTMiddleware creates a new JWT middleware
|
|
func NewJWTMiddleware(jwtProvider jwt.Provider, sessionService service.SessionService, logger *zap.Logger) *JWTMiddleware {
|
|
return &JWTMiddleware{
|
|
jwtProvider: jwtProvider,
|
|
sessionService: sessionService,
|
|
logger: logger.Named("jwt-middleware"),
|
|
}
|
|
}
|
|
|
|
// Handler returns an HTTP middleware function that validates JWT tokens
|
|
func (m *JWTMiddleware) Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Get Authorization header
|
|
authHeader := r.Header.Get("Authorization")
|
|
if authHeader == "" {
|
|
m.logger.Debug("no authorization header")
|
|
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Expected format: "JWT <token>"
|
|
parts := strings.Split(authHeader, " ")
|
|
if len(parts) != 2 || parts[0] != "JWT" {
|
|
m.logger.Debug("invalid authorization header format",
|
|
zap.String("header", authHeader),
|
|
)
|
|
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
token := parts[1]
|
|
|
|
// Validate token
|
|
sessionID, err := m.jwtProvider.ValidateToken(token)
|
|
if err != nil {
|
|
m.logger.Debug("invalid JWT token",
|
|
zap.Error(err),
|
|
)
|
|
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Get session from cache
|
|
session, err := m.sessionService.GetSession(r.Context(), sessionID)
|
|
if err != nil {
|
|
m.logger.Debug("session not found or expired",
|
|
zap.String("session_id", sessionID),
|
|
zap.Error(err),
|
|
)
|
|
ctx := context.WithValue(r.Context(), constants.SessionIsAuthorized, false)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
|
|
// Populate context with session data
|
|
ctx := r.Context()
|
|
ctx = context.WithValue(ctx, constants.SessionIsAuthorized, true)
|
|
ctx = context.WithValue(ctx, constants.SessionID, session.ID)
|
|
ctx = context.WithValue(ctx, constants.SessionUserID, session.UserID)
|
|
ctx = context.WithValue(ctx, constants.SessionUserUUID, session.UserUUID.String())
|
|
ctx = context.WithValue(ctx, constants.SessionUserEmail, session.UserEmail)
|
|
ctx = context.WithValue(ctx, constants.SessionUserName, session.UserName)
|
|
ctx = context.WithValue(ctx, constants.SessionUserRole, session.UserRole)
|
|
ctx = context.WithValue(ctx, constants.SessionTenantID, session.TenantID.String())
|
|
|
|
m.logger.Debug("JWT validated successfully",
|
|
zap.String("session_id", session.ID),
|
|
zap.Uint64("user_id", session.UserID),
|
|
zap.String("user_email", session.UserEmail),
|
|
)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
// RequireAuth is a middleware that requires authentication
|
|
func (m *JWTMiddleware) RequireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
isAuthorized, ok := r.Context().Value(constants.SessionIsAuthorized).(bool)
|
|
if !ok || !isAuthorized {
|
|
m.logger.Debug("unauthorized access attempt",
|
|
zap.String("path", r.URL.Path),
|
|
)
|
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|