monorepo/cloud/maplepress-backend/pkg/ratelimit/ratelimiter.go

172 lines
4.6 KiB
Go

package ratelimit
import (
	"context"
	"fmt"
	"math/rand"
	"time"

	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
// RateLimiter decides whether callers identified by a string key may proceed,
// based on how many requests that key has made recently. Implementations in
// this package keep their state in Redis.
type RateLimiter interface {
	// Allow reports whether one more request for key is permitted.
	// A false result means the limit for the current window has been reached.
	Allow(ctx context.Context, key string) (bool, error)

	// AllowN reports whether n more requests for key are permitted at once.
	AllowN(ctx context.Context, key string, n int) (bool, error)

	// Reset discards all recorded requests for key.
	Reset(ctx context.Context, key string) error

	// GetRemaining reports how many more requests key may make in the
	// current window.
	GetRemaining(ctx context.Context, key string) (int, error)
}
// Config describes the limits a rate limiter enforces and how it names its
// Redis keys.
type Config struct {
	// MaxRequests caps how many requests a single key may make within one
	// Window.
	MaxRequests int
	// Window is the length of the sliding time window requests are counted in.
	Window time.Duration
	// KeyPrefix is prepended, followed by ":", to every key before it is used
	// as a Redis key.
	KeyPrefix string
}
// rateLimiter is the Redis-backed RateLimiter built by NewRateLimiter.
type rateLimiter struct {
	client *redis.Client // connection used for every rate-limit operation
	config Config        // limits and key prefix supplied at construction
	logger *zap.Logger   // named "rate-limiter" by NewRateLimiter
}
// NewRateLimiter returns a RateLimiter that enforces config using the given
// Redis client for all bookkeeping.
//
// logger may be nil: in that case logging is disabled via zap.NewNop()
// instead of panicking, because calling Named on a nil *zap.Logger
// dereferences it. Otherwise all log output is emitted under the
// "rate-limiter" sub-logger name.
func NewRateLimiter(client *redis.Client, config Config, logger *zap.Logger) RateLimiter {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &rateLimiter{
		client: client,
		config: config,
		logger: logger.Named("rate-limiter"),
	}
}
// Allow reports whether a single request for key fits in the current window,
// recording it if so. It is shorthand for AllowN with n == 1 and shares its
// error and fail-open behavior.
func (r *rateLimiter) Allow(ctx context.Context, key string) (bool, error) {
	const oneRequest = 1
	return r.AllowN(ctx, key, oneRequest)
}
// AllowN checks if N requests should be allowed using sliding window counter
func (r *rateLimiter) AllowN(ctx context.Context, key string, n int) (bool, error) {
redisKey := r.getRedisKey(key)
now := time.Now()
windowStart := now.Add(-r.config.Window)
// Use Redis transaction to ensure atomicity
pipe := r.client.Pipeline()
// Remove old entries outside the window
pipe.ZRemRangeByScore(ctx, redisKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
// Count current requests in window
countCmd := pipe.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf")
// Execute pipeline
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
r.logger.Error("failed to check rate limit",
zap.String("key", key),
zap.Error(err))
// Fail open: allow request if Redis is down
return true, err
}
currentCount := countCmd.Val()
// Check if adding N requests would exceed limit
if currentCount+int64(n) > int64(r.config.MaxRequests) {
r.logger.Warn("rate limit exceeded",
zap.String("key", key),
zap.Int64("current_count", currentCount),
zap.Int("max_requests", r.config.MaxRequests))
return false, nil
}
// Add the new request(s) to the sorted set
pipe2 := r.client.Pipeline()
for i := 0; i < n; i++ {
// Use nanosecond timestamp with incremental offset to ensure uniqueness
timestamp := now.Add(time.Duration(i) * time.Nanosecond).UnixNano()
pipe2.ZAdd(ctx, redisKey, redis.Z{
Score: float64(timestamp),
Member: fmt.Sprintf("%d-%d", timestamp, i),
})
}
// Set expiration on the key (window + buffer)
pipe2.Expire(ctx, redisKey, r.config.Window+time.Minute)
// Execute pipeline
_, err = pipe2.Exec(ctx)
if err != nil && err != redis.Nil {
r.logger.Error("failed to record request",
zap.String("key", key),
zap.Error(err))
// Already counted, so return true
return true, err
}
r.logger.Debug("rate limit check passed",
zap.String("key", key),
zap.Int64("current_count", currentCount),
zap.Int("max_requests", r.config.MaxRequests))
return true, nil
}
// Reset deletes the Redis sorted set holding key's recorded requests, so the
// key's next request starts with an empty window. A Redis failure is logged
// and returned unchanged; success is logged at info level.
func (r *rateLimiter) Reset(ctx context.Context, key string) error {
	if err := r.client.Del(ctx, r.getRedisKey(key)).Err(); err != nil {
		r.logger.Error("failed to reset rate limit",
			zap.String("key", key),
			zap.Error(err))
		return err
	}

	r.logger.Info("rate limit reset", zap.String("key", key))
	return nil
}
// GetRemaining reports how many more requests key may make before reaching
// MaxRequests. It counts only the entries scored within the last Window; it
// does not trim expired entries itself. The result is never negative. On a
// Redis error it returns 0 and the error.
func (r *rateLimiter) GetRemaining(ctx context.Context, key string) (int, error) {
	lowerBound := fmt.Sprintf("%d", time.Now().Add(-r.config.Window).UnixNano())

	used, err := r.client.ZCount(ctx, r.getRedisKey(key), lowerBound, "+inf").Result()
	if err != nil && err != redis.Nil {
		r.logger.Error("failed to get remaining requests",
			zap.String("key", key),
			zap.Error(err))
		return 0, err
	}

	if left := r.config.MaxRequests - int(used); left > 0 {
		return left, nil
	}
	return 0, nil
}
// getRedisKey namespaces key under the configured prefix, producing
// "<KeyPrefix>:<key>".
func (r *rateLimiter) getRedisKey(key string) string {
	return r.config.KeyPrefix + ":" + key
}