Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
109
cloud/maplepress-backend/pkg/cache/cassandra.go
vendored
Normal file
109
cloud/maplepress-backend/pkg/cache/cassandra.go
vendored
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CassandraCacher defines the interface for Cassandra cache operations
|
||||
type CassandraCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
type cassandraCache struct {
|
||||
session *gocql.Session
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCassandraCache creates a new Cassandra cache instance
|
||||
func NewCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
logger = logger.Named("cassandra-cache")
|
||||
logger.Info("✓ Cassandra cache layer initialized")
|
||||
return &cassandraCache{
|
||||
session: session,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down Cassandra cache")
|
||||
// Note: Don't close the session here as it's managed by the database layer
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
var value []byte
|
||||
var expiresAt time.Time
|
||||
|
||||
query := `SELECT value, expires_at FROM cache WHERE key = ?`
|
||||
err := c.session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Scan(&value, &expiresAt)
|
||||
|
||||
if err == gocql.ErrNotFound {
|
||||
// Key doesn't exist - this is not an error
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if expired in application code
|
||||
if time.Now().After(expiresAt) {
|
||||
// Entry is expired, delete it and return nil
|
||||
_ = c.Delete(ctx, key) // Clean up expired entry
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
expiresAt := time.Now().Add(24 * time.Hour) // Default 24 hour expiry
|
||||
query := `INSERT INTO cache (key, expires_at, value) VALUES (?, ?, ?)`
|
||||
return c.session.Query(query, key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
expiresAt := time.Now().Add(expiry)
|
||||
query := `INSERT INTO cache (key, expires_at, value) VALUES (?, ?, ?)`
|
||||
return c.session.Query(query, key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Delete(ctx context.Context, key string) error {
|
||||
query := `DELETE FROM cache WHERE key = ?`
|
||||
return c.session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) PurgeExpired(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Thanks to the index on expires_at, this query is efficient
|
||||
iter := c.session.Query(`SELECT key FROM cache WHERE expires_at < ? ALLOW FILTERING`, now).WithContext(ctx).Iter()
|
||||
|
||||
var expiredKeys []string
|
||||
var key string
|
||||
for iter.Scan(&key) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired keys in batch
|
||||
if len(expiredKeys) > 0 {
|
||||
batch := c.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
|
||||
for _, expiredKey := range expiredKeys {
|
||||
batch.Query(`DELETE FROM cache WHERE key = ?`, expiredKey)
|
||||
}
|
||||
return c.session.ExecuteBatch(batch)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
23
cloud/maplepress-backend/pkg/cache/provider.go
vendored
Normal file
23
cloud/maplepress-backend/pkg/cache/provider.go
vendored
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// ProvideRedisCache provides a Redis cache instance
|
||||
func ProvideRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
|
||||
return NewRedisCache(cfg, logger)
|
||||
}
|
||||
|
||||
// ProvideCassandraCache provides a Cassandra cache instance
|
||||
func ProvideCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
return NewCassandraCache(session, logger)
|
||||
}
|
||||
|
||||
// ProvideTwoTierCache provides a two-tier cache instance
|
||||
func ProvideTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
|
||||
return NewTwoTierCache(redisCache, cassandraCache, logger)
|
||||
}
|
||||
144
cloud/maplepress-backend/pkg/cache/redis.go
vendored
Normal file
144
cloud/maplepress-backend/pkg/cache/redis.go
vendored
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// silentRedisLogger filters out noisy "maintnotifications" warnings from go-redis
|
||||
// This warning occurs when the Redis client tries to use newer Redis 7.2+ features
|
||||
// that may not be fully supported by the current Redis version.
|
||||
// The client automatically falls back to compatible mode, so this is harmless.
|
||||
type silentRedisLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func (l *silentRedisLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Filter out harmless compatibility warnings
|
||||
if strings.Contains(msg, "maintnotifications disabled") ||
|
||||
strings.Contains(msg, "auto mode fallback") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other Redis messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// RedisCacher defines the interface for Redis cache operations
|
||||
type RedisCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
type redisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRedisCache creates a new Redis cache instance
|
||||
func NewRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
|
||||
logger = logger.Named("redis-cache")
|
||||
|
||||
logger.Info("⏳ Connecting to Redis...",
|
||||
zap.String("host", cfg.Cache.Host),
|
||||
zap.Int("port", cfg.Cache.Port))
|
||||
|
||||
// Build Redis URL from config
|
||||
redisURL := fmt.Sprintf("redis://:%s@%s:%d/%d",
|
||||
cfg.Cache.Password,
|
||||
cfg.Cache.Host,
|
||||
cfg.Cache.Port,
|
||||
cfg.Cache.DB,
|
||||
)
|
||||
|
||||
// If no password, use simpler URL format
|
||||
if cfg.Cache.Password == "" {
|
||||
redisURL = fmt.Sprintf("redis://%s:%d/%d",
|
||||
cfg.Cache.Host,
|
||||
cfg.Cache.Port,
|
||||
cfg.Cache.DB,
|
||||
)
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
|
||||
}
|
||||
|
||||
// Suppress noisy "maintnotifications" warnings from go-redis
|
||||
// Use a custom logger that filters out these harmless compatibility warnings
|
||||
redis.SetLogger(&silentRedisLogger{logger: logger.Named("redis-client")})
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err = client.Ping(ctx).Result(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✓ Redis connected",
|
||||
zap.String("host", cfg.Cache.Host),
|
||||
zap.Int("port", cfg.Cache.Port),
|
||||
zap.Int("db", cfg.Cache.DB))
|
||||
|
||||
return &redisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down Redis cache")
|
||||
if err := c.client.Close(); err != nil {
|
||||
c.logger.Error("error closing Redis connection", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *redisCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.client.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// Key doesn't exist - this is not an error
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis get failed: %w", err)
|
||||
}
|
||||
return []byte(val), nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
if err := c.client.Set(ctx, key, val, 0).Err(); err != nil {
|
||||
return fmt.Errorf("redis set failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *redisCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
if err := c.client.Set(ctx, key, val, expiry).Err(); err != nil {
|
||||
return fmt.Errorf("redis set with expiry failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Delete(ctx context.Context, key string) error {
|
||||
if err := c.client.Del(ctx, key).Err(); err != nil {
|
||||
return fmt.Errorf("redis delete failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
114
cloud/maplepress-backend/pkg/cache/twotier.go
vendored
Normal file
114
cloud/maplepress-backend/pkg/cache/twotier.go
vendored
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/cache/twotier.go
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TwoTierCacher defines the interface for two-tier cache operations
|
||||
type TwoTierCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
// twoTierCache implements a clean 2-layer (read-through write-through) cache
|
||||
//
|
||||
// L1: Redis (fast, in-memory)
|
||||
// L2: Cassandra (persistent)
|
||||
//
|
||||
// On Get: check Redis → then Cassandra → if found in Cassandra → populate Redis
|
||||
// On Set: write to both
|
||||
// On SetWithExpiry: write to both with expiry
|
||||
// On Delete: remove from both
|
||||
type twoTierCache struct {
|
||||
redisCache RedisCacher
|
||||
cassandraCache CassandraCacher
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTwoTierCache creates a new two-tier cache instance
|
||||
func NewTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
|
||||
logger = logger.Named("two-tier-cache")
|
||||
logger.Info("✓ Two-tier cache initialized (Redis L1 + Cassandra L2)")
|
||||
return &twoTierCache{
|
||||
redisCache: redisCache,
|
||||
cassandraCache: cassandraCache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
// Try L1 (Redis) first
|
||||
val, err := c.redisCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
c.logger.Debug("cache hit from Redis", zap.String("key", key))
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Not in Redis, try L2 (Cassandra)
|
||||
val, err = c.cassandraCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
// Found in Cassandra, populate Redis for future lookups
|
||||
c.logger.Debug("cache hit from Cassandra, writing back to Redis", zap.String("key", key))
|
||||
_ = c.redisCache.Set(ctx, key, val) // Best effort, don't fail if Redis write fails
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
// Write to both layers
|
||||
if err := c.redisCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
// Write to both layers with expiry
|
||||
if err := c.redisCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Delete(ctx context.Context, key string) error {
|
||||
// Remove from both layers
|
||||
if err := c.redisCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) PurgeExpired(ctx context.Context) error {
|
||||
// Only Cassandra needs purging (Redis handles TTL automatically)
|
||||
return c.cassandraCache.PurgeExpired(ctx)
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down two-tier cache")
|
||||
c.redisCache.Shutdown(ctx)
|
||||
c.cassandraCache.Shutdown(ctx)
|
||||
c.logger.Info("two-tier cache shutdown complete")
|
||||
}
|
||||
237
cloud/maplepress-backend/pkg/distributedmutex/README.md
Normal file
237
cloud/maplepress-backend/pkg/distributedmutex/README.md
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
# Distributed Mutex
|
||||
|
||||
A Redis-based distributed mutex implementation for coordinating access to shared resources across multiple application instances.
|
||||
|
||||
## Overview
|
||||
|
||||
This package provides a distributed locking mechanism using Redis as the coordination backend. It's built on top of the `redislock` library and provides a simple interface for acquiring and releasing locks across distributed systems.
|
||||
|
||||
## Features
|
||||
|
||||
- **Distributed Locking**: Coordinate access to shared resources across multiple application instances
|
||||
- **Automatic Retry**: Built-in retry logic with configurable backoff strategy
|
||||
- **Thread-Safe**: Safe for concurrent use within a single application
|
||||
- **Formatted Keys**: Support for formatted lock keys using `Acquiref` and `Releasef`
|
||||
- **Logging**: Integrated zap logging for debugging and monitoring
|
||||
|
||||
## Installation
|
||||
|
||||
The package is already included in the project. The required dependency (`github.com/bsm/redislock`) is automatically installed.
|
||||
|
||||
## Interface
|
||||
|
||||
```go
|
||||
type Adapter interface {
|
||||
Acquire(ctx context.Context, key string)
|
||||
Acquiref(ctx context.Context, format string, a ...any)
|
||||
Release(ctx context.Context, key string)
|
||||
Releasef(ctx context.Context, format string, a ...any)
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/distributedmutex"
|
||||
)
|
||||
|
||||
// Create Redis client
|
||||
redisClient := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
})
|
||||
|
||||
// Create logger
|
||||
logger, _ := zap.NewProduction()
|
||||
|
||||
// Create distributed mutex adapter
|
||||
mutex := distributedmutex.NewAdapter(logger, redisClient)
|
||||
|
||||
// Acquire a lock
|
||||
ctx := context.Background()
|
||||
mutex.Acquire(ctx, "my-resource-key")
|
||||
|
||||
// ... perform operations on the protected resource ...
|
||||
|
||||
// Release the lock
|
||||
mutex.Release(ctx, "my-resource-key")
|
||||
```
|
||||
|
||||
### Formatted Keys Example
|
||||
|
||||
```go
|
||||
// Acquire lock with formatted key
|
||||
tenantID := "tenant-123"
|
||||
resourceID := "resource-456"
|
||||
|
||||
mutex.Acquiref(ctx, "tenant:%s:resource:%s", tenantID, resourceID)
|
||||
|
||||
// ... perform operations ...
|
||||
|
||||
mutex.Releasef(ctx, "tenant:%s:resource:%s", tenantID, resourceID)
|
||||
```
|
||||
|
||||
### Integration with Dependency Injection (Wire)
|
||||
|
||||
```go
|
||||
// In your Wire provider set
|
||||
wire.NewSet(
|
||||
distributedmutex.ProvideDistributedMutexAdapter,
|
||||
// ... other providers
|
||||
)
|
||||
|
||||
// Use in your application
|
||||
func NewMyService(mutex distributedmutex.Adapter) *MyService {
|
||||
return &MyService{
|
||||
mutex: mutex,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Lock Duration
|
||||
|
||||
The default lock duration is **1 minute**. Locks are automatically released after this time to prevent deadlocks.
|
||||
|
||||
### Retry Strategy
|
||||
|
||||
- **Retry Interval**: 250ms
|
||||
- **Max Retries**: 20 attempts
|
||||
- **Total Max Wait Time**: ~5 seconds (20 × 250ms)
|
||||
|
||||
If a lock cannot be obtained after all retries, an error is logged and the `Acquire` method returns without blocking indefinitely.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always Release Locks**: Ensure locks are released even in error cases using `defer`
|
||||
```go
|
||||
mutex.Acquire(ctx, "my-key")
|
||||
defer mutex.Release(ctx, "my-key")
|
||||
```
|
||||
|
||||
2. **Use Descriptive Keys**: Use clear, hierarchical key names
|
||||
```go
|
||||
// Good
|
||||
mutex.Acquire(ctx, "tenant:123:user:456:update")
|
||||
|
||||
// Not ideal
|
||||
mutex.Acquire(ctx, "lock1")
|
||||
```
|
||||
|
||||
3. **Keep Critical Sections Short**: Minimize the time locks are held to improve concurrency
|
||||
|
||||
4. **Handle Timeouts**: Use context with timeout for critical operations
|
||||
```go
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
mutex.Acquire(ctx, "my-key")
|
||||
```
|
||||
|
||||
5. **Avoid Nested Locks**: Be careful with acquiring multiple locks to avoid deadlocks
|
||||
|
||||
## Logging
|
||||
|
||||
The adapter logs the following events:
|
||||
|
||||
- **Debug**: Lock acquisition and release operations
|
||||
- **Error**: Failed lock acquisitions, timeout errors, and release failures
|
||||
- **Warn**: Attempts to release non-existent locks
|
||||
|
||||
## Thread Safety
|
||||
|
||||
The adapter is safe for concurrent use within a single application instance. It uses an internal mutex to protect the lock instances map from concurrent access by multiple goroutines.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The current implementation logs errors but does not return them. Consider this when using the adapter:
|
||||
|
||||
- Lock acquisition failures are logged but don't panic
|
||||
- The application continues running even if locks fail
|
||||
- Check logs for lock-related issues in production
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Lock Duration**: Locks automatically expire after 1 minute
|
||||
2. **No Lock Extension**: Currently doesn't support extending lock duration
|
||||
3. **No Deadlock Detection**: Manual deadlock prevention is required
|
||||
4. **Redis Dependency**: Requires a running Redis instance
|
||||
|
||||
## Example Use Cases
|
||||
|
||||
### Preventing Duplicate Processing
|
||||
|
||||
```go
|
||||
func ProcessJob(ctx context.Context, jobID string, mutex distributedmutex.Adapter) {
|
||||
lockKey := fmt.Sprintf("job:processing:%s", jobID)
|
||||
|
||||
mutex.Acquire(ctx, lockKey)
|
||||
defer mutex.Release(ctx, lockKey)
|
||||
|
||||
// Process job - guaranteed only one instance processes this job
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### Coordinating Resource Updates
|
||||
|
||||
```go
|
||||
func UpdateTenantSettings(ctx context.Context, tenantID string, mutex distributedmutex.Adapter) error {
|
||||
mutex.Acquiref(ctx, "tenant:%s:settings:update", tenantID)
|
||||
defer mutex.Releasef(ctx, "tenant:%s:settings:update", tenantID)
|
||||
|
||||
// Safe to update tenant settings
|
||||
// ...
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Rate Limiting Operations
|
||||
|
||||
```go
|
||||
func RateLimitedOperation(ctx context.Context, userID string, mutex distributedmutex.Adapter) {
|
||||
lockKey := fmt.Sprintf("ratelimit:user:%s", userID)
|
||||
|
||||
mutex.Acquire(ctx, lockKey)
|
||||
defer mutex.Release(ctx, lockKey)
|
||||
|
||||
// Perform rate-limited operation
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Lock Not Acquired
|
||||
|
||||
**Problem**: Locks are not being acquired (error in logs)
|
||||
|
||||
**Solutions**:
|
||||
- Verify Redis is running and accessible
|
||||
- Check network connectivity to Redis
|
||||
- Ensure Redis has sufficient memory
|
||||
- Check for Redis errors in logs
|
||||
|
||||
### Lock Contention
|
||||
|
||||
**Problem**: Frequent lock acquisition failures due to contention
|
||||
|
||||
**Solutions**:
|
||||
- Reduce critical section duration
|
||||
- Use more specific lock keys to reduce contention
|
||||
- Consider increasing retry limits if appropriate
|
||||
- Review application architecture for excessive locking
|
||||
|
||||
### Memory Leaks
|
||||
|
||||
**Problem**: Lock instances accumulating in memory
|
||||
|
||||
**Solutions**:
|
||||
- Ensure all `Acquire` calls have corresponding `Release` calls
|
||||
- Use `defer` to guarantee lock release
|
||||
- Monitor lock instance map size in production
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
package distributedmutex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bsm/redislock"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Adapter provides interface for abstracting distributed mutex operations.
|
||||
// CWE-755: Methods now return errors to properly handle exceptional conditions
|
||||
type Adapter interface {
|
||||
Acquire(ctx context.Context, key string) error
|
||||
Acquiref(ctx context.Context, format string, a ...any) error
|
||||
Release(ctx context.Context, key string) error
|
||||
Releasef(ctx context.Context, format string, a ...any) error
|
||||
}
|
||||
|
||||
type distributedMutexAdapter struct {
|
||||
logger *zap.Logger
|
||||
redis redis.UniversalClient
|
||||
locker *redislock.Client
|
||||
lockInstances map[string]*redislock.Lock
|
||||
mutex *sync.Mutex // Mutex for synchronization with goroutines
|
||||
}
|
||||
|
||||
// NewAdapter constructor that returns the default distributed mutex adapter.
|
||||
func NewAdapter(logger *zap.Logger, redisClient redis.UniversalClient) Adapter {
|
||||
logger = logger.Named("distributed-mutex")
|
||||
|
||||
// Create a new lock client
|
||||
locker := redislock.New(redisClient)
|
||||
|
||||
logger.Info("✓ Distributed mutex initialized (Redis-backed)")
|
||||
|
||||
return &distributedMutexAdapter{
|
||||
logger: logger,
|
||||
redis: redisClient,
|
||||
locker: locker,
|
||||
lockInstances: make(map[string]*redislock.Lock),
|
||||
mutex: &sync.Mutex{}, // Initialize the mutex
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire function blocks the current thread if the lock key is currently locked.
|
||||
// CWE-755: Now returns error instead of silently failing
|
||||
func (a *distributedMutexAdapter) Acquire(ctx context.Context, key string) error {
|
||||
startDT := time.Now()
|
||||
a.logger.Debug("acquiring lock", zap.String("key", key))
|
||||
|
||||
// Retry every 250ms, for up to 20x
|
||||
backoff := redislock.LimitRetry(redislock.LinearBackoff(250*time.Millisecond), 20)
|
||||
|
||||
// Obtain lock with retry
|
||||
lock, err := a.locker.Obtain(ctx, key, time.Minute, &redislock.Options{
|
||||
RetryStrategy: backoff,
|
||||
})
|
||||
if err == redislock.ErrNotObtained {
|
||||
nowDT := time.Now()
|
||||
diff := nowDT.Sub(startDT)
|
||||
a.logger.Error("could not obtain lock after retries",
|
||||
zap.String("key", key),
|
||||
zap.Time("start_dt", startDT),
|
||||
zap.Time("now_dt", nowDT),
|
||||
zap.Duration("duration", diff),
|
||||
zap.Int("max_retries", 20))
|
||||
return fmt.Errorf("could not obtain lock after 20 retries (waited %s): %w", diff, err)
|
||||
} else if err != nil {
|
||||
a.logger.Error("failed obtaining lock",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to obtain lock: %w", err)
|
||||
}
|
||||
|
||||
// DEVELOPERS NOTE:
|
||||
// The `map` datastructure in Golang is not concurrently safe, therefore we
|
||||
// need to use mutex to coordinate access of our `lockInstances` map
|
||||
// resource between all the goroutines.
|
||||
a.mutex.Lock()
|
||||
defer a.mutex.Unlock()
|
||||
|
||||
if a.lockInstances != nil { // Defensive code
|
||||
a.lockInstances[key] = lock
|
||||
}
|
||||
|
||||
a.logger.Debug("lock acquired", zap.String("key", key))
|
||||
return nil // Success
|
||||
}
|
||||
|
||||
// Acquiref function blocks the current thread if the lock key is currently locked.
|
||||
// CWE-755: Now returns error from Acquire
|
||||
func (a *distributedMutexAdapter) Acquiref(ctx context.Context, format string, args ...any) error {
|
||||
key := fmt.Sprintf(format, args...)
|
||||
return a.Acquire(ctx, key)
|
||||
}
|
||||
|
||||
// Release function releases the lock for the given key.
|
||||
// CWE-755: Now returns error instead of silently failing
|
||||
func (a *distributedMutexAdapter) Release(ctx context.Context, key string) error {
|
||||
a.logger.Debug("releasing lock", zap.String("key", key))
|
||||
|
||||
// DEVELOPERS NOTE:
|
||||
// The `map` datastructure in Golang is not concurrently safe, therefore we
|
||||
// need to use mutex to coordinate access of our `lockInstances` map
|
||||
// resource between all the goroutines.
|
||||
a.mutex.Lock()
|
||||
lockInstance, ok := a.lockInstances[key]
|
||||
if ok {
|
||||
delete(a.lockInstances, key)
|
||||
}
|
||||
a.mutex.Unlock()
|
||||
|
||||
if ok {
|
||||
if err := lockInstance.Release(ctx); err != nil {
|
||||
a.logger.Error("failed to release lock",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to release lock: %w", err)
|
||||
}
|
||||
a.logger.Debug("lock released", zap.String("key", key))
|
||||
return nil // Success
|
||||
}
|
||||
|
||||
// Lock not found - this is a warning but not an error (may have already been released)
|
||||
a.logger.Warn("lock not found for release", zap.String("key", key))
|
||||
return nil // Not an error, just not found
|
||||
}
|
||||
|
||||
// Releasef function releases the lock for a formatted key.
|
||||
// CWE-755: Now returns error from Release
|
||||
func (a *distributedMutexAdapter) Releasef(ctx context.Context, format string, args ...any) error {
|
||||
key := fmt.Sprintf(format, args...)
|
||||
return a.Release(ctx, key)
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
package distributedmutex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// mockRedisClient implements minimal required methods for testing
|
||||
type mockRedisClient struct {
|
||||
redis.UniversalClient
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Get(ctx context.Context, key string) *redis.StringCmd {
|
||||
return redis.NewStringCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd {
|
||||
return redis.NewStatusCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...any) *redis.Cmd {
|
||||
return redis.NewCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) EvalSha(ctx context.Context, sha string, keys []string, args ...any) *redis.Cmd {
|
||||
return redis.NewCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) ScriptExists(ctx context.Context, scripts ...string) *redis.BoolSliceCmd {
|
||||
return redis.NewBoolSliceCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) ScriptLoad(ctx context.Context, script string) *redis.StringCmd {
|
||||
return redis.NewStringCmd(ctx)
|
||||
}
|
||||
|
||||
func TestNewAdapter(t *testing.T) {
|
||||
logger, _ := zap.NewDevelopment()
|
||||
adapter := NewAdapter(logger, &mockRedisClient{})
|
||||
if adapter == nil {
|
||||
t.Fatal("expected non-nil adapter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireAndRelease(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger, _ := zap.NewDevelopment()
|
||||
adapter := NewAdapter(logger, &mockRedisClient{})
|
||||
|
||||
// Test basic acquire/release
|
||||
adapter.Acquire(ctx, "test-key")
|
||||
adapter.Release(ctx, "test-key")
|
||||
|
||||
// Test formatted acquire/release
|
||||
adapter.Acquiref(ctx, "test-key-%d", 1)
|
||||
adapter.Releasef(ctx, "test-key-%d", 1)
|
||||
}
|
||||
|
||||
func TestReleaseNonExistentLock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger, _ := zap.NewDevelopment()
|
||||
adapter := NewAdapter(logger, &mockRedisClient{})
|
||||
|
||||
// This should not panic, just log a warning
|
||||
adapter.Release(ctx, "non-existent-key")
|
||||
}
|
||||
13
cloud/maplepress-backend/pkg/distributedmutex/provider.go
Normal file
13
cloud/maplepress-backend/pkg/distributedmutex/provider.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
package distributedmutex
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProvideDistributedMutexAdapter creates a new distributed mutex adapter instance.
|
||||
// Accepts *redis.Client which implements redis.UniversalClient interface
|
||||
func ProvideDistributedMutexAdapter(logger *zap.Logger, redisClient *redis.Client) Adapter {
|
||||
// redis.Client implements redis.UniversalClient, so we can pass it directly
|
||||
return NewAdapter(logger, redisClient)
|
||||
}
|
||||
113
cloud/maplepress-backend/pkg/dns/verifier.go
Normal file
113
cloud/maplepress-backend/pkg/dns/verifier.go
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Verifier handles DNS TXT record verification
|
||||
type Verifier struct {
|
||||
resolver *net.Resolver
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// ProvideVerifier creates a new DNS Verifier
|
||||
func ProvideVerifier(logger *zap.Logger) *Verifier {
|
||||
return &Verifier{
|
||||
resolver: &net.Resolver{
|
||||
PreferGo: true, // Use Go's DNS resolver
|
||||
},
|
||||
logger: logger.Named("dns-verifier"),
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyDomainOwnership checks if a domain has the correct TXT record
|
||||
// Expected format: "maplepress-verify=TOKEN"
|
||||
func (v *Verifier) VerifyDomainOwnership(ctx context.Context, domain string, expectedToken string) (bool, error) {
|
||||
v.logger.Info("verifying domain ownership via DNS",
|
||||
zap.String("domain", domain))
|
||||
|
||||
// Create context with timeout (10 seconds for DNS lookup)
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Look up TXT records for the domain
|
||||
txtRecords, err := v.resolver.LookupTXT(lookupCtx, domain)
|
||||
if err != nil {
|
||||
// Check if it's a timeout
|
||||
if lookupCtx.Err() == context.DeadlineExceeded {
|
||||
v.logger.Warn("DNS lookup timed out",
|
||||
zap.String("domain", domain))
|
||||
return false, fmt.Errorf("DNS lookup timed out after 10 seconds")
|
||||
}
|
||||
|
||||
// Check if domain doesn't exist
|
||||
if dnsErr, ok := err.(*net.DNSError); ok {
|
||||
if dnsErr.IsNotFound {
|
||||
v.logger.Warn("domain not found",
|
||||
zap.String("domain", domain))
|
||||
return false, fmt.Errorf("domain not found: %s", domain)
|
||||
}
|
||||
}
|
||||
|
||||
v.logger.Error("failed to lookup TXT records",
|
||||
zap.String("domain", domain),
|
||||
zap.Error(err))
|
||||
return false, fmt.Errorf("failed to lookup DNS TXT records: %w", err)
|
||||
}
|
||||
|
||||
// Expected verification record format
|
||||
expectedRecord := fmt.Sprintf("maplepress-verify=%s", expectedToken)
|
||||
|
||||
// Check each TXT record
|
||||
for _, record := range txtRecords {
|
||||
v.logger.Debug("checking TXT record",
|
||||
zap.String("domain", domain),
|
||||
zap.String("record", record))
|
||||
|
||||
// Normalize whitespace and compare
|
||||
normalizedRecord := strings.TrimSpace(record)
|
||||
if normalizedRecord == expectedRecord {
|
||||
v.logger.Info("domain ownership verified",
|
||||
zap.String("domain", domain))
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
v.logger.Warn("verification record not found",
|
||||
zap.String("domain", domain),
|
||||
zap.String("expected", expectedRecord),
|
||||
zap.Int("records_checked", len(txtRecords)))
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetVerificationRecord returns the TXT record format for a given token
|
||||
func GetVerificationRecord(token string) string {
|
||||
return fmt.Sprintf("maplepress-verify=%s", token)
|
||||
}
|
||||
|
||||
// GetVerificationInstructions returns user-friendly instructions
|
||||
func GetVerificationInstructions(domain string, token string) string {
|
||||
record := GetVerificationRecord(token)
|
||||
return fmt.Sprintf(`To verify ownership of %s, add this DNS TXT record:
|
||||
|
||||
Host/Name: %s
|
||||
Type: TXT
|
||||
Value: %s
|
||||
|
||||
Instructions:
|
||||
1. Log in to your domain registrar (GoDaddy, Namecheap, Cloudflare, etc.)
|
||||
2. Find DNS settings or DNS management
|
||||
3. Add a new TXT record with the values above
|
||||
4. Wait 5-10 minutes for DNS propagation
|
||||
5. Click "Verify Domain" in MaplePress
|
||||
|
||||
Note: DNS changes can take up to 48 hours to propagate globally, but usually complete within 10 minutes.`,
|
||||
domain, domain, record)
|
||||
}
|
||||
61
cloud/maplepress-backend/pkg/emailer/mailgun/config.go
Normal file
61
cloud/maplepress-backend/pkg/emailer/mailgun/config.go
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
package mailgun
|
||||
|
||||
type MailgunConfigurationProvider interface {
|
||||
GetSenderEmail() string
|
||||
GetDomainName() string // Deprecated
|
||||
GetBackendDomainName() string
|
||||
GetFrontendDomainName() string
|
||||
GetMaintenanceEmail() string
|
||||
GetAPIKey() string
|
||||
GetAPIBase() string
|
||||
}
|
||||
|
||||
type mailgunConfigurationProviderImpl struct {
|
||||
senderEmail string
|
||||
domain string
|
||||
apiBase string
|
||||
maintenanceEmail string
|
||||
frontendDomain string
|
||||
backendDomain string
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewMailgunConfigurationProvider(senderEmail, domain, apiBase, maintenanceEmail, frontendDomain, backendDomain, apiKey string) MailgunConfigurationProvider {
|
||||
return &mailgunConfigurationProviderImpl{
|
||||
senderEmail: senderEmail,
|
||||
domain: domain,
|
||||
apiBase: apiBase,
|
||||
maintenanceEmail: maintenanceEmail,
|
||||
frontendDomain: frontendDomain,
|
||||
backendDomain: backendDomain,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetDomainName() string {
|
||||
return me.domain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetSenderEmail() string {
|
||||
return me.senderEmail
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetBackendDomainName() string {
|
||||
return me.backendDomain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetFrontendDomainName() string {
|
||||
return me.frontendDomain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetMaintenanceEmail() string {
|
||||
return me.maintenanceEmail
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetAPIKey() string {
|
||||
return me.apiKey
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetAPIBase() string {
|
||||
return me.apiBase
|
||||
}
|
||||
12
cloud/maplepress-backend/pkg/emailer/mailgun/interface.go
Normal file
12
cloud/maplepress-backend/pkg/emailer/mailgun/interface.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package mailgun
|
||||
|
||||
import "context"
|
||||
|
||||
type Emailer interface {
|
||||
Send(ctx context.Context, sender, subject, recipient, htmlContent string) error
|
||||
GetSenderEmail() string
|
||||
GetDomainName() string // Deprecated
|
||||
GetBackendDomainName() string
|
||||
GetFrontendDomainName() string
|
||||
GetMaintenanceEmail() string
|
||||
}
|
||||
86
cloud/maplepress-backend/pkg/emailer/mailgun/mailgun.go
Normal file
86
cloud/maplepress-backend/pkg/emailer/mailgun/mailgun.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package mailgun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/mailgun-go/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type mailgunEmailer struct {
|
||||
config MailgunConfigurationProvider
|
||||
logger *zap.Logger
|
||||
Mailgun *mailgun.MailgunImpl
|
||||
}
|
||||
|
||||
func NewEmailer(config MailgunConfigurationProvider, logger *zap.Logger) Emailer {
|
||||
logger = logger.Named("mailgun-emailer")
|
||||
|
||||
// Initialize Mailgun client
|
||||
mg := mailgun.NewMailgun(config.GetDomainName(), config.GetAPIKey())
|
||||
mg.SetAPIBase(config.GetAPIBase()) // Override to support our custom email requirements.
|
||||
|
||||
logger.Info("✓ Mailgun emailer initialized",
|
||||
zap.String("domain", config.GetDomainName()),
|
||||
zap.String("api_base", config.GetAPIBase()))
|
||||
|
||||
return &mailgunEmailer{
|
||||
config: config,
|
||||
logger: logger,
|
||||
Mailgun: mg,
|
||||
}
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) Send(ctx context.Context, sender, subject, recipient, body string) error {
|
||||
me.logger.Debug("Sending email",
|
||||
zap.String("sender", sender),
|
||||
zap.String("recipient", recipient),
|
||||
zap.String("subject", subject))
|
||||
|
||||
message := me.Mailgun.NewMessage(sender, subject, "", recipient)
|
||||
message.SetHtml(body)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
// Send the message with a 10 second timeout
|
||||
resp, id, err := me.Mailgun.Send(ctx, message)
|
||||
|
||||
if err != nil {
|
||||
me.logger.Error("Failed to send email",
|
||||
zap.String("sender", sender),
|
||||
zap.String("recipient", recipient),
|
||||
zap.String("subject", subject),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
me.logger.Info("Email sent successfully",
|
||||
zap.String("recipient", recipient),
|
||||
zap.String("subject", subject),
|
||||
zap.String("message_id", id),
|
||||
zap.String("response", resp))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetDomainName() string {
|
||||
return me.config.GetDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetSenderEmail() string {
|
||||
return me.config.GetSenderEmail()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetBackendDomainName() string {
|
||||
return me.config.GetBackendDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetFrontendDomainName() string {
|
||||
return me.config.GetFrontendDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetMaintenanceEmail() string {
|
||||
return me.config.GetMaintenanceEmail()
|
||||
}
|
||||
26
cloud/maplepress-backend/pkg/emailer/mailgun/provider.go
Normal file
26
cloud/maplepress-backend/pkg/emailer/mailgun/provider.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/emailer/mailgun/provider.go
|
||||
package mailgun
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideMailgunConfigurationProvider creates a new Mailgun configuration provider from the application config.
|
||||
func ProvideMailgunConfigurationProvider(cfg *config.Config) MailgunConfigurationProvider {
|
||||
return NewMailgunConfigurationProvider(
|
||||
cfg.Mailgun.SenderEmail,
|
||||
cfg.Mailgun.Domain,
|
||||
cfg.Mailgun.APIBase,
|
||||
cfg.Mailgun.MaintenanceEmail,
|
||||
cfg.Mailgun.FrontendDomain,
|
||||
cfg.Mailgun.BackendDomain,
|
||||
cfg.Mailgun.APIKey,
|
||||
)
|
||||
}
|
||||
|
||||
// ProvideEmailer creates a new Mailgun emailer from the configuration provider.
|
||||
func ProvideEmailer(config MailgunConfigurationProvider, logger *zap.Logger) Emailer {
|
||||
return NewEmailer(config, logger)
|
||||
}
|
||||
187
cloud/maplepress-backend/pkg/httperror/error.go
Normal file
187
cloud/maplepress-backend/pkg/httperror/error.go
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
package httperror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorResponse represents an HTTP error response (legacy format)
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// ProblemDetail represents an RFC 9457 compliant error response
|
||||
// See: https://datatracker.ietf.org/doc/html/rfc9457
|
||||
type ProblemDetail struct {
|
||||
Type string `json:"type"` // URI reference identifying the problem type
|
||||
Title string `json:"title"` // Short, human-readable summary
|
||||
Status int `json:"status"` // HTTP status code
|
||||
Detail string `json:"detail,omitempty"` // Human-readable explanation
|
||||
Instance string `json:"instance,omitempty"` // URI reference identifying the specific occurrence
|
||||
Errors map[string][]string `json:"errors,omitempty"` // Validation errors (extension field)
|
||||
Extra map[string]interface{} `json:"-"` // Additional extension members
|
||||
}
|
||||
|
||||
// WriteError writes an error response with pretty printing (legacy format)
|
||||
func WriteError(w http.ResponseWriter, code int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
response := ErrorResponse{
|
||||
Error: http.StatusText(code),
|
||||
Message: message,
|
||||
Code: code,
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(w)
|
||||
encoder.SetIndent("", " ")
|
||||
encoder.Encode(response)
|
||||
}
|
||||
|
||||
// WriteProblemDetail writes an RFC 9457 compliant error response
|
||||
func WriteProblemDetail(w http.ResponseWriter, problem *ProblemDetail) {
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
w.WriteHeader(problem.Status)
|
||||
|
||||
encoder := json.NewEncoder(w)
|
||||
encoder.SetIndent("", " ")
|
||||
encoder.Encode(problem)
|
||||
}
|
||||
|
||||
// BadRequest writes a 400 error
|
||||
func BadRequest(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusBadRequest, message)
|
||||
}
|
||||
|
||||
// Unauthorized writes a 401 error
|
||||
func Unauthorized(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusUnauthorized, message)
|
||||
}
|
||||
|
||||
// Forbidden writes a 403 error
|
||||
func Forbidden(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusForbidden, message)
|
||||
}
|
||||
|
||||
// NotFound writes a 404 error
|
||||
func NotFound(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusNotFound, message)
|
||||
}
|
||||
|
||||
// Conflict writes a 409 error
|
||||
func Conflict(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusConflict, message)
|
||||
}
|
||||
|
||||
// TooManyRequests writes a 429 error
|
||||
func TooManyRequests(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusTooManyRequests, message)
|
||||
}
|
||||
|
||||
// InternalServerError writes a 500 error
|
||||
func InternalServerError(w http.ResponseWriter, message string) {
|
||||
WriteError(w, http.StatusInternalServerError, message)
|
||||
}
|
||||
|
||||
// ValidationError writes an RFC 9457 validation error response (400)
|
||||
func ValidationError(w http.ResponseWriter, errors map[string][]string, detail string) {
|
||||
if detail == "" {
|
||||
detail = "One or more validation errors occurred"
|
||||
}
|
||||
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank", // Using about:blank as per RFC 9457 when no specific problem type URI is defined
|
||||
Title: "Validation Error",
|
||||
Status: http.StatusBadRequest,
|
||||
Detail: detail,
|
||||
Errors: errors,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemBadRequest writes an RFC 9457 bad request error (400)
|
||||
func ProblemBadRequest(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Bad Request",
|
||||
Status: http.StatusBadRequest,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemUnauthorized writes an RFC 9457 unauthorized error (401)
|
||||
func ProblemUnauthorized(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Unauthorized",
|
||||
Status: http.StatusUnauthorized,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemForbidden writes an RFC 9457 forbidden error (403)
|
||||
func ProblemForbidden(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Forbidden",
|
||||
Status: http.StatusForbidden,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemNotFound writes an RFC 9457 not found error (404)
|
||||
func ProblemNotFound(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Not Found",
|
||||
Status: http.StatusNotFound,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemConflict writes an RFC 9457 conflict error (409)
|
||||
func ProblemConflict(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Conflict",
|
||||
Status: http.StatusConflict,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemTooManyRequests writes an RFC 9457 too many requests error (429)
|
||||
func ProblemTooManyRequests(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Too Many Requests",
|
||||
Status: http.StatusTooManyRequests,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
|
||||
// ProblemInternalServerError writes an RFC 9457 internal server error (500)
|
||||
func ProblemInternalServerError(w http.ResponseWriter, detail string) {
|
||||
problem := &ProblemDetail{
|
||||
Type: "about:blank",
|
||||
Title: "Internal Server Error",
|
||||
Status: http.StatusInternalServerError,
|
||||
Detail: detail,
|
||||
}
|
||||
|
||||
WriteProblemDetail(w, problem)
|
||||
}
|
||||
31
cloud/maplepress-backend/pkg/httpresponse/response.go
Normal file
31
cloud/maplepress-backend/pkg/httpresponse/response.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package httpresponse
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// JSON writes a JSON response with pretty printing (indented)
|
||||
func JSON(w http.ResponseWriter, code int, data interface{}) error {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
encoder := json.NewEncoder(w)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(data)
|
||||
}
|
||||
|
||||
// OK writes a 200 JSON response with pretty printing
|
||||
func OK(w http.ResponseWriter, data interface{}) error {
|
||||
return JSON(w, http.StatusOK, data)
|
||||
}
|
||||
|
||||
// Created writes a 201 JSON response with pretty printing
|
||||
func Created(w http.ResponseWriter, data interface{}) error {
|
||||
return JSON(w, http.StatusCreated, data)
|
||||
}
|
||||
|
||||
// NoContent writes a 204 No Content response
|
||||
func NoContent(w http.ResponseWriter) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
70
cloud/maplepress-backend/pkg/httpvalidation/content_type.go
Normal file
70
cloud/maplepress-backend/pkg/httpvalidation/content_type.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package httpvalidation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidContentType is returned when Content-Type header is not application/json
|
||||
ErrInvalidContentType = errors.New("Content-Type must be application/json")
|
||||
// ErrMissingContentType is returned when Content-Type header is missing
|
||||
ErrMissingContentType = errors.New("Content-Type header is required")
|
||||
)
|
||||
|
||||
// ValidateJSONContentType validates that the request has application/json Content-Type
|
||||
// CWE-436: Validates Content-Type before parsing to prevent interpretation conflicts
|
||||
// Accepts both "application/json" and "application/json; charset=utf-8"
|
||||
func ValidateJSONContentType(r *http.Request) error {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
// Accept empty Content-Type for backward compatibility (some clients don't set it)
|
||||
if contentType == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for exact match or charset variant
|
||||
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrInvalidContentType
|
||||
}
|
||||
|
||||
// RequireJSONContentType validates that the request has application/json Content-Type
|
||||
// CWE-436: Strict validation that requires Content-Type header
|
||||
// Use this for new endpoints where you want to enforce the header
|
||||
func RequireJSONContentType(r *http.Request) error {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
if contentType == "" {
|
||||
return ErrInvalidContentType
|
||||
}
|
||||
|
||||
// Check for exact match or charset variant
|
||||
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrInvalidContentType
|
||||
}
|
||||
|
||||
// ValidateJSONContentTypeStrict validates that the request has application/json Content-Type
|
||||
// CWE-16: Configuration - Enforces strict Content-Type validation
|
||||
// This version REQUIRES the Content-Type header and returns specific error for missing header
|
||||
func ValidateJSONContentTypeStrict(r *http.Request) error {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
// Require Content-Type header (no empty allowed)
|
||||
if contentType == "" {
|
||||
return ErrMissingContentType
|
||||
}
|
||||
|
||||
// Check for exact match or charset variant
|
||||
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrInvalidContentType
|
||||
}
|
||||
136
cloud/maplepress-backend/pkg/leaderelection/interface.go
Normal file
136
cloud/maplepress-backend/pkg/leaderelection/interface.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// Package leaderelection provides distributed leader election for multiple application instances.
|
||||
// It ensures only one instance acts as the leader at any given time, with automatic failover.
|
||||
package leaderelection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LeaderElection provides distributed leader election across multiple application instances.
|
||||
// It uses Redis to coordinate which instance is the current leader, with automatic failover
|
||||
// if the leader crashes or becomes unavailable.
|
||||
type LeaderElection interface {
|
||||
// Start begins participating in leader election.
|
||||
// This method blocks and runs the election loop until ctx is cancelled or an error occurs.
|
||||
// The instance will automatically attempt to become leader and maintain leadership.
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// IsLeader returns true if this instance is currently the leader.
|
||||
// This is a local check and does not require network communication.
|
||||
IsLeader() bool
|
||||
|
||||
// GetLeaderID returns the unique identifier of the current leader instance.
|
||||
// Returns empty string if no leader exists (should be rare).
|
||||
GetLeaderID() (string, error)
|
||||
|
||||
// GetLeaderInfo returns detailed information about the current leader.
|
||||
GetLeaderInfo() (*LeaderInfo, error)
|
||||
|
||||
// OnBecomeLeader registers a callback function that will be executed when
|
||||
// this instance becomes the leader. Multiple callbacks can be registered.
|
||||
OnBecomeLeader(callback func())
|
||||
|
||||
// OnLoseLeadership registers a callback function that will be executed when
|
||||
// this instance loses leadership (either voluntarily or due to failure).
|
||||
// Multiple callbacks can be registered.
|
||||
OnLoseLeadership(callback func())
|
||||
|
||||
// Stop gracefully stops leader election participation.
|
||||
// If this instance is the leader, it releases leadership allowing another instance to take over.
|
||||
// This should be called during application shutdown.
|
||||
Stop() error
|
||||
|
||||
// GetInstanceID returns the unique identifier for this instance.
|
||||
GetInstanceID() string
|
||||
}
|
||||
|
||||
// LeaderInfo contains information about the current leader.
|
||||
type LeaderInfo struct {
|
||||
// InstanceID is the unique identifier of the leader instance
|
||||
InstanceID string `json:"instance_id"`
|
||||
|
||||
// Hostname is the hostname of the leader instance
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// StartedAt is when this instance became the leader
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
|
||||
// LastHeartbeat is the last time the leader renewed its lock
|
||||
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||
}
|
||||
|
||||
// Config contains configuration for leader election.
|
||||
type Config struct {
|
||||
// RedisKeyName is the Redis key used for leader election.
|
||||
// Default: "maplefile:leader:lock"
|
||||
RedisKeyName string
|
||||
|
||||
// RedisInfoKeyName is the Redis key used to store leader information.
|
||||
// Default: "maplefile:leader:info"
|
||||
RedisInfoKeyName string
|
||||
|
||||
// LockTTL is how long the leader lock lasts before expiring.
|
||||
// The leader must renew the lock before this time expires.
|
||||
// Default: 10 seconds
|
||||
// Recommended: 10-30 seconds
|
||||
LockTTL time.Duration
|
||||
|
||||
// HeartbeatInterval is how often the leader renews its lock.
|
||||
// This should be significantly less than LockTTL (e.g., LockTTL / 3).
|
||||
// Default: 3 seconds
|
||||
// Recommended: LockTTL / 3
|
||||
HeartbeatInterval time.Duration
|
||||
|
||||
// RetryInterval is how often followers check for leadership opportunity.
|
||||
// Default: 2 seconds
|
||||
// Recommended: 1-5 seconds
|
||||
RetryInterval time.Duration
|
||||
|
||||
// InstanceID uniquely identifies this application instance.
|
||||
// If empty, will be auto-generated from hostname + random suffix.
|
||||
// Default: auto-generated
|
||||
InstanceID string
|
||||
|
||||
// Hostname is the hostname of this instance.
|
||||
// If empty, will be auto-detected.
|
||||
// Default: os.Hostname()
|
||||
Hostname string
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
RedisKeyName: "maplefile:leader:lock",
|
||||
RedisInfoKeyName: "maplefile:leader:info",
|
||||
LockTTL: 10 * time.Second,
|
||||
HeartbeatInterval: 3 * time.Second,
|
||||
RetryInterval: 2 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid and returns an error if not.
|
||||
func (c *Config) Validate() error {
|
||||
if c.RedisKeyName == "" {
|
||||
c.RedisKeyName = "maplefile:leader:lock"
|
||||
}
|
||||
if c.RedisInfoKeyName == "" {
|
||||
c.RedisInfoKeyName = "maplefile:leader:info"
|
||||
}
|
||||
if c.LockTTL <= 0 {
|
||||
c.LockTTL = 10 * time.Second
|
||||
}
|
||||
if c.HeartbeatInterval <= 0 {
|
||||
c.HeartbeatInterval = 3 * time.Second
|
||||
}
|
||||
if c.RetryInterval <= 0 {
|
||||
c.RetryInterval = 2 * time.Second
|
||||
}
|
||||
|
||||
// HeartbeatInterval should be less than LockTTL
|
||||
if c.HeartbeatInterval >= c.LockTTL {
|
||||
c.HeartbeatInterval = c.LockTTL / 3
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
30
cloud/maplepress-backend/pkg/leaderelection/provider.go
Normal file
30
cloud/maplepress-backend/pkg/leaderelection/provider.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package leaderelection
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideLeaderElection provides a LeaderElection instance for Wire DI.
|
||||
func ProvideLeaderElection(
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
) (LeaderElection, error) {
|
||||
// Create configuration from app config
|
||||
// InstanceID and Hostname are auto-generated by NewRedisLeaderElection
|
||||
leConfig := &Config{
|
||||
RedisKeyName: "maplepress:leader:lock",
|
||||
RedisInfoKeyName: "maplepress:leader:info",
|
||||
LockTTL: cfg.LeaderElection.LockTTL,
|
||||
HeartbeatInterval: cfg.LeaderElection.HeartbeatInterval,
|
||||
RetryInterval: cfg.LeaderElection.RetryInterval,
|
||||
InstanceID: "", // Auto-generated from hostname + random suffix
|
||||
Hostname: "", // Auto-detected from os.Hostname()
|
||||
}
|
||||
|
||||
// redis.Client implements redis.UniversalClient interface
|
||||
return NewRedisLeaderElection(leConfig, redisClient, logger)
|
||||
}
|
||||
355
cloud/maplepress-backend/pkg/leaderelection/redis_leader.go
Normal file
355
cloud/maplepress-backend/pkg/leaderelection/redis_leader.go
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
package leaderelection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// redisLeaderElection implements LeaderElection using Redis.
|
||||
type redisLeaderElection struct {
|
||||
config *Config
|
||||
redis redis.UniversalClient
|
||||
logger *zap.Logger
|
||||
instanceID string
|
||||
hostname string
|
||||
isLeader bool
|
||||
leaderMutex sync.RWMutex
|
||||
becomeLeaderCbs []func()
|
||||
loseLeadershipCbs []func()
|
||||
callbackMutex sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
stoppedChan chan struct{}
|
||||
leaderStartTime time.Time
|
||||
lastHeartbeat time.Time
|
||||
lastHeartbeatMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRedisLeaderElection creates a new Redis-based leader election instance.
|
||||
func NewRedisLeaderElection(
|
||||
config *Config,
|
||||
redisClient redis.UniversalClient,
|
||||
logger *zap.Logger,
|
||||
) (LeaderElection, error) {
|
||||
logger = logger.Named("LeaderElection")
|
||||
|
||||
// Validate configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
// Generate instance ID if not provided
|
||||
instanceID := config.InstanceID
|
||||
if instanceID == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
}
|
||||
// Add random suffix to make it unique
|
||||
instanceID = fmt.Sprintf("%s-%d", hostname, rand.Intn(100000))
|
||||
logger.Info("Generated instance ID", zap.String("instance_id", instanceID))
|
||||
}
|
||||
|
||||
// Get hostname if not provided
|
||||
hostname := config.Hostname
|
||||
if hostname == "" {
|
||||
h, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
} else {
|
||||
hostname = h
|
||||
}
|
||||
}
|
||||
|
||||
return &redisLeaderElection{
|
||||
config: config,
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
instanceID: instanceID,
|
||||
hostname: hostname,
|
||||
isLeader: false,
|
||||
becomeLeaderCbs: make([]func(), 0),
|
||||
loseLeadershipCbs: make([]func(), 0),
|
||||
stopChan: make(chan struct{}),
|
||||
stoppedChan: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins participating in leader election.
|
||||
func (le *redisLeaderElection) Start(ctx context.Context) error {
|
||||
le.logger.Info("Starting leader election",
|
||||
zap.String("instance_id", le.instanceID),
|
||||
zap.String("hostname", le.hostname),
|
||||
zap.Duration("lock_ttl", le.config.LockTTL),
|
||||
zap.Duration("heartbeat_interval", le.config.HeartbeatInterval),
|
||||
)
|
||||
|
||||
defer close(le.stoppedChan)
|
||||
|
||||
// Main election loop
|
||||
ticker := time.NewTicker(le.config.RetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Try to become leader immediately on startup
|
||||
le.tryBecomeLeader(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
le.logger.Info("Context cancelled, stopping leader election")
|
||||
le.releaseLeadership(context.Background())
|
||||
return ctx.Err()
|
||||
|
||||
case <-le.stopChan:
|
||||
le.logger.Info("Stop signal received, stopping leader election")
|
||||
le.releaseLeadership(context.Background())
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
if le.IsLeader() {
|
||||
// If we're the leader, send heartbeat
|
||||
if err := le.sendHeartbeat(ctx); err != nil {
|
||||
le.logger.Error("Failed to send heartbeat, lost leadership",
|
||||
zap.Error(err))
|
||||
le.setLeaderStatus(false)
|
||||
le.executeCallbacks(le.loseLeadershipCbs)
|
||||
}
|
||||
} else {
|
||||
// If we're not the leader, try to become leader
|
||||
le.tryBecomeLeader(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryBecomeLeader attempts to acquire leadership.
|
||||
func (le *redisLeaderElection) tryBecomeLeader(ctx context.Context) {
|
||||
// Try to set the leader key with NX (only if not exists) and EX (expiry)
|
||||
success, err := le.redis.SetNX(ctx, le.config.RedisKeyName, le.instanceID, le.config.LockTTL).Result()
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to attempt leader election",
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if success {
|
||||
// We became the leader!
|
||||
le.logger.Info("🎉 Became the leader!",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
le.leaderStartTime = time.Now()
|
||||
le.setLeaderStatus(true)
|
||||
le.updateLeaderInfo(ctx)
|
||||
le.executeCallbacks(le.becomeLeaderCbs)
|
||||
} else {
|
||||
// Someone else is the leader
|
||||
if !le.IsLeader() {
|
||||
// Only log if we weren't already aware
|
||||
currentLeader, _ := le.GetLeaderID()
|
||||
le.logger.Debug("Another instance is the leader",
|
||||
zap.String("leader_id", currentLeader))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat renews the leader lock.
|
||||
func (le *redisLeaderElection) sendHeartbeat(ctx context.Context) error {
|
||||
// Verify we still hold the lock
|
||||
currentValue, err := le.redis.Get(ctx, le.config.RedisKeyName).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current lock value: %w", err)
|
||||
}
|
||||
|
||||
if currentValue != le.instanceID {
|
||||
return fmt.Errorf("lock held by different instance: %s", currentValue)
|
||||
}
|
||||
|
||||
// Renew the lock
|
||||
err = le.redis.Expire(ctx, le.config.RedisKeyName, le.config.LockTTL).Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to renew lock: %w", err)
|
||||
}
|
||||
|
||||
// Update heartbeat time
|
||||
le.setLastHeartbeat(time.Now())
|
||||
|
||||
// Update leader info
|
||||
le.updateLeaderInfo(ctx)
|
||||
|
||||
le.logger.Debug("Heartbeat sent",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateLeaderInfo updates the leader information in Redis.
|
||||
func (le *redisLeaderElection) updateLeaderInfo(ctx context.Context) {
|
||||
info := &LeaderInfo{
|
||||
InstanceID: le.instanceID,
|
||||
Hostname: le.hostname,
|
||||
StartedAt: le.leaderStartTime,
|
||||
LastHeartbeat: le.getLastHeartbeat(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to marshal leader info", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set with same TTL as lock
|
||||
err = le.redis.Set(ctx, le.config.RedisInfoKeyName, data, le.config.LockTTL).Err()
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to update leader info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// releaseLeadership voluntarily releases leadership.
|
||||
func (le *redisLeaderElection) releaseLeadership(ctx context.Context) {
|
||||
if !le.IsLeader() {
|
||||
return
|
||||
}
|
||||
|
||||
le.logger.Info("Releasing leadership voluntarily",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
// Only delete if we're still the owner
|
||||
script := `
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`
|
||||
|
||||
_, err := le.redis.Eval(ctx, script, []string{le.config.RedisKeyName}, le.instanceID).Result()
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to release leadership", zap.Error(err))
|
||||
}
|
||||
|
||||
// Delete leader info
|
||||
le.redis.Del(ctx, le.config.RedisInfoKeyName)
|
||||
|
||||
le.setLeaderStatus(false)
|
||||
le.executeCallbacks(le.loseLeadershipCbs)
|
||||
}
|
||||
|
||||
// IsLeader returns true if this instance is the leader.
|
||||
func (le *redisLeaderElection) IsLeader() bool {
|
||||
le.leaderMutex.RLock()
|
||||
defer le.leaderMutex.RUnlock()
|
||||
return le.isLeader
|
||||
}
|
||||
|
||||
// GetLeaderID returns the ID of the current leader.
|
||||
func (le *redisLeaderElection) GetLeaderID() (string, error) {
|
||||
ctx := context.Background()
|
||||
leaderID, err := le.redis.Get(ctx, le.config.RedisKeyName).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get leader ID: %w", err)
|
||||
}
|
||||
return leaderID, nil
|
||||
}
|
||||
|
||||
// GetLeaderInfo returns information about the current leader.
|
||||
func (le *redisLeaderElection) GetLeaderInfo() (*LeaderInfo, error) {
|
||||
ctx := context.Background()
|
||||
data, err := le.redis.Get(ctx, le.config.RedisInfoKeyName).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get leader info: %w", err)
|
||||
}
|
||||
|
||||
var info LeaderInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal leader info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// OnBecomeLeader registers a callback for when this instance becomes leader.
|
||||
func (le *redisLeaderElection) OnBecomeLeader(callback func()) {
|
||||
le.callbackMutex.Lock()
|
||||
defer le.callbackMutex.Unlock()
|
||||
le.becomeLeaderCbs = append(le.becomeLeaderCbs, callback)
|
||||
}
|
||||
|
||||
// OnLoseLeadership registers a callback for when this instance loses leadership.
|
||||
func (le *redisLeaderElection) OnLoseLeadership(callback func()) {
|
||||
le.callbackMutex.Lock()
|
||||
defer le.callbackMutex.Unlock()
|
||||
le.loseLeadershipCbs = append(le.loseLeadershipCbs, callback)
|
||||
}
|
||||
|
||||
// Stop gracefully stops leader election.
|
||||
func (le *redisLeaderElection) Stop() error {
|
||||
le.logger.Info("Stopping leader election")
|
||||
close(le.stopChan)
|
||||
|
||||
// Wait for the election loop to finish (with timeout)
|
||||
select {
|
||||
case <-le.stoppedChan:
|
||||
le.logger.Info("Leader election stopped successfully")
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
le.logger.Warn("Timeout waiting for leader election to stop")
|
||||
return fmt.Errorf("timeout waiting for leader election to stop")
|
||||
}
|
||||
}
|
||||
|
||||
// GetInstanceID returns this instance's unique identifier.
|
||||
func (le *redisLeaderElection) GetInstanceID() string {
|
||||
return le.instanceID
|
||||
}
|
||||
|
||||
// setLeaderStatus updates the leader status (thread-safe).
|
||||
func (le *redisLeaderElection) setLeaderStatus(isLeader bool) {
|
||||
le.leaderMutex.Lock()
|
||||
defer le.leaderMutex.Unlock()
|
||||
le.isLeader = isLeader
|
||||
}
|
||||
|
||||
// setLastHeartbeat updates the last heartbeat time (thread-safe).
|
||||
func (le *redisLeaderElection) setLastHeartbeat(t time.Time) {
|
||||
le.lastHeartbeatMutex.Lock()
|
||||
defer le.lastHeartbeatMutex.Unlock()
|
||||
le.lastHeartbeat = t
|
||||
}
|
||||
|
||||
// getLastHeartbeat gets the last heartbeat time (thread-safe).
|
||||
func (le *redisLeaderElection) getLastHeartbeat() time.Time {
|
||||
le.lastHeartbeatMutex.RLock()
|
||||
defer le.lastHeartbeatMutex.RUnlock()
|
||||
return le.lastHeartbeat
|
||||
}
|
||||
|
||||
// executeCallbacks executes a list of callbacks in separate goroutines.
|
||||
func (le *redisLeaderElection) executeCallbacks(callbacks []func()) {
|
||||
le.callbackMutex.RLock()
|
||||
defer le.callbackMutex.RUnlock()
|
||||
|
||||
for _, callback := range callbacks {
|
||||
go func(cb func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
le.logger.Error("Panic in leader election callback",
|
||||
zap.Any("panic", r))
|
||||
}
|
||||
}()
|
||||
cb()
|
||||
}(callback)
|
||||
}
|
||||
}
|
||||
120
cloud/maplepress-backend/pkg/logger/logger.go
Normal file
120
cloud/maplepress-backend/pkg/logger/logger.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// emojiCore wraps a zapcore.Core to add emoji icon field
|
||||
type emojiCore struct {
|
||||
zapcore.Core
|
||||
}
|
||||
|
||||
func (c *emojiCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
return &emojiCore{c.Core.With(fields)}
|
||||
}
|
||||
|
||||
func (c *emojiCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
if c.Enabled(entry.Level) {
|
||||
return ce.AddCore(entry, c)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
func (c *emojiCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only add emoji icon field for warnings and errors
|
||||
// Skip for info and debug to keep output clean
|
||||
var emoji string
|
||||
var addEmoji bool
|
||||
|
||||
switch entry.Level {
|
||||
case zapcore.WarnLevel:
|
||||
emoji = "🟡" // Yellow circle for warnings
|
||||
addEmoji = true
|
||||
case zapcore.ErrorLevel:
|
||||
emoji = "🔴" // Red circle for errors
|
||||
addEmoji = true
|
||||
case zapcore.DPanicLevel:
|
||||
emoji = "🔴" // Red circle for panic
|
||||
addEmoji = true
|
||||
case zapcore.PanicLevel:
|
||||
emoji = "🔴" // Red circle for panic
|
||||
addEmoji = true
|
||||
case zapcore.FatalLevel:
|
||||
emoji = "🔴" // Red circle for fatal
|
||||
addEmoji = true
|
||||
default:
|
||||
// No emoji for debug and info levels
|
||||
addEmoji = false
|
||||
}
|
||||
|
||||
// Only prepend emoji field if we're adding one
|
||||
if addEmoji {
|
||||
fieldsWithEmoji := make([]zapcore.Field, 0, len(fields)+1)
|
||||
fieldsWithEmoji = append(fieldsWithEmoji, zap.String("ico", emoji))
|
||||
fieldsWithEmoji = append(fieldsWithEmoji, fields...)
|
||||
return c.Core.Write(entry, fieldsWithEmoji)
|
||||
}
|
||||
|
||||
// For debug/info, write as-is without emoji
|
||||
return c.Core.Write(entry, fields)
|
||||
}
|
||||
|
||||
// ProvideLogger creates a new zap logger based on configuration
|
||||
func ProvideLogger(cfg *config.Config) (*zap.Logger, error) {
|
||||
var zapConfig zap.Config
|
||||
|
||||
// Set config based on environment
|
||||
if cfg.App.Environment == "production" {
|
||||
zapConfig = zap.NewProductionConfig()
|
||||
} else {
|
||||
zapConfig = zap.NewDevelopmentConfig()
|
||||
}
|
||||
|
||||
// Set log level
|
||||
level, err := zapcore.ParseLevel(cfg.Logger.Level)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid log level %s: %w", cfg.Logger.Level, err)
|
||||
}
|
||||
zapConfig.Level = zap.NewAtomicLevelAt(level)
|
||||
|
||||
// Set encoding format
|
||||
if cfg.Logger.Format == "console" {
|
||||
zapConfig.Encoding = "console"
|
||||
zapConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
} else {
|
||||
zapConfig.Encoding = "json"
|
||||
}
|
||||
|
||||
// Build logger with environment-specific options
|
||||
var loggerOptions []zap.Option
|
||||
|
||||
// Enable caller information in development for easier debugging
|
||||
if cfg.App.Environment != "production" {
|
||||
loggerOptions = append(loggerOptions, zap.AddCaller())
|
||||
loggerOptions = append(loggerOptions, zap.AddCallerSkip(0))
|
||||
}
|
||||
|
||||
// Add stack traces for error level and above
|
||||
loggerOptions = append(loggerOptions, zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
// Wrap core with emoji core to add icon field
|
||||
loggerOptions = append(loggerOptions, zap.WrapCore(func(core zapcore.Core) zapcore.Core {
|
||||
return &emojiCore{core}
|
||||
}))
|
||||
|
||||
logger, err := zapConfig.Build(loggerOptions...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build logger: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✓ Logger initialized",
|
||||
zap.String("level", cfg.Logger.Level),
|
||||
zap.String("format", cfg.Logger.Format),
|
||||
zap.String("environment", cfg.App.Environment))
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
231
cloud/maplepress-backend/pkg/logger/sanitizer.go
Normal file
231
cloud/maplepress-backend/pkg/logger/sanitizer.go
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// SensitiveFieldRedactor provides methods to redact sensitive data before logging
|
||||
// This addresses CWE-532 (Insertion of Sensitive Information into Log File)
|
||||
type SensitiveFieldRedactor struct {
|
||||
emailRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
// NewSensitiveFieldRedactor creates a new redactor for sensitive data
|
||||
func NewSensitiveFieldRedactor() *SensitiveFieldRedactor {
|
||||
return &SensitiveFieldRedactor{
|
||||
emailRegex: regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`),
|
||||
}
|
||||
}
|
||||
|
||||
// RedactEmail redacts an email address for logging
|
||||
// Example: "john.doe@example.com" -> "jo***@example.com"
|
||||
func (r *SensitiveFieldRedactor) RedactEmail(email string) string {
|
||||
if email == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
// Validate email format
|
||||
if !r.emailRegex.MatchString(email) {
|
||||
return "[invalid-email]"
|
||||
}
|
||||
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return "[invalid-email]"
|
||||
}
|
||||
|
||||
localPart := parts[0]
|
||||
domain := parts[1]
|
||||
|
||||
// Show first 2 characters of local part, redact the rest
|
||||
if len(localPart) <= 2 {
|
||||
return "**@" + domain
|
||||
}
|
||||
|
||||
return localPart[:2] + "***@" + domain
|
||||
}
|
||||
|
||||
// HashForLogging creates a consistent hash for unique identification without exposing the original value
|
||||
// This allows correlation across log entries without storing PII
|
||||
// Example: "john.doe@example.com" -> "a1b2c3d4"
|
||||
func (r *SensitiveFieldRedactor) HashForLogging(value string) string {
|
||||
if value == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
h := sha256.Sum256([]byte(value))
|
||||
// Return first 8 bytes (16 hex characters) for reasonable uniqueness
|
||||
return hex.EncodeToString(h[:8])
|
||||
}
|
||||
|
||||
// RedactTenantSlug redacts a tenant slug for logging
|
||||
// Example: "my-company" -> "my-***"
|
||||
func (r *SensitiveFieldRedactor) RedactTenantSlug(slug string) string {
|
||||
if slug == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
if len(slug) <= 3 {
|
||||
return "***"
|
||||
}
|
||||
|
||||
return slug[:2] + "***"
|
||||
}
|
||||
|
||||
// RedactAPIKey redacts an API key for logging
|
||||
// Shows only prefix and last 4 characters
|
||||
// Example: "live_sk_abc123def456ghi789" -> "live_sk_***i789"
|
||||
func (r *SensitiveFieldRedactor) RedactAPIKey(apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
// Show prefix (live_sk_ or test_sk_) and last 4 characters
|
||||
if strings.HasPrefix(apiKey, "live_sk_") || strings.HasPrefix(apiKey, "test_sk_") {
|
||||
prefix := apiKey[:8] // "live_sk_" or "test_sk_"
|
||||
if len(apiKey) > 12 {
|
||||
return prefix + "***" + apiKey[len(apiKey)-4:]
|
||||
}
|
||||
return prefix + "***"
|
||||
}
|
||||
|
||||
// For other formats, just show last 4 characters
|
||||
if len(apiKey) > 4 {
|
||||
return "***" + apiKey[len(apiKey)-4:]
|
||||
}
|
||||
|
||||
return "***"
|
||||
}
|
||||
|
||||
// RedactJWTToken redacts a JWT token for logging
|
||||
// Shows only first and last 8 characters
|
||||
func (r *SensitiveFieldRedactor) RedactJWTToken(token string) string {
|
||||
if token == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
if len(token) < 16 {
|
||||
return "***"
|
||||
}
|
||||
|
||||
return token[:8] + "..." + token[len(token)-8:]
|
||||
}
|
||||
|
||||
// RedactIPAddress partially redacts an IP address
|
||||
// IPv4: "192.168.1.100" -> "192.168.*.*"
|
||||
// IPv6: Redacts last 4 groups
|
||||
func (r *SensitiveFieldRedactor) RedactIPAddress(ip string) string {
|
||||
if ip == "" {
|
||||
return "[empty]"
|
||||
}
|
||||
|
||||
// IPv4
|
||||
if strings.Contains(ip, ".") {
|
||||
parts := strings.Split(ip, ".")
|
||||
if len(parts) == 4 {
|
||||
return parts[0] + "." + parts[1] + ".*.*"
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6
|
||||
if strings.Contains(ip, ":") {
|
||||
parts := strings.Split(ip, ":")
|
||||
if len(parts) >= 4 {
|
||||
return strings.Join(parts[:4], ":") + ":****"
|
||||
}
|
||||
}
|
||||
|
||||
return "***"
|
||||
}
|
||||
|
||||
// Zap Field Helpers - Provide convenient zap.Field constructors
|
||||
|
||||
// SafeEmail creates a zap field with redacted email
|
||||
func SafeEmail(key string, email string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String(key, redactor.RedactEmail(email))
|
||||
}
|
||||
|
||||
// EmailHash creates a zap field with hashed email for correlation
|
||||
func EmailHash(email string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String("email_hash", redactor.HashForLogging(email))
|
||||
}
|
||||
|
||||
// HashString hashes a string value for safe logging
|
||||
// Returns the hash string directly (not a zap.Field)
|
||||
func HashString(value string) string {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return redactor.HashForLogging(value)
|
||||
}
|
||||
|
||||
// SafeTenantSlug creates a zap field with redacted tenant slug
|
||||
func SafeTenantSlug(key string, slug string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String(key, redactor.RedactTenantSlug(slug))
|
||||
}
|
||||
|
||||
// TenantSlugHash creates a zap field with hashed tenant slug for correlation
|
||||
func TenantSlugHash(slug string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String("tenant_slug_hash", redactor.HashForLogging(slug))
|
||||
}
|
||||
|
||||
// SafeAPIKey creates a zap field with redacted API key
|
||||
func SafeAPIKey(key string, apiKey string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String(key, redactor.RedactAPIKey(apiKey))
|
||||
}
|
||||
|
||||
// SafeJWTToken creates a zap field with redacted JWT token
|
||||
func SafeJWTToken(key string, token string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String(key, redactor.RedactJWTToken(token))
|
||||
}
|
||||
|
||||
// SafeIPAddress creates a zap field with redacted IP address
|
||||
func SafeIPAddress(key string, ip string) zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return zap.String(key, redactor.RedactIPAddress(ip))
|
||||
}
|
||||
|
||||
// UserIdentifier creates safe identification fields for a user
|
||||
// Includes: user_id (safe), email_hash, email_redacted
|
||||
func UserIdentifier(userID string, email string) []zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return []zapcore.Field{
|
||||
zap.String("user_id", userID),
|
||||
zap.String("email_hash", redactor.HashForLogging(email)),
|
||||
zap.String("email_redacted", redactor.RedactEmail(email)),
|
||||
}
|
||||
}
|
||||
|
||||
// TenantIdentifier creates safe identification fields for a tenant
|
||||
// Includes: tenant_id (safe), slug_hash, slug_redacted
|
||||
func TenantIdentifier(tenantID string, slug string) []zapcore.Field {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
return []zapcore.Field{
|
||||
zap.String("tenant_id", tenantID),
|
||||
zap.String("tenant_slug_hash", redactor.HashForLogging(slug)),
|
||||
zap.String("tenant_slug_redacted", redactor.RedactTenantSlug(slug)),
|
||||
}
|
||||
}
|
||||
|
||||
// Constants for field names
|
||||
const (
|
||||
FieldUserID = "user_id"
|
||||
FieldEmailHash = "email_hash"
|
||||
FieldEmailRedacted = "email_redacted"
|
||||
FieldTenantID = "tenant_id"
|
||||
FieldTenantSlugHash = "tenant_slug_hash"
|
||||
FieldTenantSlugRedacted = "tenant_slug_redacted"
|
||||
FieldAPIKeyRedacted = "api_key_redacted"
|
||||
FieldJWTTokenRedacted = "jwt_token_redacted"
|
||||
FieldIPAddressRedacted = "ip_address_redacted"
|
||||
)
|
||||
345
cloud/maplepress-backend/pkg/logger/sanitizer_test.go
Normal file
345
cloud/maplepress-backend/pkg/logger/sanitizer_test.go
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedactEmail(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal email",
|
||||
input: "john.doe@example.com",
|
||||
expected: "jo***@example.com",
|
||||
},
|
||||
{
|
||||
name: "short local part",
|
||||
input: "ab@example.com",
|
||||
expected: "**@example.com",
|
||||
},
|
||||
{
|
||||
name: "single character local part",
|
||||
input: "a@example.com",
|
||||
expected: "**@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty email",
|
||||
input: "",
|
||||
expected: "[empty]",
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
input: "notanemail",
|
||||
expected: "[invalid-email]",
|
||||
},
|
||||
{
|
||||
name: "long email",
|
||||
input: "very.long.email.address@subdomain.example.com",
|
||||
expected: "ve***@subdomain.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := redactor.RedactEmail(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactEmail(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashForLogging(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "email",
|
||||
input: "john.doe@example.com",
|
||||
},
|
||||
{
|
||||
name: "tenant slug",
|
||||
input: "my-company",
|
||||
},
|
||||
{
|
||||
name: "another email",
|
||||
input: "jane.smith@test.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash1 := redactor.HashForLogging(tt.input)
|
||||
hash2 := redactor.HashForLogging(tt.input)
|
||||
|
||||
// Hash should be consistent
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("HashForLogging is not consistent: %q != %q", hash1, hash2)
|
||||
}
|
||||
|
||||
// Hash should be 16 characters (8 bytes in hex)
|
||||
if len(hash1) != 16 {
|
||||
t.Errorf("HashForLogging length = %d, want 16", len(hash1))
|
||||
}
|
||||
|
||||
// Hash should not contain original value
|
||||
if hash1 == tt.input {
|
||||
t.Errorf("HashForLogging returned original value")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Different inputs should produce different hashes
|
||||
hash1 := redactor.HashForLogging("john.doe@example.com")
|
||||
hash2 := redactor.HashForLogging("jane.smith@example.com")
|
||||
if hash1 == hash2 {
|
||||
t.Error("Different inputs produced same hash")
|
||||
}
|
||||
|
||||
// Empty string
|
||||
emptyHash := redactor.HashForLogging("")
|
||||
if emptyHash != "[empty]" {
|
||||
t.Errorf("HashForLogging(\"\") = %q, want [empty]", emptyHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactTenantSlug(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal slug",
|
||||
input: "my-company",
|
||||
expected: "my***",
|
||||
},
|
||||
{
|
||||
name: "short slug",
|
||||
input: "abc",
|
||||
expected: "***",
|
||||
},
|
||||
{
|
||||
name: "very short slug",
|
||||
input: "ab",
|
||||
expected: "***",
|
||||
},
|
||||
{
|
||||
name: "empty slug",
|
||||
input: "",
|
||||
expected: "[empty]",
|
||||
},
|
||||
{
|
||||
name: "long slug",
|
||||
input: "very-long-company-name",
|
||||
expected: "ve***",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := redactor.RedactTenantSlug(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactTenantSlug(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactAPIKey(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "live API key",
|
||||
input: "live_sk_abc123def456ghi789",
|
||||
expected: "live_sk_***i789",
|
||||
},
|
||||
{
|
||||
name: "test API key",
|
||||
input: "test_sk_xyz789uvw456rst123",
|
||||
expected: "test_sk_***t123",
|
||||
},
|
||||
{
|
||||
name: "short live key",
|
||||
input: "live_sk_abc",
|
||||
expected: "live_sk_***",
|
||||
},
|
||||
{
|
||||
name: "other format",
|
||||
input: "sk_abc123def456",
|
||||
expected: "***f456",
|
||||
},
|
||||
{
|
||||
name: "very short key",
|
||||
input: "abc",
|
||||
expected: "***",
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
input: "",
|
||||
expected: "[empty]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := redactor.RedactAPIKey(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactAPIKey(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactJWTToken(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal JWT",
|
||||
input: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U",
|
||||
expected: "eyJhbGci...P0THsR8U",
|
||||
},
|
||||
{
|
||||
name: "short token",
|
||||
input: "short",
|
||||
expected: "***",
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
input: "",
|
||||
expected: "[empty]",
|
||||
},
|
||||
{
|
||||
name: "minimum length token",
|
||||
input: "1234567890123456",
|
||||
expected: "12345678...90123456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := redactor.RedactJWTToken(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactJWTToken(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactIPAddress(t *testing.T) {
|
||||
redactor := NewSensitiveFieldRedactor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
input: "192.168.1.100",
|
||||
expected: "192.168.*.*",
|
||||
},
|
||||
{
|
||||
name: "IPv4 public",
|
||||
input: "8.8.8.8",
|
||||
expected: "8.8.*.*",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
input: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
|
||||
expected: "2001:0db8:85a3:0000:****",
|
||||
},
|
||||
{
|
||||
name: "IPv6 shortened",
|
||||
input: "2001:db8::1",
|
||||
expected: "2001:db8::1:****",
|
||||
},
|
||||
{
|
||||
name: "empty IP",
|
||||
input: "",
|
||||
expected: "[empty]",
|
||||
},
|
||||
{
|
||||
name: "invalid IP",
|
||||
input: "notanip",
|
||||
expected: "***",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := redactor.RedactIPAddress(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactIPAddress(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIdentifier(t *testing.T) {
|
||||
userID := "user_123"
|
||||
email := "john.doe@example.com"
|
||||
|
||||
fields := UserIdentifier(userID, email)
|
||||
|
||||
if len(fields) != 3 {
|
||||
t.Errorf("UserIdentifier returned %d fields, want 3", len(fields))
|
||||
}
|
||||
|
||||
// Check that fields contain expected keys
|
||||
fieldKeys := make(map[string]bool)
|
||||
for _, field := range fields {
|
||||
fieldKeys[field.Key] = true
|
||||
}
|
||||
|
||||
expectedKeys := []string{"user_id", "email_hash", "email_redacted"}
|
||||
for _, key := range expectedKeys {
|
||||
if !fieldKeys[key] {
|
||||
t.Errorf("UserIdentifier missing key: %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTenantIdentifier(t *testing.T) {
|
||||
tenantID := "tenant_123"
|
||||
slug := "my-company"
|
||||
|
||||
fields := TenantIdentifier(tenantID, slug)
|
||||
|
||||
if len(fields) != 3 {
|
||||
t.Errorf("TenantIdentifier returned %d fields, want 3", len(fields))
|
||||
}
|
||||
|
||||
// Check that fields contain expected keys
|
||||
fieldKeys := make(map[string]bool)
|
||||
for _, field := range fields {
|
||||
fieldKeys[field.Key] = true
|
||||
}
|
||||
|
||||
expectedKeys := []string{"tenant_id", "tenant_slug_hash", "tenant_slug_redacted"}
|
||||
for _, key := range expectedKeys {
|
||||
if !fieldKeys[key] {
|
||||
t.Errorf("TenantIdentifier missing key: %s", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
327
cloud/maplepress-backend/pkg/ratelimit/login_ratelimiter.go
Normal file
327
cloud/maplepress-backend/pkg/ratelimit/login_ratelimiter.go
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LoginRateLimiter provides specialized rate limiting for login attempts
|
||||
// with account lockout functionality
|
||||
type LoginRateLimiter interface {
|
||||
// CheckAndRecordAttempt checks if login attempt is allowed and records it
|
||||
// Returns: allowed (bool), isLocked (bool), remainingAttempts (int), error
|
||||
CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error)
|
||||
|
||||
// RecordFailedAttempt records a failed login attempt
|
||||
RecordFailedAttempt(ctx context.Context, email string, clientIP string) error
|
||||
|
||||
// RecordSuccessfulLogin records a successful login and resets counters
|
||||
RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error
|
||||
|
||||
// IsAccountLocked checks if an account is locked due to too many failed attempts
|
||||
IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error)
|
||||
|
||||
// UnlockAccount manually unlocks an account (admin function)
|
||||
UnlockAccount(ctx context.Context, email string) error
|
||||
|
||||
// GetFailedAttempts returns the number of failed attempts for an email
|
||||
GetFailedAttempts(ctx context.Context, email string) (int, error)
|
||||
}
|
||||
|
||||
// LoginRateLimiterConfig holds configuration for login rate limiting
|
||||
type LoginRateLimiterConfig struct {
|
||||
// MaxAttemptsPerIP is the maximum login attempts per IP in the window
|
||||
MaxAttemptsPerIP int
|
||||
// IPWindow is the time window for IP-based rate limiting
|
||||
IPWindow time.Duration
|
||||
|
||||
// MaxFailedAttemptsPerAccount is the maximum failed attempts before account lockout
|
||||
MaxFailedAttemptsPerAccount int
|
||||
// AccountLockoutDuration is how long to lock an account after too many failures
|
||||
AccountLockoutDuration time.Duration
|
||||
|
||||
// KeyPrefix is the prefix for Redis keys
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
// DefaultLoginRateLimiterConfig returns recommended configuration
|
||||
func DefaultLoginRateLimiterConfig() LoginRateLimiterConfig {
|
||||
return LoginRateLimiterConfig{
|
||||
MaxAttemptsPerIP: 10, // 10 attempts per IP
|
||||
IPWindow: 15 * time.Minute, // in 15 minutes
|
||||
MaxFailedAttemptsPerAccount: 10, // 10 failed attempts per account
|
||||
AccountLockoutDuration: 30 * time.Minute, // lock for 30 minutes
|
||||
KeyPrefix: "login_rl",
|
||||
}
|
||||
}
|
||||
|
||||
type loginRateLimiter struct {
|
||||
client *redis.Client
|
||||
config LoginRateLimiterConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewLoginRateLimiter creates a new login rate limiter
|
||||
func NewLoginRateLimiter(client *redis.Client, config LoginRateLimiterConfig, logger *zap.Logger) LoginRateLimiter {
|
||||
return &loginRateLimiter{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger.Named("login-rate-limiter"),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAndRecordAttempt checks if login attempt is allowed
|
||||
// CWE-307: Implements protection against brute force attacks
|
||||
func (r *loginRateLimiter) CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error) {
|
||||
// Check account lockout first
|
||||
locked, remaining, err := r.IsAccountLocked(ctx, email)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check account lockout",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, false, 0, err
|
||||
}
|
||||
|
||||
if locked {
|
||||
r.logger.Warn("login attempt on locked account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Duration("remaining_lockout", remaining))
|
||||
return false, true, 0, nil
|
||||
}
|
||||
|
||||
// Check IP-based rate limit
|
||||
ipKey := r.getIPKey(clientIP)
|
||||
allowed, err := r.checkIPRateLimit(ctx, ipKey)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check IP rate limit",
|
||||
zap.String("ip", clientIP),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, false, 0, err
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
r.logger.Warn("IP rate limit exceeded",
|
||||
zap.String("ip", clientIP))
|
||||
return false, false, 0, nil
|
||||
}
|
||||
|
||||
// Record the attempt for IP
|
||||
if err := r.recordIPAttempt(ctx, ipKey); err != nil {
|
||||
r.logger.Error("failed to record IP attempt",
|
||||
zap.String("ip", clientIP),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// Get remaining attempts for account
|
||||
failedAttempts, err := r.GetFailedAttempts(ctx, email)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to get failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
remainingAttempts := r.config.MaxFailedAttemptsPerAccount - failedAttempts
|
||||
if remainingAttempts < 0 {
|
||||
remainingAttempts = 0
|
||||
}
|
||||
|
||||
r.logger.Debug("login attempt check passed",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Int("remaining_attempts", remainingAttempts))
|
||||
|
||||
return true, false, remainingAttempts, nil
|
||||
}
|
||||
|
||||
// RecordFailedAttempt records a failed login attempt
|
||||
// CWE-307: Tracks failed attempts to enable account lockout
|
||||
func (r *loginRateLimiter) RecordFailedAttempt(ctx context.Context, email string, clientIP string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
|
||||
// Increment failed attempt counter
|
||||
count, err := r.client.Incr(ctx, accountKey).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to increment failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Set expiration on first failed attempt
|
||||
if count == 1 {
|
||||
r.client.Expire(ctx, accountKey, r.config.AccountLockoutDuration)
|
||||
}
|
||||
|
||||
// Check if account should be locked
|
||||
if count >= int64(r.config.MaxFailedAttemptsPerAccount) {
|
||||
lockKey := r.getLockKey(email)
|
||||
err := r.client.Set(ctx, lockKey, "locked", r.config.AccountLockoutDuration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to lock account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Warn("account locked due to too many failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Int64("failed_attempts", count),
|
||||
zap.Duration("lockout_duration", r.config.AccountLockoutDuration))
|
||||
}
|
||||
|
||||
r.logger.Info("failed login attempt recorded",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Int64("total_failed_attempts", count))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordSuccessfulLogin records a successful login and resets counters
|
||||
func (r *loginRateLimiter) RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
// Delete failed attempt counter
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, accountKey)
|
||||
pipe.Del(ctx, lockKey)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("failed to reset login counters",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("successful login recorded, counters reset",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", clientIP))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAccountLocked checks if an account is locked
|
||||
func (r *loginRateLimiter) IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error) {
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
ttl, err := r.client.TTL(ctx, lockKey).Result()
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// TTL returns -2 if key doesn't exist, -1 if no expiration
|
||||
if ttl < 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
return true, ttl, nil
|
||||
}
|
||||
|
||||
// UnlockAccount manually unlocks an account
|
||||
func (r *loginRateLimiter) UnlockAccount(ctx context.Context, email string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, accountKey)
|
||||
pipe.Del(ctx, lockKey)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("failed to unlock account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("account unlocked",
|
||||
zap.String("email_hash", hashEmail(email)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFailedAttempts returns the number of failed attempts
|
||||
func (r *loginRateLimiter) GetFailedAttempts(ctx context.Context, email string) (int, error) {
|
||||
accountKey := r.getAccountKey(email)
|
||||
|
||||
count, err := r.client.Get(ctx, accountKey).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// checkIPRateLimit checks if IP has exceeded rate limit
|
||||
func (r *loginRateLimiter) checkIPRateLimit(ctx context.Context, ipKey string) (bool, error) {
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.IPWindow)
|
||||
|
||||
// Remove old entries
|
||||
r.client.ZRemRangeByScore(ctx, ipKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
||||
|
||||
// Count current attempts
|
||||
count, err := r.client.ZCount(ctx, ipKey,
|
||||
fmt.Sprintf("%d", windowStart.UnixNano()),
|
||||
"+inf").Result()
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count < int64(r.config.MaxAttemptsPerIP), nil
|
||||
}
|
||||
|
||||
// recordIPAttempt records an IP attempt
|
||||
func (r *loginRateLimiter) recordIPAttempt(ctx context.Context, ipKey string) error {
|
||||
now := time.Now()
|
||||
timestamp := now.UnixNano()
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.ZAdd(ctx, ipKey, redis.Z{
|
||||
Score: float64(timestamp),
|
||||
Member: fmt.Sprintf("%d", timestamp),
|
||||
})
|
||||
pipe.Expire(ctx, ipKey, r.config.IPWindow+time.Minute)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Key generation helpers
|
||||
func (r *loginRateLimiter) getIPKey(ip string) string {
|
||||
return fmt.Sprintf("%s:ip:%s", r.config.KeyPrefix, ip)
|
||||
}
|
||||
|
||||
func (r *loginRateLimiter) getAccountKey(email string) string {
|
||||
return fmt.Sprintf("%s:account:%s:attempts", r.config.KeyPrefix, hashEmail(email))
|
||||
}
|
||||
|
||||
func (r *loginRateLimiter) getLockKey(email string) string {
|
||||
return fmt.Sprintf("%s:account:%s:locked", r.config.KeyPrefix, hashEmail(email))
|
||||
}
|
||||
|
||||
// hashEmail creates a consistent hash of an email for use as a key
|
||||
// CWE-532: Prevents PII in Redis keys
|
||||
func hashEmail(email string) string {
|
||||
// Use a simple hash for key generation (not for security)
|
||||
// In production, consider using SHA-256
|
||||
hash := 0
|
||||
for _, c := range email {
|
||||
hash = (hash * 31) + int(c)
|
||||
}
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
45
cloud/maplepress-backend/pkg/ratelimit/provider.go
Normal file
45
cloud/maplepress-backend/pkg/ratelimit/provider.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideRateLimiter provides a rate limiter for dependency injection (registration endpoints)
|
||||
func ProvideRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
|
||||
rateLimitConfig := Config{
|
||||
MaxRequests: cfg.RateLimit.RegistrationMaxRequests,
|
||||
Window: cfg.RateLimit.RegistrationWindow,
|
||||
KeyPrefix: "ratelimit:registration",
|
||||
}
|
||||
|
||||
return NewRateLimiter(redisClient, rateLimitConfig, logger)
|
||||
}
|
||||
|
||||
// ProvideGenericRateLimiter provides a rate limiter for generic CRUD endpoints (CWE-770)
|
||||
// This is used for authenticated endpoints like tenant/user/site management, admin endpoints
|
||||
// Strategy: User-based limiting (authenticated user ID from JWT)
|
||||
func ProvideGenericRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
|
||||
rateLimitConfig := Config{
|
||||
MaxRequests: cfg.RateLimit.GenericMaxRequests,
|
||||
Window: cfg.RateLimit.GenericWindow,
|
||||
KeyPrefix: "ratelimit:generic",
|
||||
}
|
||||
|
||||
return NewRateLimiter(redisClient, rateLimitConfig, logger)
|
||||
}
|
||||
|
||||
// ProvidePluginAPIRateLimiter provides a rate limiter for WordPress plugin API endpoints (CWE-770)
|
||||
// This is used for plugin endpoints that are core business/revenue endpoints
|
||||
// Strategy: Site-based limiting (API key → site_id)
|
||||
func ProvidePluginAPIRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
|
||||
rateLimitConfig := Config{
|
||||
MaxRequests: cfg.RateLimit.PluginAPIMaxRequests,
|
||||
Window: cfg.RateLimit.PluginAPIWindow,
|
||||
KeyPrefix: "ratelimit:plugin",
|
||||
}
|
||||
|
||||
return NewRateLimiter(redisClient, rateLimitConfig, logger)
|
||||
}
|
||||
23
cloud/maplepress-backend/pkg/ratelimit/providers.go
Normal file
23
cloud/maplepress-backend/pkg/ratelimit/providers.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideLoginRateLimiter creates a LoginRateLimiter for dependency injection
|
||||
// CWE-307: Implements rate limiting and account lockout protection against brute force attacks
|
||||
func ProvideLoginRateLimiter(client *redis.Client, cfg *config.Config, logger *zap.Logger) LoginRateLimiter {
|
||||
// Use configuration from environment variables
|
||||
loginConfig := LoginRateLimiterConfig{
|
||||
MaxAttemptsPerIP: cfg.RateLimit.LoginMaxAttemptsPerIP,
|
||||
IPWindow: cfg.RateLimit.LoginIPWindow,
|
||||
MaxFailedAttemptsPerAccount: cfg.RateLimit.LoginMaxFailedAttemptsPerAccount,
|
||||
AccountLockoutDuration: cfg.RateLimit.LoginAccountLockoutDuration,
|
||||
KeyPrefix: "login_rl",
|
||||
}
|
||||
|
||||
return NewLoginRateLimiter(client, loginConfig, logger)
|
||||
}
|
||||
172
cloud/maplepress-backend/pkg/ratelimit/ratelimiter.go
Normal file
172
cloud/maplepress-backend/pkg/ratelimit/ratelimiter.go
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RateLimiter provides rate limiting functionality using Redis
|
||||
type RateLimiter interface {
|
||||
// Allow checks if a request should be allowed based on the key
|
||||
// Returns true if allowed, false if rate limit exceeded
|
||||
Allow(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// AllowN checks if N requests should be allowed
|
||||
AllowN(ctx context.Context, key string, n int) (bool, error)
|
||||
|
||||
// Reset resets the rate limit for a key
|
||||
Reset(ctx context.Context, key string) error
|
||||
|
||||
// GetRemaining returns the number of remaining requests
|
||||
GetRemaining(ctx context.Context, key string) (int, error)
|
||||
}
|
||||
|
||||
// Config holds rate limiter configuration
|
||||
type Config struct {
|
||||
// MaxRequests is the maximum number of requests allowed
|
||||
MaxRequests int
|
||||
// Window is the time window for rate limiting
|
||||
Window time.Duration
|
||||
// KeyPrefix is the prefix for Redis keys
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
client *redis.Client
|
||||
config Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(client *redis.Client, config Config, logger *zap.Logger) RateLimiter {
|
||||
return &rateLimiter{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger.Named("rate-limiter"),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed
|
||||
func (r *rateLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
||||
return r.AllowN(ctx, key, 1)
|
||||
}
|
||||
|
||||
// AllowN checks if N requests should be allowed using sliding window counter
|
||||
func (r *rateLimiter) AllowN(ctx context.Context, key string, n int) (bool, error) {
|
||||
redisKey := r.getRedisKey(key)
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.Window)
|
||||
|
||||
// Use Redis transaction to ensure atomicity
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Remove old entries outside the window
|
||||
pipe.ZRemRangeByScore(ctx, redisKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
||||
|
||||
// Count current requests in window
|
||||
countCmd := pipe.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf")
|
||||
|
||||
// Execute pipeline
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
r.logger.Error("failed to check rate limit",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
// Fail open: allow request if Redis is down
|
||||
return true, err
|
||||
}
|
||||
|
||||
currentCount := countCmd.Val()
|
||||
|
||||
// Check if adding N requests would exceed limit
|
||||
if currentCount+int64(n) > int64(r.config.MaxRequests) {
|
||||
r.logger.Warn("rate limit exceeded",
|
||||
zap.String("key", key),
|
||||
zap.Int64("current_count", currentCount),
|
||||
zap.Int("max_requests", r.config.MaxRequests))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Add the new request(s) to the sorted set
|
||||
pipe2 := r.client.Pipeline()
|
||||
for i := 0; i < n; i++ {
|
||||
// Use nanosecond timestamp with incremental offset to ensure uniqueness
|
||||
timestamp := now.Add(time.Duration(i) * time.Nanosecond).UnixNano()
|
||||
pipe2.ZAdd(ctx, redisKey, redis.Z{
|
||||
Score: float64(timestamp),
|
||||
Member: fmt.Sprintf("%d-%d", timestamp, i),
|
||||
})
|
||||
}
|
||||
|
||||
// Set expiration on the key (window + buffer)
|
||||
pipe2.Expire(ctx, redisKey, r.config.Window+time.Minute)
|
||||
|
||||
// Execute pipeline
|
||||
_, err = pipe2.Exec(ctx)
|
||||
if err != nil && err != redis.Nil {
|
||||
r.logger.Error("failed to record request",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
// Already counted, so return true
|
||||
return true, err
|
||||
}
|
||||
|
||||
r.logger.Debug("rate limit check passed",
|
||||
zap.String("key", key),
|
||||
zap.Int64("current_count", currentCount),
|
||||
zap.Int("max_requests", r.config.MaxRequests))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Reset resets the rate limit for a key
|
||||
func (r *rateLimiter) Reset(ctx context.Context, key string) error {
|
||||
redisKey := r.getRedisKey(key)
|
||||
err := r.client.Del(ctx, redisKey).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to reset rate limit",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("rate limit reset",
|
||||
zap.String("key", key))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemaining returns the number of remaining requests in the current window
|
||||
func (r *rateLimiter) GetRemaining(ctx context.Context, key string) (int, error) {
|
||||
redisKey := r.getRedisKey(key)
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.Window)
|
||||
|
||||
// Count current requests in window
|
||||
count, err := r.client.ZCount(ctx, redisKey,
|
||||
fmt.Sprintf("%d", windowStart.UnixNano()),
|
||||
"+inf").Result()
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
r.logger.Error("failed to get remaining requests",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
remaining := r.config.MaxRequests - int(count)
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// getRedisKey constructs the Redis key with prefix
|
||||
func (r *rateLimiter) getRedisKey(key string) string {
|
||||
return fmt.Sprintf("%s:%s", r.config.KeyPrefix, key)
|
||||
}
|
||||
18
cloud/maplepress-backend/pkg/search/config.go
Normal file
18
cloud/maplepress-backend/pkg/search/config.go
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/search/config.go
|
||||
package search
|
||||
|
||||
// Config holds Meilisearch configuration
|
||||
type Config struct {
|
||||
Host string
|
||||
APIKey string
|
||||
IndexPrefix string // e.g., "maplepress_" or "site_"
|
||||
}
|
||||
|
||||
// NewConfig creates a new Meilisearch configuration
|
||||
func NewConfig(host, apiKey, indexPrefix string) *Config {
|
||||
return &Config{
|
||||
Host: host,
|
||||
APIKey: apiKey,
|
||||
IndexPrefix: indexPrefix,
|
||||
}
|
||||
}
|
||||
216
cloud/maplepress-backend/pkg/search/index.go
Normal file
216
cloud/maplepress-backend/pkg/search/index.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/search/index.go
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/meilisearch/meilisearch-go"
|
||||
)
|
||||
|
||||
// PageDocument represents a document in the search index
|
||||
type PageDocument struct {
|
||||
ID string `json:"id"` // page_id
|
||||
SiteID string `json:"site_id"` // for filtering (though each site has its own index)
|
||||
TenantID string `json:"tenant_id"` // for additional isolation
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"` // HTML stripped
|
||||
Excerpt string `json:"excerpt"`
|
||||
URL string `json:"url"`
|
||||
Status string `json:"status"` // publish, draft, trash
|
||||
PostType string `json:"post_type"` // page, post
|
||||
Author string `json:"author"`
|
||||
PublishedAt int64 `json:"published_at"` // Unix timestamp for sorting
|
||||
ModifiedAt int64 `json:"modified_at"` // Unix timestamp for sorting
|
||||
}
|
||||
|
||||
// CreateIndex creates a new index for a site
|
||||
func (c *Client) CreateIndex(siteID string) error {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
|
||||
// Create index with site_id as primary key
|
||||
_, err := c.client.CreateIndex(&meilisearch.IndexConfig{
|
||||
Uid: indexName,
|
||||
PrimaryKey: "id", // page_id is the primary key
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
// Configure index settings
|
||||
return c.ConfigureIndex(siteID)
|
||||
}
|
||||
|
||||
// ConfigureIndex configures the index settings
|
||||
func (c *Client) ConfigureIndex(siteID string) error {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
// Set searchable attributes (in order of priority)
|
||||
searchableAttributes := []string{
|
||||
"title",
|
||||
"excerpt",
|
||||
"content",
|
||||
}
|
||||
|
||||
_, err := index.UpdateSearchableAttributes(&searchableAttributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set searchable attributes: %w", err)
|
||||
}
|
||||
|
||||
// Set filterable attributes
|
||||
filterableAttributes := []interface{}{
|
||||
"status",
|
||||
"post_type",
|
||||
"author",
|
||||
"published_at",
|
||||
}
|
||||
|
||||
_, err = index.UpdateFilterableAttributes(&filterableAttributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set filterable attributes: %w", err)
|
||||
}
|
||||
|
||||
// Set ranking rules
|
||||
rankingRules := []string{
|
||||
"words",
|
||||
"typo",
|
||||
"proximity",
|
||||
"attribute",
|
||||
"sort",
|
||||
"exactness",
|
||||
}
|
||||
|
||||
_, err = index.UpdateRankingRules(&rankingRules)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set ranking rules: %w", err)
|
||||
}
|
||||
|
||||
// Set displayed attributes (don't return full content in search results)
|
||||
displayedAttributes := []string{
|
||||
"id",
|
||||
"title",
|
||||
"excerpt",
|
||||
"url",
|
||||
"status",
|
||||
"post_type",
|
||||
"author",
|
||||
"published_at",
|
||||
"modified_at",
|
||||
}
|
||||
|
||||
_, err = index.UpdateDisplayedAttributes(&displayedAttributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set displayed attributes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexExists checks if an index exists
|
||||
func (c *Client) IndexExists(siteID string) (bool, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
|
||||
_, err := c.client.GetIndex(indexName)
|
||||
if err != nil {
|
||||
// Check if error is "index not found" (status code 404)
|
||||
if meiliErr, ok := err.(*meilisearch.Error); ok {
|
||||
if meiliErr.StatusCode == 404 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return false, fmt.Errorf("failed to check index existence: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// DeleteIndex deletes an index for a site
|
||||
func (c *Client) DeleteIndex(siteID string) error {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
|
||||
_, err := c.client.DeleteIndex(indexName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDocuments adds or updates documents in the index
|
||||
func (c *Client) AddDocuments(siteID string, documents []PageDocument) (*meilisearch.TaskInfo, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
taskInfo, err := index.AddDocuments(documents, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add documents to index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
// UpdateDocuments updates documents in the index
|
||||
func (c *Client) UpdateDocuments(siteID string, documents []PageDocument) (*meilisearch.TaskInfo, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
taskInfo, err := index.UpdateDocuments(documents, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update documents in index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
// DeleteDocument deletes a single document from the index
|
||||
func (c *Client) DeleteDocument(siteID string, documentID string) (*meilisearch.TaskInfo, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
taskInfo, err := index.DeleteDocument(documentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete document %s from index %s: %w", documentID, indexName, err)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
// DeleteDocuments deletes multiple documents from the index
|
||||
func (c *Client) DeleteDocuments(siteID string, documentIDs []string) (*meilisearch.TaskInfo, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
taskInfo, err := index.DeleteDocuments(documentIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete documents from index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
// DeleteAllDocuments deletes all documents from the index
|
||||
func (c *Client) DeleteAllDocuments(siteID string) (*meilisearch.TaskInfo, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
taskInfo, err := index.DeleteAllDocuments()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete all documents from index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return taskInfo, nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about an index
|
||||
func (c *Client) GetStats(siteID string) (*meilisearch.StatsIndex, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
index := c.client.Index(indexName)
|
||||
|
||||
stats, err := index.GetStats()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get stats for index %s: %w", indexName, err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
47
cloud/maplepress-backend/pkg/search/meilisearch.go
Normal file
47
cloud/maplepress-backend/pkg/search/meilisearch.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/search/meilisearch.go
|
||||
package search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/meilisearch/meilisearch-go"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Client wraps the Meilisearch client
|
||||
type Client struct {
|
||||
client meilisearch.ServiceManager
|
||||
config *Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewClient creates a new Meilisearch client
|
||||
func NewClient(config *Config, logger *zap.Logger) (*Client, error) {
|
||||
if config.Host == "" {
|
||||
return nil, fmt.Errorf("meilisearch host is required")
|
||||
}
|
||||
|
||||
client := meilisearch.New(config.Host, meilisearch.WithAPIKey(config.APIKey))
|
||||
|
||||
return &Client{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger.Named("meilisearch"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetIndexName returns the full index name for a site
|
||||
func (c *Client) GetIndexName(siteID string) string {
|
||||
return c.config.IndexPrefix + siteID
|
||||
}
|
||||
|
||||
// Health checks if Meilisearch is healthy
|
||||
func (c *Client) Health() error {
|
||||
_, err := c.client.Health()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetClient returns the underlying Meilisearch client (for advanced operations)
|
||||
func (c *Client) GetClient() meilisearch.ServiceManager {
|
||||
return c.client
|
||||
}
|
||||
22
cloud/maplepress-backend/pkg/search/provider.go
Normal file
22
cloud/maplepress-backend/pkg/search/provider.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package search
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProvideClient provides a Meilisearch client
|
||||
func ProvideClient(cfg *config.Config, logger *zap.Logger) (*Client, error) {
|
||||
searchConfig := NewConfig(
|
||||
cfg.Meilisearch.Host,
|
||||
cfg.Meilisearch.APIKey,
|
||||
cfg.Meilisearch.IndexPrefix,
|
||||
)
|
||||
|
||||
client, err := NewClient(searchConfig, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
155
cloud/maplepress-backend/pkg/search/search.go
Normal file
155
cloud/maplepress-backend/pkg/search/search.go
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/search/search.go
|
||||
package search
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SearchRequest represents a search request
|
||||
type SearchRequest struct {
|
||||
Query string
|
||||
Limit int64
|
||||
Offset int64
|
||||
Filter string // e.g., "status = publish"
|
||||
}
|
||||
|
||||
// SearchResult represents a search result
|
||||
type SearchResult struct {
|
||||
Hits []map[string]interface{} `json:"hits"`
|
||||
Query string `json:"query"`
|
||||
ProcessingTimeMs int64 `json:"processing_time_ms"`
|
||||
TotalHits int64 `json:"total_hits"`
|
||||
Limit int64 `json:"limit"`
|
||||
Offset int64 `json:"offset"`
|
||||
}
|
||||
|
||||
// Search performs a search query on the index
|
||||
func (c *Client) Search(siteID string, req SearchRequest) (*SearchResult, error) {
|
||||
indexName := c.GetIndexName(siteID)
|
||||
|
||||
c.logger.Info("initiating search",
|
||||
zap.String("site_id", siteID),
|
||||
zap.String("index_name", indexName),
|
||||
zap.String("query", req.Query),
|
||||
zap.Int64("limit", req.Limit),
|
||||
zap.Int64("offset", req.Offset),
|
||||
zap.String("filter", req.Filter))
|
||||
|
||||
// Build search request manually to avoid hybrid field
|
||||
searchBody := map[string]interface{}{
|
||||
"q": req.Query,
|
||||
"limit": req.Limit,
|
||||
"offset": req.Offset,
|
||||
"attributesToHighlight": []string{"title", "excerpt", "content"},
|
||||
}
|
||||
|
||||
// Add filter if provided
|
||||
if req.Filter != "" {
|
||||
searchBody["filter"] = req.Filter
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(searchBody)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to marshal search request", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to marshal search request: %w", err)
|
||||
}
|
||||
|
||||
// Uncomment for debugging: shows exact JSON payload sent to Meilisearch
|
||||
// c.logger.Debug("search request payload", zap.String("json", string(jsonData)))
|
||||
|
||||
// Build search URL
|
||||
searchURL := fmt.Sprintf("%s/indexes/%s/search", c.config.Host, indexName)
|
||||
// Uncomment for debugging: shows the Meilisearch API endpoint
|
||||
// c.logger.Debug("search URL", zap.String("url", searchURL))
|
||||
|
||||
// Create HTTP request
|
||||
httpReq, err := http.NewRequest("POST", searchURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
c.logger.Error("failed to create HTTP request", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to create search request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey))
|
||||
|
||||
// Execute request
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to execute HTTP request", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to execute search request: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
c.logger.Error("received nil response from search API")
|
||||
return nil, fmt.Errorf("received nil response from search API")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
c.logger.Info("received search response",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("status", resp.Status))
|
||||
|
||||
// Read response body for logging
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.logger.Error("failed to read response body", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Uncomment for debugging: shows full Meilisearch response
|
||||
// c.logger.Debug("search response body", zap.String("body", string(bodyBytes)))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
c.logger.Error("search request failed",
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.String("response_body", string(bodyBytes)))
|
||||
return nil, fmt.Errorf("search request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var searchResp struct {
|
||||
Hits []map[string]interface{} `json:"hits"`
|
||||
Query string `json:"query"`
|
||||
ProcessingTimeMs int64 `json:"processingTimeMs"`
|
||||
EstimatedTotalHits int `json:"estimatedTotalHits"`
|
||||
Limit int64 `json:"limit"`
|
||||
Offset int64 `json:"offset"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &searchResp); err != nil {
|
||||
c.logger.Error("failed to decode search response", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to decode search response: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("search completed successfully",
|
||||
zap.Int("hits_count", len(searchResp.Hits)),
|
||||
zap.Int("total_hits", searchResp.EstimatedTotalHits),
|
||||
zap.Int64("processing_time_ms", searchResp.ProcessingTimeMs))
|
||||
|
||||
// Convert response
|
||||
result := &SearchResult{
|
||||
Hits: searchResp.Hits,
|
||||
Query: searchResp.Query,
|
||||
ProcessingTimeMs: searchResp.ProcessingTimeMs,
|
||||
TotalHits: int64(searchResp.EstimatedTotalHits),
|
||||
Limit: req.Limit,
|
||||
Offset: req.Offset,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SearchWithHighlights performs a search with custom highlighting
|
||||
// Note: Currently uses same implementation as Search()
|
||||
func (c *Client) SearchWithHighlights(siteID string, req SearchRequest, highlightTags []string) (*SearchResult, error) {
|
||||
// For now, just use the regular Search method
|
||||
// TODO: Implement custom highlight tags if needed
|
||||
return c.Search(siteID, req)
|
||||
}
|
||||
520
cloud/maplepress-backend/pkg/security/README.md
Normal file
520
cloud/maplepress-backend/pkg/security/README.md
Normal file
|
|
@ -0,0 +1,520 @@
|
|||
# Security Package
|
||||
|
||||
This package provides secure password hashing and memory-safe storage for sensitive data.
|
||||
|
||||
## Packages
|
||||
|
||||
### Password (`pkg/security/password`)
|
||||
|
||||
Provides Argon2id-based password hashing and verification with secure default parameters following OWASP recommendations.
|
||||
|
||||
### SecureString (`pkg/security/securestring`)
|
||||
|
||||
Memory-safe string storage using `memguard` to protect sensitive data like passwords and API keys from memory dumps and swap files.
|
||||
|
||||
### SecureBytes (`pkg/security/securebytes`)
|
||||
|
||||
Memory-safe byte slice storage using `memguard` to protect sensitive binary data.
|
||||
|
||||
### IPCountryBlocker (`pkg/security/ipcountryblocker`)
|
||||
|
||||
GeoIP-based country blocking using MaxMind's GeoLite2 database to block requests from specific countries.
|
||||
|
||||
## Installation
|
||||
|
||||
The packages are included in the project. Required dependencies:
|
||||
- `github.com/awnumar/memguard` - For secure memory management
|
||||
- `golang.org/x/crypto/argon2` - For password hashing
|
||||
- `github.com/oschwald/geoip2-golang` - For GeoIP lookups
|
||||
|
||||
## Usage
|
||||
|
||||
### Password Hashing
|
||||
|
||||
```go
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/password"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// Create password provider
|
||||
passwordProvider := password.NewPasswordProvider()
|
||||
|
||||
// Hash a password
|
||||
plainPassword := "mySecurePassword123!"
|
||||
securePass, err := securestring.NewSecureString(plainPassword)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
defer securePass.Wipe() // Always wipe after use
|
||||
|
||||
hashedPassword, err := passwordProvider.GenerateHashFromPassword(securePass)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
|
||||
// Verify a password
|
||||
match, err := passwordProvider.ComparePasswordAndHash(securePass, hashedPassword)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
if match {
|
||||
// Password is correct
|
||||
}
|
||||
```
|
||||
|
||||
### Secure String Storage
|
||||
|
||||
```go
|
||||
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
|
||||
|
||||
// Store sensitive data securely
|
||||
apiKey := "secret-api-key-12345"
|
||||
secureKey, err := securestring.NewSecureString(apiKey)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
defer secureKey.Wipe() // Always wipe when done
|
||||
|
||||
// Use the secure string
|
||||
keyValue := secureKey.String() // Get the value when needed
|
||||
// ... use keyValue ...
|
||||
|
||||
// The original string should be cleared
|
||||
apiKey = ""
|
||||
```
|
||||
|
||||
### Secure Bytes Storage
|
||||
|
||||
```go
|
||||
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securebytes"
|
||||
|
||||
// Store sensitive binary data
|
||||
sensitiveData := []byte{0x01, 0x02, 0x03, 0x04}
|
||||
secureData, err := securebytes.NewSecureBytes(sensitiveData)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
defer secureData.Wipe()
|
||||
|
||||
// Use the secure bytes
|
||||
data := secureData.Bytes()
|
||||
// ... use data ...
|
||||
|
||||
// Clear the original slice
|
||||
for i := range sensitiveData {
|
||||
sensitiveData[i] = 0
|
||||
}
|
||||
```
|
||||
|
||||
### Generate Random Values
|
||||
|
||||
```go
|
||||
passwordProvider := password.NewPasswordProvider()
|
||||
|
||||
// Generate random bytes
|
||||
randomBytes, err := passwordProvider.GenerateSecureRandomBytes(32)
|
||||
|
||||
// Generate random hex string (length * 2 characters)
|
||||
randomString, err := passwordProvider.GenerateSecureRandomString(16)
|
||||
// Returns a 32-character hex string
|
||||
```
|
||||
|
||||
### IP Country Blocking
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/ipcountryblocker"
|
||||
)
|
||||
|
||||
// Create the blocker (typically done via dependency injection)
|
||||
cfg, _ := config.Load()
|
||||
blocker := ipcountryblocker.NewProvider(cfg, logger)
|
||||
defer blocker.Close()
|
||||
|
||||
// Check if an IP is blocked
|
||||
ip := net.ParseIP("192.0.2.1")
|
||||
if blocker.IsBlockedIP(context.Background(), ip) {
|
||||
// Handle blocked IP
|
||||
return errors.New("access denied: your country is blocked")
|
||||
}
|
||||
|
||||
// Check if a country code is blocked
|
||||
if blocker.IsBlockedCountry("CN") {
|
||||
// Country is in the blocked list
|
||||
}
|
||||
|
||||
// Get country code for an IP
|
||||
countryCode, err := blocker.GetCountryCode(context.Background(), ip)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
// countryCode will be ISO 3166-1 alpha-2 code like "US", "CA", "GB"
|
||||
```
|
||||
|
||||
**Configuration**:
|
||||
```bash
|
||||
# Environment variables
|
||||
APP_GEOLITE_DB_PATH=/path/to/GeoLite2-Country.mmdb
|
||||
APP_BANNED_COUNTRIES=CN,RU,KP # Comma-separated ISO 3166-1 alpha-2 codes
|
||||
```
|
||||
|
||||
## Password Hashing Details
|
||||
|
||||
### Algorithm: Argon2id
|
||||
|
||||
Argon2id is the recommended password hashing algorithm by OWASP. It combines:
|
||||
- Argon2i: Resistant to side-channel attacks
|
||||
- Argon2d: Resistant to GPU cracking attacks
|
||||
|
||||
### Default Parameters
|
||||
|
||||
```
|
||||
Memory: 64 MB (65536 KB)
|
||||
Iterations: 3
|
||||
Parallelism: 2 threads
|
||||
Salt Length: 16 bytes
|
||||
Key Length: 32 bytes
|
||||
```
|
||||
|
||||
These parameters provide strong security while maintaining reasonable performance for authentication systems.
|
||||
|
||||
### Hash Format
|
||||
|
||||
```
|
||||
$argon2id$v=19$m=65536,t=3,p=2$<base64-salt>$<base64-hash>
|
||||
```
|
||||
|
||||
Example:
|
||||
```
|
||||
$argon2id$v=19$m=65536,t=3,p=2$YWJjZGVmZ2hpamtsbW5vcA$9XJqrJ8fQvVrMz0FqJ7gBGqKvYLvLxC8HzPqKvYLvLxC
|
||||
```
|
||||
|
||||
The hash includes all parameters, so it can be verified even if you change the default parameters later.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. Always Wipe Sensitive Data
|
||||
|
||||
```go
|
||||
securePass, _ := securestring.NewSecureString(password)
|
||||
defer securePass.Wipe() // Ensures cleanup even on panic
|
||||
|
||||
// ... use securePass ...
|
||||
```
|
||||
|
||||
### 2. Clear Original Data
|
||||
|
||||
After creating a secure string/bytes, clear the original data:
|
||||
|
||||
```go
|
||||
password := "secret"
|
||||
securePass, _ := securestring.NewSecureString(password)
|
||||
password = "" // Clear the original string
|
||||
|
||||
// Even better for byte slices:
|
||||
data := []byte("secret")
|
||||
secureData, _ := securebytes.NewSecureBytes(data)
|
||||
for i := range data {
|
||||
data[i] = 0
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Minimize Exposure Time
|
||||
|
||||
Get values from secure storage only when needed:
|
||||
|
||||
```go
|
||||
// Bad - exposes value for too long
|
||||
value := secureString.String()
|
||||
// ... lots of code ...
|
||||
useValue(value)
|
||||
|
||||
// Good - get value right before use
|
||||
// ... lots of code ...
|
||||
useValue(secureString.String())
|
||||
```
|
||||
|
||||
### 4. Use Dependency Injection
|
||||
|
||||
```go
|
||||
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/password"
|
||||
|
||||
// In your Wire provider set
|
||||
wire.NewSet(
|
||||
password.ProvidePasswordProvider,
|
||||
// ... other providers
|
||||
)
|
||||
|
||||
// Use in your service
|
||||
type AuthService struct {
|
||||
passwordProvider password.PasswordProvider
|
||||
}
|
||||
|
||||
func NewAuthService(pp password.PasswordProvider) *AuthService {
|
||||
return &AuthService{passwordProvider: pp}
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Handle Errors Properly
|
||||
|
||||
```go
|
||||
securePass, err := securestring.NewSecureString(password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create secure string: %w", err)
|
||||
}
|
||||
defer securePass.Wipe()
|
||||
```
|
||||
|
||||
### 6. Clean Up GeoIP Resources
|
||||
|
||||
```go
|
||||
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/ipcountryblocker"
|
||||
|
||||
// Always close the provider when done to release database resources
|
||||
blocker := ipcountryblocker.NewProvider(cfg, logger)
|
||||
defer blocker.Close()
|
||||
```
|
||||
|
||||
## Memory Safety
|
||||
|
||||
### How It Works
|
||||
|
||||
The `memguard` library provides:
|
||||
|
||||
1. **Locked Memory**: Prevents sensitive data from being swapped to disk
|
||||
2. **Guarded Heap**: Detects buffer overflows and underflows
|
||||
3. **Secure Wiping**: Overwrites memory with random data before freeing
|
||||
4. **Read Protection**: Makes memory pages read-only when not in use
|
||||
|
||||
### When to Use
|
||||
|
||||
Use secure storage for:
|
||||
- Passwords and password hashes (during verification)
|
||||
- API keys and tokens
|
||||
- Encryption keys
|
||||
- Private keys
|
||||
- Database credentials
|
||||
- OAuth secrets
|
||||
- JWT signing keys
|
||||
- Session tokens
|
||||
- Any sensitive user data
|
||||
|
||||
### When NOT to Use
|
||||
|
||||
Don't use for:
|
||||
- Public data
|
||||
- Non-sensitive configuration
|
||||
- Data that needs to be logged
|
||||
- Data that will be stored long-term in memory
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Password Hashing
|
||||
|
||||
Argon2id is intentionally slow to prevent brute-force attacks:
|
||||
- Expected time: ~50-100ms per hash on modern hardware
|
||||
- This is acceptable for authentication (login) operations
|
||||
- DO NOT use for high-throughput operations
|
||||
|
||||
### Memory Usage
|
||||
|
||||
SecureString/SecureBytes use locked memory:
|
||||
- Each instance locks a page in RAM (typically 4KB minimum)
|
||||
- Don't create thousands of instances
|
||||
- Reuse instances when possible
|
||||
- Always wipe when done
|
||||
|
||||
## Examples
|
||||
|
||||
### Complete Login Example
|
||||
|
||||
```go
|
||||
func (s *AuthService) Login(ctx context.Context, email, password string) (*User, error) {
|
||||
// Create secure string from password
|
||||
securePass, err := securestring.NewSecureString(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer securePass.Wipe()
|
||||
|
||||
// Clear the original password
|
||||
password = ""
|
||||
|
||||
// Get user from database
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify password
|
||||
match, err := s.passwordProvider.ComparePasswordAndHash(securePass, user.PasswordHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !match {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
```
|
||||
|
||||
### Complete Registration Example
|
||||
|
||||
```go
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (*User, error) {
|
||||
// Validate password strength first
|
||||
if len(password) < 8 {
|
||||
return nil, ErrWeakPassword
|
||||
}
|
||||
|
||||
// Create secure string from password
|
||||
securePass, err := securestring.NewSecureString(password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer securePass.Wipe()
|
||||
|
||||
// Clear the original password
|
||||
password = ""
|
||||
|
||||
// Hash the password
|
||||
hashedPassword, err := s.passwordProvider.GenerateHashFromPassword(securePass)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create user with hashed password
|
||||
user := &User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "failed to create buffer"
|
||||
|
||||
**Problem**: memguard couldn't allocate locked memory
|
||||
|
||||
**Solutions**:
|
||||
- Check system limits for locked memory (`ulimit -l`)
|
||||
- Reduce number of concurrent SecureString/SecureBytes instances
|
||||
- Ensure proper cleanup with `Wipe()`
|
||||
|
||||
### "buffer is not alive"
|
||||
|
||||
**Problem**: Trying to use a SecureString/SecureBytes after it was wiped
|
||||
|
||||
**Solutions**:
|
||||
- Don't use secure data after calling `Wipe()`
|
||||
- Check your defer ordering
|
||||
- Create new instances if you need the data again
|
||||
|
||||
### Slow Performance
|
||||
|
||||
**Problem**: Password hashing is too slow
|
||||
|
||||
**Solutions**:
|
||||
- This is by design for security
|
||||
- Don't hash passwords in high-throughput operations
|
||||
- Consider caching authentication results (with care)
|
||||
- Use async operations for registration/password changes
|
||||
|
||||
### "failed to open GeoLite2 DB"
|
||||
|
||||
**Problem**: Cannot open the GeoIP2 database
|
||||
|
||||
**Solutions**:
|
||||
- Verify APP_GEOLITE_DB_PATH points to a valid .mmdb file
|
||||
- Download the GeoLite2-Country database from MaxMind
|
||||
- Check file permissions
|
||||
- Ensure the database file is not corrupted
|
||||
|
||||
### "no country found for IP"
|
||||
|
||||
**Problem**: IP address lookup returns no country
|
||||
|
||||
**Solutions**:
|
||||
- This is normal for private IP ranges (10.x.x.x, 192.168.x.x, etc.)
|
||||
- The IP might not be in the GeoIP2 database
|
||||
- Update to a newer GeoLite2 database
|
||||
- By default, unknown IPs are allowed (returns false from IsBlockedIP)
|
||||
|
||||
## IP Country Blocking Details
|
||||
|
||||
### GeoLite2 Database
|
||||
|
||||
The IP country blocker uses MaxMind's GeoLite2-Country database for IP geolocation.
|
||||
|
||||
**How to Get the Database**:
|
||||
1. Create a free account at https://www.maxmind.com/en/geolite2/signup
|
||||
2. Generate a license key
|
||||
3. Download GeoLite2-Country database (.mmdb format)
|
||||
4. Set APP_GEOLITE_DB_PATH to the file location
|
||||
|
||||
**Database Updates**:
|
||||
- MaxMind updates GeoLite2 databases weekly
|
||||
- Set up automated updates for production systems
|
||||
- Database file is typically 5-10 MB
|
||||
|
||||
### Country Codes
|
||||
|
||||
Uses ISO 3166-1 alpha-2 country codes:
|
||||
- US - United States
|
||||
- CA - Canada
|
||||
- GB - United Kingdom
|
||||
- CN - China
|
||||
- RU - Russia
|
||||
- KP - North Korea
|
||||
- etc.
|
||||
|
||||
Full list: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
|
||||
|
||||
### Blocking Behavior
|
||||
|
||||
**Default Behavior**:
|
||||
- If IP lookup fails → Allow (returns false)
|
||||
- If country not found → Allow (returns false)
|
||||
- If country is blocked → Block (returns true)
|
||||
|
||||
**To block unknown IPs**, modify IsBlockedIP to return true on error (line 101 in ipcountryblocker.go).
|
||||
|
||||
### Thread Safety
|
||||
|
||||
The provider is thread-safe:
|
||||
- Uses sync.RWMutex for concurrent access to blocked countries map
|
||||
- GeoIP2 Reader is thread-safe by design
|
||||
- Safe to use in HTTP middleware and concurrent handlers
|
||||
|
||||
### Performance
|
||||
|
||||
**Lookup Speed**:
|
||||
- In-memory database lookups are very fast (~microseconds)
|
||||
- Database is memory-mapped for efficiency
|
||||
- Suitable for high-traffic applications
|
||||
|
||||
**Memory Usage**:
|
||||
- GeoLite2-Country database: ~5-10 MB in memory
|
||||
- Blocked countries map: negligible (few KB)
|
||||
|
||||
## References
|
||||
|
||||
- [Argon2 RFC](https://tools.ietf.org/html/rfc9106)
|
||||
- [OWASP Password Storage Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html)
|
||||
- [memguard Documentation](https://github.com/awnumar/memguard)
|
||||
- [Alex Edwards: How to Hash and Verify Passwords With Argon2 in Go](https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go)
|
||||
- [MaxMind GeoLite2 Free Geolocation Data](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data)
|
||||
- [ISO 3166-1 Country Codes](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2)
|
||||
96
cloud/maplepress-backend/pkg/security/apikey/generator.go
Normal file
96
cloud/maplepress-backend/pkg/security/apikey/generator.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrefixLive is the prefix for production API keys
|
||||
PrefixLive = "live_sk_"
|
||||
// PrefixTest is the prefix for test/sandbox API keys
|
||||
PrefixTest = "test_sk_"
|
||||
// KeyLength is the length of the random part (40 chars in base64url)
|
||||
KeyLength = 30 // 30 bytes = 40 base64url chars
|
||||
)
|
||||
|
||||
// Generator generates API keys
|
||||
type Generator interface {
|
||||
// Generate creates a new live API key
|
||||
Generate() (string, error)
|
||||
// GenerateTest creates a new test API key
|
||||
GenerateTest() (string, error)
|
||||
}
|
||||
|
||||
type generator struct{}
|
||||
|
||||
// NewGenerator creates a new API key generator
|
||||
func NewGenerator() Generator {
|
||||
return &generator{}
|
||||
}
|
||||
|
||||
// Generate creates a new live API key
|
||||
func (g *generator) Generate() (string, error) {
|
||||
return g.generateWithPrefix(PrefixLive)
|
||||
}
|
||||
|
||||
// GenerateTest creates a new test API key
|
||||
func (g *generator) GenerateTest() (string, error) {
|
||||
return g.generateWithPrefix(PrefixTest)
|
||||
}
|
||||
|
||||
func (g *generator) generateWithPrefix(prefix string) (string, error) {
|
||||
// Generate cryptographically secure random bytes
|
||||
b := make([]byte, KeyLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode to base64url (URL-safe, no padding)
|
||||
key := base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
// Remove any special chars and make lowercase for consistency
|
||||
key = strings.Map(func(r rune) rune {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
|
||||
return r
|
||||
}
|
||||
return -1 // Remove character
|
||||
}, key)
|
||||
|
||||
// Ensure we have at least 40 characters
|
||||
if len(key) < 40 {
|
||||
// Pad with additional random bytes if needed
|
||||
additional := make([]byte, 10)
|
||||
rand.Read(additional)
|
||||
extraKey := base64.RawURLEncoding.EncodeToString(additional)
|
||||
key += extraKey
|
||||
}
|
||||
|
||||
// Trim to exactly 40 characters
|
||||
key = key[:40]
|
||||
|
||||
return prefix + key, nil
|
||||
}
|
||||
|
||||
// ExtractPrefix extracts the prefix from an API key
|
||||
func ExtractPrefix(apiKey string) string {
|
||||
if len(apiKey) < 13 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[:13] // "live_sk_a1b2" or "test_sk_a1b2"
|
||||
}
|
||||
|
||||
// ExtractLastFour extracts the last 4 characters from an API key
|
||||
func ExtractLastFour(apiKey string) string {
|
||||
if len(apiKey) < 4 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[len(apiKey)-4:]
|
||||
}
|
||||
|
||||
// IsValid checks if an API key has a valid format
|
||||
func IsValid(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, PrefixLive) || strings.HasPrefix(apiKey, PrefixTest)
|
||||
}
|
||||
35
cloud/maplepress-backend/pkg/security/apikey/hasher.go
Normal file
35
cloud/maplepress-backend/pkg/security/apikey/hasher.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Hasher hashes and verifies API keys using SHA-256
|
||||
type Hasher interface {
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
Hash(apiKey string) string
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
Verify(apiKey string, hash string) bool
|
||||
}
|
||||
|
||||
type hasher struct{}
|
||||
|
||||
// NewHasher creates a new API key hasher
|
||||
func NewHasher() Hasher {
|
||||
return &hasher{}
|
||||
}
|
||||
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
func (h *hasher) Hash(apiKey string) string {
|
||||
hash := sha256.Sum256([]byte(apiKey))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
// This prevents timing attacks
|
||||
func (h *hasher) Verify(apiKey string, expectedHash string) bool {
|
||||
actualHash := h.Hash(apiKey)
|
||||
return subtle.ConstantTimeCompare([]byte(actualHash), []byte(expectedHash)) == 1
|
||||
}
|
||||
11
cloud/maplepress-backend/pkg/security/apikey/provider.go
Normal file
11
cloud/maplepress-backend/pkg/security/apikey/provider.go
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package apikey
|
||||
|
||||
// ProvideGenerator provides an API key generator for dependency injection
|
||||
func ProvideGenerator() Generator {
|
||||
return NewGenerator()
|
||||
}
|
||||
|
||||
// ProvideHasher provides an API key hasher for dependency injection
|
||||
func ProvideHasher() Hasher {
|
||||
return NewHasher()
|
||||
}
|
||||
168
cloud/maplepress-backend/pkg/security/clientip/extractor.go
Normal file
168
cloud/maplepress-backend/pkg/security/clientip/extractor.go
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Extractor provides secure client IP address extraction
|
||||
// CWE-348: Prevents X-Forwarded-For header spoofing by validating trusted proxies
|
||||
type Extractor struct {
|
||||
trustedProxies []*net.IPNet
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExtractor creates a new IP extractor with trusted proxy configuration
|
||||
// trustedProxyCIDRs should contain CIDR blocks of trusted reverse proxies
|
||||
// Example: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
|
||||
func NewExtractor(trustedProxyCIDRs []string, logger *zap.Logger) (*Extractor, error) {
|
||||
var trustedProxies []*net.IPNet
|
||||
|
||||
for _, cidr := range trustedProxyCIDRs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse trusted proxy CIDR",
|
||||
zap.String("cidr", cidr),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
trustedProxies = append(trustedProxies, ipNet)
|
||||
}
|
||||
|
||||
logger.Info("client IP extractor initialized",
|
||||
zap.Int("trusted_proxy_ranges", len(trustedProxies)))
|
||||
|
||||
return &Extractor{
|
||||
trustedProxies: trustedProxies,
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewDefaultExtractor creates an extractor with no trusted proxies
|
||||
// This is safe for direct connections but will ignore X-Forwarded-For headers
|
||||
func NewDefaultExtractor(logger *zap.Logger) *Extractor {
|
||||
logger.Warn("client IP extractor initialized with NO trusted proxies - X-Forwarded-For will be ignored")
|
||||
return &Extractor{
|
||||
trustedProxies: []*net.IPNet{},
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}
|
||||
}
|
||||
|
||||
// Extract extracts the real client IP address from the HTTP request
|
||||
// CWE-348: Secure implementation that prevents header spoofing
|
||||
func (e *Extractor) Extract(r *http.Request) string {
|
||||
// Step 1: Get the immediate connection's remote address
|
||||
remoteAddr := r.RemoteAddr
|
||||
|
||||
// Remove port from RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||
remoteIP := e.stripPort(remoteAddr)
|
||||
|
||||
// Step 2: Parse the remote IP
|
||||
parsedRemoteIP := net.ParseIP(remoteIP)
|
||||
if parsedRemoteIP == nil {
|
||||
e.logger.Warn("failed to parse remote IP address",
|
||||
zap.String("remote_addr", remoteAddr))
|
||||
return remoteIP // Return as-is if we can't parse it
|
||||
}
|
||||
|
||||
// Step 3: Check if the immediate connection is from a trusted proxy
|
||||
if !e.isTrustedProxy(parsedRemoteIP) {
|
||||
// NOT from a trusted proxy - do NOT trust X-Forwarded-For header
|
||||
// This prevents clients from spoofing their IP by setting the header
|
||||
e.logger.Debug("remote IP is not a trusted proxy, using RemoteAddr",
|
||||
zap.String("remote_ip", remoteIP))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 4: Remote IP is trusted, check X-Forwarded-For header
|
||||
// Format: "client, proxy1, proxy2" (leftmost is original client)
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
// No X-Forwarded-For header, use RemoteAddr
|
||||
e.logger.Debug("no X-Forwarded-For header from trusted proxy",
|
||||
zap.String("remote_ip", remoteIP))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 5: Parse X-Forwarded-For header
|
||||
// Take the FIRST IP (leftmost) which should be the original client
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) == 0 {
|
||||
e.logger.Debug("empty X-Forwarded-For header",
|
||||
zap.String("remote_ip", remoteIP))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Get the first IP and trim whitespace
|
||||
clientIP := strings.TrimSpace(ips[0])
|
||||
|
||||
// Step 6: Validate the client IP
|
||||
parsedClientIP := net.ParseIP(clientIP)
|
||||
if parsedClientIP == nil {
|
||||
e.logger.Warn("invalid IP in X-Forwarded-For header",
|
||||
zap.String("xff", xff),
|
||||
zap.String("client_ip", clientIP))
|
||||
return remoteIP // Fall back to RemoteAddr
|
||||
}
|
||||
|
||||
e.logger.Debug("extracted client IP from X-Forwarded-For",
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("remote_proxy", remoteIP),
|
||||
zap.String("xff_chain", xff))
|
||||
|
||||
return clientIP
|
||||
}
|
||||
|
||||
// ExtractOrDefault extracts the client IP or returns a default value
|
||||
func (e *Extractor) ExtractOrDefault(r *http.Request, defaultIP string) string {
|
||||
ip := e.Extract(r)
|
||||
if ip == "" {
|
||||
return defaultIP
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isTrustedProxy checks if an IP is in the trusted proxy list
|
||||
func (e *Extractor) isTrustedProxy(ip net.IP) bool {
|
||||
for _, ipNet := range e.trustedProxies {
|
||||
if ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort removes the port from an address string
|
||||
// Handles both IPv4 (192.168.1.1:8080) and IPv6 ([::1]:8080) formats
|
||||
func (e *Extractor) stripPort(addr string) string {
|
||||
// For IPv6, check for bracket format [IP]:port
|
||||
if strings.HasPrefix(addr, "[") {
|
||||
// IPv6 format: [::1]:8080
|
||||
if idx := strings.LastIndex(addr, "]:"); idx != -1 {
|
||||
return addr[1:idx] // Extract IP between [ and ]
|
||||
}
|
||||
// Malformed IPv6 address
|
||||
return addr
|
||||
}
|
||||
|
||||
// For IPv4, split on last colon
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
}
|
||||
|
||||
// No port found
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetTrustedProxyCount returns the number of configured trusted proxy ranges
|
||||
func (e *Extractor) GetTrustedProxyCount() int {
|
||||
return len(e.trustedProxies)
|
||||
}
|
||||
|
||||
// HasTrustedProxies returns true if any trusted proxies are configured
|
||||
func (e *Extractor) HasTrustedProxies() bool {
|
||||
return len(e.trustedProxies) > 0
|
||||
}
|
||||
19
cloud/maplepress-backend/pkg/security/clientip/provider.go
Normal file
19
cloud/maplepress-backend/pkg/security/clientip/provider.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideExtractor provides a client IP extractor configured from the application config
|
||||
func ProvideExtractor(cfg *config.Config, logger *zap.Logger) (*Extractor, error) {
|
||||
// If no trusted proxies configured, use default (no X-Forwarded-For trust)
|
||||
if len(cfg.Security.TrustedProxies) == 0 {
|
||||
logger.Info("no trusted proxies configured - X-Forwarded-For headers will be ignored for security")
|
||||
return NewDefaultExtractor(logger), nil
|
||||
}
|
||||
|
||||
// Create extractor with trusted proxies
|
||||
return NewExtractor(cfg.Security.TrustedProxies, logger)
|
||||
}
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// Provider defines the interface for IP-based country blocking operations.
|
||||
// It provides methods to check if an IP or country is blocked and to retrieve
|
||||
// country codes for given IP addresses.
|
||||
type Provider interface {
|
||||
// IsBlockedCountry checks if a country is in the blocked list.
|
||||
// isoCode must be an ISO 3166-1 alpha-2 country code.
|
||||
IsBlockedCountry(isoCode string) bool
|
||||
|
||||
// IsBlockedIP determines if an IP address originates from a blocked country.
|
||||
// Returns false for nil IP addresses or if country lookup fails.
|
||||
IsBlockedIP(ctx context.Context, ip net.IP) bool
|
||||
|
||||
// GetCountryCode returns the ISO 3166-1 alpha-2 country code for an IP address.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
GetCountryCode(ctx context.Context, ip net.IP) (string, error)
|
||||
|
||||
// Close releases resources associated with the provider.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// provider implements the Provider interface using MaxMind's GeoIP2 database.
|
||||
type provider struct {
|
||||
db *geoip2.Reader
|
||||
blockedCountries map[string]struct{} // Uses empty struct to optimize memory
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex // Protects concurrent access to blockedCountries
|
||||
}
|
||||
|
||||
// NewProvider creates a new IP country blocking provider using the provided configuration.
|
||||
// It initializes the GeoIP2 database and sets up the blocked countries list.
|
||||
// Fatally crashes the entire application if the database cannot be opened.
|
||||
func NewProvider(cfg *config.Config, logger *zap.Logger) Provider {
|
||||
logger.Info("⏳ Loading GeoIP2 database...",
|
||||
zap.String("db_path", cfg.App.GeoLiteDBPath))
|
||||
|
||||
db, err := geoip2.Open(cfg.App.GeoLiteDBPath)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to open GeoLite2 database",
|
||||
zap.String("db_path", cfg.App.GeoLiteDBPath),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
blocked := make(map[string]struct{}, len(cfg.App.BannedCountries))
|
||||
for _, country := range cfg.App.BannedCountries {
|
||||
blocked[country] = struct{}{}
|
||||
}
|
||||
|
||||
logger.Info("✓ IP country blocker initialized",
|
||||
zap.Int("blocked_countries", len(cfg.App.BannedCountries)))
|
||||
|
||||
return &provider{
|
||||
db: db,
|
||||
blockedCountries: blocked,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedCountry checks if a country code exists in the blocked countries map.
|
||||
// Thread-safe through RLock.
|
||||
func (p *provider) IsBlockedCountry(isoCode string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
_, exists := p.blockedCountries[isoCode]
|
||||
return exists
|
||||
}
|
||||
|
||||
// IsBlockedIP performs a country lookup for the IP and checks if it's blocked.
|
||||
// Returns false for nil IPs or failed lookups to fail safely.
|
||||
func (p *provider) IsBlockedIP(ctx context.Context, ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
code, err := p.GetCountryCode(ctx, ip)
|
||||
if err != nil {
|
||||
// Developers Note:
|
||||
// Comment this console log as it contributes a `noisy` server log.
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
// p.logger.WarnContext(ctx, "failed to get country code",
|
||||
// zap.Any("ip", ip),
|
||||
// zap.Any("error", err))
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
// Developers Note:
|
||||
// If the country d.n.e. exist that means we will return with `false`
|
||||
// indicating this IP address is allowed to access our server. If this
|
||||
// is concerning then you might set this to `true` to block on all
|
||||
// IP address which are not categorized by country.
|
||||
return false
|
||||
}
|
||||
|
||||
return p.IsBlockedCountry(code)
|
||||
}
|
||||
|
||||
// GetCountryCode performs a GeoIP2 database lookup to determine an IP's country.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
func (p *provider) GetCountryCode(ctx context.Context, ip net.IP) (string, error) {
|
||||
record, err := p.db.Country(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup country: %w", err)
|
||||
}
|
||||
|
||||
if record == nil || record.Country.IsoCode == "" {
|
||||
return "", fmt.Errorf("no country found for IP: %v", ip)
|
||||
}
|
||||
|
||||
return record.Country.IsoCode, nil
|
||||
}
|
||||
|
||||
// Close cleanly shuts down the GeoIP2 database connection.
|
||||
func (p *provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideIPCountryBlocker creates a new IP country blocker provider instance.
|
||||
func ProvideIPCountryBlocker(cfg *config.Config, logger *zap.Logger) Provider {
|
||||
return NewProvider(cfg, logger)
|
||||
}
|
||||
221
cloud/maplepress-backend/pkg/security/ipcrypt/encryptor.go
Normal file
221
cloud/maplepress-backend/pkg/security/ipcrypt/encryptor.go
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// IPEncryptor provides secure IP address encryption for GDPR compliance
|
||||
// Uses AES-GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// Encrypts IP addresses before storage and provides expiration checking
|
||||
type IPEncryptor struct {
|
||||
gcm cipher.AEAD
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewIPEncryptor creates a new IP encryptor with the given encryption key
|
||||
// keyHex should be a 32-character hex string (16 bytes for AES-128)
|
||||
// or 64-character hex string (32 bytes for AES-256)
|
||||
// Example: "0123456789abcdef0123456789abcdef" (AES-128)
|
||||
// Recommended: Use AES-256 with 64-character hex key
|
||||
func NewIPEncryptor(keyHex string, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
// Decode hex key to bytes
|
||||
keyBytes, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid hex key: %w", err)
|
||||
}
|
||||
|
||||
// AES requires exactly 16, 24, or 32 bytes
|
||||
if len(keyBytes) != 16 && len(keyBytes) != 24 && len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("key must be 16, 24, or 32 bytes (32, 48, or 64 hex characters), got %d bytes", len(keyBytes))
|
||||
}
|
||||
|
||||
// Create AES cipher block
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// GCM provides both confidentiality and integrity
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("IP encryptor initialized with AES-GCM",
|
||||
zap.Int("key_length_bytes", len(keyBytes)),
|
||||
zap.Int("nonce_size", gcm.NonceSize()),
|
||||
zap.Int("overhead", gcm.Overhead()))
|
||||
|
||||
return &IPEncryptor{
|
||||
gcm: gcm,
|
||||
logger: logger.Named("ip-encryptor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts an IP address for secure storage using AES-GCM
|
||||
// Returns base64-encoded encrypted IP address with embedded nonce
|
||||
// Format: base64(nonce + ciphertext + auth_tag)
|
||||
// Supports both IPv4 and IPv6 addresses
|
||||
//
|
||||
// Security Properties:
|
||||
// - Semantic security: same IP address produces different ciphertext each time
|
||||
// - Authentication: tampering with ciphertext is detected
|
||||
// - Unique nonce per encryption prevents pattern analysis
|
||||
func (e *IPEncryptor) Encrypt(ipAddress string) (string, error) {
|
||||
if ipAddress == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Parse IP address to validate format
|
||||
ip := net.ParseIP(ipAddress)
|
||||
if ip == nil {
|
||||
e.logger.Warn("invalid IP address format",
|
||||
zap.String("ip", ipAddress))
|
||||
return "", fmt.Errorf("invalid IP address: %s", ipAddress)
|
||||
}
|
||||
|
||||
// Convert to 16-byte representation (IPv4 gets converted to IPv6 format)
|
||||
ipBytes := ip.To16()
|
||||
if ipBytes == nil {
|
||||
return "", fmt.Errorf("failed to convert IP to 16-byte format")
|
||||
}
|
||||
|
||||
// Generate a random nonce (number used once)
|
||||
// GCM requires a unique nonce for each encryption operation
|
||||
nonce := make([]byte, e.gcm.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
e.logger.Error("failed to generate nonce", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the IP bytes using AES-GCM
|
||||
// GCM appends the authentication tag to the ciphertext
|
||||
// nil additional data means no associated data
|
||||
ciphertext := e.gcm.Seal(nil, nonce, ipBytes, nil)
|
||||
|
||||
// Prepend nonce to ciphertext for storage
|
||||
// Format: nonce || ciphertext+tag
|
||||
encryptedData := append(nonce, ciphertext...)
|
||||
|
||||
// Encode to base64 for database storage (text-safe)
|
||||
encryptedBase64 := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
e.logger.Debug("IP address encrypted with AES-GCM",
|
||||
zap.Int("plaintext_length", len(ipBytes)),
|
||||
zap.Int("nonce_length", len(nonce)),
|
||||
zap.Int("ciphertext_length", len(ciphertext)),
|
||||
zap.Int("total_encrypted_length", len(encryptedData)),
|
||||
zap.Int("base64_length", len(encryptedBase64)))
|
||||
|
||||
return encryptedBase64, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts an encrypted IP address
|
||||
// Takes base64-encoded encrypted IP and returns original IP address string
|
||||
// Verifies authentication tag to detect tampering
|
||||
func (e *IPEncryptor) Decrypt(encryptedBase64 string) (string, error) {
|
||||
if encryptedBase64 == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Decode base64 to bytes
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
e.logger.Warn("invalid base64-encoded encrypted IP",
|
||||
zap.String("base64", encryptedBase64),
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
// Extract nonce from the beginning
|
||||
nonceSize := e.gcm.NonceSize()
|
||||
if len(encryptedData) < nonceSize {
|
||||
return "", fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", nonceSize, len(encryptedData))
|
||||
}
|
||||
|
||||
nonce := encryptedData[:nonceSize]
|
||||
ciphertext := encryptedData[nonceSize:]
|
||||
|
||||
// Decrypt and verify authentication tag using AES-GCM
|
||||
ipBytes, err := e.gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to decrypt IP address (authentication failed or corrupted data)",
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to IP address
|
||||
ip := net.IP(ipBytes)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("failed to parse decrypted IP bytes")
|
||||
}
|
||||
|
||||
// Convert to string
|
||||
ipString := ip.String()
|
||||
|
||||
e.logger.Debug("IP address decrypted with AES-GCM",
|
||||
zap.Int("encrypted_length", len(encryptedData)),
|
||||
zap.Int("decrypted_length", len(ipBytes)))
|
||||
|
||||
return ipString, nil
|
||||
}
|
||||
|
||||
// IsExpired checks if an IP address timestamp has expired (> 90 days old)
|
||||
// GDPR compliance: IP addresses must be deleted after 90 days
|
||||
func (e *IPEncryptor) IsExpired(timestamp time.Time) bool {
|
||||
if timestamp.IsZero() {
|
||||
return false // No timestamp means not expired (will be cleaned up later)
|
||||
}
|
||||
|
||||
// Calculate age in days
|
||||
age := time.Since(timestamp)
|
||||
ageInDays := int(age.Hours() / 24)
|
||||
|
||||
expired := ageInDays > 90
|
||||
|
||||
if expired {
|
||||
e.logger.Debug("IP timestamp expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Int("age_days", ageInDays))
|
||||
}
|
||||
|
||||
return expired
|
||||
}
|
||||
|
||||
// ShouldCleanup checks if an IP address should be cleaned up based on timestamp
|
||||
// Returns true if timestamp is older than 90 days OR if timestamp is zero (unset)
|
||||
func (e *IPEncryptor) ShouldCleanup(timestamp time.Time) bool {
|
||||
// Always cleanup if timestamp is not set (backwards compatibility)
|
||||
if timestamp.IsZero() {
|
||||
return false // Don't cleanup unset timestamps immediately
|
||||
}
|
||||
|
||||
return e.IsExpired(timestamp)
|
||||
}
|
||||
|
||||
// ValidateKey validates that a key is properly formatted for IP encryption
|
||||
// Returns true if key is valid 32-character hex string (AES-128) or 64-character (AES-256)
|
||||
func ValidateKey(keyHex string) error {
|
||||
// Check length (must be 16, 24, or 32 bytes = 32, 48, or 64 hex chars)
|
||||
if len(keyHex) != 32 && len(keyHex) != 48 && len(keyHex) != 64 {
|
||||
return fmt.Errorf("key must be 32, 48, or 64 hex characters, got %d characters", len(keyHex))
|
||||
}
|
||||
|
||||
// Check if valid hex
|
||||
_, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key must be valid hex string: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
cloud/maplepress-backend/pkg/security/ipcrypt/provider.go
Normal file
13
cloud/maplepress-backend/pkg/security/ipcrypt/provider.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideIPEncryptor provides an IP encryptor instance
|
||||
// CWE-359: GDPR compliance for IP address storage
|
||||
func ProvideIPEncryptor(cfg *config.Config, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
return NewIPEncryptor(cfg.Security.IPEncryptionKey, logger)
|
||||
}
|
||||
110
cloud/maplepress-backend/pkg/security/jwt/jwt.go
Normal file
110
cloud/maplepress-backend/pkg/security/jwt/jwt.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/validator"
|
||||
)
|
||||
|
||||
// Provider provides interface for JWT token generation and validation
|
||||
type Provider interface {
|
||||
GenerateToken(sessionID string, duration time.Duration) (string, time.Time, error)
|
||||
GenerateTokenPair(sessionID string, accessDuration time.Duration, refreshDuration time.Duration) (accessToken string, accessExpiry time.Time, refreshToken string, refreshExpiry time.Time, err error)
|
||||
ValidateToken(tokenString string) (sessionID string, err error)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
// NewProvider creates a new JWT provider with security validation
|
||||
func NewProvider(cfg *config.Config) Provider {
|
||||
// Validate JWT secret security before creating provider
|
||||
v := validator.NewCredentialValidator()
|
||||
if err := v.ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
|
||||
// Log detailed error with remediation steps
|
||||
log.Printf("[SECURITY ERROR] %s", err.Error())
|
||||
|
||||
// In production, this is a fatal error that should prevent startup
|
||||
if cfg.App.Environment == "production" {
|
||||
panic(fmt.Sprintf("SECURITY: Invalid JWT secret in production environment: %s", err.Error()))
|
||||
}
|
||||
|
||||
// In development, log warning but allow to continue
|
||||
log.Printf("[WARNING] Continuing with weak JWT secret in %s environment. This is NOT safe for production!", cfg.App.Environment)
|
||||
}
|
||||
|
||||
return &provider{
|
||||
secret: []byte(cfg.App.JWTSecret),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates a single JWT token
|
||||
func (p *provider) GenerateToken(sessionID string, duration time.Duration) (string, time.Time, error) {
|
||||
expiresAt := time.Now().Add(duration)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"session_id": sessionID,
|
||||
"exp": expiresAt.Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString(p.secret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, expiresAt, nil
|
||||
}
|
||||
|
||||
// GenerateTokenPair generates both access token and refresh token
|
||||
func (p *provider) GenerateTokenPair(sessionID string, accessDuration time.Duration, refreshDuration time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
// Generate access token
|
||||
accessToken, accessExpiry, err := p.GenerateToken(sessionID, accessDuration)
|
||||
if err != nil {
|
||||
return "", time.Time{}, "", time.Time{}, fmt.Errorf("failed to generate access token: %w", err)
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
refreshToken, refreshExpiry, err := p.GenerateToken(sessionID, refreshDuration)
|
||||
if err != nil {
|
||||
return "", time.Time{}, "", time.Time{}, fmt.Errorf("failed to generate refresh token: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, accessExpiry, refreshToken, refreshExpiry, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token and returns the session ID
|
||||
func (p *provider) ValidateToken(tokenString string) (string, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return p.secret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return "", fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
sessionID, ok := claims["session_id"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("session_id not found in token")
|
||||
}
|
||||
|
||||
return sessionID, nil
|
||||
}
|
||||
10
cloud/maplepress-backend/pkg/security/jwt/provider.go
Normal file
10
cloud/maplepress-backend/pkg/security/jwt/provider.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideProvider provides a JWT provider instance for Wire dependency injection
|
||||
func ProvideProvider(cfg *config.Config) Provider {
|
||||
return NewProvider(cfg)
|
||||
}
|
||||
149
cloud/maplepress-backend/pkg/security/password/breachcheck.go
Normal file
149
cloud/maplepress-backend/pkg/security/password/breachcheck.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/breachcheck.go
|
||||
package password
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPasswordBreached indicates the password has been found in known data breaches
|
||||
ErrPasswordBreached = fmt.Errorf("password has been found in data breaches")
|
||||
)
|
||||
|
||||
// BreachChecker checks if passwords have been compromised in known data breaches
|
||||
// using the Have I Been Pwned API's k-anonymity model
|
||||
type BreachChecker interface {
|
||||
// CheckPassword checks if a password has been breached
|
||||
// Returns the number of times the password was found in breaches (0 = safe)
|
||||
CheckPassword(ctx context.Context, password string) (int, error)
|
||||
|
||||
// IsPasswordBreached returns true if password has been found in breaches
|
||||
IsPasswordBreached(ctx context.Context, password string) (bool, error)
|
||||
}
|
||||
|
||||
type breachChecker struct {
|
||||
httpClient *http.Client
|
||||
apiURL string
|
||||
userAgent string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewBreachChecker creates a new password breach checker
|
||||
// CWE-521: Password breach checking using Have I Been Pwned API
|
||||
// Uses k-anonymity model - only sends first 5 characters of SHA-1 hash
|
||||
func NewBreachChecker(logger *zap.Logger) BreachChecker {
|
||||
return &breachChecker{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
apiURL: "https://api.pwnedpasswords.com/range/",
|
||||
userAgent: "MaplePress-Backend-Password-Checker",
|
||||
logger: logger.Named("breach-checker"),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckPassword checks if a password has been breached using HIBP k-anonymity API
|
||||
// Returns the number of times the password appears in breaches (0 = safe)
|
||||
// CWE-521: This implements password breach checking without sending the full password
|
||||
func (bc *breachChecker) CheckPassword(ctx context.Context, password string) (int, error) {
|
||||
// Step 1: SHA-1 hash the password
|
||||
hash := sha1.Sum([]byte(password))
|
||||
hashStr := strings.ToUpper(hex.EncodeToString(hash[:]))
|
||||
|
||||
// Step 2: Take first 5 characters (k-anonymity prefix)
|
||||
prefix := hashStr[:5]
|
||||
suffix := hashStr[5:]
|
||||
|
||||
// Step 3: Query HIBP API with prefix only
|
||||
url := bc.apiURL + prefix
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
bc.logger.Error("failed to create HIBP request", zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set User-Agent as required by HIBP API
|
||||
req.Header.Set("User-Agent", bc.userAgent)
|
||||
req.Header.Set("Add-Padding", "true") // Request padding for additional privacy
|
||||
|
||||
bc.logger.Debug("checking password against HIBP",
|
||||
zap.String("prefix", prefix))
|
||||
|
||||
resp, err := bc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
bc.logger.Error("failed to query HIBP API", zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to query breach database: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
bc.logger.Error("received nil response from HIBP API")
|
||||
return 0, fmt.Errorf("received nil response from breach database")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bc.logger.Error("HIBP API returned non-OK status",
|
||||
zap.Int("status", resp.StatusCode))
|
||||
return 0, fmt.Errorf("breach database returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Step 4: Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
bc.logger.Error("failed to read HIBP response", zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Step 5: Parse response and look for our suffix
|
||||
// Response format: SUFFIX:COUNT\r\n for each hash
|
||||
lines := strings.Split(string(body), "\r\n")
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is our hash
|
||||
if parts[0] == suffix {
|
||||
count, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
bc.logger.Warn("failed to parse breach count",
|
||||
zap.String("line", line),
|
||||
zap.Error(err))
|
||||
return 0, fmt.Errorf("failed to parse breach count: %w", err)
|
||||
}
|
||||
|
||||
bc.logger.Warn("password found in data breaches",
|
||||
zap.Int("breach_count", count))
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Password not found in breaches
|
||||
bc.logger.Debug("password not found in breaches")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// IsPasswordBreached returns true if password has been found in data breaches
|
||||
// This is a convenience wrapper around CheckPassword
|
||||
func (bc *breachChecker) IsPasswordBreached(ctx context.Context, password string) (bool, error) {
|
||||
count, err := bc.CheckPassword(ctx, password)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
200
cloud/maplepress-backend/pkg/security/password/password.go
Normal file
200
cloud/maplepress-backend/pkg/security/password/password.go
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/password.go
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
|
||||
ErrIncompatibleVersion = errors.New("incompatible version of argon2")
|
||||
ErrPasswordTooShort = errors.New("password must be at least 8 characters")
|
||||
ErrPasswordTooLong = errors.New("password must not exceed 128 characters")
|
||||
|
||||
// Granular password strength errors (CWE-521: Weak Password Requirements)
|
||||
ErrPasswordNoUppercase = errors.New("password must contain at least one uppercase letter (A-Z)")
|
||||
ErrPasswordNoLowercase = errors.New("password must contain at least one lowercase letter (a-z)")
|
||||
ErrPasswordNoNumber = errors.New("password must contain at least one number (0-9)")
|
||||
ErrPasswordNoSpecialChar = errors.New("password must contain at least one special character (!@#$%^&*()_+-=[]{}; etc.)")
|
||||
ErrPasswordTooWeak = errors.New("password must contain uppercase, lowercase, number, and special character")
|
||||
)
|
||||
|
||||
// PasswordProvider provides secure password hashing and verification using Argon2id.
|
||||
type PasswordProvider interface {
|
||||
GenerateHashFromPassword(password *securestring.SecureString) (string, error)
|
||||
ComparePasswordAndHash(password *securestring.SecureString, hash string) (bool, error)
|
||||
AlgorithmName() string
|
||||
GenerateSecureRandomBytes(length int) ([]byte, error)
|
||||
GenerateSecureRandomString(length int) (string, error)
|
||||
}
|
||||
|
||||
type passwordProvider struct {
|
||||
memory uint32
|
||||
iterations uint32
|
||||
parallelism uint8
|
||||
saltLength uint32
|
||||
keyLength uint32
|
||||
}
|
||||
|
||||
// NewPasswordProvider creates a new password provider with secure default parameters.
|
||||
// The default parameters are based on OWASP recommendations for Argon2id:
|
||||
// - Memory: 64 MB
|
||||
// - Iterations: 3
|
||||
// - Parallelism: 2
|
||||
// - Salt length: 16 bytes
|
||||
// - Key length: 32 bytes
|
||||
func NewPasswordProvider() PasswordProvider {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was adapted from: "How to Hash and Verify Passwords With Argon2 in Go"
|
||||
// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Establish the parameters to use for Argon2
|
||||
return &passwordProvider{
|
||||
memory: 64 * 1024, // 64 MB
|
||||
iterations: 3,
|
||||
parallelism: 2,
|
||||
saltLength: 16,
|
||||
keyLength: 32,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword takes a secure string and returns an Argon2id hashed string.
|
||||
// The returned hash string includes all parameters needed for verification:
|
||||
// Format: $argon2id$v=19$m=65536,t=3,p=2$<base64-salt>$<base64-hash>
|
||||
func (p *passwordProvider) GenerateHashFromPassword(password *securestring.SecureString) (string, error) {
|
||||
salt, err := generateRandomBytes(p.saltLength)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
passwordBytes := password.Bytes()
|
||||
|
||||
// Generate the hash using Argon2id
|
||||
hash := argon2.IDKey(passwordBytes, salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
|
||||
// Base64 encode the salt and hashed password
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return a string using the standard encoded hash representation
|
||||
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
|
||||
argon2.Version, p.memory, p.iterations, p.parallelism, b64Salt, b64Hash)
|
||||
|
||||
return encodedHash, nil
|
||||
}
|
||||
|
||||
// ComparePasswordAndHash verifies that a password matches the provided hash.
|
||||
// It uses constant-time comparison to prevent timing attacks.
|
||||
// Returns true if the password matches, false otherwise.
|
||||
func (p *passwordProvider) ComparePasswordAndHash(password *securestring.SecureString, encodedHash string) (match bool, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was adapted from: "How to Hash and Verify Passwords With Argon2 in Go"
|
||||
// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Extract the parameters, salt and derived key from the encoded password hash
|
||||
params, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Derive the key from the password using the same parameters
|
||||
otherHash := argon2.IDKey(password.Bytes(), salt, params.iterations, params.memory, params.parallelism, params.keyLength)
|
||||
|
||||
// Check that the contents of the hashed passwords are identical
|
||||
// Using subtle.ConstantTimeCompare() to help prevent timing attacks
|
||||
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AlgorithmName returns the name of the hashing algorithm used.
|
||||
func (p *passwordProvider) AlgorithmName() string {
|
||||
return "argon2id"
|
||||
}
|
||||
|
||||
// GenerateSecureRandomBytes generates a cryptographically secure random byte slice.
|
||||
func (p *passwordProvider) GenerateSecureRandomBytes(length int) ([]byte, error) {
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate secure random bytes: %w", err)
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
// GenerateSecureRandomString generates a cryptographically secure random hex string.
|
||||
// The returned string will be twice the length parameter (2 hex chars per byte).
|
||||
func (p *passwordProvider) GenerateSecureRandomString(length int) (string, error) {
|
||||
bytes, err := p.GenerateSecureRandomBytes(length)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateRandomBytes generates cryptographically secure random bytes.
|
||||
func generateRandomBytes(n uint32) ([]byte, error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was adapted from: "How to Hash and Verify Passwords With Argon2 in Go"
|
||||
// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// decodeHash extracts the parameters, salt, and hash from an encoded hash string.
|
||||
func decodeHash(encodedHash string) (p *passwordProvider, salt, hash []byte, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was adapted from: "How to Hash and Verify Passwords With Argon2 in Go"
|
||||
// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
vals := strings.Split(encodedHash, "$")
|
||||
if len(vals) != 6 {
|
||||
return nil, nil, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, ErrIncompatibleVersion
|
||||
}
|
||||
|
||||
p = &passwordProvider{}
|
||||
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.saltLength = uint32(len(salt))
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(vals[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.keyLength = uint32(len(hash))
|
||||
|
||||
return p, salt, hash, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package password
|
||||
|
||||
// ProvidePasswordProvider creates a new password provider instance.
|
||||
func ProvidePasswordProvider() PasswordProvider {
|
||||
return NewPasswordProvider()
|
||||
}
|
||||
44
cloud/maplepress-backend/pkg/security/password/timing.go
Normal file
44
cloud/maplepress-backend/pkg/security/password/timing.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/timing.go
|
||||
package password
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// DummyPasswordHash is a pre-computed valid Argon2id hash used for timing attack mitigation
|
||||
// This hash is computed with the same parameters as real password hashes
|
||||
// CWE-208: Observable Timing Discrepancy - Prevents user enumeration via timing attacks
|
||||
const DummyPasswordHash = "$argon2id$v=19$m=65536,t=3,p=2$c29tZXJhbmRvbXNhbHQxMjM0$kixiIQQ/y8E7dSH0j8p8KPBUlCMUGQOvH2kP7XYPkVs"
|
||||
|
||||
// ComparePasswordWithDummy performs password comparison but always uses a dummy hash
|
||||
// This is used when a user doesn't exist to maintain constant time behavior
|
||||
// CWE-208: Observable Timing Discrepancy - Mitigates timing-based user enumeration
|
||||
func (p *passwordProvider) ComparePasswordWithDummy(password *securestring.SecureString) error {
|
||||
// Perform the same expensive operation (Argon2 hashing) even for non-existent users
|
||||
// This ensures the timing is constant regardless of whether the user exists
|
||||
_, _ = p.ComparePasswordAndHash(password, DummyPasswordHash)
|
||||
|
||||
// Always return false (user doesn't exist, so authentication always fails)
|
||||
// The important part is that we spent the same amount of time
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimingSafeCompare performs a timing-safe password comparison
|
||||
// It always performs the password hashing operation regardless of whether
|
||||
// the user exists or the password matches
|
||||
// CWE-208: Observable Timing Discrepancy - Prevents timing attacks
|
||||
func TimingSafeCompare(provider PasswordProvider, password *securestring.SecureString, hash string, userExists bool) (bool, error) {
|
||||
if !userExists {
|
||||
// User doesn't exist - perform dummy hash comparison to maintain constant time
|
||||
if pp, ok := provider.(*passwordProvider); ok {
|
||||
_ = pp.ComparePasswordWithDummy(password)
|
||||
} else {
|
||||
// Fallback if type assertion fails
|
||||
_, _ = provider.ComparePasswordAndHash(password, DummyPasswordHash)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// User exists - perform real comparison
|
||||
return provider.ComparePasswordAndHash(password, hash)
|
||||
}
|
||||
90
cloud/maplepress-backend/pkg/security/password/validator.go
Normal file
90
cloud/maplepress-backend/pkg/security/password/validator.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
package password
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinPasswordLength is the minimum required password length
|
||||
MinPasswordLength = 8
|
||||
// MaxPasswordLength is the maximum allowed password length
|
||||
MaxPasswordLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
// Special characters allowed in passwords
|
||||
specialCharRegex = regexp.MustCompile(`[!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]`)
|
||||
)
|
||||
|
||||
// PasswordValidator provides password strength validation
|
||||
type PasswordValidator interface {
|
||||
ValidatePasswordStrength(password string) error
|
||||
}
|
||||
|
||||
type passwordValidator struct{}
|
||||
|
||||
// NewPasswordValidator creates a new password validator
|
||||
func NewPasswordValidator() PasswordValidator {
|
||||
return &passwordValidator{}
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength validates that a password meets strength requirements
|
||||
// Requirements:
|
||||
// - At least 8 characters long
|
||||
// - At most 128 characters long
|
||||
// - Contains at least one uppercase letter
|
||||
// - Contains at least one lowercase letter
|
||||
// - Contains at least one digit
|
||||
// - Contains at least one special character
|
||||
//
|
||||
// CWE-521: Returns granular error messages to help users create strong passwords
|
||||
func (v *passwordValidator) ValidatePasswordStrength(password string) error {
|
||||
// Check length first
|
||||
if len(password) < MinPasswordLength {
|
||||
return ErrPasswordTooShort
|
||||
}
|
||||
|
||||
if len(password) > MaxPasswordLength {
|
||||
return ErrPasswordTooLong
|
||||
}
|
||||
|
||||
// Check character type requirements
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasNumber bool
|
||||
hasSpecial bool
|
||||
)
|
||||
|
||||
for _, char := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsNumber(char):
|
||||
hasNumber = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for special characters
|
||||
hasSpecial = specialCharRegex.MatchString(password)
|
||||
|
||||
// Return granular error for the first missing requirement
|
||||
// This provides specific feedback to users about what's missing
|
||||
if !hasUpper {
|
||||
return ErrPasswordNoUppercase
|
||||
}
|
||||
if !hasLower {
|
||||
return ErrPasswordNoLowercase
|
||||
}
|
||||
if !hasNumber {
|
||||
return ErrPasswordNoNumber
|
||||
}
|
||||
if !hasSpecial {
|
||||
return ErrPasswordNoSpecialChar
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
20
cloud/maplepress-backend/pkg/security/provider.go
Normal file
20
cloud/maplepress-backend/pkg/security/provider.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package security
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
|
||||
)
|
||||
|
||||
// ProvideJWTProvider provides a JWT provider instance
|
||||
func ProvideJWTProvider(cfg *config.Config) jwt.Provider {
|
||||
return jwt.NewProvider(cfg)
|
||||
}
|
||||
|
||||
// ProvideClientIPExtractor provides a client IP extractor instance
|
||||
// CWE-348: Secure IP extraction with X-Forwarded-For validation
|
||||
func ProvideClientIPExtractor(cfg *config.Config, logger *zap.Logger) (*clientip.Extractor, error) {
|
||||
return clientip.ProvideExtractor(cfg, logger)
|
||||
}
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
package securebytes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureBytes is used to store a byte slice securely in memory.
|
||||
// It uses memguard to protect sensitive data from being exposed in memory dumps,
|
||||
// swap files, or other memory scanning attacks.
|
||||
type SecureBytes struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureBytes creates a new SecureBytes instance from the given byte slice.
|
||||
// The original byte slice should be wiped after creating SecureBytes to ensure
|
||||
// the sensitive data is only stored in the secure buffer.
|
||||
func NewSecureBytes(b []byte) (*SecureBytes, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, errors.New("byte slice cannot be empty")
|
||||
}
|
||||
|
||||
buffer := memguard.NewBuffer(len(b))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
copy(buffer.Bytes(), b)
|
||||
|
||||
return &SecureBytes{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// Bytes returns the securely stored byte slice.
|
||||
// WARNING: The returned bytes are still protected by memguard, but any copies
|
||||
// made from this slice will not be protected. Use with caution.
|
||||
func (sb *SecureBytes) Bytes() []byte {
|
||||
return sb.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the byte slice from memory and makes it unrecoverable.
|
||||
// After calling Wipe, the SecureBytes instance should not be used.
|
||||
func (sb *SecureBytes) Wipe() error {
|
||||
sb.buffer.Wipe()
|
||||
sb.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
package securestring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureString is used to store a string securely in memory.
|
||||
// It uses memguard to protect sensitive data like passwords, API keys, etc.
|
||||
// from being exposed in memory dumps, swap files, or other memory scanning attacks.
|
||||
type SecureString struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureString creates a new SecureString instance from the given string.
|
||||
// The original string should be cleared/wiped after creating SecureString to ensure
|
||||
// the sensitive data is only stored in the secure buffer.
|
||||
func NewSecureString(s string) (*SecureString, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, errors.New("string cannot be empty")
|
||||
}
|
||||
|
||||
// Use memguard's built-in method for creating from bytes
|
||||
buffer := memguard.NewBufferFromBytes([]byte(s))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
return &SecureString{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// String returns the securely stored string.
|
||||
// WARNING: The returned string is a copy and will not be protected by memguard.
|
||||
// Use this method carefully and wipe the string after use if possible.
|
||||
func (ss *SecureString) String() string {
|
||||
if ss.buffer == nil {
|
||||
return ""
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
return ""
|
||||
}
|
||||
return ss.buffer.String()
|
||||
}
|
||||
|
||||
// Bytes returns the byte representation of the securely stored string.
|
||||
// WARNING: The returned bytes are still protected by memguard, but any copies
|
||||
// made from this slice will not be protected. Use with caution.
|
||||
func (ss *SecureString) Bytes() []byte {
|
||||
if ss.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return ss.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the string from memory and makes it unrecoverable.
|
||||
// After calling Wipe, the SecureString instance should not be used.
|
||||
func (ss *SecureString) Wipe() error {
|
||||
if ss.buffer != nil {
|
||||
if ss.buffer.IsAlive() {
|
||||
ss.buffer.Destroy()
|
||||
}
|
||||
}
|
||||
ss.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinJWTSecretLength is the minimum required length for JWT secrets (256 bits)
|
||||
MinJWTSecretLength = 32
|
||||
|
||||
// RecommendedJWTSecretLength is the recommended length for JWT secrets (512 bits)
|
||||
RecommendedJWTSecretLength = 64
|
||||
|
||||
// MinEntropyBits is the minimum Shannon entropy in bits per character
|
||||
// For reference: random base64 has ~6 bits/char, we require minimum 4.0
|
||||
MinEntropyBits = 4.0
|
||||
|
||||
// MinProductionEntropyBits is the minimum entropy required for production
|
||||
MinProductionEntropyBits = 4.5
|
||||
|
||||
// MaxRepeatingCharacters is the maximum allowed consecutive repeating characters
|
||||
MaxRepeatingCharacters = 3
|
||||
)
|
||||
|
||||
// WeakSecrets contains common weak/default secrets that should never be used
|
||||
var WeakSecrets = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"changeme",
|
||||
"change-me",
|
||||
"change_me",
|
||||
"12345",
|
||||
"123456",
|
||||
"1234567",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"default",
|
||||
"test",
|
||||
"testing",
|
||||
"admin",
|
||||
"administrator",
|
||||
"root",
|
||||
"qwerty",
|
||||
"qwertyuiop",
|
||||
"letmein",
|
||||
"welcome",
|
||||
"monkey",
|
||||
"dragon",
|
||||
"master",
|
||||
"sunshine",
|
||||
"princess",
|
||||
"football",
|
||||
"starwars",
|
||||
"baseball",
|
||||
"superman",
|
||||
"iloveyou",
|
||||
"trustno1",
|
||||
"hello",
|
||||
"abc123",
|
||||
"password123",
|
||||
"admin123",
|
||||
"guest",
|
||||
"user",
|
||||
"demo",
|
||||
"sample",
|
||||
"example",
|
||||
}
|
||||
|
||||
// DangerousPatterns contains patterns that indicate a secret should be changed
|
||||
var DangerousPatterns = []string{
|
||||
"change",
|
||||
"replace",
|
||||
"update",
|
||||
"modify",
|
||||
"sample",
|
||||
"example",
|
||||
"todo",
|
||||
"fixme",
|
||||
"temp",
|
||||
"temporary",
|
||||
}
|
||||
|
||||
// CredentialValidator validates credentials and secrets for security issues
|
||||
type CredentialValidator interface {
|
||||
ValidateJWTSecret(secret string, environment string) error
|
||||
ValidateAllCredentials(cfg *config.Config) error
|
||||
}
|
||||
|
||||
type credentialValidator struct{}
|
||||
|
||||
// NewCredentialValidator creates a new credential validator
|
||||
func NewCredentialValidator() CredentialValidator {
|
||||
return &credentialValidator{}
|
||||
}
|
||||
|
||||
// ValidateJWTSecret validates JWT secret strength and security
|
||||
// CWE-798: Comprehensive validation to prevent hard-coded/weak credentials
|
||||
func (v *credentialValidator) ValidateJWTSecret(secret string, environment string) error {
|
||||
// Check minimum length
|
||||
if len(secret) < MinJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short (%d characters). Minimum required: %d characters (256 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
MinJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for common weak secrets (case-insensitive)
|
||||
secretLower := strings.ToLower(secret)
|
||||
for _, weak := range WeakSecrets {
|
||||
if secretLower == weak || strings.Contains(secretLower, weak) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret cannot contain common weak value: '%s'. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
weak,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for dangerous patterns indicating default/placeholder values
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(secretLower, pattern) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains suspicious pattern '%s' which suggests it's a placeholder. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
pattern,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for repeating character patterns (e.g., "aaaa", "1111")
|
||||
if err := checkRepeatingPatterns(secret); err != nil {
|
||||
return fmt.Errorf(
|
||||
"JWT secret validation failed: %s. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sequential patterns (e.g., "abcd", "1234")
|
||||
if hasSequentialPattern(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains sequential patterns (e.g., 'abcd', '1234') which reduces entropy. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Calculate Shannon entropy
|
||||
entropy := calculateShannonEntropy(secret)
|
||||
minEntropy := MinEntropyBits
|
||||
if environment == "production" {
|
||||
minEntropy = MinProductionEntropyBits
|
||||
}
|
||||
|
||||
if entropy < minEntropy {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient entropy: %.2f bits/char (minimum: %.1f bits/char for %s). "+
|
||||
"The secret appears to have low randomness. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
entropy,
|
||||
minEntropy,
|
||||
environment,
|
||||
)
|
||||
}
|
||||
|
||||
// In production, enforce stricter requirements
|
||||
if environment == "production" {
|
||||
// Check recommended length for production
|
||||
if len(secret) < RecommendedJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short for production environment (%d characters). "+
|
||||
"Recommended: %d characters (512 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
RecommendedJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sufficient character complexity
|
||||
if !hasSufficientComplexity(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient complexity for production. It should contain a mix of uppercase, lowercase, " +
|
||||
"digits, and special characters (at least 3 types). Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Validate base64-like characteristics (recommended generation method)
|
||||
if !looksLikeBase64(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret does not appear to be randomly generated (expected base64-like characteristics). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAllCredentials validates all credentials in the configuration
|
||||
func (v *credentialValidator) ValidateAllCredentials(cfg *config.Config) error {
|
||||
var errors []string
|
||||
|
||||
// Validate JWT Secret
|
||||
if err := v.ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("JWT Secret validation failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
// In production, ensure other critical configs are not using defaults/placeholders
|
||||
if cfg.App.Environment == "production" {
|
||||
// Check Meilisearch API key
|
||||
if cfg.Meilisearch.APIKey == "" {
|
||||
errors = append(errors, "Meilisearch API key must be set in production")
|
||||
} else if containsDangerousPattern(cfg.Meilisearch.APIKey) {
|
||||
errors = append(errors, "Meilisearch API key appears to be a placeholder/default value")
|
||||
}
|
||||
|
||||
// Check database hosts are not using localhost
|
||||
for _, host := range cfg.Database.Hosts {
|
||||
if strings.Contains(strings.ToLower(host), "localhost") || host == "127.0.0.1" {
|
||||
errors = append(errors, "Database hosts should not use localhost in production")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache host is not localhost
|
||||
if strings.Contains(strings.ToLower(cfg.Cache.Host), "localhost") || cfg.Cache.Host == "127.0.0.1" {
|
||||
errors = append(errors, "Cache host should not use localhost in production")
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("credential validation failed:\n - %s", strings.Join(errors, "\n - "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShannonEntropy calculates the Shannon entropy of a string in bits per character
|
||||
// Shannon entropy measures the randomness/unpredictability of data
|
||||
// Formula: H(X) = -Σ(p(x) * log2(p(x))) where p(x) is the probability of character x
|
||||
func calculateShannonEntropy(s string) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Count character frequencies
|
||||
frequencies := make(map[rune]int)
|
||||
for _, char := range s {
|
||||
frequencies[char]++
|
||||
}
|
||||
|
||||
// Calculate entropy
|
||||
var entropy float64
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range frequencies {
|
||||
probability := float64(count) / length
|
||||
entropy -= probability * math.Log2(probability)
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// hasSufficientComplexity checks if the secret has a good mix of character types
|
||||
// Requires at least 3 out of 4 character types for production
|
||||
func hasSufficientComplexity(secret string) bool {
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
hasSpecial bool
|
||||
)
|
||||
|
||||
for _, char := range secret {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(char):
|
||||
hasDigit = true
|
||||
default:
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
// Require at least 3 out of 4 character types
|
||||
count := 0
|
||||
if hasUpper {
|
||||
count++
|
||||
}
|
||||
if hasLower {
|
||||
count++
|
||||
}
|
||||
if hasDigit {
|
||||
count++
|
||||
}
|
||||
if hasSpecial {
|
||||
count++
|
||||
}
|
||||
|
||||
return count >= 3
|
||||
}
|
||||
|
||||
// checkRepeatingPatterns checks for excessive repeating characters
|
||||
func checkRepeatingPatterns(s string) error {
|
||||
if len(s) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repeatCount := 1
|
||||
lastChar := rune(s[0])
|
||||
|
||||
for _, char := range s[1:] {
|
||||
if char == lastChar {
|
||||
repeatCount++
|
||||
if repeatCount > MaxRepeatingCharacters {
|
||||
return fmt.Errorf(
|
||||
"contains %d consecutive repeating characters ('%c'), maximum allowed: %d",
|
||||
repeatCount,
|
||||
lastChar,
|
||||
MaxRepeatingCharacters,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
repeatCount = 1
|
||||
lastChar = char
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasSequentialPattern detects common sequential patterns
|
||||
func hasSequentialPattern(s string) bool {
|
||||
if len(s) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for at least 4 consecutive sequential characters
|
||||
for i := 0; i < len(s)-3; i++ {
|
||||
// Check ascending sequence (e.g., "abcd", "1234")
|
||||
if s[i+1] == s[i]+1 && s[i+2] == s[i]+2 && s[i+3] == s[i]+3 {
|
||||
return true
|
||||
}
|
||||
// Check descending sequence (e.g., "dcba", "4321")
|
||||
if s[i+1] == s[i]-1 && s[i+2] == s[i]-2 && s[i+3] == s[i]-3 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// looksLikeBase64 checks if the string has base64-like characteristics
|
||||
// Base64 uses: A-Z, a-z, 0-9, +, /, and = for padding
|
||||
func looksLikeBase64(s string) bool {
|
||||
if len(s) < MinJWTSecretLength {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
validChars int
|
||||
)
|
||||
|
||||
// Base64 valid characters
|
||||
for _, char := range s {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
validChars++
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
validChars++
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
validChars++
|
||||
case char == '+' || char == '/' || char == '=' || char == '-' || char == '_':
|
||||
validChars++
|
||||
default:
|
||||
// Invalid character for base64
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Should have good mix of character types typical of base64
|
||||
charTypesCount := 0
|
||||
if hasUpper {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasLower {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasDigit {
|
||||
charTypesCount++
|
||||
}
|
||||
|
||||
// Base64 typically has at least uppercase, lowercase, and digits
|
||||
// Also check that it doesn't look like a repeated pattern
|
||||
if charTypesCount < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for repeated patterns (e.g., "AbCd12!@" repeated)
|
||||
// If the string has low unique character count relative to its length, it's probably not random
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range s {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
|
||||
// Random base64 should have at least 50% unique characters for strings over 32 chars
|
||||
uniqueRatio := float64(len(uniqueChars)) / float64(len(s))
|
||||
return uniqueRatio >= 0.4 // At least 40% unique characters
|
||||
}
|
||||
|
||||
// containsDangerousPattern checks if a string contains any dangerous patterns
|
||||
func containsDangerousPattern(value string) bool {
|
||||
valueLower := strings.ToLower(value)
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(valueLower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Simplified comprehensive test for JWT secret validation
|
||||
func TestJWTSecretValidation(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
// Good secrets - these should pass
|
||||
goodSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
}{
|
||||
{
|
||||
name: "Good 32-char for dev",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "development",
|
||||
},
|
||||
{
|
||||
name: "Good 64-char for prod",
|
||||
secret: "1WDduocStecRuIv+Us1t/RnYDoW1ZcEEbU+H+WykJG+IT5WnijzBb8uUPzGKju+D",
|
||||
env: "production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range goodSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for valid secret, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Bad secrets - these should fail
|
||||
badSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
mustContain string
|
||||
}{
|
||||
{
|
||||
name: "Too short",
|
||||
secret: "short",
|
||||
env: "development",
|
||||
mustContain: "too short",
|
||||
},
|
||||
{
|
||||
name: "Common weak - password",
|
||||
secret: "password-is-not-secure-but-32char",
|
||||
env: "development",
|
||||
mustContain: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern",
|
||||
secret: "please-change-this-ima7xR+9nT0Yz",
|
||||
env: "development",
|
||||
mustContain: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
env: "development",
|
||||
mustContain: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
env: "development",
|
||||
mustContain: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy",
|
||||
secret: "abababababababababababababababab",
|
||||
env: "development",
|
||||
mustContain: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Prod too short",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "production",
|
||||
mustContain: "too short for production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range badSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got no error", tt.mustContain)
|
||||
} else if !contains(err.Error(), tt.mustContain) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.mustContain, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && len(substr) > 0 && findSubstring(s, substr)))
|
||||
}
|
||||
|
||||
func findSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,535 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCalculateShannonEntropy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
minBits float64
|
||||
maxBits float64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have 0 entropy",
|
||||
},
|
||||
{
|
||||
name: "All same character",
|
||||
input: "aaaaaaaaaa",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have very low entropy",
|
||||
},
|
||||
{
|
||||
name: "Low entropy - repeated pattern",
|
||||
input: "abcabcabcabc",
|
||||
minBits: 1.5,
|
||||
maxBits: 2.0,
|
||||
expected: "should have low entropy",
|
||||
},
|
||||
{
|
||||
name: "Medium entropy - simple password",
|
||||
input: "Password123",
|
||||
minBits: 3.0,
|
||||
maxBits: 4.5,
|
||||
expected: "should have medium entropy",
|
||||
},
|
||||
{
|
||||
name: "High entropy - random base64",
|
||||
input: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
minBits: 4.0,
|
||||
maxBits: 6.0,
|
||||
expected: "should have high entropy",
|
||||
},
|
||||
{
|
||||
name: "Very high entropy - long random base64",
|
||||
input: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
minBits: 4.5,
|
||||
maxBits: 6.5,
|
||||
expected: "should have very high entropy",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
entropy := calculateShannonEntropy(tt.input)
|
||||
if entropy < tt.minBits || entropy > tt.maxBits {
|
||||
t.Errorf("%s: got %.2f bits/char, expected between %.1f and %.1f", tt.expected, entropy, tt.minBits, tt.maxBits)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSufficientComplexity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnop",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only uppercase",
|
||||
input: "ABCDEFGHIJKLMNOP",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only digits",
|
||||
input: "1234567890",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase",
|
||||
input: "AbCdEfGhIjKl",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits",
|
||||
input: "abc123def456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Uppercase + digits",
|
||||
input: "ABC123DEF456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + digits",
|
||||
input: "Abc123Def456",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + special",
|
||||
input: "AbC+DeF/GhI=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits + special",
|
||||
input: "abc123+def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "All four types",
|
||||
input: "Abc123+Def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 string",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b+xK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSufficientComplexity(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSufficientComplexity(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRepeatingPatterns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Single character",
|
||||
input: "a",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "No repeating",
|
||||
input: "abcdefgh",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Two repeating (ok)",
|
||||
input: "aabcdeef",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Three repeating (ok)",
|
||||
input: "aaabcdeee",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Four repeating (error)",
|
||||
input: "aaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Five repeating (error)",
|
||||
input: "aaaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple groups of three (ok)",
|
||||
input: "aaabbbccc",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Repeating in middle (error)",
|
||||
input: "abcdddddef",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Repeating at end (error)",
|
||||
input: "abcdefgggg",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkRepeatingPatterns(tt.input)
|
||||
if (err != nil) != tt.shouldErr {
|
||||
t.Errorf("checkRepeatingPatterns(%q) error = %v, shouldErr = %v", tt.input, err, tt.shouldErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSequentialPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No sequential",
|
||||
input: "acegikmo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Ascending sequence - abcd",
|
||||
input: "xyzabcdefg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending sequence - dcba",
|
||||
input: "xyzdcbafg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Ascending digits - 1234",
|
||||
input: "abc1234def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending digits - 4321",
|
||||
input: "abc4321def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Random characters",
|
||||
input: "xK8vN2mP9sQ4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Base64-like",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSequentialPattern(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSequentialPattern(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeBase64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Real base64",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 without padding",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 with URL-safe chars",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b-_",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Generated secret",
|
||||
input: "xK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Simple password",
|
||||
input: "Password123!Password123!Password123!",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := looksLikeBase64(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("looksLikeBase64(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
secret string
|
||||
environment string
|
||||
shouldErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Too short - 20 chars",
|
||||
secret: "12345678901234567890",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "too short",
|
||||
},
|
||||
{
|
||||
name: "Minimum length - 32 chars (acceptable for dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains password",
|
||||
secret: "my-password-is-secure-123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - secret",
|
||||
secret: "secretsecretsecretsecretsecretsec",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains 12345",
|
||||
secret: "abcd12345efghijklmnopqrstuvwxyz",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - change",
|
||||
secret: "please-change-this-j8EJm9ZKnuTYxcVK",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - sample",
|
||||
secret: "sample-secret-j8EJm9ZKnuTYxcVKQ",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - abcd",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - 1234",
|
||||
secret: "12345678901234567890123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy secret",
|
||||
secret: "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpP",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Good secret - base64 style (dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Good secret - longer (dev)",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - too short (32 chars)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "too short for production",
|
||||
},
|
||||
{
|
||||
name: "Production - insufficient complexity",
|
||||
secret: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient complexity",
|
||||
},
|
||||
{
|
||||
name: "Production - low entropy pattern",
|
||||
secret: strings.Repeat("AbCd12!@", 8), // 64 chars but repetitive
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Production - good secret",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - excellent secret with padding",
|
||||
secret: "7mK2nP8sR4wT6xZ3bA5cxK7mN1oQ9uS4vY2zA6bxK7mN1oQ9uS4vY2zA6b+W0E=",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.environment)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Errorf("ValidateJWTSecret() expected error containing %q, got no error", tt.errContains)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateJWTSecret() error = %q, should contain %q", err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("ValidateJWTSecret() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret_EdgeCases(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
t.Run("Secret with mixed weak patterns", func(t *testing.T) {
|
||||
secret := "password123admin" // Contains multiple weak patterns
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err == nil {
|
||||
t.Error("Expected error for secret containing weak patterns, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at minimum length", func(t *testing.T) {
|
||||
// 32 characters exactly
|
||||
secret := "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx"
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 32-char secret in development, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at recommended length", func(t *testing.T) {
|
||||
// 64 characters exactly - using real random base64
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFir"
|
||||
err := validator.ValidateJWTSecret(secret, "production")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 64-char secret in production, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure validation is performant
|
||||
func BenchmarkCalculateShannonEntropy(b *testing.B) {
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
calculateShannonEntropy(secret)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateJWTSecret(b *testing.B) {
|
||||
validator := NewCredentialValidator()
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateJWTSecret(secret, "production")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package validator
|
||||
|
||||
// ProvideCredentialValidator provides a credential validator for dependency injection
|
||||
func ProvideCredentialValidator() CredentialValidator {
|
||||
return NewCredentialValidator()
|
||||
}
|
||||
33
cloud/maplepress-backend/pkg/storage/cache/redis.go
vendored
Normal file
33
cloud/maplepress-backend/pkg/storage/cache/redis.go
vendored
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProvideRedisClient creates a new Redis client
|
||||
func ProvideRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
|
||||
logger.Info("connecting to Redis",
|
||||
zap.String("host", cfg.Cache.Host),
|
||||
zap.Int("port", cfg.Cache.Port))
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Host, cfg.Cache.Port),
|
||||
Password: cfg.Cache.Password,
|
||||
DB: cfg.Cache.DB,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx := context.Background()
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("successfully connected to Redis")
|
||||
|
||||
return client, nil
|
||||
}
|
||||
121
cloud/maplepress-backend/pkg/storage/database/cassandra.go
Normal file
121
cloud/maplepress-backend/pkg/storage/database/cassandra.go
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/storage/database/cassandra/cassandra.go
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// gocqlLogger wraps zap logger to filter out noisy gocql warnings
|
||||
type gocqlLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Print implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Print(v ...interface{}) {
|
||||
msg := fmt.Sprint(v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
// These are harmless and occur due to Docker networking
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// Printf implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Printf(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// Println implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Println(v ...interface{}) {
|
||||
msg := fmt.Sprintln(v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// ProvideCassandraSession creates a new Cassandra session
|
||||
func ProvideCassandraSession(cfg *config.Config, logger *zap.Logger) (*gocql.Session, error) {
|
||||
logger.Info("⏳ Connecting to Cassandra...",
|
||||
zap.Strings("hosts", cfg.Database.Hosts),
|
||||
zap.String("keyspace", cfg.Database.Keyspace))
|
||||
|
||||
// Create cluster configuration
|
||||
cluster := gocql.NewCluster(cfg.Database.Hosts...)
|
||||
cluster.Keyspace = cfg.Database.Keyspace
|
||||
cluster.Consistency = parseConsistency(cfg.Database.Consistency)
|
||||
cluster.ProtoVersion = 4
|
||||
cluster.ConnectTimeout = 10 * time.Second
|
||||
cluster.Timeout = 10 * time.Second
|
||||
cluster.NumConns = 2
|
||||
|
||||
// Set custom logger to filter out noisy warnings
|
||||
cluster.Logger = &gocqlLogger{logger: logger.Named("gocql")}
|
||||
|
||||
// Retry policy
|
||||
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
|
||||
NumRetries: 3,
|
||||
Min: 1 * time.Second,
|
||||
Max: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Cassandra: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✓ Cassandra connected",
|
||||
zap.String("consistency", cfg.Database.Consistency),
|
||||
zap.Int("connections", cluster.NumConns))
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// parseConsistency converts string consistency level to gocql.Consistency
|
||||
func parseConsistency(consistency string) gocql.Consistency {
|
||||
switch consistency {
|
||||
case "ANY":
|
||||
return gocql.Any
|
||||
case "ONE":
|
||||
return gocql.One
|
||||
case "TWO":
|
||||
return gocql.Two
|
||||
case "THREE":
|
||||
return gocql.Three
|
||||
case "QUORUM":
|
||||
return gocql.Quorum
|
||||
case "ALL":
|
||||
return gocql.All
|
||||
case "LOCAL_QUORUM":
|
||||
return gocql.LocalQuorum
|
||||
case "EACH_QUORUM":
|
||||
return gocql.EachQuorum
|
||||
case "LOCAL_ONE":
|
||||
return gocql.LocalOne
|
||||
default:
|
||||
return gocql.Quorum // Default to QUORUM
|
||||
}
|
||||
}
|
||||
199
cloud/maplepress-backend/pkg/storage/database/migration.go
Normal file
199
cloud/maplepress-backend/pkg/storage/database/migration.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/cassandra"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// silentGocqlLogger filters out noisy "invalid peer" warnings from gocql
|
||||
type silentGocqlLogger struct{}
|
||||
|
||||
func (l *silentGocqlLogger) Print(v ...interface{}) {
|
||||
// Silently discard all gocql logs including "invalid peer" warnings
|
||||
}
|
||||
|
||||
func (l *silentGocqlLogger) Printf(format string, v ...interface{}) {
|
||||
// Silently discard all gocql logs including "invalid peer" warnings
|
||||
}
|
||||
|
||||
func (l *silentGocqlLogger) Println(v ...interface{}) {
|
||||
// Silently discard all gocql logs including "invalid peer" warnings
|
||||
}
|
||||
|
||||
// Migrator handles database schema migrations
|
||||
// This encapsulates all migration logic and makes it testable
|
||||
type Migrator struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMigrator creates a new migration manager
|
||||
func NewMigrator(cfg *config.Config, logger *zap.Logger) *Migrator {
|
||||
if logger == nil {
|
||||
// Create a no-op logger if none provided (for backward compatibility)
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &Migrator{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Up runs all pending migrations with dirty state recovery
|
||||
func (m *Migrator) Up() error {
|
||||
// Ensure keyspace exists before running migrations
|
||||
m.logger.Debug("Ensuring keyspace exists...")
|
||||
if err := m.ensureKeyspaceExists(); err != nil {
|
||||
return fmt.Errorf("failed to ensure keyspace exists: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Debug("Creating migrator...")
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
m.logger.Debug("Checking migration version...")
|
||||
version, dirty, err := migrateInstance.Version()
|
||||
if err != nil && err != migrate.ErrNilVersion {
|
||||
return fmt.Errorf("failed to get migration version: %w", err)
|
||||
}
|
||||
|
||||
if dirty {
|
||||
m.logger.Warn("Database is in dirty state, attempting to force clean state",
|
||||
zap.Uint("version", uint(version)))
|
||||
if err := migrateInstance.Force(int(version)); err != nil {
|
||||
return fmt.Errorf("failed to force clean migration state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
if err := migrateInstance.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
// Get final version
|
||||
finalVersion, _, err := migrateInstance.Version()
|
||||
if err != nil && err != migrate.ErrNilVersion {
|
||||
m.logger.Warn("Could not get final migration version", zap.Error(err))
|
||||
} else if err != migrate.ErrNilVersion {
|
||||
m.logger.Debug("Database migrations completed successfully",
|
||||
zap.Uint("version", uint(finalVersion)))
|
||||
} else {
|
||||
m.logger.Debug("Database migrations completed successfully (no migrations applied)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Down rolls back the last migration
|
||||
// Useful for development and rollback scenarios
|
||||
func (m *Migrator) Down() error {
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
if err := migrateInstance.Steps(-1); err != nil {
|
||||
return fmt.Errorf("failed to rollback migration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Version returns the current migration version
|
||||
func (m *Migrator) Version() (uint, bool, error) {
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
return migrateInstance.Version()
|
||||
}
|
||||
|
||||
// ForceVersion forces the migration version (useful for fixing dirty states)
|
||||
func (m *Migrator) ForceVersion(version int) error {
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
if err := migrateInstance.Force(version); err != nil {
|
||||
return fmt.Errorf("failed to force version %d: %w", version, err)
|
||||
}
|
||||
|
||||
m.logger.Info("Successfully forced migration version", zap.Int("version", version))
|
||||
return nil
|
||||
}
|
||||
|
||||
// createMigrate creates a migrate instance with proper configuration
|
||||
func (m *Migrator) createMigrate() (*migrate.Migrate, error) {
|
||||
// Set global gocql logger to suppress "invalid peer" warnings
|
||||
// This affects the internal gocql connections used by golang-migrate
|
||||
gocql.Logger = &silentGocqlLogger{}
|
||||
|
||||
// Build Cassandra connection string
|
||||
// Format: cassandra://host:port/keyspace?consistency=level
|
||||
databaseURL := fmt.Sprintf("cassandra://%s/%s?consistency=%s",
|
||||
m.config.Database.Hosts[0], // Use first host for migrations
|
||||
m.config.Database.Keyspace,
|
||||
m.config.Database.Consistency,
|
||||
)
|
||||
|
||||
// Create migrate instance
|
||||
migrateInstance, err := migrate.New(m.config.Database.MigrationsPath, databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize migrate: %w", err)
|
||||
}
|
||||
|
||||
return migrateInstance, nil
|
||||
}
|
||||
|
||||
// ensureKeyspaceExists creates the keyspace if it doesn't exist
|
||||
// This must be done before running migrations since golang-migrate requires the keyspace to exist
|
||||
func (m *Migrator) ensureKeyspaceExists() error {
|
||||
// Create cluster configuration without keyspace
|
||||
cluster := gocql.NewCluster(m.config.Database.Hosts...)
|
||||
cluster.Port = 9042
|
||||
cluster.Consistency = gocql.Quorum
|
||||
cluster.ProtoVersion = 4
|
||||
|
||||
// Suppress noisy "invalid peer" warnings from gocql
|
||||
// Use a minimal logger that discards these harmless Docker networking warnings
|
||||
cluster.Logger = &silentGocqlLogger{}
|
||||
|
||||
// Create session to system keyspace
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Cassandra: %w", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Create keyspace if it doesn't exist
|
||||
replicationFactor := m.config.Database.Replication
|
||||
createKeyspaceQuery := fmt.Sprintf(`
|
||||
CREATE KEYSPACE IF NOT EXISTS %s
|
||||
WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': %d}
|
||||
AND durable_writes = true
|
||||
`, m.config.Database.Keyspace, replicationFactor)
|
||||
|
||||
m.logger.Debug("Creating keyspace if it doesn't exist",
|
||||
zap.String("keyspace", m.config.Database.Keyspace))
|
||||
if err := session.Query(createKeyspaceQuery).Exec(); err != nil {
|
||||
return fmt.Errorf("failed to create keyspace: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Debug("Keyspace is ready", zap.String("keyspace", m.config.Database.Keyspace))
|
||||
return nil
|
||||
}
|
||||
54
cloud/maplepress-backend/pkg/storage/object/s3/config.go
Normal file
54
cloud/maplepress-backend/pkg/storage/object/s3/config.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package s3
|
||||
|
||||
type S3ObjectStorageConfigurationProvider interface {
|
||||
GetAccessKey() string
|
||||
GetSecretKey() string
|
||||
GetEndpoint() string
|
||||
GetRegion() string
|
||||
GetBucketName() string
|
||||
GetIsPublicBucket() bool
|
||||
}
|
||||
|
||||
type s3ObjectStorageConfigurationProviderImpl struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
endpoint string
|
||||
region string
|
||||
bucketName string
|
||||
isPublicBucket bool
|
||||
}
|
||||
|
||||
func NewS3ObjectStorageConfigurationProvider(accessKey, secretKey, endpoint, region, bucketName string, isPublicBucket bool) S3ObjectStorageConfigurationProvider {
|
||||
return &s3ObjectStorageConfigurationProviderImpl{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
endpoint: endpoint,
|
||||
region: region,
|
||||
bucketName: bucketName,
|
||||
isPublicBucket: isPublicBucket,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetAccessKey() string {
|
||||
return s.accessKey
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetSecretKey() string {
|
||||
return s.secretKey
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetEndpoint() string {
|
||||
return s.endpoint
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetRegion() string {
|
||||
return s.region
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetBucketName() string {
|
||||
return s.bucketName
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorageConfigurationProviderImpl) GetIsPublicBucket() bool {
|
||||
return s.isPublicBucket
|
||||
}
|
||||
23
cloud/maplepress-backend/pkg/storage/object/s3/provider.go
Normal file
23
cloud/maplepress-backend/pkg/storage/object/s3/provider.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package s3
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideS3ObjectStorage provides an S3 object storage instance
|
||||
func ProvideS3ObjectStorage(cfg *config.Config, logger *zap.Logger) S3ObjectStorage {
|
||||
// Create configuration provider
|
||||
configProvider := NewS3ObjectStorageConfigurationProvider(
|
||||
cfg.AWS.AccessKey,
|
||||
cfg.AWS.SecretKey,
|
||||
cfg.AWS.Endpoint,
|
||||
cfg.AWS.Region,
|
||||
cfg.AWS.BucketName,
|
||||
false, // Default to private bucket
|
||||
)
|
||||
|
||||
// Return new S3 storage instance
|
||||
return NewObjectStorage(configProvider, logger)
|
||||
}
|
||||
508
cloud/maplepress-backend/pkg/storage/object/s3/s3.go
Normal file
508
cloud/maplepress-backend/pkg/storage/object/s3/s3.go
Normal file
|
|
@ -0,0 +1,508 @@
|
|||
// monorepo/cloud/maplefileapps-backend/pkg/storage/object/s3/s3.go
|
||||
package s3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/aws/smithy-go"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ACL constants for public and private objects
|
||||
const (
|
||||
ACLPrivate = "private"
|
||||
ACLPublicRead = "public-read"
|
||||
)
|
||||
|
||||
type S3ObjectStorage interface {
|
||||
UploadContent(ctx context.Context, objectKey string, content []byte) error
|
||||
UploadContentWithVisibility(ctx context.Context, objectKey string, content []byte, isPublic bool) error
|
||||
UploadContentFromMulipart(ctx context.Context, objectKey string, file multipart.File) error
|
||||
UploadContentFromMulipartWithVisibility(ctx context.Context, objectKey string, file multipart.File, isPublic bool) error
|
||||
BucketExists(ctx context.Context, bucketName string) (bool, error)
|
||||
DeleteByKeys(ctx context.Context, key []string) error
|
||||
Cut(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error
|
||||
CutWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error
|
||||
Copy(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error
|
||||
CopyWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error
|
||||
GetBinaryData(ctx context.Context, objectKey string) (io.ReadCloser, error)
|
||||
DownloadToLocalfile(ctx context.Context, objectKey string, filePath string) (string, error)
|
||||
ListAllObjects(ctx context.Context) (*s3.ListObjectsOutput, error)
|
||||
FindMatchingObjectKey(s3Objects *s3.ListObjectsOutput, partialKey string) string
|
||||
IsPublicBucket() bool
|
||||
// GeneratePresignedUploadURL creates a presigned URL for uploading objects
|
||||
GeneratePresignedUploadURL(ctx context.Context, key string, duration time.Duration) (string, error)
|
||||
GetDownloadablePresignedURL(ctx context.Context, key string, duration time.Duration) (string, error)
|
||||
ObjectExists(ctx context.Context, key string) (bool, error)
|
||||
GetObjectSize(ctx context.Context, key string) (int64, error)
|
||||
}
|
||||
|
||||
type s3ObjectStorage struct {
|
||||
S3Client *s3.Client
|
||||
PresignClient *s3.PresignClient
|
||||
Logger *zap.Logger
|
||||
BucketName string
|
||||
IsPublic bool
|
||||
}
|
||||
|
||||
// NewObjectStorage connects to a specific S3 bucket instance and returns a connected
|
||||
// instance structure.
|
||||
func NewObjectStorage(s3Config S3ObjectStorageConfigurationProvider, logger *zap.Logger) S3ObjectStorage {
|
||||
logger = logger.Named("s3-object-storage")
|
||||
|
||||
// DEVELOPERS NOTE:
|
||||
// How can I use the AWS SDK v2 for Go with DigitalOcean Spaces? via https://stackoverflow.com/a/74284205
|
||||
logger.Info("⏳ Connecting to S3-compatible storage...",
|
||||
zap.String("endpoint", s3Config.GetEndpoint()),
|
||||
zap.String("bucket", s3Config.GetBucketName()),
|
||||
zap.String("region", s3Config.GetRegion()))
|
||||
|
||||
// STEP 1: initialize the custom `endpoint` we will connect to.
|
||||
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...any) (aws.Endpoint, error) {
|
||||
return aws.Endpoint{
|
||||
URL: s3Config.GetEndpoint(),
|
||||
}, nil
|
||||
})
|
||||
|
||||
// STEP 2: Configure.
|
||||
sdkConfig, err := config.LoadDefaultConfig(
|
||||
context.TODO(), config.WithRegion(s3Config.GetRegion()),
|
||||
config.WithEndpointResolverWithOptions(customResolver),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.GetAccessKey(), s3Config.GetSecretKey(), "")),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Fatal("S3ObjectStorage failed loading default config", zap.Error(err)) // We need to crash the program at start to satisfy google wire requirement of having no errors.
|
||||
}
|
||||
|
||||
// STEP 3\: Load up s3 instance.
|
||||
s3Client := s3.NewFromConfig(sdkConfig)
|
||||
|
||||
// Create our storage handler.
|
||||
s3Storage := &s3ObjectStorage{
|
||||
S3Client: s3Client,
|
||||
PresignClient: s3.NewPresignClient(s3Client),
|
||||
Logger: logger,
|
||||
BucketName: s3Config.GetBucketName(),
|
||||
IsPublic: s3Config.GetIsPublicBucket(),
|
||||
}
|
||||
|
||||
logger.Debug("Verifying bucket exists...")
|
||||
|
||||
// STEP 4: Connect to the s3 bucket instance and confirm that bucket exists.
|
||||
doesExist, err := s3Storage.BucketExists(context.TODO(), s3Config.GetBucketName())
|
||||
if err != nil {
|
||||
logger.Fatal("S3ObjectStorage failed checking if bucket exists",
|
||||
zap.String("bucket", s3Config.GetBucketName()),
|
||||
zap.Error(err)) // We need to crash the program at start to satisfy google wire requirement of having no errors.
|
||||
}
|
||||
if !doesExist {
|
||||
logger.Fatal("S3ObjectStorage failed - bucket does not exist",
|
||||
zap.String("bucket", s3Config.GetBucketName())) // We need to crash the program at start to satisfy google wire requirement of having no errors.
|
||||
}
|
||||
|
||||
logger.Info("✓ S3-compatible storage connected",
|
||||
zap.String("bucket", s3Config.GetBucketName()),
|
||||
zap.Bool("public", s3Config.GetIsPublicBucket()))
|
||||
|
||||
// Return our s3 storage handler.
|
||||
return s3Storage
|
||||
}
|
||||
|
||||
// IsPublicBucket returns whether the bucket is configured as public by default
|
||||
func (s *s3ObjectStorage) IsPublicBucket() bool {
|
||||
return s.IsPublic
|
||||
}
|
||||
|
||||
// UploadContent uploads content using the default bucket visibility setting
|
||||
func (s *s3ObjectStorage) UploadContent(ctx context.Context, objectKey string, content []byte) error {
|
||||
return s.UploadContentWithVisibility(ctx, objectKey, content, s.IsPublic)
|
||||
}
|
||||
|
||||
// UploadContentWithVisibility uploads content with specified visibility (public or private)
|
||||
func (s *s3ObjectStorage) UploadContentWithVisibility(ctx context.Context, objectKey string, content []byte, isPublic bool) error {
|
||||
acl := ACLPrivate
|
||||
if isPublic {
|
||||
acl = ACLPublicRead
|
||||
}
|
||||
|
||||
s.Logger.Debug("Uploading content with visibility",
|
||||
zap.String("objectKey", objectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.String("acl", acl))
|
||||
|
||||
_, err := s.S3Client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(objectKey),
|
||||
Body: bytes.NewReader(content),
|
||||
ACL: types.ObjectCannedACL(acl),
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error("Failed to upload content",
|
||||
zap.String("objectKey", objectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.Any("error", err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadContentFromMulipart uploads file using the default bucket visibility setting
|
||||
func (s *s3ObjectStorage) UploadContentFromMulipart(ctx context.Context, objectKey string, file multipart.File) error {
|
||||
return s.UploadContentFromMulipartWithVisibility(ctx, objectKey, file, s.IsPublic)
|
||||
}
|
||||
|
||||
// UploadContentFromMulipartWithVisibility uploads a multipart file with specified visibility
|
||||
func (s *s3ObjectStorage) UploadContentFromMulipartWithVisibility(ctx context.Context, objectKey string, file multipart.File, isPublic bool) error {
|
||||
acl := ACLPrivate
|
||||
if isPublic {
|
||||
acl = ACLPublicRead
|
||||
}
|
||||
|
||||
s.Logger.Debug("Uploading multipart file with visibility",
|
||||
zap.String("objectKey", objectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.String("acl", acl))
|
||||
|
||||
// Create the S3 upload input parameters
|
||||
params := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(objectKey),
|
||||
Body: file,
|
||||
ACL: types.ObjectCannedACL(acl),
|
||||
}
|
||||
|
||||
// Perform the file upload to S3
|
||||
_, err := s.S3Client.PutObject(ctx, params)
|
||||
if err != nil {
|
||||
s.Logger.Error("Failed to upload multipart file",
|
||||
zap.String("objectKey", objectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.Any("error", err))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorage) BucketExists(ctx context.Context, bucketName string) (bool, error) {
|
||||
// Note: https://docs.aws.amazon.com/code-library/latest/ug/go_2_s3_code_examples.html#actions
|
||||
|
||||
_, err := s.S3Client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: aws.String(bucketName),
|
||||
})
|
||||
exists := true
|
||||
if err != nil {
|
||||
var apiError smithy.APIError
|
||||
if errors.As(err, &apiError) {
|
||||
switch apiError.(type) {
|
||||
case *types.NotFound:
|
||||
s.Logger.Debug("Bucket is available", zap.String("bucket", bucketName))
|
||||
exists = false
|
||||
err = nil
|
||||
default:
|
||||
s.Logger.Error("Either you don't have access to bucket or another error occurred",
|
||||
zap.String("bucket", bucketName),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorage) GetDownloadablePresignedURL(ctx context.Context, key string, duration time.Duration) (string, error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// AWS S3 Bucket — presigned URL APIs with Go (2022) via https://ronen-niv.medium.com/aws-s3-handling-presigned-urls-2718ab247d57
|
||||
|
||||
presignedUrl, err := s.PresignClient.PresignGetObject(context.Background(),
|
||||
&s3.GetObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(key),
|
||||
ResponseContentDisposition: aws.String("attachment"), // This field allows the file to download it directly from your browser
|
||||
},
|
||||
s3.WithPresignExpires(duration))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return presignedUrl.URL, nil
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorage) DeleteByKeys(ctx context.Context, objectKeys []string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var objectIds []types.ObjectIdentifier
|
||||
for _, key := range objectKeys {
|
||||
objectIds = append(objectIds, types.ObjectIdentifier{Key: aws.String(key)})
|
||||
}
|
||||
_, err := s.S3Client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Delete: &types.Delete{Objects: objectIds},
|
||||
})
|
||||
if err != nil {
|
||||
s.Logger.Error("Couldn't delete objects from bucket",
|
||||
zap.String("bucket", s.BucketName),
|
||||
zap.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Cut moves a file using the default bucket visibility setting
|
||||
func (s *s3ObjectStorage) Cut(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error {
|
||||
return s.CutWithVisibility(ctx, sourceObjectKey, destinationObjectKey, s.IsPublic)
|
||||
}
|
||||
|
||||
// CutWithVisibility moves a file with specified visibility
|
||||
func (s *s3ObjectStorage) CutWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 60*time.Second) // Increase timout so it runs longer then usual to handle this unique case.
|
||||
defer cancel()
|
||||
|
||||
// First copy the object with the desired visibility
|
||||
if err := s.CopyWithVisibility(ctx, sourceObjectKey, destinationObjectKey, isPublic); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete the original object
|
||||
_, deleteErr := s.S3Client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(sourceObjectKey),
|
||||
})
|
||||
if deleteErr != nil {
|
||||
s.Logger.Error("Failed to delete original object:", zap.Any("deleteErr", deleteErr))
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
s.Logger.Debug("Original object deleted.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy copies a file using the default bucket visibility setting
|
||||
func (s *s3ObjectStorage) Copy(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error {
|
||||
return s.CopyWithVisibility(ctx, sourceObjectKey, destinationObjectKey, s.IsPublic)
|
||||
}
|
||||
|
||||
// CopyWithVisibility copies a file with specified visibility
|
||||
func (s *s3ObjectStorage) CopyWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 60*time.Second) // Increase timout so it runs longer then usual to handle this unique case.
|
||||
defer cancel()
|
||||
|
||||
acl := ACLPrivate
|
||||
if isPublic {
|
||||
acl = ACLPublicRead
|
||||
}
|
||||
|
||||
s.Logger.Debug("Copying object with visibility",
|
||||
zap.String("sourceKey", sourceObjectKey),
|
||||
zap.String("destinationKey", destinationObjectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.String("acl", acl))
|
||||
|
||||
_, copyErr := s.S3Client.CopyObject(ctx, &s3.CopyObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
CopySource: aws.String(s.BucketName + "/" + sourceObjectKey),
|
||||
Key: aws.String(destinationObjectKey),
|
||||
ACL: types.ObjectCannedACL(acl),
|
||||
})
|
||||
if copyErr != nil {
|
||||
s.Logger.Error("Failed to copy object:",
|
||||
zap.String("sourceKey", sourceObjectKey),
|
||||
zap.String("destinationKey", destinationObjectKey),
|
||||
zap.Bool("isPublic", isPublic),
|
||||
zap.Any("copyErr", copyErr))
|
||||
return copyErr
|
||||
}
|
||||
|
||||
s.Logger.Debug("Object copied successfully.")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBinaryData function will return the binary data for the particular key.
|
||||
func (s *s3ObjectStorage) GetBinaryData(ctx context.Context, objectKey string) (io.ReadCloser, error) {
|
||||
input := &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(objectKey),
|
||||
}
|
||||
|
||||
s3object, err := s.S3Client.GetObject(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s3object.Body, nil
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorage) DownloadToLocalfile(ctx context.Context, objectKey string, filePath string) (string, error) {
|
||||
responseBin, err := s.GetBinaryData(ctx, objectKey)
|
||||
if err != nil {
|
||||
return filePath, err
|
||||
}
|
||||
out, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return filePath, err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, responseBin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filePath, err
|
||||
}
|
||||
|
||||
func (s *s3ObjectStorage) ListAllObjects(ctx context.Context) (*s3.ListObjectsOutput, error) {
|
||||
input := &s3.ListObjectsInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
}
|
||||
|
||||
objects, err := s.S3Client.ListObjects(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Function will iterate over all the s3 objects to match the partial key with
|
||||
// the actual key found in the S3 bucket.
|
||||
func (s *s3ObjectStorage) FindMatchingObjectKey(s3Objects *s3.ListObjectsOutput, partialKey string) string {
|
||||
for _, obj := range s3Objects.Contents {
|
||||
|
||||
match := strings.Contains(*obj.Key, partialKey)
|
||||
|
||||
// If a match happens then it means we have found the ACTUAL KEY in the
|
||||
// s3 objects inside the bucket.
|
||||
if match == true {
|
||||
return *obj.Key
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GeneratePresignedUploadURL creates a presigned URL for uploading objects to S3
|
||||
func (s *s3ObjectStorage) GeneratePresignedUploadURL(ctx context.Context, key string, duration time.Duration) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create PutObjectInput without ACL to avoid requiring x-amz-acl header
|
||||
putObjectInput := &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(key),
|
||||
// Removed ACL field - files inherit bucket's default privacy settings.
|
||||
}
|
||||
|
||||
presignedUrl, err := s.PresignClient.PresignPutObject(ctx, putObjectInput, s3.WithPresignExpires(duration))
|
||||
if err != nil {
|
||||
s.Logger.Error("Failed to generate presigned upload URL",
|
||||
zap.String("key", key),
|
||||
zap.Duration("duration", duration),
|
||||
zap.Error(err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.Logger.Debug("Generated presigned upload URL",
|
||||
zap.String("key", key),
|
||||
zap.Duration("duration", duration))
|
||||
|
||||
return presignedUrl.URL, nil
|
||||
}
|
||||
|
||||
// ObjectExists checks if an object exists at the given key using HeadObject
|
||||
func (s *s3ObjectStorage) ObjectExists(ctx context.Context, key string) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := s.S3Client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
var apiError smithy.APIError
|
||||
if errors.As(err, &apiError) {
|
||||
switch apiError.(type) {
|
||||
case *types.NotFound:
|
||||
// Object doesn't exist
|
||||
s.Logger.Debug("Object does not exist",
|
||||
zap.String("key", key))
|
||||
return false, nil
|
||||
case *types.NoSuchKey:
|
||||
// Object doesn't exist
|
||||
s.Logger.Debug("Object does not exist (NoSuchKey)",
|
||||
zap.String("key", key))
|
||||
return false, nil
|
||||
default:
|
||||
// Some other error occurred
|
||||
s.Logger.Error("Error checking object existence",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
// Non-API error
|
||||
s.Logger.Error("Error checking object existence",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
s.Logger.Debug("Object exists",
|
||||
zap.String("key", key))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetObjectSize returns the size of an object at the given key using HeadObject
|
||||
func (s *s3ObjectStorage) GetObjectSize(ctx context.Context, key string) (int64, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.S3Client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.BucketName),
|
||||
Key: aws.String(key),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
var apiError smithy.APIError
|
||||
if errors.As(err, &apiError) {
|
||||
switch apiError.(type) {
|
||||
case *types.NotFound:
|
||||
s.Logger.Debug("Object not found when getting size",
|
||||
zap.String("key", key))
|
||||
return 0, errors.New("object not found")
|
||||
case *types.NoSuchKey:
|
||||
s.Logger.Debug("Object not found when getting size (NoSuchKey)",
|
||||
zap.String("key", key))
|
||||
return 0, errors.New("object not found")
|
||||
default:
|
||||
s.Logger.Error("Error getting object size",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
s.Logger.Error("Error getting object size",
|
||||
zap.String("key", key),
|
||||
zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Let's use aws.ToInt64 which handles both pointer and non-pointer cases
|
||||
size := aws.ToInt64(result.ContentLength)
|
||||
|
||||
s.Logger.Debug("Retrieved object size",
|
||||
zap.String("key", key),
|
||||
zap.Int64("size", size))
|
||||
|
||||
return size, nil
|
||||
}
|
||||
516
cloud/maplepress-backend/pkg/transaction/saga.go
Normal file
516
cloud/maplepress-backend/pkg/transaction/saga.go
Normal file
|
|
@ -0,0 +1,516 @@
|
|||
package transaction
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Package transaction provides a SAGA pattern implementation for managing distributed transactions.
|
||||
//
|
||||
// # What is SAGA Pattern?
|
||||
//
|
||||
// SAGA is a pattern for managing distributed transactions through a sequence of local transactions,
|
||||
// each with a corresponding compensating transaction that undoes its effects if a later step fails.
|
||||
//
|
||||
// # When to Use SAGA
|
||||
//
|
||||
// Use SAGA when you have multiple database operations that need to succeed or fail together,
|
||||
// but you can't use traditional ACID transactions (e.g., with Cassandra, distributed services,
|
||||
// or operations across multiple bounded contexts).
|
||||
//
|
||||
// # Key Concepts
|
||||
//
|
||||
// - Forward Transaction: A database write operation (e.g., CreateTenant)
|
||||
// - Compensating Transaction: An undo operation (e.g., DeleteTenant)
|
||||
// - LIFO Execution: Compensations execute in reverse order (Last In, First Out)
|
||||
//
|
||||
// # Example Usage: User Registration Flow
|
||||
//
|
||||
// Problem: When registering a user, we create a tenant, then create a user.
|
||||
// If user creation fails, the tenant becomes orphaned in the database.
|
||||
//
|
||||
// Solution: Use SAGA to automatically delete the tenant if user creation fails.
|
||||
//
|
||||
// func (s *RegisterService) Register(ctx context.Context, input *RegisterInput) (*RegisterResponse, error) {
|
||||
// // Step 1: Create SAGA instance
|
||||
// saga := transaction.NewSaga("user-registration", s.logger)
|
||||
//
|
||||
// // Step 2: Validate input (no DB writes, no compensation needed)
|
||||
// if err := s.validateInputUC.Execute(input); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // Step 3: Create tenant (FIRST DB WRITE - register compensation)
|
||||
// tenantOutput, err := s.createTenantUC.Execute(ctx, input)
|
||||
// if err != nil {
|
||||
// return nil, err // No rollback needed - tenant creation failed
|
||||
// }
|
||||
//
|
||||
// // Register compensation: if anything fails later, delete this tenant
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// s.logger.Warn("compensating: deleting tenant",
|
||||
// zap.String("tenant_id", tenantOutput.ID))
|
||||
// return s.deleteTenantUC.Execute(ctx, tenantOutput.ID)
|
||||
// })
|
||||
//
|
||||
// // Step 4: Create user (SECOND DB WRITE)
|
||||
// userOutput, err := s.createUserUC.Execute(ctx, tenantOutput.ID, input)
|
||||
// if err != nil {
|
||||
// s.logger.Error("user creation failed - rolling back tenant",
|
||||
// zap.Error(err))
|
||||
//
|
||||
// // Execute SAGA rollback - this will delete the tenant
|
||||
// saga.Rollback(ctx)
|
||||
//
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // Success! Both tenant and user created, no rollback needed
|
||||
// return &RegisterResponse{
|
||||
// TenantID: tenantOutput.ID,
|
||||
// UserID: userOutput.ID,
|
||||
// }, nil
|
||||
// }
|
||||
//
|
||||
// # Example Usage: Multi-Step Saga
|
||||
//
|
||||
// For operations with many steps, register multiple compensations:
|
||||
//
|
||||
// func (uc *ComplexOperationUseCase) Execute(ctx context.Context) error {
|
||||
// saga := transaction.NewSaga("complex-operation", uc.logger)
|
||||
//
|
||||
// // Step 1: Create resource A
|
||||
// resourceA, err := uc.createResourceA(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteResourceA(ctx, resourceA.ID)
|
||||
// })
|
||||
//
|
||||
// // Step 2: Create resource B
|
||||
// resourceB, err := uc.createResourceB(ctx)
|
||||
// if err != nil {
|
||||
// saga.Rollback(ctx) // Deletes A
|
||||
// return err
|
||||
// }
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteResourceB(ctx, resourceB.ID)
|
||||
// })
|
||||
//
|
||||
// // Step 3: Create resource C
|
||||
// resourceC, err := uc.createResourceC(ctx)
|
||||
// if err != nil {
|
||||
// saga.Rollback(ctx) // Deletes B, then A (LIFO order)
|
||||
// return err
|
||||
// }
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteResourceC(ctx, resourceC.ID)
|
||||
// })
|
||||
//
|
||||
// // All steps succeeded - no rollback needed
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// # Important Notes for Junior Developers
|
||||
//
|
||||
// 1. LIFO Order: Compensations execute in REVERSE order of registration
|
||||
// If you create: Tenant → User → Email
|
||||
// Rollback deletes: Email → User → Tenant
|
||||
//
|
||||
// 2. Idempotency: Compensating operations should be idempotent (safe to call multiple times)
|
||||
// Your DeleteTenant should not error if tenant is already deleted
|
||||
//
|
||||
// 3. Failures Continue: If one compensation fails, others still execute
|
||||
// This ensures maximum cleanup even if some operations fail
|
||||
//
|
||||
// 4. Logging: All operations are logged with emoji icons (🔴 for errors, 🟡 for warnings)
|
||||
// Monitor logs for "saga rollback had failures" - indicates manual intervention needed
|
||||
//
|
||||
// 5. When NOT to Use SAGA:
|
||||
// - Single database operation (no need for compensation)
|
||||
// - Read-only operations (no state changes to rollback)
|
||||
// - Operations where compensation isn't possible (e.g., sending an email can't be unsent)
|
||||
//
|
||||
// 6. Testing: Always test your rollback scenarios!
|
||||
// Mock the second operation to fail and verify the first is rolled back
|
||||
//
|
||||
// # Common Pitfalls to Avoid
|
||||
//
|
||||
// - DON'T register compensations before the operation succeeds
|
||||
// - DON'T forget to call saga.Rollback(ctx) when an operation fails
|
||||
// - DON'T assume compensations will always succeed (they might fail too)
|
||||
// - DON'T use SAGA for operations that can use database transactions
|
||||
// - DO make your compensating operations idempotent
|
||||
// - DO log all compensation failures for investigation
|
||||
//
|
||||
// # See Also
|
||||
//
|
||||
// For real-world examples, see:
|
||||
// - internal/service/gateway/register.go (user registration with SAGA)
|
||||
// - internal/usecase/tenant/delete.go (compensating transaction example)
|
||||
// - internal/usecase/user/delete.go (compensating transaction example)
|
||||
|
||||
// Compensator defines a function that undoes a previously executed operation.
|
||||
//
|
||||
// A compensator is the "undo" function for a database write operation.
|
||||
// For example:
|
||||
// - Forward operation: CreateTenant
|
||||
// - Compensator: DeleteTenant
|
||||
//
|
||||
// Compensators must:
|
||||
// - Accept a context (for cancellation/timeouts)
|
||||
// - Return an error if compensation fails
|
||||
// - Be idempotent (safe to call multiple times)
|
||||
// - Clean up the exact resources created by the forward operation
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// // Forward operation: Create tenant
|
||||
// tenantID := "tenant-123"
|
||||
// err := tenantRepo.Create(ctx, tenant)
|
||||
//
|
||||
// // Compensator: Delete tenant
|
||||
// compensator := func(ctx context.Context) error {
|
||||
// return tenantRepo.Delete(ctx, tenantID)
|
||||
// }
|
||||
//
|
||||
// saga.AddCompensation(compensator)
|
||||
type Compensator func(ctx context.Context) error
|
||||
|
||||
// Saga manages a sequence of operations with compensating transactions.
|
||||
//
|
||||
// A Saga coordinates a multi-step workflow where each step that performs a database
|
||||
// write registers a compensating transaction. If any step fails, all registered
|
||||
// compensations are executed in reverse order (LIFO) to undo previous changes.
|
||||
//
|
||||
// # How it Works
|
||||
//
|
||||
// 1. Create a Saga instance with NewSaga()
|
||||
// 2. Execute your operations in sequence
|
||||
// 3. After each successful write, call AddCompensation() with the undo operation
|
||||
// 4. If any operation fails, call Rollback() to undo all previous changes
|
||||
// 5. If all operations succeed, no action needed (compensations are never called)
|
||||
//
|
||||
// # Thread Safety
|
||||
//
|
||||
// Saga is NOT thread-safe. Do not share a single Saga instance across goroutines.
|
||||
// Each workflow execution should create its own Saga instance.
|
||||
//
|
||||
// # Fields
|
||||
//
|
||||
// - name: Human-readable name for logging (e.g., "user-registration")
|
||||
// - compensators: Stack of undo functions, executed in LIFO order
|
||||
// - logger: Structured logger for tracking saga execution and failures
|
||||
type Saga struct {
|
||||
name string // Name of the saga (for logging)
|
||||
compensators []Compensator // Stack of compensating transactions (LIFO)
|
||||
logger *zap.Logger // Logger for tracking saga execution
|
||||
}
|
||||
|
||||
// NewSaga creates a new SAGA instance with the given name.
|
||||
//
|
||||
// The name parameter should be a descriptive identifier for the workflow
|
||||
// (e.g., "user-registration", "order-processing", "account-setup").
|
||||
// This name appears in all log messages for easy tracking and debugging.
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - name: A descriptive name for this saga workflow (used in logging)
|
||||
// - logger: A zap logger instance (will be enhanced with saga-specific fields)
|
||||
//
|
||||
// # Returns
|
||||
//
|
||||
// A new Saga instance ready to coordinate multi-step operations.
|
||||
//
|
||||
// # Example
|
||||
//
|
||||
// // In your use case
|
||||
// func (uc *RegisterUseCase) Execute(ctx context.Context, input *Input) error {
|
||||
// // Create a new saga for this registration workflow
|
||||
// saga := transaction.NewSaga("user-registration", uc.logger)
|
||||
//
|
||||
// // ... use saga for your operations ...
|
||||
// }
|
||||
//
|
||||
// # Important
|
||||
//
|
||||
// Each workflow execution should create its own Saga instance.
|
||||
// Do NOT reuse a Saga instance across multiple workflow executions.
|
||||
func NewSaga(name string, logger *zap.Logger) *Saga {
|
||||
return &Saga{
|
||||
name: name,
|
||||
compensators: make([]Compensator, 0),
|
||||
logger: logger.Named("saga").With(zap.String("saga_name", name)),
|
||||
}
|
||||
}
|
||||
|
||||
// AddCompensation registers a compensating transaction for rollback.
|
||||
//
|
||||
// Call this method IMMEDIATELY AFTER a successful database write operation
|
||||
// to register the corresponding undo operation.
|
||||
//
|
||||
// # Execution Order: LIFO (Last In, First Out)
|
||||
//
|
||||
// Compensations are executed in REVERSE order of registration during rollback.
|
||||
// This ensures proper cleanup order:
|
||||
// - If you create: Tenant → User → Subscription
|
||||
// - Rollback deletes: Subscription → User → Tenant
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - compensate: A function that undoes the operation (e.g., DeleteTenant)
|
||||
//
|
||||
// # When to Call
|
||||
//
|
||||
// // ✅ CORRECT: Register compensation AFTER operation succeeds
|
||||
// tenantOutput, err := uc.createTenantUC.Execute(ctx, input)
|
||||
// if err != nil {
|
||||
// return nil, err // Operation failed - no compensation needed
|
||||
// }
|
||||
// // Operation succeeded - NOW register the undo operation
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteTenantUC.Execute(ctx, tenantOutput.ID)
|
||||
// })
|
||||
//
|
||||
// // ❌ WRONG: Don't register compensation BEFORE operation
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteTenantUC.Execute(ctx, tenantOutput.ID)
|
||||
// })
|
||||
// tenantOutput, err := uc.createTenantUC.Execute(ctx, input) // Might fail!
|
||||
//
|
||||
// # Example: Basic Usage
|
||||
//
|
||||
// // Step 1: Create tenant
|
||||
// tenant, err := uc.createTenantUC.Execute(ctx, input)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // Step 2: Register compensation for tenant
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// uc.logger.Warn("rolling back: deleting tenant",
|
||||
// zap.String("tenant_id", tenant.ID))
|
||||
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
|
||||
// })
|
||||
//
|
||||
// # Example: Capturing Variables in Closure
|
||||
//
|
||||
// // Be careful with variable scope in closures!
|
||||
// for _, item := range items {
|
||||
// created, err := uc.createItem(ctx, item)
|
||||
// if err != nil {
|
||||
// saga.Rollback(ctx)
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// // ✅ CORRECT: Capture the variable value
|
||||
// itemID := created.ID // Capture in local variable
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteItem(ctx, itemID) // Use captured value
|
||||
// })
|
||||
//
|
||||
// // ❌ WRONG: Variable will have wrong value at rollback time
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteItem(ctx, created.ID) // 'created' may change!
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// # Tips for Writing Good Compensators
|
||||
//
|
||||
// 1. Make them idempotent (safe to call multiple times)
|
||||
// 2. Log what you're compensating for easier debugging
|
||||
// 3. Capture all necessary IDs before the closure
|
||||
// 4. Handle "not found" errors gracefully (resource may already be deleted)
|
||||
// 5. Return errors if compensation truly fails (logged but doesn't stop other compensations)
|
||||
func (s *Saga) AddCompensation(compensate Compensator) {
|
||||
s.compensators = append(s.compensators, compensate)
|
||||
s.logger.Debug("compensation registered",
|
||||
zap.Int("total_compensations", len(s.compensators)))
|
||||
}
|
||||
|
||||
// Rollback executes all registered compensating transactions in reverse order (LIFO).
|
||||
//
|
||||
// Call this method when any operation in your workflow fails AFTER you've started
|
||||
// registering compensations. This will undo all previously successful operations
|
||||
// by executing their compensating transactions in reverse order.
|
||||
//
|
||||
// # When to Call
|
||||
//
|
||||
// tenant, err := uc.createTenantUC.Execute(ctx, input)
|
||||
// if err != nil {
|
||||
// return nil, err // No compensations registered yet - no rollback needed
|
||||
// }
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
|
||||
// })
|
||||
//
|
||||
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
|
||||
// if err != nil {
|
||||
// // Compensations ARE registered - MUST call rollback!
|
||||
// saga.Rollback(ctx)
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// # Execution Behavior
|
||||
//
|
||||
// 1. LIFO Order: Compensations execute in REVERSE order of registration
|
||||
// - If you registered: [DeleteTenant, DeleteUser, DeleteSubscription]
|
||||
// - Rollback executes: DeleteSubscription → DeleteUser → DeleteTenant
|
||||
//
|
||||
// 2. Best Effort: If a compensation fails, it's logged but others still execute
|
||||
// - This maximizes cleanup even if some operations fail
|
||||
// - Failed compensations are logged with 🔴 emoji for investigation
|
||||
//
|
||||
// 3. No Panic: Rollback never panics, even if all compensations fail
|
||||
// - Failures are logged for manual intervention
|
||||
// - Returns without error (compensation failures are logged, not returned)
|
||||
//
|
||||
// # Example: Basic Rollback
|
||||
//
|
||||
// func (uc *RegisterUseCase) Execute(ctx context.Context, input *Input) error {
|
||||
// saga := transaction.NewSaga("user-registration", uc.logger)
|
||||
//
|
||||
// // Step 1: Create tenant
|
||||
// tenant, err := uc.createTenantUC.Execute(ctx, input)
|
||||
// if err != nil {
|
||||
// return err // No rollback needed
|
||||
// }
|
||||
// saga.AddCompensation(func(ctx context.Context) error {
|
||||
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
|
||||
// })
|
||||
//
|
||||
// // Step 2: Create user
|
||||
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
|
||||
// if err != nil {
|
||||
// uc.logger.Error("user creation failed", zap.Error(err))
|
||||
// saga.Rollback(ctx) // ← Deletes tenant
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// // Both operations succeeded - no rollback needed
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// # Log Output Example
|
||||
//
|
||||
// Successful rollback:
|
||||
//
|
||||
// WARN 🟡 executing saga rollback {"saga_name": "user-registration", "compensation_count": 1}
|
||||
// INFO executing compensation {"step": 1, "index": 0}
|
||||
// INFO deleting tenant {"tenant_id": "tenant-123"}
|
||||
// INFO tenant deleted successfully {"tenant_id": "tenant-123"}
|
||||
// INFO compensation succeeded {"step": 1}
|
||||
// WARN 🟡 saga rollback completed {"total_compensations": 1, "successes": 1, "failures": 0}
|
||||
//
|
||||
// Failed compensation:
|
||||
//
|
||||
// WARN 🟡 executing saga rollback
|
||||
// INFO executing compensation
|
||||
// ERROR 🔴 failed to delete tenant {"error": "connection lost"}
|
||||
// ERROR 🔴 compensation failed {"step": 1, "error": "..."}
|
||||
// WARN 🟡 saga rollback completed {"successes": 0, "failures": 1}
|
||||
// ERROR 🔴 saga rollback had failures - manual intervention may be required
|
||||
//
|
||||
// # Important Notes
|
||||
//
|
||||
// 1. Always call Rollback if you've registered ANY compensations and a later step fails
|
||||
// 2. Don't call Rollback if no compensations have been registered yet
|
||||
// 3. Rollback is safe to call multiple times (idempotent) but wasteful
|
||||
// 4. Monitor logs for "saga rollback had failures" - indicates manual cleanup needed
|
||||
// 5. Context cancellation is respected - compensations will see cancelled context
|
||||
//
|
||||
// # Parameters
|
||||
//
|
||||
// - ctx: Context for cancellation/timeout (passed to each compensating function)
|
||||
//
|
||||
// # What Gets Logged
|
||||
//
|
||||
// - Start of rollback (warning level with 🟡 emoji)
|
||||
// - Each compensation execution attempt
|
||||
// - Success or failure of each compensation
|
||||
// - Summary of rollback results
|
||||
// - Alert if any compensations failed (error level with 🔴 emoji)
|
||||
func (s *Saga) Rollback(ctx context.Context) {
|
||||
if len(s.compensators) == 0 {
|
||||
s.logger.Info("no compensations to execute")
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Warn("executing saga rollback",
|
||||
zap.Int("compensation_count", len(s.compensators)))
|
||||
|
||||
successCount := 0
|
||||
failureCount := 0
|
||||
|
||||
// Execute in reverse order (LIFO - Last In, First Out)
|
||||
for i := len(s.compensators) - 1; i >= 0; i-- {
|
||||
compensationStep := len(s.compensators) - i
|
||||
|
||||
s.logger.Info("executing compensation",
|
||||
zap.Int("step", compensationStep),
|
||||
zap.Int("index", i))
|
||||
|
||||
if err := s.compensators[i](ctx); err != nil {
|
||||
failureCount++
|
||||
// Log with error level (automatically adds emoji)
|
||||
s.logger.Error("compensation failed",
|
||||
zap.Int("step", compensationStep),
|
||||
zap.Int("index", i),
|
||||
zap.Error(err))
|
||||
// Continue with other compensations even if one fails
|
||||
} else {
|
||||
successCount++
|
||||
s.logger.Info("compensation succeeded",
|
||||
zap.Int("step", compensationStep),
|
||||
zap.Int("index", i))
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Warn("saga rollback completed",
|
||||
zap.Int("total_compensations", len(s.compensators)),
|
||||
zap.Int("successes", successCount),
|
||||
zap.Int("failures", failureCount))
|
||||
|
||||
// If any compensations failed, this indicates a serious issue
|
||||
// The operations team should be alerted to investigate
|
||||
if failureCount > 0 {
|
||||
s.logger.Error("saga rollback had failures - manual intervention may be required",
|
||||
zap.Int("failed_compensations", failureCount))
|
||||
}
|
||||
}
|
||||
|
||||
// MustRollback is a convenience method that executes rollback.
|
||||
//
|
||||
// This method currently has the same behavior as Rollback() - it executes
|
||||
// all compensating transactions but does NOT panic on failure.
|
||||
//
|
||||
// # When to Use
|
||||
//
|
||||
// Use this method when you want to make it explicit in your code that rollback
|
||||
// is critical and must be executed, even though the actual behavior is the same
|
||||
// as Rollback().
|
||||
//
|
||||
// # Example
|
||||
//
|
||||
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
|
||||
// if err != nil {
|
||||
// // Make it explicit that rollback is critical
|
||||
// saga.MustRollback(ctx)
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// # Note for Junior Developers
|
||||
//
|
||||
// Despite the name "MustRollback", this method does NOT panic if compensations fail.
|
||||
// Compensation failures are logged for manual intervention, but the method returns normally.
|
||||
//
|
||||
// The name "Must" indicates that YOU must call this method if compensations are registered,
|
||||
// not that the rollback itself must succeed.
|
||||
//
|
||||
// If you need actual panic behavior on compensation failure, you would need to check
|
||||
// logs or implement custom panic logic.
|
||||
func (s *Saga) MustRollback(ctx context.Context) {
|
||||
s.Rollback(ctx)
|
||||
}
|
||||
275
cloud/maplepress-backend/pkg/validation/email.go
Normal file
275
cloud/maplepress-backend/pkg/validation/email.go
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/validation/email.go
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EmailValidator provides comprehensive email validation and normalization
|
||||
// CWE-20: Improper Input Validation - Ensures email addresses are properly validated
|
||||
type EmailValidator struct {
|
||||
validator *Validator
|
||||
}
|
||||
|
||||
// NewEmailValidator creates a new email validator
|
||||
func NewEmailValidator() *EmailValidator {
|
||||
return &EmailValidator{
|
||||
validator: NewValidator(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAndNormalize validates and normalizes an email address
|
||||
// Returns the normalized email and any validation error
|
||||
func (ev *EmailValidator) ValidateAndNormalize(email, fieldName string) (string, error) {
|
||||
// Step 1: Basic validation using existing validator
|
||||
if err := ev.validator.ValidateEmail(email, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Step 2: Normalize the email
|
||||
normalized := ev.Normalize(email)
|
||||
|
||||
// Step 3: Additional security checks
|
||||
if err := ev.ValidateSecurityConstraints(normalized, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// Normalize normalizes an email address for consistent storage and comparison
|
||||
// CWE-180: Incorrect Behavior Order: Validate Before Canonicalize
|
||||
func (ev *EmailValidator) Normalize(email string) string {
|
||||
// Trim whitespace
|
||||
email = strings.TrimSpace(email)
|
||||
|
||||
// Convert to lowercase (email local parts are case-sensitive per RFC 5321,
|
||||
// but most providers treat them as case-insensitive for better UX)
|
||||
email = strings.ToLower(email)
|
||||
|
||||
// Remove any null bytes
|
||||
email = strings.ReplaceAll(email, "\x00", "")
|
||||
|
||||
// Gmail-specific normalization (optional - commented out by default)
|
||||
// This removes dots and plus-aliases from Gmail addresses
|
||||
// Uncomment if you want to prevent abuse via Gmail aliases
|
||||
// email = ev.normalizeGmail(email)
|
||||
|
||||
return email
|
||||
}
|
||||
|
||||
// ValidateSecurityConstraints performs additional security validation
|
||||
func (ev *EmailValidator) ValidateSecurityConstraints(email, fieldName string) error {
|
||||
// Check for suspicious patterns
|
||||
|
||||
// 1. Detect emails with excessive special characters (potential obfuscation)
|
||||
specialCharCount := 0
|
||||
for _, ch := range email {
|
||||
if ch == '+' || ch == '.' || ch == '_' || ch == '-' || ch == '%' {
|
||||
specialCharCount++
|
||||
}
|
||||
}
|
||||
if specialCharCount > 10 {
|
||||
return fmt.Errorf("%s: contains too many special characters", fieldName)
|
||||
}
|
||||
|
||||
// 2. Detect potentially disposable email patterns
|
||||
if ev.isLikelyDisposable(email) {
|
||||
// Note: This is a warning-level check. In production, you might want to
|
||||
// either reject these or flag them for review.
|
||||
// For now, we'll allow them but this can be configured.
|
||||
}
|
||||
|
||||
// 3. Check for common typos in popular domains
|
||||
if typo := ev.detectCommonDomainTypo(email); typo != "" {
|
||||
return fmt.Errorf("%s: possible typo detected, did you mean %s?", fieldName, typo)
|
||||
}
|
||||
|
||||
// 4. Prevent IP-based email addresses
|
||||
if ev.hasIPAddress(email) {
|
||||
return fmt.Errorf("%s: IP-based email addresses are not allowed", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLikelyDisposable checks if email is from a known disposable email provider
|
||||
// This is a basic implementation - in production, use a service like:
|
||||
// - https://github.com/disposable/disposable-email-domains
|
||||
// - or an API service
|
||||
func (ev *EmailValidator) isLikelyDisposable(email string) bool {
|
||||
// Extract domain
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Common disposable email patterns
|
||||
disposablePatterns := []string{
|
||||
"temp",
|
||||
"disposable",
|
||||
"throwaway",
|
||||
"guerrilla",
|
||||
"mailinator",
|
||||
"10minute",
|
||||
"trashmail",
|
||||
"yopmail",
|
||||
"fakeinbox",
|
||||
}
|
||||
|
||||
for _, pattern := range disposablePatterns {
|
||||
if strings.Contains(domain, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Known disposable domains (small sample - expand as needed)
|
||||
disposableDomains := map[string]bool{
|
||||
"mailinator.com": true,
|
||||
"guerrillamail.com": true,
|
||||
"10minutemail.com": true,
|
||||
"tempmailaddress.com": true,
|
||||
"yopmail.com": true,
|
||||
"fakeinbox.com": true,
|
||||
"trashmail.com": true,
|
||||
"throwaway.email": true,
|
||||
}
|
||||
|
||||
return disposableDomains[domain]
|
||||
}
|
||||
|
||||
// detectCommonDomainTypo checks for common typos in popular email domains
|
||||
func (ev *EmailValidator) detectCommonDomainTypo(email string) string {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
localPart := parts[0]
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Common typos map: typo -> correct
|
||||
typos := map[string]string{
|
||||
"gmial.com": "gmail.com",
|
||||
"gmai.com": "gmail.com",
|
||||
"gmil.com": "gmail.com",
|
||||
"yahooo.com": "yahoo.com",
|
||||
"yaho.com": "yahoo.com",
|
||||
"hotmial.com": "hotmail.com",
|
||||
"hotmal.com": "hotmail.com",
|
||||
"outlok.com": "outlook.com",
|
||||
"outloo.com": "outlook.com",
|
||||
"iclodu.com": "icloud.com",
|
||||
"iclod.com": "icloud.com",
|
||||
"protonmai.com": "protonmail.com",
|
||||
"protonmal.com": "protonmail.com",
|
||||
}
|
||||
|
||||
if correct, found := typos[domain]; found {
|
||||
return localPart + "@" + correct
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasIPAddress checks if email domain is an IP address
|
||||
func (ev *EmailValidator) hasIPAddress(email string) bool {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := parts[1]
|
||||
|
||||
// Check for IPv4 pattern: [192.168.1.1]
|
||||
if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unbracketed IP patterns (less common but possible)
|
||||
// Simple heuristic: contains only digits and dots
|
||||
hasOnlyDigitsAndDots := true
|
||||
for _, ch := range domain {
|
||||
if ch != '.' && (ch < '0' || ch > '9') {
|
||||
hasOnlyDigitsAndDots = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return hasOnlyDigitsAndDots && strings.Count(domain, ".") >= 3
|
||||
}
|
||||
|
||||
// normalizeGmail normalizes Gmail addresses by removing dots and plus-aliases
|
||||
// Gmail ignores dots in the local part and treats everything after + as an alias
|
||||
// Example: john.doe+test@gmail.com -> johndoe@gmail.com
|
||||
func (ev *EmailValidator) normalizeGmail(email string) string {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return email
|
||||
}
|
||||
|
||||
localPart := parts[0]
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Only normalize for Gmail and Googlemail
|
||||
if domain != "gmail.com" && domain != "googlemail.com" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Remove dots from local part
|
||||
localPart = strings.ReplaceAll(localPart, ".", "")
|
||||
|
||||
// Remove everything after + (plus-alias)
|
||||
if plusIndex := strings.Index(localPart, "+"); plusIndex != -1 {
|
||||
localPart = localPart[:plusIndex]
|
||||
}
|
||||
|
||||
return localPart + "@" + domain
|
||||
}
|
||||
|
||||
// ValidateEmailList validates a list of email addresses
|
||||
// Returns the first error encountered, or nil if all are valid
|
||||
func (ev *EmailValidator) ValidateEmailList(emails []string, fieldName string) ([]string, error) {
|
||||
normalized := make([]string, 0, len(emails))
|
||||
|
||||
for i, email := range emails {
|
||||
norm, err := ev.ValidateAndNormalize(email, fmt.Sprintf("%s[%d]", fieldName, i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, norm)
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// IsValidEmailDomain checks if a domain is likely valid (has proper structure)
|
||||
// This is a lightweight check - for production, consider DNS MX record validation
|
||||
func (ev *EmailValidator) IsValidEmailDomain(email string) bool {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Must have at least one dot
|
||||
if !strings.Contains(domain, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
// TLD must be at least 2 characters
|
||||
tldParts := strings.Split(domain, ".")
|
||||
if len(tldParts) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
tld := tldParts[len(tldParts)-1]
|
||||
if len(tld) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
120
cloud/maplepress-backend/pkg/validation/helpers.go
Normal file
120
cloud/maplepress-backend/pkg/validation/helpers.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ValidatePathUUID validates a UUID path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathUUID(r *http.Request, paramName string) (string, error) {
|
||||
value := r.PathValue(paramName)
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
validator := NewValidator()
|
||||
if err := validator.ValidateUUID(value, paramName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePathSlug validates a slug path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathSlug(r *http.Request, paramName string) (string, error) {
|
||||
value := r.PathValue(paramName)
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
validator := NewValidator()
|
||||
if err := validator.ValidateSlug(value, paramName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePathInt validates an integer path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathInt(r *http.Request, paramName string) (int64, error) {
|
||||
valueStr := r.PathValue(paramName)
|
||||
if valueStr == "" {
|
||||
return 0, fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
value, err := strconv.ParseInt(valueStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s must be a valid integer", paramName)
|
||||
}
|
||||
|
||||
if value <= 0 {
|
||||
return 0, fmt.Errorf("%s must be greater than 0", paramName)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePagination validates pagination query parameters
|
||||
// Returns limit and offset with defaults and bounds checking
|
||||
func ValidatePagination(r *http.Request, defaultLimit int) (limit int, offset int, err error) {
|
||||
limit = defaultLimit
|
||||
offset = 0
|
||||
|
||||
// Validate limit
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
parsedLimit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || parsedLimit <= 0 || parsedLimit > 100 {
|
||||
return 0, 0, fmt.Errorf("limit must be between 1 and 100")
|
||||
}
|
||||
limit = parsedLimit
|
||||
}
|
||||
|
||||
// Validate offset
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
parsedOffset, err := strconv.Atoi(offsetStr)
|
||||
if err != nil || parsedOffset < 0 {
|
||||
return 0, 0, fmt.Errorf("offset must be >= 0")
|
||||
}
|
||||
offset = parsedOffset
|
||||
}
|
||||
|
||||
return limit, offset, nil
|
||||
}
|
||||
|
||||
// ValidateSortField validates sort field against whitelist
|
||||
// CWE-89: SQL Injection prevention via whitelist
|
||||
func ValidateSortField(r *http.Request, allowedFields []string) (string, error) {
|
||||
sortBy := r.URL.Query().Get("sort_by")
|
||||
if sortBy == "" {
|
||||
return "", nil // Optional field
|
||||
}
|
||||
|
||||
for _, allowed := range allowedFields {
|
||||
if sortBy == allowed {
|
||||
return sortBy, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid sort_by field (allowed: %v)", allowedFields)
|
||||
}
|
||||
|
||||
// ValidateQueryEmail validates an email query parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidateQueryEmail(r *http.Request, paramName string) (string, error) {
|
||||
email := r.URL.Query().Get(paramName)
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
emailValidator := NewEmailValidator()
|
||||
normalizedEmail, err := emailValidator.ValidateAndNormalize(email, paramName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return normalizedEmail, nil
|
||||
}
|
||||
6
cloud/maplepress-backend/pkg/validation/provider.go
Normal file
6
cloud/maplepress-backend/pkg/validation/provider.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
package validation
|
||||
|
||||
// ProvideValidator provides a Validator instance
|
||||
func ProvideValidator() *Validator {
|
||||
return NewValidator()
|
||||
}
|
||||
498
cloud/maplepress-backend/pkg/validation/validator.go
Normal file
498
cloud/maplepress-backend/pkg/validation/validator.go
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Common validation errors
|
||||
var (
|
||||
ErrRequired = fmt.Errorf("field is required")
|
||||
ErrInvalidEmail = fmt.Errorf("invalid email format")
|
||||
ErrInvalidURL = fmt.Errorf("invalid URL format")
|
||||
ErrInvalidDomain = fmt.Errorf("invalid domain format")
|
||||
ErrTooShort = fmt.Errorf("value is too short")
|
||||
ErrTooLong = fmt.Errorf("value is too long")
|
||||
ErrInvalidCharacters = fmt.Errorf("contains invalid characters")
|
||||
ErrInvalidFormat = fmt.Errorf("invalid format")
|
||||
ErrInvalidValue = fmt.Errorf("invalid value")
|
||||
ErrWhitespaceOnly = fmt.Errorf("cannot contain only whitespace")
|
||||
ErrContainsHTML = fmt.Errorf("cannot contain HTML tags")
|
||||
ErrInvalidSlug = fmt.Errorf("invalid slug format")
|
||||
)
|
||||
|
||||
// Regex patterns for validation
|
||||
var (
|
||||
// Email validation: RFC 5322 compliant
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._%+\-]*[a-zA-Z0-9]@[a-zA-Z0-9][a-zA-Z0-9.\-]*[a-zA-Z0-9]\.[a-zA-Z]{2,}$`)
|
||||
|
||||
// Domain validation: alphanumeric with dots and hyphens
|
||||
domainRegex = regexp.MustCompile(`^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
|
||||
|
||||
// Slug validation: lowercase alphanumeric with hyphens
|
||||
slugRegex = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
|
||||
// HTML tag detection
|
||||
htmlTagRegex = regexp.MustCompile(`<[^>]+>`)
|
||||
|
||||
// UUID validation (version 4)
|
||||
uuidRegex = regexp.MustCompile(`^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$`)
|
||||
|
||||
// Alphanumeric only
|
||||
alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
|
||||
)
|
||||
|
||||
// Reserved slugs that cannot be used for tenant names
|
||||
var ReservedSlugs = map[string]bool{
|
||||
"api": true,
|
||||
"admin": true,
|
||||
"www": true,
|
||||
"mail": true,
|
||||
"email": true,
|
||||
"health": true,
|
||||
"status": true,
|
||||
"metrics": true,
|
||||
"static": true,
|
||||
"cdn": true,
|
||||
"assets": true,
|
||||
"blog": true,
|
||||
"docs": true,
|
||||
"help": true,
|
||||
"support": true,
|
||||
"login": true,
|
||||
"logout": true,
|
||||
"signup": true,
|
||||
"register": true,
|
||||
"app": true,
|
||||
"dashboard": true,
|
||||
"settings": true,
|
||||
"account": true,
|
||||
"profile": true,
|
||||
"root": true,
|
||||
"system": true,
|
||||
"public": true,
|
||||
"private": true,
|
||||
}
|
||||
|
||||
// Validator provides input validation utilities
|
||||
type Validator struct{}
|
||||
|
||||
// NewValidator creates a new validator instance
|
||||
func NewValidator() *Validator {
|
||||
return &Validator{}
|
||||
}
|
||||
|
||||
// ==================== String Validation ====================
|
||||
|
||||
// ValidateRequired checks if a string is not empty
|
||||
func (v *Validator) ValidateRequired(value, fieldName string) error {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateLength checks if string length is within range
|
||||
func (v *Validator) ValidateLength(value, fieldName string, min, max int) error {
|
||||
length := len(strings.TrimSpace(value))
|
||||
|
||||
if length < min {
|
||||
return fmt.Errorf("%s: %w (minimum %d characters)", fieldName, ErrTooShort, min)
|
||||
}
|
||||
|
||||
if max > 0 && length > max {
|
||||
return fmt.Errorf("%s: %w (maximum %d characters)", fieldName, ErrTooLong, max)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNotWhitespaceOnly ensures the string contains non-whitespace characters
|
||||
func (v *Validator) ValidateNotWhitespaceOnly(value, fieldName string) error {
|
||||
if len(strings.TrimSpace(value)) == 0 && len(value) > 0 {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrWhitespaceOnly)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNoHTML checks that the string doesn't contain HTML tags
|
||||
func (v *Validator) ValidateNoHTML(value, fieldName string) error {
|
||||
if htmlTagRegex.MatchString(value) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrContainsHTML)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAlphanumeric checks if string contains only alphanumeric characters
|
||||
func (v *Validator) ValidateAlphanumeric(value, fieldName string) error {
|
||||
if !alphanumericRegex.MatchString(value) {
|
||||
return fmt.Errorf("%s: %w (only letters and numbers allowed)", fieldName, ErrInvalidCharacters)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePrintable ensures string contains only printable characters
|
||||
func (v *Validator) ValidatePrintable(value, fieldName string) error {
|
||||
for _, r := range value {
|
||||
if !unicode.IsPrint(r) && !unicode.IsSpace(r) {
|
||||
return fmt.Errorf("%s: %w (contains non-printable characters)", fieldName, ErrInvalidCharacters)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Email Validation ====================
|
||||
|
||||
// ValidateEmail validates email format using RFC 5322 compliant regex
|
||||
func (v *Validator) ValidateEmail(email, fieldName string) error {
|
||||
email = strings.TrimSpace(email)
|
||||
|
||||
// Check required
|
||||
if email == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (RFC 5321: max 320 chars)
|
||||
if len(email) > 320 {
|
||||
return fmt.Errorf("%s: %w (maximum 320 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Validate using regex
|
||||
if !emailRegex.MatchString(email) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Additional validation using net/mail package
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Check for consecutive dots
|
||||
if strings.Contains(email, "..") {
|
||||
return fmt.Errorf("%s: %w (consecutive dots not allowed)", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Check for leading/trailing dots in local part
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) == 2 {
|
||||
if strings.HasPrefix(parts[0], ".") || strings.HasSuffix(parts[0], ".") {
|
||||
return fmt.Errorf("%s: %w (local part cannot start or end with dot)", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== URL Validation ====================
|
||||
|
||||
// ValidateURL validates URL format and ensures it has a valid scheme
|
||||
func (v *Validator) ValidateURL(urlStr, fieldName string) error {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
|
||||
// Check required
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (max 2048 chars for URL)
|
||||
if len(urlStr) > 2048 {
|
||||
return fmt.Errorf("%s: %w (maximum 2048 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Parse URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Ensure scheme is present and valid
|
||||
if parsedURL.Scheme == "" {
|
||||
return fmt.Errorf("%s: %w (missing scheme)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Only allow http and https
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("%s: %w (only http and https schemes allowed)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Ensure host is present
|
||||
if parsedURL.Host == "" {
|
||||
return fmt.Errorf("%s: %w (missing host)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateHTTPSURL validates URL and ensures it uses HTTPS
|
||||
func (v *Validator) ValidateHTTPSURL(urlStr, fieldName string) error {
|
||||
if err := v.ValidateURL(urlStr, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: invalid URL format", fieldName)
|
||||
}
|
||||
if parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("%s: must use HTTPS protocol", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Domain Validation ====================
|
||||
|
||||
// ValidateDomain validates domain name format
|
||||
// Supports standard domains (example.com) and localhost with ports (localhost:8081) for development
|
||||
func (v *Validator) ValidateDomain(domain, fieldName string) error {
|
||||
domain = strings.TrimSpace(strings.ToLower(domain))
|
||||
|
||||
// Check required
|
||||
if domain == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (max 253 chars per RFC 1035)
|
||||
if len(domain) > 253 {
|
||||
return fmt.Errorf("%s: %w (maximum 253 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(domain) < 4 {
|
||||
return fmt.Errorf("%s: %w (minimum 4 characters)", fieldName, ErrTooShort)
|
||||
}
|
||||
|
||||
// Allow localhost with optional port for development
|
||||
// Examples: localhost, localhost:8080, localhost:3000
|
||||
if strings.HasPrefix(domain, "localhost") {
|
||||
// If it has a port, validate the port format
|
||||
if strings.Contains(domain, ":") {
|
||||
parts := strings.Split(domain, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("%s: %w (invalid localhost format)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
// Port should be numeric
|
||||
if parts[1] == "" {
|
||||
return fmt.Errorf("%s: %w (missing port number)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
// Basic port validation (could be more strict)
|
||||
for _, c := range parts[1] {
|
||||
if c < '0' || c > '9' {
|
||||
return fmt.Errorf("%s: %w (port must be numeric)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Allow 127.0.0.1 and other local IPs with optional port for development
|
||||
if strings.HasPrefix(domain, "127.") || strings.HasPrefix(domain, "192.168.") || strings.HasPrefix(domain, "10.") {
|
||||
// If it has a port, just verify format (IP:port)
|
||||
if strings.Contains(domain, ":") {
|
||||
parts := strings.Split(domain, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("%s: %w (invalid IP format)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate standard domain format (example.com)
|
||||
if !domainRegex.MatchString(domain) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
|
||||
// Check each label length (max 63 chars per RFC 1035)
|
||||
labels := strings.Split(domain, ".")
|
||||
for _, label := range labels {
|
||||
if len(label) > 63 {
|
||||
return fmt.Errorf("%s: %w (label exceeds 63 characters)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Slug Validation ====================
|
||||
|
||||
// ValidateSlug validates slug format (lowercase alphanumeric with hyphens)
|
||||
func (v *Validator) ValidateSlug(slug, fieldName string) error {
|
||||
slug = strings.TrimSpace(strings.ToLower(slug))
|
||||
|
||||
// Check required
|
||||
if slug == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (3-63 chars)
|
||||
if len(slug) < 3 {
|
||||
return fmt.Errorf("%s: %w (minimum 3 characters)", fieldName, ErrTooShort)
|
||||
}
|
||||
|
||||
if len(slug) > 63 {
|
||||
return fmt.Errorf("%s: %w (maximum 63 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Validate format
|
||||
if !slugRegex.MatchString(slug) {
|
||||
return fmt.Errorf("%s: %w (only lowercase letters, numbers, and hyphens allowed)", fieldName, ErrInvalidSlug)
|
||||
}
|
||||
|
||||
// Check for reserved slugs
|
||||
if ReservedSlugs[slug] {
|
||||
return fmt.Errorf("%s: '%s' is a reserved slug and cannot be used", fieldName, slug)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSlug generates a URL-friendly slug from a name
|
||||
// Converts to lowercase, replaces spaces and special chars with hyphens
|
||||
// Ensures the slug matches the slug validation regex
|
||||
func (v *Validator) GenerateSlug(name string) string {
|
||||
// Convert to lowercase and trim spaces
|
||||
slug := strings.TrimSpace(strings.ToLower(name))
|
||||
|
||||
// Replace any non-alphanumeric characters (except hyphens) with hyphens
|
||||
var result strings.Builder
|
||||
prevWasHyphen := false
|
||||
|
||||
for _, char := range slug {
|
||||
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') {
|
||||
result.WriteRune(char)
|
||||
prevWasHyphen = false
|
||||
} else if !prevWasHyphen {
|
||||
// Replace any non-alphanumeric character with a hyphen
|
||||
// But don't add consecutive hyphens
|
||||
result.WriteRune('-')
|
||||
prevWasHyphen = true
|
||||
}
|
||||
}
|
||||
|
||||
slug = result.String()
|
||||
|
||||
// Remove leading and trailing hyphens
|
||||
slug = strings.Trim(slug, "-")
|
||||
|
||||
// Enforce length constraints (3-63 chars)
|
||||
if len(slug) < 3 {
|
||||
// If too short, pad with random suffix
|
||||
slug = slug + "-" + strings.ToLower(fmt.Sprintf("%d", time.Now().UnixNano()%10000))
|
||||
}
|
||||
|
||||
if len(slug) > 63 {
|
||||
// Truncate to 63 chars
|
||||
slug = slug[:63]
|
||||
// Remove trailing hyphen if any
|
||||
slug = strings.TrimRight(slug, "-")
|
||||
}
|
||||
|
||||
return slug
|
||||
}
|
||||
|
||||
// ==================== UUID Validation ====================
|
||||
|
||||
// ValidateUUID validates UUID format (version 4)
|
||||
func (v *Validator) ValidateUUID(id, fieldName string) error {
|
||||
id = strings.TrimSpace(strings.ToLower(id))
|
||||
|
||||
// Check required
|
||||
if id == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Validate format
|
||||
if !uuidRegex.MatchString(id) {
|
||||
return fmt.Errorf("%s: %w (must be a valid UUID v4)", fieldName, ErrInvalidFormat)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Enum Validation ====================
|
||||
|
||||
// ValidateEnum checks if value is in the allowed list (whitelist validation)
|
||||
func (v *Validator) ValidateEnum(value, fieldName string, allowedValues []string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Check required
|
||||
if value == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check if value is in allowed list
|
||||
for _, allowed := range allowedValues {
|
||||
if value == allowed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: %w (allowed values: %s)", fieldName, ErrInvalidValue, strings.Join(allowedValues, ", "))
|
||||
}
|
||||
|
||||
// ==================== Number Validation ====================
|
||||
|
||||
// ValidateRange checks if a number is within the specified range
|
||||
func (v *Validator) ValidateRange(value int, fieldName string, min, max int) error {
|
||||
if value < min {
|
||||
return fmt.Errorf("%s: value must be at least %d", fieldName, min)
|
||||
}
|
||||
|
||||
if max > 0 && value > max {
|
||||
return fmt.Errorf("%s: value must be at most %d", fieldName, max)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Sanitization ====================
|
||||
|
||||
// SanitizeString removes potentially dangerous characters and trims whitespace
|
||||
func (v *Validator) SanitizeString(value string) string {
|
||||
// Trim whitespace
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Remove null bytes
|
||||
value = strings.ReplaceAll(value, "\x00", "")
|
||||
|
||||
// Normalize Unicode
|
||||
// Note: For production, consider using golang.org/x/text/unicode/norm
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// StripHTML removes all HTML tags from a string
|
||||
func (v *Validator) StripHTML(value string) string {
|
||||
return htmlTagRegex.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
// ==================== Combined Validations ====================
|
||||
|
||||
// ValidateAndSanitizeString performs validation and sanitization
|
||||
func (v *Validator) ValidateAndSanitizeString(value, fieldName string, minLen, maxLen int) (string, error) {
|
||||
// Sanitize first
|
||||
value = v.SanitizeString(value)
|
||||
|
||||
// Validate required
|
||||
if err := v.ValidateRequired(value, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate length
|
||||
if err := v.ValidateLength(value, fieldName, minLen, maxLen); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate printable characters
|
||||
if err := v.ValidatePrintable(value, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
472
cloud/maplepress-backend/pkg/validation/validator_test.go
Normal file
472
cloud/maplepress-backend/pkg/validation/validator_test.go
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateRequired(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid non-empty string", "test", false},
|
||||
{"Empty string", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Tab only", "\t", true},
|
||||
{"Newline only", "\n", true},
|
||||
{"Valid with spaces", "hello world", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateRequired(tt.value, "test_field")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateRequired() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLength(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
min int
|
||||
max int
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid length", "hello", 3, 10, false},
|
||||
{"Too short", "ab", 3, 10, true},
|
||||
{"Too long", "hello world this is too long", 3, 10, true},
|
||||
{"Exact minimum", "abc", 3, 10, false},
|
||||
{"Exact maximum", "0123456789", 3, 10, false},
|
||||
{"No maximum (0)", "very long string here", 3, 0, false},
|
||||
{"Whitespace counted correctly", " test ", 4, 10, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateLength(tt.value, "test_field", tt.min, tt.max)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateLength() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid emails
|
||||
{"Valid email", "user@example.com", false},
|
||||
{"Valid email with plus", "user+tag@example.com", false},
|
||||
{"Valid email with dot", "first.last@example.com", false},
|
||||
{"Valid email with hyphen", "user-name@example-domain.com", false},
|
||||
{"Valid email with numbers", "user123@example456.com", false},
|
||||
{"Valid email with subdomain", "user@sub.example.com", false},
|
||||
|
||||
// Invalid emails
|
||||
{"Empty email", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Missing @", "userexample.com", true},
|
||||
{"Missing domain", "user@", true},
|
||||
{"Missing local part", "@example.com", true},
|
||||
{"No TLD", "user@localhost", true},
|
||||
{"Consecutive dots in local", "user..name@example.com", true},
|
||||
{"Leading dot in local", ".user@example.com", true},
|
||||
{"Trailing dot in local", "user.@example.com", true},
|
||||
{"Double @", "user@@example.com", true},
|
||||
{"Spaces in email", "user name@example.com", true},
|
||||
{"Invalid characters", "user<>@example.com", true},
|
||||
{"Too long", strings.Repeat("a", 320) + "@example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateEmail(tt.email, "email")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateEmail() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid URLs
|
||||
{"Valid HTTP URL", "http://example.com", false},
|
||||
{"Valid HTTPS URL", "https://example.com", false},
|
||||
{"Valid URL with path", "https://example.com/path/to/resource", false},
|
||||
{"Valid URL with query", "https://example.com?param=value", false},
|
||||
{"Valid URL with port", "https://example.com:8080", false},
|
||||
{"Valid URL with subdomain", "https://sub.example.com", false},
|
||||
|
||||
// Invalid URLs
|
||||
{"Empty URL", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Missing scheme", "example.com", true},
|
||||
{"Invalid scheme", "ftp://example.com", true},
|
||||
{"Missing host", "https://", true},
|
||||
{"Invalid characters", "https://exam ple.com", true},
|
||||
{"Too long", "https://" + strings.Repeat("a", 2048) + ".com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateURL(tt.url, "url")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateURL() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHTTPSURL(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid HTTPS URL", "https://example.com", false},
|
||||
{"HTTP URL (should fail)", "http://example.com", true},
|
||||
{"FTP URL (should fail)", "ftp://example.com", true},
|
||||
{"Invalid URL", "not-a-url", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateHTTPSURL(tt.url, "url")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateHTTPSURL() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid domains
|
||||
{"Valid domain", "example.com", false},
|
||||
{"Valid subdomain", "sub.example.com", false},
|
||||
{"Valid deep subdomain", "deep.sub.example.com", false},
|
||||
{"Valid with hyphen", "my-site.example.com", false},
|
||||
{"Valid with numbers", "site123.example456.com", false},
|
||||
|
||||
// Invalid domains
|
||||
{"Empty domain", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Too short", "a.b", true},
|
||||
{"Too long", strings.Repeat("a", 254) + ".com", true},
|
||||
{"Label too long", strings.Repeat("a", 64) + ".example.com", true},
|
||||
{"No TLD", "localhost", true},
|
||||
{"Leading hyphen", "-example.com", true},
|
||||
{"Trailing hyphen", "example-.com", true},
|
||||
{"Double dot", "example..com", true},
|
||||
{"Leading dot", ".example.com", true},
|
||||
{"Trailing dot", "example.com.", true},
|
||||
{"Underscore", "my_site.example.com", true},
|
||||
{"Spaces", "my site.example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateDomain(tt.domain, "domain")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateDomain() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSlug(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slug string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid slugs
|
||||
{"Valid slug", "my-company", false},
|
||||
{"Valid slug with numbers", "company123", false},
|
||||
{"Valid slug all lowercase", "testcompany", false},
|
||||
{"Valid slug with multiple hyphens", "my-test-company", false},
|
||||
|
||||
// Invalid slugs
|
||||
{"Empty slug", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Too short", "ab", true},
|
||||
{"Too long", strings.Repeat("a", 64), true},
|
||||
{"Uppercase letters", "MyCompany", true},
|
||||
{"Leading hyphen", "-company", true},
|
||||
{"Trailing hyphen", "company-", true},
|
||||
{"Double hyphen", "my--company", true},
|
||||
{"Underscore", "my_company", true},
|
||||
{"Spaces", "my company", true},
|
||||
{"Special characters", "my@company", true},
|
||||
|
||||
// Reserved slugs
|
||||
{"Reserved: api", "api", true},
|
||||
{"Reserved: admin", "admin", true},
|
||||
{"Reserved: www", "www", true},
|
||||
{"Reserved: login", "login", true},
|
||||
{"Reserved: register", "register", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateSlug(tt.slug, "slug")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateSlug() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUUID(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid UUID v4", "550e8400-e29b-41d4-a716-446655440000", false},
|
||||
{"Valid UUID v4 lowercase", "123e4567-e89b-42d3-a456-426614174000", false},
|
||||
{"Empty UUID", "", true},
|
||||
{"Invalid format", "not-a-uuid", true},
|
||||
{"Invalid version", "550e8400-e29b-21d4-a716-446655440000", true},
|
||||
{"Missing hyphens", "550e8400e29b41d4a716446655440000", true},
|
||||
{"Too short", "550e8400-e29b-41d4-a716", true},
|
||||
{"With uppercase", "550E8400-E29B-41D4-A716-446655440000", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateUUID(tt.uuid, "id")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateUUID() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEnum(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
allowedValues := []string{"free", "basic", "pro", "enterprise"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid: free", "free", false},
|
||||
{"Valid: basic", "basic", false},
|
||||
{"Valid: pro", "pro", false},
|
||||
{"Valid: enterprise", "enterprise", false},
|
||||
{"Invalid: premium", "premium", true},
|
||||
{"Invalid: empty", "", true},
|
||||
{"Invalid: wrong case", "FREE", true},
|
||||
{"Invalid: typo", "basi", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateEnum(tt.value, "plan_tier", allowedValues)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateEnum() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRange(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value int
|
||||
min int
|
||||
max int
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid within range", 5, 1, 10, false},
|
||||
{"Valid at minimum", 1, 1, 10, false},
|
||||
{"Valid at maximum", 10, 1, 10, false},
|
||||
{"Below minimum", 0, 1, 10, true},
|
||||
{"Above maximum", 11, 1, 10, true},
|
||||
{"No maximum (0)", 1000, 1, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateRange(tt.value, "count", tt.min, tt.max)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateRange() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNoHTML(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Plain text", "Hello world", false},
|
||||
{"Text with punctuation", "Hello, world!", false},
|
||||
{"HTML tag <script>", "<script>alert('xss')</script>", true},
|
||||
{"HTML tag <img>", "<img src='x'>", true},
|
||||
{"HTML tag <div>", "<div>content</div>", true},
|
||||
{"HTML tag <a>", "<a href='#'>link</a>", true},
|
||||
{"Less than symbol", "5 < 10", false},
|
||||
{"Greater than symbol", "10 > 5", false},
|
||||
{"Both symbols", "5 < x < 10", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateNoHTML(tt.value, "content")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateNoHTML() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeString(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Trim whitespace", " hello ", "hello"},
|
||||
{"Remove null bytes", "hello\x00world", "helloworld"},
|
||||
{"Already clean", "hello", "hello"},
|
||||
{"Empty string", "", ""},
|
||||
{"Only whitespace", " ", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := v.SanitizeString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeString() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripHTML(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Remove script tag", "<script>alert('xss')</script>", "alert('xss')"},
|
||||
{"Remove div tag", "<div>content</div>", "content"},
|
||||
{"Remove multiple tags", "<p>Hello <b>world</b></p>", "Hello world"},
|
||||
{"No tags", "plain text", "plain text"},
|
||||
{"Empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := v.StripHTML(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("StripHTML() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAndSanitizeString(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
minLen int
|
||||
maxLen int
|
||||
wantValue string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid and clean", "hello", 3, 10, "hello", false},
|
||||
{"Trim and validate", " hello ", 3, 10, "hello", false},
|
||||
{"Too short after trim", " a ", 3, 10, "", true},
|
||||
{"Too long", "hello world this is too long", 3, 10, "", true},
|
||||
{"Empty after trim", " ", 3, 10, "", true},
|
||||
{"Valid with null byte removed", "hel\x00lo", 3, 10, "hello", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := v.ValidateAndSanitizeString(tt.input, "test_field", tt.minLen, tt.maxLen)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateAndSanitizeString() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
if !tt.wantError && result != tt.wantValue {
|
||||
t.Errorf("ValidateAndSanitizeString() = %q, want %q", result, tt.wantValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrintable(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"All printable", "Hello World 123!", false},
|
||||
{"With tabs and newlines", "Hello\tWorld\n", false},
|
||||
{"With control character", "Hello\x01World", true},
|
||||
{"With bell character", "Hello\x07", true},
|
||||
{"Empty string", "", false},
|
||||
{"Unicode printable", "Hello 世界", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidatePrintable(tt.value, "test_field")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidatePrintable() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue