package service import ( "context" "encoding/json" "fmt" "time" "github.com/google/uuid" "go.uber.org/zap" "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain" "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/cache" "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/logger" ) const ( // SessionCachePrefix is the prefix for session cache keys SessionCachePrefix = "session:" // UserSessionsPrefix is the prefix for user session list keys (tracks all sessions for a user) UserSessionsPrefix = "user_sessions:" // DefaultSessionDuration is the default session expiration time DefaultSessionDuration = 14 * 24 * time.Hour // 14 days ) // SessionService handles session management operations type SessionService interface { CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error) GetSession(ctx context.Context, sessionID string) (*domain.Session, error) DeleteSession(ctx context.Context, sessionID string) error // CWE-384: Session Fixation Prevention InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error) } type sessionService struct { cache cache.TwoTierCacher logger *zap.Logger } // NewSessionService creates a new session service func NewSessionService(cache cache.TwoTierCacher, logger *zap.Logger) SessionService { return &sessionService{ cache: cache, logger: logger.Named("session-service"), } } // CreateSession creates a new session and stores it in the cache // CWE-384: Tracks user sessions to enable invalidation on login (session fixation prevention) func (s *sessionService) CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error) { // Create new session session := domain.NewSession(userID, userUUID, userEmail, userName, userRole, tenantID, DefaultSessionDuration) // Serialize session to JSON sessionData, err := json.Marshal(session) if err != nil { s.logger.Error("failed to marshal session", zap.String("session_id", session.ID), zap.Error(err), ) return nil, fmt.Errorf("failed to marshal session: %w", err) } // Store in cache with expiry cacheKey := SessionCachePrefix + session.ID if err := s.cache.SetWithExpiry(ctx, cacheKey, sessionData, DefaultSessionDuration); err != nil { s.logger.Error("failed to store session in cache", zap.String("session_id", session.ID), zap.Error(err), ) return nil, fmt.Errorf("failed to store session: %w", err) } // CWE-384: Track session ID for this user (for session invalidation) if err := s.addUserSession(ctx, userUUID, session.ID); err != nil { // Log error but don't fail session creation s.logger.Warn("failed to track user session (non-fatal)", zap.String("session_id", session.ID), zap.String("user_uuid", userUUID.String()), zap.Error(err), ) } // CWE-532: Use redacted email for logging s.logger.Info("session created", zap.String("session_id", session.ID), zap.Uint64("user_id", userID), logger.EmailHash(userEmail), logger.SafeEmail("email_redacted", userEmail), ) return session, nil } // GetSession retrieves a session from the cache func (s *sessionService) GetSession(ctx context.Context, sessionID string) (*domain.Session, error) { cacheKey := SessionCachePrefix + sessionID // Get from cache sessionData, err := s.cache.Get(ctx, cacheKey) if err != nil { s.logger.Error("failed to get session from cache", zap.String("session_id", sessionID), zap.Error(err), ) return nil, fmt.Errorf("failed to get session: %w", err) } if sessionData == nil { s.logger.Debug("session not found", zap.String("session_id", sessionID), ) return nil, fmt.Errorf("session not found") } // Deserialize session from JSON var session domain.Session if err := json.Unmarshal(sessionData, &session); err != nil { s.logger.Error("failed to unmarshal session", zap.String("session_id", sessionID), zap.Error(err), ) return nil, fmt.Errorf("failed to unmarshal session: %w", err) } // Check if session is expired if session.IsExpired() { s.logger.Info("session expired, deleting", zap.String("session_id", sessionID), ) _ = s.DeleteSession(ctx, sessionID) // Best effort cleanup return nil, fmt.Errorf("session expired") } s.logger.Debug("session retrieved", zap.String("session_id", sessionID), zap.Uint64("user_id", session.UserID), ) return &session, nil } // DeleteSession removes a session from the cache func (s *sessionService) DeleteSession(ctx context.Context, sessionID string) error { cacheKey := SessionCachePrefix + sessionID if err := s.cache.Delete(ctx, cacheKey); err != nil { s.logger.Error("failed to delete session from cache", zap.String("session_id", sessionID), zap.Error(err), ) return fmt.Errorf("failed to delete session: %w", err) } s.logger.Info("session deleted", zap.String("session_id", sessionID), ) return nil } // InvalidateUserSessions invalidates all sessions for a given user // CWE-384: This prevents session fixation attacks by ensuring old sessions are invalidated on login func (s *sessionService) InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error { s.logger.Info("invalidating all sessions for user", zap.String("user_uuid", userUUID.String())) // Get all session IDs for this user sessionIDs, err := s.GetUserSessions(ctx, userUUID) if err != nil { s.logger.Error("failed to get user sessions for invalidation", zap.String("user_uuid", userUUID.String()), zap.Error(err), ) return fmt.Errorf("failed to get user sessions: %w", err) } // Delete each session for _, sessionID := range sessionIDs { if err := s.DeleteSession(ctx, sessionID); err != nil { // Log but continue - best effort cleanup s.logger.Warn("failed to delete session during invalidation", zap.String("session_id", sessionID), zap.Error(err), ) } } // Clear the user sessions list userSessionsKey := UserSessionsPrefix + userUUID.String() if err := s.cache.Delete(ctx, userSessionsKey); err != nil { // Log but don't fail - this is cleanup s.logger.Warn("failed to delete user sessions list", zap.String("user_uuid", userUUID.String()), zap.Error(err), ) } s.logger.Info("invalidated all sessions for user", zap.String("user_uuid", userUUID.String()), zap.Int("sessions_count", len(sessionIDs)), ) return nil } // GetUserSessions retrieves all session IDs for a given user func (s *sessionService) GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error) { userSessionsKey := UserSessionsPrefix + userUUID.String() // Get the session IDs list from cache data, err := s.cache.Get(ctx, userSessionsKey) if err != nil { return nil, fmt.Errorf("failed to get user sessions: %w", err) } if data == nil { // No sessions tracked for this user return []string{}, nil } // Deserialize session IDs var sessionIDs []string if err := json.Unmarshal(data, &sessionIDs); err != nil { return nil, fmt.Errorf("failed to unmarshal user sessions: %w", err) } return sessionIDs, nil } // addUserSession adds a session ID to the user's session list // CWE-384: Helper method for tracking user sessions to enable invalidation func (s *sessionService) addUserSession(ctx context.Context, userUUID uuid.UUID, sessionID string) error { userSessionsKey := UserSessionsPrefix + userUUID.String() // Get existing session IDs sessionIDs, err := s.GetUserSessions(ctx, userUUID) if err != nil && err.Error() != "failed to get user sessions: record not found" { return fmt.Errorf("failed to get existing sessions: %w", err) } // Add new session ID sessionIDs = append(sessionIDs, sessionID) // Serialize and store data, err := json.Marshal(sessionIDs) if err != nil { return fmt.Errorf("failed to marshal session IDs: %w", err) } // Store with same expiry as sessions if err := s.cache.SetWithExpiry(ctx, userSessionsKey, data, DefaultSessionDuration); err != nil { return fmt.Errorf("failed to store user sessions: %w", err) } return nil }