172 lines
4.6 KiB
Go
172 lines
4.6 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// RateLimiter is the contract for a per-key request limiter backed by Redis.
type RateLimiter interface {
	// Allow reports whether a single request identified by key may proceed.
	// The result is true when the request is allowed and false when the rate
	// limit has been exceeded.
	Allow(ctx context.Context, key string) (bool, error)

	// AllowN reports whether n requests identified by key may proceed.
	AllowN(ctx context.Context, key string, n int) (bool, error)

	// Reset clears the rate limit state held for key.
	Reset(ctx context.Context, key string) error

	// GetRemaining returns how many more requests key may still make.
	GetRemaining(ctx context.Context, key string) (int, error)
}
|
|
|
|
// Config carries the tuning parameters of a rate limiter.
type Config struct {
	// MaxRequests caps how many requests are allowed.
	MaxRequests int
	// Window is the span of time over which requests are counted.
	Window time.Duration
	// KeyPrefix is placed in front of every Redis key.
	KeyPrefix string
}
|
|
|
|
type rateLimiter struct {
|
|
client *redis.Client
|
|
config Config
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewRateLimiter creates a new rate limiter
|
|
func NewRateLimiter(client *redis.Client, config Config, logger *zap.Logger) RateLimiter {
|
|
return &rateLimiter{
|
|
client: client,
|
|
config: config,
|
|
logger: logger.Named("rate-limiter"),
|
|
}
|
|
}
|
|
|
|
// Allow checks if a request should be allowed
|
|
func (r *rateLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
|
return r.AllowN(ctx, key, 1)
|
|
}
|
|
|
|
// AllowN checks if N requests should be allowed using sliding window counter
|
|
func (r *rateLimiter) AllowN(ctx context.Context, key string, n int) (bool, error) {
|
|
redisKey := r.getRedisKey(key)
|
|
now := time.Now()
|
|
windowStart := now.Add(-r.config.Window)
|
|
|
|
// Use Redis transaction to ensure atomicity
|
|
pipe := r.client.Pipeline()
|
|
|
|
// Remove old entries outside the window
|
|
pipe.ZRemRangeByScore(ctx, redisKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
|
|
|
// Count current requests in window
|
|
countCmd := pipe.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf")
|
|
|
|
// Execute pipeline
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil && err != redis.Nil {
|
|
r.logger.Error("failed to check rate limit",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
// Fail open: allow request if Redis is down
|
|
return true, err
|
|
}
|
|
|
|
currentCount := countCmd.Val()
|
|
|
|
// Check if adding N requests would exceed limit
|
|
if currentCount+int64(n) > int64(r.config.MaxRequests) {
|
|
r.logger.Warn("rate limit exceeded",
|
|
zap.String("key", key),
|
|
zap.Int64("current_count", currentCount),
|
|
zap.Int("max_requests", r.config.MaxRequests))
|
|
return false, nil
|
|
}
|
|
|
|
// Add the new request(s) to the sorted set
|
|
pipe2 := r.client.Pipeline()
|
|
for i := 0; i < n; i++ {
|
|
// Use nanosecond timestamp with incremental offset to ensure uniqueness
|
|
timestamp := now.Add(time.Duration(i) * time.Nanosecond).UnixNano()
|
|
pipe2.ZAdd(ctx, redisKey, redis.Z{
|
|
Score: float64(timestamp),
|
|
Member: fmt.Sprintf("%d-%d", timestamp, i),
|
|
})
|
|
}
|
|
|
|
// Set expiration on the key (window + buffer)
|
|
pipe2.Expire(ctx, redisKey, r.config.Window+time.Minute)
|
|
|
|
// Execute pipeline
|
|
_, err = pipe2.Exec(ctx)
|
|
if err != nil && err != redis.Nil {
|
|
r.logger.Error("failed to record request",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
// Already counted, so return true
|
|
return true, err
|
|
}
|
|
|
|
r.logger.Debug("rate limit check passed",
|
|
zap.String("key", key),
|
|
zap.Int64("current_count", currentCount),
|
|
zap.Int("max_requests", r.config.MaxRequests))
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// Reset resets the rate limit for a key
|
|
func (r *rateLimiter) Reset(ctx context.Context, key string) error {
|
|
redisKey := r.getRedisKey(key)
|
|
err := r.client.Del(ctx, redisKey).Err()
|
|
if err != nil {
|
|
r.logger.Error("failed to reset rate limit",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
r.logger.Info("rate limit reset",
|
|
zap.String("key", key))
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetRemaining returns the number of remaining requests in the current window
|
|
func (r *rateLimiter) GetRemaining(ctx context.Context, key string) (int, error) {
|
|
redisKey := r.getRedisKey(key)
|
|
now := time.Now()
|
|
windowStart := now.Add(-r.config.Window)
|
|
|
|
// Count current requests in window
|
|
count, err := r.client.ZCount(ctx, redisKey,
|
|
fmt.Sprintf("%d", windowStart.UnixNano()),
|
|
"+inf").Result()
|
|
|
|
if err != nil && err != redis.Nil {
|
|
r.logger.Error("failed to get remaining requests",
|
|
zap.String("key", key),
|
|
zap.Error(err))
|
|
return 0, err
|
|
}
|
|
|
|
remaining := r.config.MaxRequests - int(count)
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
return remaining, nil
|
|
}
|
|
|
|
// getRedisKey constructs the Redis key with prefix
|
|
func (r *rateLimiter) getRedisKey(key string) string {
|
|
return fmt.Sprintf("%s:%s", r.config.KeyPrefix, key)
|
|
}
|