monorepo/cloud/maplefile-backend/internal/service/auth/refresh_token.go

177 lines
6.8 KiB
Go

// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/internal/service/auth/refresh_token.go
package auth
import (
"context"
"fmt"
"time"
"github.com/gocql/gocql"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
uc_user "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/internal/usecase/user"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/auditlog"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/hash"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/jwt"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/cache/cassandracache"
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/transaction"
)
type (
	// RefreshTokenRequestDTO carries the refresh token a client presents to
	// RefreshTokenService.Execute.
	//
	// NOTE(review): the token is read from the JSON key "value" rather than
	// "refresh_token". That key is part of the wire contract, so it must not be
	// renamed without coordinating with clients.
	RefreshTokenRequestDTO struct {
		RefreshToken string `json:"value"`
	}

	// RefreshTokenResponseDTO is returned after a successful refresh. It holds
	// the newly minted token pair and their expiry timestamps, which Execute
	// formats as RFC 3339 strings. Username is populated from the user's email.
	RefreshTokenResponseDTO struct {
		Message                string `json:"message"`
		AccessToken            string `json:"access_token"`
		RefreshToken           string `json:"refresh_token"`
		AccessTokenExpiryDate  string `json:"access_token_expiry_date"`
		RefreshTokenExpiryDate string `json:"refresh_token_expiry_date"`
		Username               string `json:"username"`
	}

	// RefreshTokenService exchanges a still-valid refresh token for a fresh
	// access/refresh token pair.
	RefreshTokenService interface {
		Execute(ctx context.Context, req *RefreshTokenRequestDTO) (*RefreshTokenResponseDTO, error)
	}
)
// refreshTokenServiceImpl is the RefreshTokenService implementation built by
// NewRefreshTokenService. It checks refresh tokens against the JWT provider and
// the cache, and rotates the refresh token on every successful refresh.
type refreshTokenServiceImpl struct {
// config supplies the JWT access/refresh token durations used when minting a new pair.
config *config.Config
// logger is scoped with the name "RefreshTokenService" by the constructor.
logger *zap.Logger
// auditLogger receives a token-refresh audit event after a successful refresh.
auditLogger auditlog.AuditLogger
// cache maps "refresh:<hashed token>" keys to the owning user ID.
cache cassandracache.CassandraCacher
// jwtProvider parses incoming refresh JWTs and generates new token pairs.
jwtProvider jwt.JWTProvider
// userGetByIDUC loads the user whose email is returned as the response Username.
userGetByIDUC uc_user.UserGetByIDUseCase
}
// NewRefreshTokenService assembles a RefreshTokenService from its collaborators.
// The supplied logger is scoped with the name "RefreshTokenService"; every
// other dependency is stored as given.
func NewRefreshTokenService(
	config *config.Config,
	logger *zap.Logger,
	auditLogger auditlog.AuditLogger,
	cache cassandracache.CassandraCacher,
	jwtProvider jwt.JWTProvider,
	userGetByIDUC uc_user.UserGetByIDUseCase,
) RefreshTokenService {
	svc := new(refreshTokenServiceImpl)
	svc.config = config
	svc.logger = logger.Named("RefreshTokenService")
	svc.auditLogger = auditLogger
	svc.cache = cache
	svc.jwtProvider = jwtProvider
	svc.userGetByIDUC = userGetByIDUC
	return svc
}
// Execute exchanges a valid refresh token for a new access/refresh token pair.
//
// Flow:
//  1. Validate the refresh token JWT and extract the user ID from it.
//  2. Confirm the hashed refresh token is still present in the cache and is
//     bound to that same user ID.
//  3. Load the user record. This happens BEFORE any token is rotated, so a
//     lookup failure leaves the caller's existing refresh token usable. The
//     alternative would delete it without ever handing out a replacement.
//  4. Generate a new pair, store the new refresh token, then delete the old one.
//     If deleting the old token fails, the saga compensation removes the new
//     token again.
//
// Errors returned to the caller carry deliberately generic messages; details
// are written to the logger instead.
func (s *refreshTokenServiceImpl) Execute(ctx context.Context, req *RefreshTokenRequestDTO) (*RefreshTokenResponseDTO, error) {
	// Create SAGA for token refresh workflow
	saga := transaction.NewSaga("refresh-token", s.logger)
	s.logger.Info("starting token refresh")

	// cacheKey derives the cache key under which a refresh token is stored.
	// SECURITY: the token is hashed so raw tokens never appear in cache keys.
	cacheKey := func(token string) string {
		return fmt.Sprintf("refresh:%s", hash.HashToken(token))
	}

	if req == nil {
		s.logger.Warn("Missing refresh token request")
		return nil, fmt.Errorf("invalid refresh token")
	}

	// Step 1: Validate refresh token JWT
	userID, err := s.jwtProvider.ProcessJWTToken(req.RefreshToken)
	if err != nil {
		s.logger.Warn("Invalid refresh token JWT", zap.Error(err))
		return nil, fmt.Errorf("invalid refresh token")
	}

	// Step 2: Check that the refresh token still exists in the cache.
	// NOTE(review): this Get and the later Delete are not atomic. Two concurrent
	// requests presenting the same refresh token can both pass this check and
	// both receive a new pair. Closing that gap needs an atomic get-and-delete
	// (or a lock) in the cache layer; confirm whether CassandraCacher offers one.
	refreshKey := cacheKey(req.RefreshToken)
	cachedUserID, err := s.cache.Get(ctx, refreshKey)
	if err != nil || cachedUserID == nil {
		// zap.Error(nil) is a no-op field, so this is safe on a plain cache miss.
		s.logger.Warn("Refresh token not found in cache", zap.String("user_id", userID), zap.Error(err))
		return nil, fmt.Errorf("refresh token not found or expired")
	}

	// Step 3: The user bound to the cached token must match the JWT subject.
	if string(cachedUserID) != userID {
		s.logger.Warn("User ID mismatch", zap.String("jwt_user_id", userID), zap.String("cached_user_id", string(cachedUserID)))
		return nil, fmt.Errorf("invalid refresh token")
	}

	// Step 4: Load the user BEFORE touching any tokens (read-only, no compensation).
	// If this fails, the old refresh token remains valid so the client can retry.
	userUUID, err := gocql.ParseUUID(userID)
	if err != nil {
		s.logger.Error("Invalid user ID", zap.Error(err))
		return nil, fmt.Errorf("invalid user ID")
	}
	user, err := s.userGetByIDUC.Execute(ctx, userUUID)
	if err != nil || user == nil {
		s.logger.Error("User not found", zap.String("user_id", userID), zap.Error(err))
		return nil, fmt.Errorf("user not found")
	}

	// Step 5: Generate new token pair (token rotation for security)
	newAccessToken, accessExpiry, newRefreshToken, refreshExpiry, err := s.jwtProvider.GenerateJWTTokenPair(
		userID,
		s.config.JWT.AccessTokenDuration,
		s.config.JWT.RefreshTokenDuration,
	)
	if err != nil {
		s.logger.Error("Failed to generate new tokens", zap.Error(err))
		return nil, fmt.Errorf("failed to generate new tokens")
	}

	// Step 6: Store the NEW refresh token FIRST so the user is never left
	// without a valid token if a later step fails.
	newRefreshKey := cacheKey(newRefreshToken)
	if err := s.cache.SetWithExpiry(ctx, newRefreshKey, []byte(userID), s.config.JWT.RefreshTokenDuration); err != nil {
		s.logger.Error("Failed to store new refresh token", zap.Error(err))
		return nil, fmt.Errorf("failed to store new refresh token")
	}
	// Compensation: if the old token cannot be deleted, remove the new one again.
	saga.AddCompensation(func(ctx context.Context) error {
		s.logger.Info("compensating: deleting new refresh token",
			zap.String("new_refresh_key", newRefreshKey))
		return s.cache.Delete(ctx, newRefreshKey)
	})

	// Step 7: Delete the old refresh token. This is the last fallible step, so
	// no compensation is registered for it.
	if err := s.cache.Delete(ctx, refreshKey); err != nil {
		s.logger.Error("Failed to delete old refresh token",
			zap.String("refresh_key", refreshKey),
			zap.Error(err))
		// Trigger compensation: Delete new token (restore consistency)
		saga.Rollback(ctx)
		return nil, fmt.Errorf("failed to delete old refresh token: %w", err)
	}

	// Log only a short prefix of the new token; guard the slice so an
	// unexpectedly short token cannot cause a panic.
	tokenPrefix := newRefreshToken
	if len(tokenPrefix) > 16 {
		tokenPrefix = tokenPrefix[:16]
	}
	s.logger.Info("Token refreshed successfully",
		zap.String("user_id", userID),
		zap.String("new_refresh_token", tokenPrefix+"..."))

	// Audit log token refresh
	s.auditLogger.LogAuth(ctx, auditlog.EventTypeTokenRefresh, auditlog.OutcomeSuccess,
		"", "", map[string]string{
			"user_id": userID,
		})

	return &RefreshTokenResponseDTO{
		Message:                "Token refreshed successfully",
		AccessToken:            newAccessToken,
		RefreshToken:           newRefreshToken,
		AccessTokenExpiryDate:  accessExpiry.Format(time.RFC3339),
		RefreshTokenExpiryDate: refreshExpiry.Format(time.RFC3339),
		Username:               user.Email,
	}, nil
}