monorepo/cloud/maplepress-backend/internal/service/session.go

258 lines
8.1 KiB
Go

package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/cache"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/logger"
)
// Cache key prefixes and lifetimes used by the session service.
const (
	// SessionCachePrefix namespaces individual session records in the cache;
	// a record's key is SessionCachePrefix followed by the session ID.
	SessionCachePrefix = "session:"

	// UserSessionsPrefix namespaces the per-user list of session IDs (used to
	// invalidate all of a user's sessions); the key is UserSessionsPrefix
	// followed by the user's UUID string.
	UserSessionsPrefix = "user_sessions:"

	// DefaultSessionDuration is the expiry applied to new sessions and to the
	// per-user session list: two weeks (14 days).
	DefaultSessionDuration = 2 * 7 * 24 * time.Hour
)
// SessionService manages user sessions: creating, loading and deleting a
// single session, plus listing and invalidating every session tracked for a
// user.
type SessionService interface {
	// CreateSession builds a new session for the given user and persists it.
	CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error)
	// GetSession loads the session with the given ID.
	GetSession(ctx context.Context, sessionID string) (*domain.Session, error)
	// DeleteSession removes the session with the given ID.
	DeleteSession(ctx context.Context, sessionID string) error

	// CWE-384: Session Fixation Prevention

	// GetUserSessions lists the session IDs tracked for a user.
	GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error)
	// InvalidateUserSessions deletes every session tracked for a user.
	InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error
}
// sessionService is the SessionService implementation returned by
// NewSessionService.
type sessionService struct {
// cache stores both individual session records (SessionCachePrefix keys)
// and per-user session ID lists (UserSessionsPrefix keys).
cache cache.TwoTierCacher
// logger is the "session-service" named child logger built by NewSessionService.
logger *zap.Logger
}
// NewSessionService returns a SessionService that keeps sessions in the given
// cache and logs through a child of logger named "session-service".
func NewSessionService(cache cache.TwoTierCacher, logger *zap.Logger) SessionService {
	named := logger.Named("session-service")
	svc := &sessionService{logger: named, cache: cache}
	return svc
}
// CreateSession builds a new session for the user and stores its JSON encoding
// in the cache under SessionCachePrefix+ID with a DefaultSessionDuration expiry.
// It then records the session ID in the user's session list.
//
// CWE-384: the per-user list is what InvalidateUserSessions walks to revoke old
// sessions (session fixation prevention). Failing to update that list is
// logged as a warning but does not fail session creation.
func (s *sessionService) CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error) {
	session := domain.NewSession(userID, userUUID, userEmail, userName, userRole, tenantID, DefaultSessionDuration)
	sessionIDField := zap.String("session_id", session.ID)

	payload, marshalErr := json.Marshal(session)
	if marshalErr != nil {
		s.logger.Error("failed to marshal session", sessionIDField, zap.Error(marshalErr))
		return nil, fmt.Errorf("failed to marshal session: %w", marshalErr)
	}

	key := SessionCachePrefix + session.ID
	if storeErr := s.cache.SetWithExpiry(ctx, key, payload, DefaultSessionDuration); storeErr != nil {
		s.logger.Error("failed to store session in cache", sessionIDField, zap.Error(storeErr))
		return nil, fmt.Errorf("failed to store session: %w", storeErr)
	}

	// Best effort: an untracked session is still usable; it just won't be
	// found by InvalidateUserSessions.
	if trackErr := s.addUserSession(ctx, userUUID, session.ID); trackErr != nil {
		s.logger.Warn("failed to track user session (non-fatal)",
			sessionIDField,
			zap.String("user_uuid", userUUID.String()),
			zap.Error(trackErr),
		)
	}

	// CWE-532: use redacted email for logging (logger.EmailHash / logger.SafeEmail).
	s.logger.Info("session created",
		sessionIDField,
		zap.Uint64("user_id", userID),
		logger.EmailHash(userEmail),
		logger.SafeEmail("email_redacted", userEmail),
	)
	return session, nil
}
// GetSession loads and decodes the session stored under
// SessionCachePrefix+sessionID.
//
// It returns an error when the cache lookup fails, when the cache returns no
// payload, when the payload is not valid session JSON, or when the session is
// expired. An expired session is deleted (best effort) before the error is
// returned.
func (s *sessionService) GetSession(ctx context.Context, sessionID string) (*domain.Session, error) {
	idField := zap.String("session_id", sessionID)

	raw, err := s.cache.Get(ctx, SessionCachePrefix+sessionID)
	switch {
	case err != nil:
		s.logger.Error("failed to get session from cache", idField, zap.Error(err))
		return nil, fmt.Errorf("failed to get session: %w", err)
	case raw == nil:
		s.logger.Debug("session not found", idField)
		return nil, fmt.Errorf("session not found")
	}

	session := new(domain.Session)
	if err := json.Unmarshal(raw, session); err != nil {
		s.logger.Error("failed to unmarshal session", idField, zap.Error(err))
		return nil, fmt.Errorf("failed to unmarshal session: %w", err)
	}

	if session.IsExpired() {
		s.logger.Info("session expired, deleting", idField)
		// Best effort cleanup: the caller sees "session expired" regardless
		// of whether the delete succeeds.
		_ = s.DeleteSession(ctx, sessionID)
		return nil, fmt.Errorf("session expired")
	}

	s.logger.Debug("session retrieved", idField, zap.Uint64("user_id", session.UserID))
	return session, nil
}
// DeleteSession removes the cache record stored under
// SessionCachePrefix+sessionID. It does not modify the owning user's session
// list.
func (s *sessionService) DeleteSession(ctx context.Context, sessionID string) error {
	idField := zap.String("session_id", sessionID)

	if err := s.cache.Delete(ctx, SessionCachePrefix+sessionID); err != nil {
		s.logger.Error("failed to delete session from cache", idField, zap.Error(err))
		return fmt.Errorf("failed to delete session: %w", err)
	}

	s.logger.Info("session deleted", idField)
	return nil
}
// InvalidateUserSessions deletes every session listed for the user and then
// deletes the user's session list itself.
//
// CWE-384: intended to ensure old sessions are invalidated on login (session
// fixation prevention). Only a failure to read the session list is returned.
// Failures to delete individual sessions or the list are logged and skipped.
func (s *sessionService) InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error {
	uuidStr := userUUID.String()
	s.logger.Info("invalidating all sessions for user", zap.String("user_uuid", uuidStr))

	ids, err := s.GetUserSessions(ctx, userUUID)
	if err != nil {
		s.logger.Error("failed to get user sessions for invalidation",
			zap.String("user_uuid", uuidStr),
			zap.Error(err),
		)
		return fmt.Errorf("failed to get user sessions: %w", err)
	}

	for i := range ids {
		// Best effort: keep going so one failed delete doesn't spare the rest.
		if delErr := s.DeleteSession(ctx, ids[i]); delErr != nil {
			s.logger.Warn("failed to delete session during invalidation",
				zap.String("session_id", ids[i]),
				zap.Error(delErr),
			)
		}
	}

	// Best effort cleanup of the list key; failure is logged only.
	if delErr := s.cache.Delete(ctx, UserSessionsPrefix+uuidStr); delErr != nil {
		s.logger.Warn("failed to delete user sessions list",
			zap.String("user_uuid", uuidStr),
			zap.Error(delErr),
		)
	}

	// sessions_count is the number of IDs that were listed, including any
	// whose delete failed above.
	s.logger.Info("invalidated all sessions for user",
		zap.String("user_uuid", uuidStr),
		zap.Int("sessions_count", len(ids)),
	)
	return nil
}
// GetUserSessions returns the session IDs stored in the user's session list
// (cache key UserSessionsPrefix + user UUID string).
//
// A user without a list yields an empty, non-nil slice and a nil error. That
// covers both a nil cache payload and a cache error whose message is exactly
// "record not found", the condition addUserSession already tolerates as
// "no existing list". As a result, InvalidateUserSessions does not fail for
// users who never had a session tracked. Any other cache error, or a payload
// that is not a JSON string array, is returned as an error.
func (s *sessionService) GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error) {
	// Presumably the cache layer's message for a missing key.
	// NOTE(review): switch to errors.Is if the cache package exports a sentinel.
	const recordNotFound = "record not found"

	data, err := s.cache.Get(ctx, UserSessionsPrefix+userUUID.String())
	if err != nil {
		if err.Error() == recordNotFound {
			// Same outcome as the nil-payload case below.
			return []string{}, nil
		}
		return nil, fmt.Errorf("failed to get user sessions: %w", err)
	}
	if data == nil {
		// No sessions tracked for this user.
		return []string{}, nil
	}

	var sessionIDs []string
	if err := json.Unmarshal(data, &sessionIDs); err != nil {
		return nil, fmt.Errorf("failed to unmarshal user sessions: %w", err)
	}
	return sessionIDs, nil
}
// addUserSession appends sessionID to the user's session list (cache key
// UserSessionsPrefix + user UUID string). It rewrites the list with a fresh
// DefaultSessionDuration expiry.
//
// CWE-384: this list is what InvalidateUserSessions walks to revoke a user's
// sessions. A "record not found" error from the lookup starts a new list. Any
// other read, marshal or store failure is returned.
//
// NOTE(review): this is an unsynchronized read-modify-write. Two concurrent
// calls for the same user can each overwrite the other's append, leaving one
// session untracked.
func (s *sessionService) addUserSession(ctx context.Context, userUUID uuid.UUID, sessionID string) error {
	userSessionsKey := UserSessionsPrefix + userUUID.String()

	sessionIDs, err := s.GetUserSessions(ctx, userUUID)
	if err != nil && !isCacheRecordNotFound(err) {
		return fmt.Errorf("failed to get existing sessions: %w", err)
	}

	// On the not-found path sessionIDs is nil; append handles that.
	sessionIDs = append(sessionIDs, sessionID)

	data, err := json.Marshal(sessionIDs)
	if err != nil {
		return fmt.Errorf("failed to marshal session IDs: %w", err)
	}

	// Store with the same expiry as sessions, refreshed on every append.
	if err := s.cache.SetWithExpiry(ctx, userSessionsKey, data, DefaultSessionDuration); err != nil {
		return fmt.Errorf("failed to store user sessions: %w", err)
	}
	return nil
}

// isCacheRecordNotFound reports whether err, or any error it wraps via
// Unwrap() error, has the exact message "record not found". Searching the
// chain avoids comparing against one caller's full wrapped message, so the
// check keeps working if a wrap prefix changes.
func isCacheRecordNotFound(err error) bool {
	for err != nil {
		if err.Error() == "record not found" {
			return true
		}
		// Same single-step unwrap that errors.Unwrap performs.
		wrapper, ok := err.(interface{ Unwrap() error })
		if !ok {
			return false
		}
		err = wrapper.Unwrap()
	}
	return false
}