258 lines
8.1 KiB
Go
258 lines
8.1 KiB
Go
package service
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"

	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/internal/domain"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/cache"
	"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/logger"
)
|
|
|
|
// Cache key namespaces and lifetimes used by the session service.
const (
	// DefaultSessionDuration is how long a newly created session, and the
	// cache entries that back it, remain valid: 14 days.
	DefaultSessionDuration = 14 * 24 * time.Hour

	// SessionCachePrefix namespaces the cache entry holding one serialized
	// session; the full key is SessionCachePrefix + session ID.
	SessionCachePrefix = "session:"

	// UserSessionsPrefix namespaces the cache entry holding the JSON list of
	// session IDs tracked for one user; the full key is
	// UserSessionsPrefix + user UUID string.
	UserSessionsPrefix = "user_sessions:"
)
|
|
|
|
// SessionService handles session management operations
|
|
type SessionService interface {
|
|
CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error)
|
|
GetSession(ctx context.Context, sessionID string) (*domain.Session, error)
|
|
DeleteSession(ctx context.Context, sessionID string) error
|
|
// CWE-384: Session Fixation Prevention
|
|
InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error
|
|
GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error)
|
|
}
|
|
|
|
type sessionService struct {
|
|
cache cache.TwoTierCacher
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSessionService creates a new session service
|
|
func NewSessionService(cache cache.TwoTierCacher, logger *zap.Logger) SessionService {
|
|
return &sessionService{
|
|
cache: cache,
|
|
logger: logger.Named("session-service"),
|
|
}
|
|
}
|
|
|
|
// CreateSession creates a new session and stores it in the cache
|
|
// CWE-384: Tracks user sessions to enable invalidation on login (session fixation prevention)
|
|
func (s *sessionService) CreateSession(ctx context.Context, userID uint64, userUUID uuid.UUID, userEmail, userName, userRole string, tenantID uuid.UUID) (*domain.Session, error) {
|
|
// Create new session
|
|
session := domain.NewSession(userID, userUUID, userEmail, userName, userRole, tenantID, DefaultSessionDuration)
|
|
|
|
// Serialize session to JSON
|
|
sessionData, err := json.Marshal(session)
|
|
if err != nil {
|
|
s.logger.Error("failed to marshal session",
|
|
zap.String("session_id", session.ID),
|
|
zap.Error(err),
|
|
)
|
|
return nil, fmt.Errorf("failed to marshal session: %w", err)
|
|
}
|
|
|
|
// Store in cache with expiry
|
|
cacheKey := SessionCachePrefix + session.ID
|
|
if err := s.cache.SetWithExpiry(ctx, cacheKey, sessionData, DefaultSessionDuration); err != nil {
|
|
s.logger.Error("failed to store session in cache",
|
|
zap.String("session_id", session.ID),
|
|
zap.Error(err),
|
|
)
|
|
return nil, fmt.Errorf("failed to store session: %w", err)
|
|
}
|
|
|
|
// CWE-384: Track session ID for this user (for session invalidation)
|
|
if err := s.addUserSession(ctx, userUUID, session.ID); err != nil {
|
|
// Log error but don't fail session creation
|
|
s.logger.Warn("failed to track user session (non-fatal)",
|
|
zap.String("session_id", session.ID),
|
|
zap.String("user_uuid", userUUID.String()),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
|
|
// CWE-532: Use redacted email for logging
|
|
s.logger.Info("session created",
|
|
zap.String("session_id", session.ID),
|
|
zap.Uint64("user_id", userID),
|
|
logger.EmailHash(userEmail),
|
|
logger.SafeEmail("email_redacted", userEmail),
|
|
)
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// GetSession retrieves a session from the cache
|
|
func (s *sessionService) GetSession(ctx context.Context, sessionID string) (*domain.Session, error) {
|
|
cacheKey := SessionCachePrefix + sessionID
|
|
|
|
// Get from cache
|
|
sessionData, err := s.cache.Get(ctx, cacheKey)
|
|
if err != nil {
|
|
s.logger.Error("failed to get session from cache",
|
|
zap.String("session_id", sessionID),
|
|
zap.Error(err),
|
|
)
|
|
return nil, fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
if sessionData == nil {
|
|
s.logger.Debug("session not found",
|
|
zap.String("session_id", sessionID),
|
|
)
|
|
return nil, fmt.Errorf("session not found")
|
|
}
|
|
|
|
// Deserialize session from JSON
|
|
var session domain.Session
|
|
if err := json.Unmarshal(sessionData, &session); err != nil {
|
|
s.logger.Error("failed to unmarshal session",
|
|
zap.String("session_id", sessionID),
|
|
zap.Error(err),
|
|
)
|
|
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
|
}
|
|
|
|
// Check if session is expired
|
|
if session.IsExpired() {
|
|
s.logger.Info("session expired, deleting",
|
|
zap.String("session_id", sessionID),
|
|
)
|
|
_ = s.DeleteSession(ctx, sessionID) // Best effort cleanup
|
|
return nil, fmt.Errorf("session expired")
|
|
}
|
|
|
|
s.logger.Debug("session retrieved",
|
|
zap.String("session_id", sessionID),
|
|
zap.Uint64("user_id", session.UserID),
|
|
)
|
|
|
|
return &session, nil
|
|
}
|
|
|
|
// DeleteSession removes a session from the cache
|
|
func (s *sessionService) DeleteSession(ctx context.Context, sessionID string) error {
|
|
cacheKey := SessionCachePrefix + sessionID
|
|
|
|
if err := s.cache.Delete(ctx, cacheKey); err != nil {
|
|
s.logger.Error("failed to delete session from cache",
|
|
zap.String("session_id", sessionID),
|
|
zap.Error(err),
|
|
)
|
|
return fmt.Errorf("failed to delete session: %w", err)
|
|
}
|
|
|
|
s.logger.Info("session deleted",
|
|
zap.String("session_id", sessionID),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// InvalidateUserSessions invalidates all sessions for a given user
|
|
// CWE-384: This prevents session fixation attacks by ensuring old sessions are invalidated on login
|
|
func (s *sessionService) InvalidateUserSessions(ctx context.Context, userUUID uuid.UUID) error {
|
|
s.logger.Info("invalidating all sessions for user",
|
|
zap.String("user_uuid", userUUID.String()))
|
|
|
|
// Get all session IDs for this user
|
|
sessionIDs, err := s.GetUserSessions(ctx, userUUID)
|
|
if err != nil {
|
|
s.logger.Error("failed to get user sessions for invalidation",
|
|
zap.String("user_uuid", userUUID.String()),
|
|
zap.Error(err),
|
|
)
|
|
return fmt.Errorf("failed to get user sessions: %w", err)
|
|
}
|
|
|
|
// Delete each session
|
|
for _, sessionID := range sessionIDs {
|
|
if err := s.DeleteSession(ctx, sessionID); err != nil {
|
|
// Log but continue - best effort cleanup
|
|
s.logger.Warn("failed to delete session during invalidation",
|
|
zap.String("session_id", sessionID),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Clear the user sessions list
|
|
userSessionsKey := UserSessionsPrefix + userUUID.String()
|
|
if err := s.cache.Delete(ctx, userSessionsKey); err != nil {
|
|
// Log but don't fail - this is cleanup
|
|
s.logger.Warn("failed to delete user sessions list",
|
|
zap.String("user_uuid", userUUID.String()),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
|
|
s.logger.Info("invalidated all sessions for user",
|
|
zap.String("user_uuid", userUUID.String()),
|
|
zap.Int("sessions_count", len(sessionIDs)),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUserSessions retrieves all session IDs for a given user
|
|
func (s *sessionService) GetUserSessions(ctx context.Context, userUUID uuid.UUID) ([]string, error) {
|
|
userSessionsKey := UserSessionsPrefix + userUUID.String()
|
|
|
|
// Get the session IDs list from cache
|
|
data, err := s.cache.Get(ctx, userSessionsKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get user sessions: %w", err)
|
|
}
|
|
|
|
if data == nil {
|
|
// No sessions tracked for this user
|
|
return []string{}, nil
|
|
}
|
|
|
|
// Deserialize session IDs
|
|
var sessionIDs []string
|
|
if err := json.Unmarshal(data, &sessionIDs); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal user sessions: %w", err)
|
|
}
|
|
|
|
return sessionIDs, nil
|
|
}
|
|
|
|
// addUserSession adds a session ID to the user's session list
|
|
// CWE-384: Helper method for tracking user sessions to enable invalidation
|
|
func (s *sessionService) addUserSession(ctx context.Context, userUUID uuid.UUID, sessionID string) error {
|
|
userSessionsKey := UserSessionsPrefix + userUUID.String()
|
|
|
|
// Get existing session IDs
|
|
sessionIDs, err := s.GetUserSessions(ctx, userUUID)
|
|
if err != nil && err.Error() != "failed to get user sessions: record not found" {
|
|
return fmt.Errorf("failed to get existing sessions: %w", err)
|
|
}
|
|
|
|
// Add new session ID
|
|
sessionIDs = append(sessionIDs, sessionID)
|
|
|
|
// Serialize and store
|
|
data, err := json.Marshal(sessionIDs)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal session IDs: %w", err)
|
|
}
|
|
|
|
// Store with same expiry as sessions
|
|
if err := s.cache.SetWithExpiry(ctx, userSessionsKey, data, DefaultSessionDuration); err != nil {
|
|
return fmt.Errorf("failed to store user sessions: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|