177 lines
6.8 KiB
Go
177 lines
6.8 KiB
Go
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/internal/service/auth/refresh_token.go
|
|
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/gocql/gocql"
|
|
"go.uber.org/zap"
|
|
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
|
uc_user "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/internal/usecase/user"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/auditlog"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/hash"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/jwt"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/cache/cassandracache"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/transaction"
|
|
)
|
|
|
|
type (
	// RefreshTokenRequestDTO is the payload a client sends to exchange its
	// refresh token for a new token pair. Note that the raw refresh token is
	// carried in the JSON field "value", not "refresh_token".
	RefreshTokenRequestDTO struct {
		RefreshToken string `json:"value"`
	}

	// RefreshTokenResponseDTO carries a newly issued token pair back to the
	// client. The expiry fields are plain strings; Execute formats them as
	// RFC 3339 timestamps, and Username is populated from the user's email.
	RefreshTokenResponseDTO struct {
		Message                string `json:"message"`
		AccessToken            string `json:"access_token"`
		RefreshToken           string `json:"refresh_token"`
		AccessTokenExpiryDate  string `json:"access_token_expiry_date"`
		RefreshTokenExpiryDate string `json:"refresh_token_expiry_date"`
		Username               string `json:"username"`
	}

	// RefreshTokenService rotates a valid refresh token into a fresh
	// access/refresh token pair.
	RefreshTokenService interface {
		Execute(ctx context.Context, req *RefreshTokenRequestDTO) (*RefreshTokenResponseDTO, error)
	}
)
|
|
|
|
type refreshTokenServiceImpl struct {
|
|
config *config.Config
|
|
logger *zap.Logger
|
|
auditLogger auditlog.AuditLogger
|
|
cache cassandracache.CassandraCacher
|
|
jwtProvider jwt.JWTProvider
|
|
userGetByIDUC uc_user.UserGetByIDUseCase
|
|
}
|
|
|
|
func NewRefreshTokenService(
|
|
config *config.Config,
|
|
logger *zap.Logger,
|
|
auditLogger auditlog.AuditLogger,
|
|
cache cassandracache.CassandraCacher,
|
|
jwtProvider jwt.JWTProvider,
|
|
userGetByIDUC uc_user.UserGetByIDUseCase,
|
|
) RefreshTokenService {
|
|
return &refreshTokenServiceImpl{
|
|
config: config,
|
|
logger: logger.Named("RefreshTokenService"),
|
|
auditLogger: auditLogger,
|
|
cache: cache,
|
|
jwtProvider: jwtProvider,
|
|
userGetByIDUC: userGetByIDUC,
|
|
}
|
|
}
|
|
|
|
func (s *refreshTokenServiceImpl) Execute(ctx context.Context, req *RefreshTokenRequestDTO) (*RefreshTokenResponseDTO, error) {
|
|
// Create SAGA for token refresh workflow
|
|
saga := transaction.NewSaga("refresh-token", s.logger)
|
|
|
|
s.logger.Info("starting token refresh")
|
|
|
|
// Step 1: Validate refresh token JWT
|
|
userID, err := s.jwtProvider.ProcessJWTToken(req.RefreshToken)
|
|
if err != nil {
|
|
s.logger.Warn("Invalid refresh token JWT", zap.Error(err))
|
|
return nil, fmt.Errorf("invalid refresh token")
|
|
}
|
|
|
|
// Step 2: Check if refresh token exists in cache
|
|
// SECURITY: Hash refresh token to match how it was stored (prevents token leakage via cache keys)
|
|
refreshTokenHash := hash.HashToken(req.RefreshToken)
|
|
refreshKey := fmt.Sprintf("refresh:%s", refreshTokenHash)
|
|
cachedUserID, err := s.cache.Get(ctx, refreshKey)
|
|
if err != nil || cachedUserID == nil {
|
|
s.logger.Warn("Refresh token not found in cache", zap.String("user_id", userID))
|
|
return nil, fmt.Errorf("refresh token not found or expired")
|
|
}
|
|
|
|
// Step 3: Verify user IDs match
|
|
if string(cachedUserID) != userID {
|
|
s.logger.Warn("User ID mismatch", zap.String("jwt_user_id", userID), zap.String("cached_user_id", string(cachedUserID)))
|
|
return nil, fmt.Errorf("invalid refresh token")
|
|
}
|
|
|
|
// Step 4: Generate new token pair (token rotation for security)
|
|
newAccessToken, accessExpiry, newRefreshToken, refreshExpiry, err := s.jwtProvider.GenerateJWTTokenPair(
|
|
userID,
|
|
s.config.JWT.AccessTokenDuration,
|
|
s.config.JWT.RefreshTokenDuration,
|
|
)
|
|
if err != nil {
|
|
s.logger.Error("Failed to generate new tokens", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to generate new tokens")
|
|
}
|
|
|
|
// Step 5: Store NEW refresh token FIRST (compensate: delete new token)
|
|
// CRITICAL: Store new token before deleting old token to prevent lockout
|
|
// SECURITY: Hash new refresh token to prevent token leakage via cache key inspection
|
|
newRefreshTokenHash := hash.HashToken(newRefreshToken)
|
|
newRefreshKey := fmt.Sprintf("refresh:%s", newRefreshTokenHash)
|
|
if err := s.cache.SetWithExpiry(ctx, newRefreshKey, []byte(userID), s.config.JWT.RefreshTokenDuration); err != nil {
|
|
s.logger.Error("Failed to store new refresh token", zap.Error(err))
|
|
return nil, fmt.Errorf("failed to store new refresh token")
|
|
}
|
|
|
|
// Register compensation: if deletion of old token fails, delete new token
|
|
newRefreshKeyCaptured := newRefreshKey
|
|
saga.AddCompensation(func(ctx context.Context) error {
|
|
s.logger.Info("compensating: deleting new refresh token",
|
|
zap.String("new_refresh_key", newRefreshKeyCaptured))
|
|
return s.cache.Delete(ctx, newRefreshKeyCaptured)
|
|
})
|
|
|
|
// Step 6: Delete old refresh token from cache (compensate: restore old token)
|
|
oldRefreshKeyCaptured := refreshKey
|
|
oldUserIDCaptured := userID
|
|
if err := s.cache.Delete(ctx, refreshKey); err != nil {
|
|
s.logger.Error("Failed to delete old refresh token",
|
|
zap.String("refresh_key", refreshKey),
|
|
zap.Error(err))
|
|
|
|
// Trigger compensation: Delete new token (restore consistency)
|
|
saga.Rollback(ctx)
|
|
return nil, fmt.Errorf("failed to delete old refresh token: %w", err)
|
|
}
|
|
|
|
// Register compensation: restore old token with reduced TTL (1 hour grace period)
|
|
saga.AddCompensation(func(ctx context.Context) error {
|
|
s.logger.Info("compensating: restoring old refresh token",
|
|
zap.String("old_refresh_key", oldRefreshKeyCaptured))
|
|
// Restore with reduced TTL (1 hour) to allow user retry without long-lived old token
|
|
return s.cache.SetWithExpiry(ctx, oldRefreshKeyCaptured, []byte(oldUserIDCaptured), 1*time.Hour)
|
|
})
|
|
|
|
// Step 7: Get user to retrieve username/email (read-only, no compensation needed)
|
|
userUUID, err := gocql.ParseUUID(userID)
|
|
if err != nil {
|
|
s.logger.Error("Invalid user ID", zap.Error(err))
|
|
// No rollback needed for UUID parsing error (tokens already rotated successfully)
|
|
return nil, fmt.Errorf("invalid user ID")
|
|
}
|
|
|
|
user, err := s.userGetByIDUC.Execute(ctx, userUUID)
|
|
if err != nil || user == nil {
|
|
s.logger.Error("User not found", zap.String("user_id", userID), zap.Error(err))
|
|
// No rollback needed for user lookup error (tokens already rotated successfully)
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
|
|
s.logger.Info("Token refreshed successfully",
|
|
zap.String("user_id", userID),
|
|
zap.String("new_refresh_token", newRefreshToken[:16]+"...")) // Log prefix only for security
|
|
|
|
// Audit log token refresh
|
|
s.auditLogger.LogAuth(ctx, auditlog.EventTypeTokenRefresh, auditlog.OutcomeSuccess,
|
|
"", "", map[string]string{
|
|
"user_id": userID,
|
|
})
|
|
|
|
return &RefreshTokenResponseDTO{
|
|
Message: "Token refreshed successfully",
|
|
AccessToken: newAccessToken,
|
|
RefreshToken: newRefreshToken,
|
|
AccessTokenExpiryDate: accessExpiry.Format(time.RFC3339),
|
|
RefreshTokenExpiryDate: refreshExpiry.Format(time.RFC3339),
|
|
Username: user.Email,
|
|
}, nil
|
|
}
|