366 lines
No EOL
12 KiB
Go
366 lines
No EOL
12 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// AuthFailureRateLimiter throttles callers that repeatedly fail authorization
// checks, limiting how fast a single user can probe for privilege escalation
// or unauthorized access.
type AuthFailureRateLimiter interface {
	// CheckAuthFailure reports whether userID may attempt action on resourceID.
	// Results, in order: whether the attempt is allowed, how many user-level
	// failures remain, when the current limit resets, and any error raised by
	// the backing store.
	CheckAuthFailure(ctx context.Context, userID, resourceID, action string) (bool, int, time.Time, error)

	// RecordAuthFailure stores one failed authorization attempt together with
	// a human-readable reason.
	RecordAuthFailure(ctx context.Context, userID, resourceID, action, reason string) error

	// RecordAuthSuccess notes a successful authorization. Implementations may
	// use it to relax or reset failure counters.
	RecordAuthSuccess(ctx context.Context, userID, resourceID, action string) error

	// IsUserBlocked reports whether userID is temporarily barred from
	// authorization attempts and, if so, how long the block has left.
	IsUserBlocked(ctx context.Context, userID string) (bool, time.Duration, error)

	// GetFailureCount returns the number of authorization failures currently
	// counted against userID.
	GetFailureCount(ctx context.Context, userID string) (int, error)

	// GetResourceFailureCount returns the number of authorization failures
	// userID has accumulated against a single resource.
	GetResourceFailureCount(ctx context.Context, userID, resourceID string) (int, error)

	// ResetUserFailures manually clears the failure counters kept for userID.
	ResetUserFailures(ctx context.Context, userID string) error
}
|
|
|
|
// AuthFailureRateLimiterConfig tunes the thresholds, time windows, and Redis
// key namespace used by the authorization-failure rate limiter.
type AuthFailureRateLimiterConfig struct {
	// MaxFailuresPerUser caps how many failures one user may accumulate,
	// across all resources, before being blocked.
	MaxFailuresPerUser int

	// MaxFailuresPerResource caps how many failures one user may accumulate
	// against a single resource.
	MaxFailuresPerResource int

	// FailureWindow is how far back in time failures are counted.
	FailureWindow time.Duration

	// BlockDuration is how long a user stays blocked after exceeding the
	// per-user limit.
	BlockDuration time.Duration

	// AlertThreshold is the failure count at which a monitoring alert is
	// emitted.
	AlertThreshold int

	// KeyPrefix namespaces every Redis key the limiter uses.
	KeyPrefix string
}
|
|
|
|
// DefaultAuthFailureRateLimiterConfig returns recommended configuration
|
|
// Following OWASP guidelines for authorization failure handling
|
|
func DefaultAuthFailureRateLimiterConfig() AuthFailureRateLimiterConfig {
|
|
return AuthFailureRateLimiterConfig{
|
|
MaxFailuresPerUser: 20, // 20 total auth failures per user
|
|
MaxFailuresPerResource: 5, // 5 failures per specific resource
|
|
FailureWindow: 15 * time.Minute, // in 15-minute window
|
|
BlockDuration: 30 * time.Minute, // block for 30 minutes
|
|
AlertThreshold: 10, // alert after 10 failures
|
|
KeyPrefix: "auth_fail_rl",
|
|
}
|
|
}
|
|
|
|
type authFailureRateLimiter struct {
|
|
client *redis.Client
|
|
config AuthFailureRateLimiterConfig
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewAuthFailureRateLimiter creates a new authorization failure rate limiter
|
|
func NewAuthFailureRateLimiter(client *redis.Client, config AuthFailureRateLimiterConfig, logger *zap.Logger) AuthFailureRateLimiter {
|
|
return &authFailureRateLimiter{
|
|
client: client,
|
|
config: config,
|
|
logger: logger.Named("auth-failure-rate-limiter"),
|
|
}
|
|
}
|
|
|
|
// CheckAuthFailure checks if the user has exceeded authorization failure limits
|
|
// CWE-307: Protection against authorization brute force attacks
|
|
// OWASP A01:2021: Broken Access Control - Rate limiting authorization failures
|
|
func (r *authFailureRateLimiter) CheckAuthFailure(ctx context.Context, userID string, resourceID string, action string) (bool, int, time.Time, error) {
|
|
// Check if user is blocked
|
|
blocked, remaining, err := r.IsUserBlocked(ctx, userID)
|
|
if err != nil {
|
|
r.logger.Error("failed to check user block status",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Error(err))
|
|
// Fail open on Redis error (security vs availability trade-off)
|
|
return true, 0, time.Time{}, err
|
|
}
|
|
|
|
if blocked {
|
|
resetTime := time.Now().Add(remaining)
|
|
r.logger.Warn("blocked user attempted authorization",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.String("action", action),
|
|
zap.Duration("remaining_block", remaining))
|
|
return false, 0, resetTime, nil
|
|
}
|
|
|
|
// Check per-user failure count
|
|
userFailures, err := r.GetFailureCount(ctx, userID)
|
|
if err != nil {
|
|
r.logger.Error("failed to get user failure count",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Error(err))
|
|
// Fail open on Redis error
|
|
return true, 0, time.Time{}, err
|
|
}
|
|
|
|
// Check per-resource failure count
|
|
resourceFailures, err := r.GetResourceFailureCount(ctx, userID, resourceID)
|
|
if err != nil {
|
|
r.logger.Error("failed to get resource failure count",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.Error(err))
|
|
// Fail open on Redis error
|
|
return true, 0, time.Time{}, err
|
|
}
|
|
|
|
// Check if limits exceeded
|
|
if userFailures >= r.config.MaxFailuresPerUser {
|
|
r.blockUser(ctx, userID)
|
|
resetTime := time.Now().Add(r.config.BlockDuration)
|
|
r.logger.Warn("user exceeded authorization failure limit",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Int("failures", userFailures))
|
|
return false, 0, resetTime, nil
|
|
}
|
|
|
|
if resourceFailures >= r.config.MaxFailuresPerResource {
|
|
resetTime := time.Now().Add(r.config.FailureWindow)
|
|
r.logger.Warn("user exceeded resource-specific failure limit",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.Int("failures", resourceFailures))
|
|
return false, r.config.MaxFailuresPerUser - userFailures, resetTime, nil
|
|
}
|
|
|
|
remainingAttempts := r.config.MaxFailuresPerUser - userFailures
|
|
resetTime := time.Now().Add(r.config.FailureWindow)
|
|
|
|
return true, remainingAttempts, resetTime, nil
|
|
}
|
|
|
|
// RecordAuthFailure records an authorization failure
|
|
// CWE-778: Insufficient Logging of security events
|
|
func (r *authFailureRateLimiter) RecordAuthFailure(ctx context.Context, userID string, resourceID string, action string, reason string) error {
|
|
now := time.Now()
|
|
timestamp := now.UnixNano()
|
|
|
|
// Record per-user failure
|
|
userKey := r.getUserFailureKey(userID)
|
|
pipe := r.client.Pipeline()
|
|
|
|
// Add to sorted set with timestamp as score (for windowing)
|
|
pipe.ZAdd(ctx, userKey, redis.Z{
|
|
Score: float64(timestamp),
|
|
Member: fmt.Sprintf("%d:%s:%s", timestamp, resourceID, action),
|
|
})
|
|
pipe.Expire(ctx, userKey, r.config.FailureWindow)
|
|
|
|
// Record per-resource failure
|
|
if resourceID != "" {
|
|
resourceKey := r.getResourceFailureKey(userID, resourceID)
|
|
pipe.ZAdd(ctx, resourceKey, redis.Z{
|
|
Score: float64(timestamp),
|
|
Member: fmt.Sprintf("%d:%s", timestamp, action),
|
|
})
|
|
pipe.Expire(ctx, resourceKey, r.config.FailureWindow)
|
|
}
|
|
|
|
// Increment total failure counter for metrics
|
|
metricsKey := r.getMetricsKey(userID)
|
|
pipe.Incr(ctx, metricsKey)
|
|
pipe.Expire(ctx, metricsKey, 24*time.Hour) // Keep metrics for 24 hours
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
r.logger.Error("failed to record authorization failure",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.String("action", action),
|
|
zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
// Check if we should alert
|
|
count, _ := r.GetFailureCount(ctx, userID)
|
|
if count == r.config.AlertThreshold {
|
|
r.logger.Error("SECURITY ALERT: User reached authorization failure alert threshold",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.String("action", action),
|
|
zap.String("reason", reason),
|
|
zap.Int("failure_count", count))
|
|
}
|
|
|
|
r.logger.Warn("authorization failure recorded",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.String("action", action),
|
|
zap.String("reason", reason),
|
|
zap.Int("total_failures", count))
|
|
|
|
return nil
|
|
}
|
|
|
|
// RecordAuthSuccess records a successful authorization
|
|
func (r *authFailureRateLimiter) RecordAuthSuccess(ctx context.Context, userID string, resourceID string, action string) error {
|
|
// Optionally, we could reset or reduce failure counts on success
|
|
// For now, we just log the success for audit purposes
|
|
r.logger.Debug("authorization success recorded",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.String("resource_id_hash", hashID(resourceID)),
|
|
zap.String("action", action))
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsUserBlocked checks if a user is temporarily blocked
|
|
func (r *authFailureRateLimiter) IsUserBlocked(ctx context.Context, userID string) (bool, time.Duration, error) {
|
|
blockKey := r.getBlockKey(userID)
|
|
|
|
ttl, err := r.client.TTL(ctx, blockKey).Result()
|
|
if err != nil {
|
|
return false, 0, err
|
|
}
|
|
|
|
// TTL returns -2 if key doesn't exist, -1 if no expiration
|
|
if ttl < 0 {
|
|
return false, 0, nil
|
|
}
|
|
|
|
return true, ttl, nil
|
|
}
|
|
|
|
// GetFailureCount returns the number of authorization failures for a user
|
|
func (r *authFailureRateLimiter) GetFailureCount(ctx context.Context, userID string) (int, error) {
|
|
userKey := r.getUserFailureKey(userID)
|
|
now := time.Now()
|
|
windowStart := now.Add(-r.config.FailureWindow)
|
|
|
|
// Remove old entries outside the window
|
|
r.client.ZRemRangeByScore(ctx, userKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
|
|
|
// Count current failures in window
|
|
count, err := r.client.ZCount(ctx, userKey,
|
|
fmt.Sprintf("%d", windowStart.UnixNano()),
|
|
"+inf").Result()
|
|
|
|
if err != nil && err != redis.Nil {
|
|
return 0, err
|
|
}
|
|
|
|
return int(count), nil
|
|
}
|
|
|
|
// GetResourceFailureCount returns failures for a specific resource
|
|
func (r *authFailureRateLimiter) GetResourceFailureCount(ctx context.Context, userID string, resourceID string) (int, error) {
|
|
if resourceID == "" {
|
|
return 0, nil
|
|
}
|
|
|
|
resourceKey := r.getResourceFailureKey(userID, resourceID)
|
|
now := time.Now()
|
|
windowStart := now.Add(-r.config.FailureWindow)
|
|
|
|
// Remove old entries
|
|
r.client.ZRemRangeByScore(ctx, resourceKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
|
|
|
// Count current failures
|
|
count, err := r.client.ZCount(ctx, resourceKey,
|
|
fmt.Sprintf("%d", windowStart.UnixNano()),
|
|
"+inf").Result()
|
|
|
|
if err != nil && err != redis.Nil {
|
|
return 0, err
|
|
}
|
|
|
|
return int(count), nil
|
|
}
|
|
|
|
// ResetUserFailures manually resets failure counters for a user
|
|
func (r *authFailureRateLimiter) ResetUserFailures(ctx context.Context, userID string) error {
|
|
pattern := fmt.Sprintf("%s:user:%s:*", r.config.KeyPrefix, hashID(userID))
|
|
|
|
// Find all keys for this user
|
|
keys, err := r.client.Keys(ctx, pattern).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(keys) > 0 {
|
|
pipe := r.client.Pipeline()
|
|
for _, key := range keys {
|
|
pipe.Del(ctx, key)
|
|
}
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
r.logger.Error("failed to reset user failures",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Error(err))
|
|
return err
|
|
}
|
|
}
|
|
|
|
r.logger.Info("user authorization failures reset",
|
|
zap.String("user_id_hash", hashID(userID)))
|
|
|
|
return nil
|
|
}
|
|
|
|
// blockUser blocks a user from further authorization attempts
|
|
func (r *authFailureRateLimiter) blockUser(ctx context.Context, userID string) error {
|
|
blockKey := r.getBlockKey(userID)
|
|
err := r.client.Set(ctx, blockKey, "blocked", r.config.BlockDuration).Err()
|
|
if err != nil {
|
|
r.logger.Error("failed to block user",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
r.logger.Error("SECURITY: User blocked due to excessive authorization failures",
|
|
zap.String("user_id_hash", hashID(userID)),
|
|
zap.Duration("block_duration", r.config.BlockDuration))
|
|
|
|
return nil
|
|
}
|
|
|
|
// Key generation helpers
|
|
func (r *authFailureRateLimiter) getUserFailureKey(userID string) string {
|
|
return fmt.Sprintf("%s:user:%s:failures", r.config.KeyPrefix, hashID(userID))
|
|
}
|
|
|
|
func (r *authFailureRateLimiter) getResourceFailureKey(userID string, resourceID string) string {
|
|
return fmt.Sprintf("%s:user:%s:resource:%s:failures", r.config.KeyPrefix, hashID(userID), hashID(resourceID))
|
|
}
|
|
|
|
func (r *authFailureRateLimiter) getBlockKey(userID string) string {
|
|
return fmt.Sprintf("%s:user:%s:blocked", r.config.KeyPrefix, hashID(userID))
|
|
}
|
|
|
|
func (r *authFailureRateLimiter) getMetricsKey(userID string) string {
|
|
return fmt.Sprintf("%s:user:%s:metrics", r.config.KeyPrefix, hashID(userID))
|
|
}
|
|
|
|
// hashID maps an identifier to a deterministic token used in Redis keys and
// log fields instead of the raw ID (CWE-532).
//
// The token is the hex encoding of the first 16 bytes (32 hex characters) of
// the SHA-256 digest, which keeps keys short. The empty string maps to the
// literal "empty".
//
// NOTE(review): the hash is unsalted, so low-entropy IDs (for example
// sequential integers) can be recovered by brute force. Confirm this is
// acceptable for the IDs passed in.
func hashID(id string) string {
	if len(id) == 0 {
		return "empty"
	}
	digest := sha256.Sum256([]byte(id))
	truncated := digest[:16]
	return hex.EncodeToString(truncated)
}