package ratelimit import ( "context" "fmt" "time" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // RateLimiter provides rate limiting functionality using Redis type RateLimiter interface { // Allow checks if a request should be allowed based on the key // Returns true if allowed, false if rate limit exceeded Allow(ctx context.Context, key string) (bool, error) // AllowN checks if N requests should be allowed AllowN(ctx context.Context, key string, n int) (bool, error) // Reset resets the rate limit for a key Reset(ctx context.Context, key string) error // GetRemaining returns the number of remaining requests GetRemaining(ctx context.Context, key string) (int, error) } // Config holds rate limiter configuration type Config struct { // MaxRequests is the maximum number of requests allowed MaxRequests int // Window is the time window for rate limiting Window time.Duration // KeyPrefix is the prefix for Redis keys KeyPrefix string } type rateLimiter struct { client *redis.Client config Config logger *zap.Logger } // NewRateLimiter creates a new rate limiter func NewRateLimiter(client *redis.Client, config Config, logger *zap.Logger) RateLimiter { return &rateLimiter{ client: client, config: config, logger: logger.Named("rate-limiter"), } } // Allow checks if a request should be allowed func (r *rateLimiter) Allow(ctx context.Context, key string) (bool, error) { return r.AllowN(ctx, key, 1) } // AllowN checks if N requests should be allowed using sliding window counter func (r *rateLimiter) AllowN(ctx context.Context, key string, n int) (bool, error) { redisKey := r.getRedisKey(key) now := time.Now() windowStart := now.Add(-r.config.Window) // Use Redis transaction to ensure atomicity pipe := r.client.Pipeline() // Remove old entries outside the window pipe.ZRemRangeByScore(ctx, redisKey, "0", fmt.Sprintf("%d", windowStart.UnixNano())) // Count current requests in window countCmd := pipe.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf") // Execute pipeline _, err := pipe.Exec(ctx) if err != nil && err != redis.Nil { r.logger.Error("failed to check rate limit", zap.String("key", key), zap.Error(err)) // Fail open: allow request if Redis is down return true, err } currentCount := countCmd.Val() // Check if adding N requests would exceed limit if currentCount+int64(n) > int64(r.config.MaxRequests) { r.logger.Warn("rate limit exceeded", zap.String("key", key), zap.Int64("current_count", currentCount), zap.Int("max_requests", r.config.MaxRequests)) return false, nil } // Add the new request(s) to the sorted set pipe2 := r.client.Pipeline() for i := 0; i < n; i++ { // Use nanosecond timestamp with incremental offset to ensure uniqueness timestamp := now.Add(time.Duration(i) * time.Nanosecond).UnixNano() pipe2.ZAdd(ctx, redisKey, redis.Z{ Score: float64(timestamp), Member: fmt.Sprintf("%d-%d", timestamp, i), }) } // Set expiration on the key (window + buffer) pipe2.Expire(ctx, redisKey, r.config.Window+time.Minute) // Execute pipeline _, err = pipe2.Exec(ctx) if err != nil && err != redis.Nil { r.logger.Error("failed to record request", zap.String("key", key), zap.Error(err)) // Already counted, so return true return true, err } r.logger.Debug("rate limit check passed", zap.String("key", key), zap.Int64("current_count", currentCount), zap.Int("max_requests", r.config.MaxRequests)) return true, nil } // Reset resets the rate limit for a key func (r *rateLimiter) Reset(ctx context.Context, key string) error { redisKey := r.getRedisKey(key) err := r.client.Del(ctx, redisKey).Err() if err != nil { r.logger.Error("failed to reset rate limit", zap.String("key", key), zap.Error(err)) return err } r.logger.Info("rate limit reset", zap.String("key", key)) return nil } // GetRemaining returns the number of remaining requests in the current window func (r *rateLimiter) GetRemaining(ctx context.Context, key string) (int, error) { redisKey := r.getRedisKey(key) now := time.Now() windowStart := now.Add(-r.config.Window) // Count current requests in window count, err := r.client.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf").Result() if err != nil && err != redis.Nil { r.logger.Error("failed to get remaining requests", zap.String("key", key), zap.Error(err)) return 0, err } remaining := r.config.MaxRequests - int(count) if remaining < 0 { remaining = 0 } return remaining, nil } // getRedisKey constructs the Redis key with prefix func (r *rateLimiter) getRedisKey(key string) string { return fmt.Sprintf("%s:%s", r.config.KeyPrefix, key) }