Initial commit: Open sourcing all of the Maple Open Technologies code.

This commit is contained in:
Bartlomiej Mika 2025-12-02 14:33:08 -05:00
commit 755d54a99d
2010 changed files with 448675 additions and 0 deletions

View file

@ -0,0 +1,109 @@
package cache
import (
	"context"
	"errors"
	"time"

	"github.com/gocql/gocql"
	"go.uber.org/zap"
)
// CassandraCacher defines the interface for Cassandra cache operations
type CassandraCacher interface {
	// Shutdown releases cache resources; the shared gocql session is NOT closed here.
	Shutdown(ctx context.Context)
	// Get returns the value for key, or (nil, nil) when the key is absent or expired.
	Get(ctx context.Context, key string) ([]byte, error)
	// Set stores val under key using the implementation's default expiry.
	Set(ctx context.Context, key string, val []byte) error
	// SetWithExpiry stores val under key, expiring after the given duration.
	SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
	// Delete removes key from the cache.
	Delete(ctx context.Context, key string) error
	// PurgeExpired removes all entries whose expiry timestamp has passed.
	PurgeExpired(ctx context.Context) error
}
// cassandraCache is the Cassandra-backed implementation of CassandraCacher.
type cassandraCache struct {
	session *gocql.Session // shared session, owned by the database layer
	logger  *zap.Logger    // named sub-logger for cache operations
}
// NewCassandraCache creates a new Cassandra cache instance backed by the
// given session. The session is borrowed, not owned: Shutdown never closes it.
func NewCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
	cacheLogger := logger.Named("cassandra-cache")
	cacheLogger.Info("✓ Cassandra cache layer initialized")
	cache := cassandraCache{session: session, logger: cacheLogger}
	return &cache
}
// Shutdown logs the shutdown. There is deliberately nothing else to tear
// down: the gocql session is owned and closed by the database layer.
func (c *cassandraCache) Shutdown(ctx context.Context) {
	c.logger.Info("shutting down Cassandra cache")
}
// Get fetches the value stored under key. It returns (nil, nil) when the key
// does not exist or when the stored entry has already expired; expired rows
// are opportunistically deleted on read.
//
// Fix: the gocql.ErrNotFound sentinel is now matched with errors.Is instead
// of ==, so a wrapped not-found error is still treated as a cache miss.
func (c *cassandraCache) Get(ctx context.Context, key string) ([]byte, error) {
	var value []byte
	var expiresAt time.Time
	query := `SELECT value, expires_at FROM cache WHERE key = ?`
	err := c.session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Scan(&value, &expiresAt)
	if errors.Is(err, gocql.ErrNotFound) {
		// Key doesn't exist - this is not an error
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	// Expiry is enforced in application code because rows are written with an
	// absolute expires_at timestamp rather than a native Cassandra TTL.
	if time.Now().After(expiresAt) {
		// Best-effort cleanup of the expired row; a failure here is not fatal
		// because PurgeExpired will remove it later.
		_ = c.Delete(ctx, key)
		return nil, nil
	}
	return value, nil
}
// Set stores val under key using the default retention of 24 hours.
func (c *cassandraCache) Set(ctx context.Context, key string, val []byte) error {
	// Same write path as an explicit TTL, just with the default duration.
	const defaultExpiry = 24 * time.Hour
	return c.SetWithExpiry(ctx, key, val, defaultExpiry)
}
// SetWithExpiry stores val under key, recording an absolute expiry timestamp
// of now + expiry. Expiry is enforced on read (Get) and by PurgeExpired.
func (c *cassandraCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
	deadline := time.Now().Add(expiry)
	const stmt = `INSERT INTO cache (key, expires_at, value) VALUES (?, ?, ?)`
	q := c.session.Query(stmt, key, deadline, val).WithContext(ctx).Consistency(gocql.LocalQuorum)
	return q.Exec()
}
// Delete removes the row for key from the cache table.
func (c *cassandraCache) Delete(ctx context.Context, key string) error {
	const stmt = `DELETE FROM cache WHERE key = ?`
	q := c.session.Query(stmt, key).WithContext(ctx).Consistency(gocql.LocalQuorum)
	return q.Exec()
}
// PurgeExpired scans for entries whose expiry timestamp has passed and
// deletes them.
//
// Fix: deletes are now issued in bounded batches. The original code built a
// single unbounded logged batch, which can exceed Cassandra's configured
// batch size thresholds (batch_size_warn/fail_threshold) and be rejected
// outright when many keys expire at once.
func (c *cassandraCache) PurgeExpired(ctx context.Context) error {
	now := time.Now()
	// Thanks to the index on expires_at, this query is efficient
	iter := c.session.Query(`SELECT key FROM cache WHERE expires_at < ? ALLOW FILTERING`, now).WithContext(ctx).Iter()
	var expiredKeys []string
	var key string
	for iter.Scan(&key) {
		expiredKeys = append(expiredKeys, key)
	}
	if err := iter.Close(); err != nil {
		return err
	}
	// Delete expired keys in chunks so no single batch grows without bound.
	const batchSize = 100
	for start := 0; start < len(expiredKeys); start += batchSize {
		end := start + batchSize
		if end > len(expiredKeys) {
			end = len(expiredKeys)
		}
		batch := c.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
		for _, expiredKey := range expiredKeys[start:end] {
			batch.Query(`DELETE FROM cache WHERE key = ?`, expiredKey)
		}
		if err := c.session.ExecuteBatch(batch); err != nil {
			return err
		}
	}
	return nil
}

View file

@ -0,0 +1,23 @@
package cache
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"github.com/gocql/gocql"
)
// ProvideRedisCache provides a Redis cache instance.
// (Dependency-injection provider; thin wrapper over NewRedisCache.)
func ProvideRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
	return NewRedisCache(cfg, logger)
}

// ProvideCassandraCache provides a Cassandra cache instance.
// (Dependency-injection provider; thin wrapper over NewCassandraCache.)
func ProvideCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
	return NewCassandraCache(session, logger)
}

// ProvideTwoTierCache provides a two-tier cache instance composed of the
// given Redis (L1) and Cassandra (L2) layers.
// (Dependency-injection provider; thin wrapper over NewTwoTierCache.)
func ProvideTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
	return NewTwoTierCache(redisCache, cassandraCache, logger)
}

View file

@ -0,0 +1,144 @@
package cache
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// silentRedisLogger filters out noisy "maintnotifications" warnings from go-redis.
// This warning occurs when the Redis client tries to use newer Redis 7.2+ features
// that may not be fully supported by the current Redis version.
// The client automatically falls back to compatible mode, so this is harmless.
type silentRedisLogger struct {
	logger *zap.Logger // destination for the messages that are not filtered out
}
// Printf implements go-redis's internal logging interface. Known-harmless
// compatibility warnings are dropped; everything else is logged at debug level.
func (l *silentRedisLogger) Printf(ctx context.Context, format string, v ...interface{}) {
	message := fmt.Sprintf(format, v...)
	// Drop the two harmless compatibility warnings, keep everything else.
	isNoise := strings.Contains(message, "maintnotifications disabled") ||
		strings.Contains(message, "auto mode fallback")
	if isNoise {
		return
	}
	l.logger.Debug(message)
}
// RedisCacher defines the interface for Redis cache operations
type RedisCacher interface {
	// Shutdown closes the underlying Redis client.
	Shutdown(ctx context.Context)
	// Get returns the value for key, or (nil, nil) when the key is absent.
	Get(ctx context.Context, key string) ([]byte, error)
	// Set stores val under key with no expiry.
	Set(ctx context.Context, key string, val []byte) error
	// SetWithExpiry stores val under key with the given TTL.
	SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
	// Delete removes key from the cache.
	Delete(ctx context.Context, key string) error
}
// redisCache is the go-redis backed implementation of RedisCacher.
type redisCache struct {
	client *redis.Client // owned client; closed in Shutdown
	logger *zap.Logger   // named sub-logger for cache operations
}
// NewRedisCache creates a new Redis cache instance and verifies connectivity
// with a ping (5 second timeout) before returning it.
//
// Fix: the connection options are built directly from the config instead of
// interpolating the password into a "redis://" URL string. The URL approach
// silently broke authentication whenever the password contained URL-reserved
// characters (e.g. "@", "/", "#"), because redis.ParseURL mis-parsed the
// hand-assembled, unescaped string.
func NewRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
	logger = logger.Named("redis-cache")
	logger.Info("⏳ Connecting to Redis...",
		zap.String("host", cfg.Cache.Host),
		zap.Int("port", cfg.Cache.Port))
	opt := &redis.Options{
		Addr:     fmt.Sprintf("%s:%d", cfg.Cache.Host, cfg.Cache.Port),
		Password: cfg.Cache.Password, // empty string means no AUTH, matching the old URL form
		DB:       cfg.Cache.DB,
	}
	// Suppress noisy "maintnotifications" warnings from go-redis
	// Use a custom logger that filters out these harmless compatibility warnings
	redis.SetLogger(&silentRedisLogger{logger: logger.Named("redis-client")})
	client := redis.NewClient(opt)
	// Test connection before handing the client out
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if _, err := client.Ping(ctx).Result(); err != nil {
		return nil, fmt.Errorf("failed to connect to Redis: %w", err)
	}
	logger.Info("✓ Redis connected",
		zap.String("host", cfg.Cache.Host),
		zap.Int("port", cfg.Cache.Port),
		zap.Int("db", cfg.Cache.DB))
	return &redisCache{
		client: client,
		logger: logger,
	}, nil
}
// Shutdown closes the underlying Redis client; a close failure is logged
// but not propagated (there is nothing the caller could do with it).
func (c *redisCache) Shutdown(ctx context.Context) {
	c.logger.Info("shutting down Redis cache")
	err := c.client.Close()
	if err != nil {
		c.logger.Error("error closing Redis connection", zap.Error(err))
	}
}
// Get returns the value for key. A missing key is a cache miss and yields
// (nil, nil), not an error.
func (c *redisCache) Get(ctx context.Context, key string) ([]byte, error) {
	val, err := c.client.Get(ctx, key).Result()
	switch {
	case errors.Is(err, redis.Nil):
		// Key doesn't exist - this is not an error
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("redis get failed: %w", err)
	default:
		return []byte(val), nil
	}
}
// Set stores val under key with no expiry (the key persists until deleted).
func (c *redisCache) Set(ctx context.Context, key string, val []byte) error {
	err := c.client.Set(ctx, key, val, 0).Err()
	if err != nil {
		return fmt.Errorf("redis set failed: %w", err)
	}
	return nil
}
// SetWithExpiry stores val under key with the given TTL; Redis evicts the
// key automatically once the TTL elapses.
func (c *redisCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
	err := c.client.Set(ctx, key, val, expiry).Err()
	if err != nil {
		return fmt.Errorf("redis set with expiry failed: %w", err)
	}
	return nil
}
// Delete removes key from Redis. Deleting an absent key is not an error.
func (c *redisCache) Delete(ctx context.Context, key string) error {
	err := c.client.Del(ctx, key).Err()
	if err != nil {
		return fmt.Errorf("redis delete failed: %w", err)
	}
	return nil
}

View file

@ -0,0 +1,114 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/cache/twotier.go
package cache
import (
	"context"
	"errors"
	"time"

	"go.uber.org/zap"
)
// TwoTierCacher defines the interface for two-tier cache operations
type TwoTierCacher interface {
	// Shutdown shuts down both cache layers.
	Shutdown(ctx context.Context)
	// Get returns the value for key from L1 or L2, or (nil, nil) on a miss.
	Get(ctx context.Context, key string) ([]byte, error)
	// Set writes val to both layers (write-through).
	Set(ctx context.Context, key string, val []byte) error
	// SetWithExpiry writes val to both layers with the given TTL.
	SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
	// Delete removes key from both layers.
	Delete(ctx context.Context, key string) error
	// PurgeExpired removes expired entries from the persistent layer.
	PurgeExpired(ctx context.Context) error
}
// twoTierCache implements a clean 2-layer (read-through write-through) cache
//
// L1: Redis (fast, in-memory)
// L2: Cassandra (persistent)
//
// On Get: check Redis → then Cassandra → if found in Cassandra → populate Redis
// On Set: write to both
// On SetWithExpiry: write to both with expiry
// On Delete: remove from both
type twoTierCache struct {
	redisCache     RedisCacher     // L1 — fast, in-memory
	cassandraCache CassandraCacher // L2 — persistent
	logger         *zap.Logger     // named sub-logger
}
// NewTwoTierCache creates a new two-tier cache instance composed of a Redis
// L1 layer and a Cassandra L2 layer.
func NewTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
	cacheLogger := logger.Named("two-tier-cache")
	cacheLogger.Info("✓ Two-tier cache initialized (Redis L1 + Cassandra L2)")
	tiered := twoTierCache{
		redisCache:     redisCache,
		cassandraCache: cassandraCache,
		logger:         cacheLogger,
	}
	return &tiered
}
// Get looks up key in Redis (L1) first and falls back to Cassandra (L2).
// A value found only in Cassandra is written back to Redis so subsequent
// reads hit L1.
//
// Fix: an L1 (Redis) read failure is now logged and treated as a miss so the
// persistent layer can still serve the request; previously a transient Redis
// outage failed every Get even when Cassandra held the value.
func (c *twoTierCache) Get(ctx context.Context, key string) ([]byte, error) {
	// Try L1 (Redis) first
	val, err := c.redisCache.Get(ctx, key)
	if err != nil {
		// Degrade gracefully: L1 being unavailable should not hide L2.
		c.logger.Warn("redis lookup failed, falling back to Cassandra",
			zap.String("key", key), zap.Error(err))
	} else if val != nil {
		c.logger.Debug("cache hit from Redis", zap.String("key", key))
		return val, nil
	}
	// Not in Redis, try L2 (Cassandra)
	val, err = c.cassandraCache.Get(ctx, key)
	if err != nil {
		return nil, err
	}
	if val != nil {
		// Found in Cassandra, populate Redis for future lookups
		c.logger.Debug("cache hit from Cassandra, writing back to Redis", zap.String("key", key))
		_ = c.redisCache.Set(ctx, key, val) // Best effort, don't fail if Redis write fails
	}
	return val, nil
}
// Set writes val to both cache layers (write-through). Redis is written
// first; a failure at either layer aborts and returns that error.
func (c *twoTierCache) Set(ctx context.Context, key string, val []byte) error {
	if err := c.redisCache.Set(ctx, key, val); err != nil {
		return err
	}
	return c.cassandraCache.Set(ctx, key, val)
}
// SetWithExpiry writes val with the given TTL to both cache layers
// (write-through). Redis is written first; a failure at either layer aborts
// and returns that error.
func (c *twoTierCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
	if err := c.redisCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
		return err
	}
	return c.cassandraCache.SetWithExpiry(ctx, key, val, expiry)
}
// Delete removes key from both cache layers.
//
// Fix: both deletes are now always attempted. Previously a Redis failure
// short-circuited the call, leaving the entry behind in Cassandra where the
// next Get would resurrect it into Redis — a lost deletion. Errors from the
// two layers are combined with errors.Join so neither is swallowed.
func (c *twoTierCache) Delete(ctx context.Context, key string) error {
	redisErr := c.redisCache.Delete(ctx, key)
	cassandraErr := c.cassandraCache.Delete(ctx, key)
	return errors.Join(redisErr, cassandraErr)
}
// PurgeExpired removes expired entries from the Cassandra layer.
// Only Cassandra needs purging (Redis handles TTL automatically)
func (c *twoTierCache) PurgeExpired(ctx context.Context) error {
	return c.cassandraCache.PurgeExpired(ctx)
}
// Shutdown shuts down both layers, L1 (Redis) first, then L2 (Cassandra).
func (c *twoTierCache) Shutdown(ctx context.Context) {
	c.logger.Info("shutting down two-tier cache")
	layers := []interface{ Shutdown(context.Context) }{c.redisCache, c.cassandraCache}
	for _, layer := range layers {
		layer.Shutdown(ctx)
	}
	c.logger.Info("two-tier cache shutdown complete")
}

View file

@ -0,0 +1,237 @@
# Distributed Mutex
A Redis-based distributed mutex implementation for coordinating access to shared resources across multiple application instances.
## Overview
This package provides a distributed locking mechanism using Redis as the coordination backend. It's built on top of the `redislock` library and provides a simple interface for acquiring and releasing locks across distributed systems.
## Features
- **Distributed Locking**: Coordinate access to shared resources across multiple application instances
- **Automatic Retry**: Built-in retry logic with configurable backoff strategy
- **Thread-Safe**: Safe for concurrent use within a single application
- **Formatted Keys**: Support for formatted lock keys using `Acquiref` and `Releasef`
- **Logging**: Integrated zap logging for debugging and monitoring
## Installation
The package is already included in the project. The required dependency (`github.com/bsm/redislock`) is automatically installed.
## Interface
```go
type Adapter interface {
    Acquire(ctx context.Context, key string) error
    Acquiref(ctx context.Context, format string, a ...any) error
    Release(ctx context.Context, key string) error
    Releasef(ctx context.Context, format string, a ...any) error
}
```
## Usage
### Basic Example
```go
import (
"context"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/distributedmutex"
)
// Create Redis client
redisClient := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
})
// Create logger
logger, _ := zap.NewProduction()
// Create distributed mutex adapter
mutex := distributedmutex.NewAdapter(logger, redisClient)
// Acquire a lock
ctx := context.Background()
mutex.Acquire(ctx, "my-resource-key")
// ... perform operations on the protected resource ...
// Release the lock
mutex.Release(ctx, "my-resource-key")
```
### Formatted Keys Example
```go
// Acquire lock with formatted key
tenantID := "tenant-123"
resourceID := "resource-456"
mutex.Acquiref(ctx, "tenant:%s:resource:%s", tenantID, resourceID)
// ... perform operations ...
mutex.Releasef(ctx, "tenant:%s:resource:%s", tenantID, resourceID)
```
### Integration with Dependency Injection (Wire)
```go
// In your Wire provider set
wire.NewSet(
distributedmutex.ProvideDistributedMutexAdapter,
// ... other providers
)
// Use in your application
func NewMyService(mutex distributedmutex.Adapter) *MyService {
return &MyService{
mutex: mutex,
}
}
```
## Configuration
### Lock Duration
The default lock duration is **1 minute**. Locks are automatically released after this time to prevent deadlocks.
### Retry Strategy
- **Retry Interval**: 250ms
- **Max Retries**: 20 attempts
- **Total Max Wait Time**: ~5 seconds (20 × 250ms)
If a lock cannot be obtained after all retries, the failure is logged and `Acquire` returns an error instead of blocking indefinitely.
## Best Practices
1. **Always Release Locks**: Ensure locks are released even in error cases using `defer`
```go
mutex.Acquire(ctx, "my-key")
defer mutex.Release(ctx, "my-key")
```
2. **Use Descriptive Keys**: Use clear, hierarchical key names
```go
// Good
mutex.Acquire(ctx, "tenant:123:user:456:update")
// Not ideal
mutex.Acquire(ctx, "lock1")
```
3. **Keep Critical Sections Short**: Minimize the time locks are held to improve concurrency
4. **Handle Timeouts**: Use context with timeout for critical operations
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mutex.Acquire(ctx, "my-key")
```
5. **Avoid Nested Locks**: Be careful with acquiring multiple locks to avoid deadlocks
## Logging
The adapter logs the following events:
- **Debug**: Lock acquisition and release operations
- **Error**: Failed lock acquisitions, timeout errors, and release failures
- **Warn**: Attempts to release non-existent locks
## Thread Safety
The adapter is safe for concurrent use within a single application instance. It uses an internal mutex to protect the lock instances map from concurrent access by multiple goroutines.
## Error Handling
All adapter methods both log failures and return an error (CWE-755 mitigation). Keep this in mind when using the adapter:
- `Acquire`/`Acquiref` return an error when the lock cannot be obtained after all retries
- `Release`/`Releasef` return an error when releasing a held lock fails; releasing an unknown key only logs a warning and returns nil
- Check logs for lock-related issues in production
## Limitations
1. **Lock Duration**: Locks automatically expire after 1 minute
2. **No Lock Extension**: Currently doesn't support extending lock duration
3. **No Deadlock Detection**: Manual deadlock prevention is required
4. **Redis Dependency**: Requires a running Redis instance
## Example Use Cases
### Preventing Duplicate Processing
```go
func ProcessJob(ctx context.Context, jobID string, mutex distributedmutex.Adapter) {
lockKey := fmt.Sprintf("job:processing:%s", jobID)
mutex.Acquire(ctx, lockKey)
defer mutex.Release(ctx, lockKey)
// Process job - guaranteed only one instance processes this job
// ...
}
```
### Coordinating Resource Updates
```go
func UpdateTenantSettings(ctx context.Context, tenantID string, mutex distributedmutex.Adapter) error {
mutex.Acquiref(ctx, "tenant:%s:settings:update", tenantID)
defer mutex.Releasef(ctx, "tenant:%s:settings:update", tenantID)
// Safe to update tenant settings
// ...
return nil
}
```
### Rate Limiting Operations
```go
func RateLimitedOperation(ctx context.Context, userID string, mutex distributedmutex.Adapter) {
lockKey := fmt.Sprintf("ratelimit:user:%s", userID)
mutex.Acquire(ctx, lockKey)
defer mutex.Release(ctx, lockKey)
// Perform rate-limited operation
// ...
}
```
## Troubleshooting
### Lock Not Acquired
**Problem**: Locks are not being acquired (error in logs)
**Solutions**:
- Verify Redis is running and accessible
- Check network connectivity to Redis
- Ensure Redis has sufficient memory
- Check for Redis errors in logs
### Lock Contention
**Problem**: Frequent lock acquisition failures due to contention
**Solutions**:
- Reduce critical section duration
- Use more specific lock keys to reduce contention
- Consider increasing retry limits if appropriate
- Review application architecture for excessive locking
### Memory Leaks
**Problem**: Lock instances accumulating in memory
**Solutions**:
- Ensure all `Acquire` calls have corresponding `Release` calls
- Use `defer` to guarantee lock release
- Monitor lock instance map size in production

View file

@ -0,0 +1,138 @@
package distributedmutex
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/bsm/redislock"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
// Adapter provides interface for abstracting distributed mutex operations.
// CWE-755: Methods now return errors to properly handle exceptional conditions
type Adapter interface {
	// Acquire blocks until the lock for key is obtained, or returns an error
	// once the retry budget is exhausted.
	Acquire(ctx context.Context, key string) error
	// Acquiref is Acquire with a fmt.Sprintf-formatted key.
	Acquiref(ctx context.Context, format string, a ...any) error
	// Release releases the lock previously acquired for key.
	Release(ctx context.Context, key string) error
	// Releasef is Release with a fmt.Sprintf-formatted key.
	Releasef(ctx context.Context, format string, a ...any) error
}
// distributedMutexAdapter is the Redis-backed implementation of Adapter.
type distributedMutexAdapter struct {
	logger *zap.Logger           // named sub-logger
	redis  redis.UniversalClient // underlying Redis connection
	locker *redislock.Client     // redislock client built on top of redis
	// lockInstances tracks the currently-held lock handle per key so Release
	// can find it. NOTE(review): a second in-process Acquire of the same key
	// overwrites the previous handle — confirm callers never do that.
	lockInstances map[string]*redislock.Lock
	mutex         *sync.Mutex // Mutex for synchronization with goroutines; guards lockInstances
}
// NewAdapter constructor that returns the default distributed mutex adapter,
// backed by the given Redis client.
func NewAdapter(logger *zap.Logger, redisClient redis.UniversalClient) Adapter {
	namedLogger := logger.Named("distributed-mutex")
	namedLogger.Info("✓ Distributed mutex initialized (Redis-backed)")
	adapter := distributedMutexAdapter{
		logger: namedLogger,
		redis:  redisClient,
		// Lock client built on top of the same Redis connection.
		locker:        redislock.New(redisClient),
		lockInstances: map[string]*redislock.Lock{},
		mutex:         new(sync.Mutex),
	}
	return &adapter
}
// Acquire function blocks the current thread if the lock key is currently
// locked, retrying every 250ms up to 20 times (≈5s total). On success the
// lock handle is stored so a later Release(key) can find and release it.
// The lock auto-expires after one minute to prevent deadlocks.
// CWE-755: Now returns error instead of silently failing
//
// Fix: the redislock.ErrNotObtained sentinel is matched with errors.Is
// instead of ==, so a wrapped not-obtained error is still recognized.
func (a *distributedMutexAdapter) Acquire(ctx context.Context, key string) error {
	startDT := time.Now()
	a.logger.Debug("acquiring lock", zap.String("key", key))
	// Retry every 250ms, for up to 20x
	backoff := redislock.LimitRetry(redislock.LinearBackoff(250*time.Millisecond), 20)
	// Obtain lock with retry
	lock, err := a.locker.Obtain(ctx, key, time.Minute, &redislock.Options{
		RetryStrategy: backoff,
	})
	if errors.Is(err, redislock.ErrNotObtained) {
		nowDT := time.Now()
		diff := nowDT.Sub(startDT)
		a.logger.Error("could not obtain lock after retries",
			zap.String("key", key),
			zap.Time("start_dt", startDT),
			zap.Time("now_dt", nowDT),
			zap.Duration("duration", diff),
			zap.Int("max_retries", 20))
		return fmt.Errorf("could not obtain lock after 20 retries (waited %s): %w", diff, err)
	} else if err != nil {
		a.logger.Error("failed obtaining lock",
			zap.String("key", key),
			zap.Error(err))
		return fmt.Errorf("failed to obtain lock: %w", err)
	}
	// DEVELOPERS NOTE:
	// The `map` datastructure in Golang is not concurrently safe, therefore we
	// need to use mutex to coordinate access of our `lockInstances` map
	// resource between all the goroutines.
	a.mutex.Lock()
	defer a.mutex.Unlock()
	if a.lockInstances != nil { // Defensive code
		a.lockInstances[key] = lock
	}
	a.logger.Debug("lock acquired", zap.String("key", key))
	return nil // Success
}
// Acquiref function blocks the current thread if the lock key is currently locked.
// The key is built with fmt.Sprintf(format, args...) before delegating to Acquire.
// CWE-755: Now returns error from Acquire
func (a *distributedMutexAdapter) Acquiref(ctx context.Context, format string, args ...any) error {
	key := fmt.Sprintf(format, args...)
	return a.Acquire(ctx, key)
}
// Release function releases the lock for the given key. Releasing a key that
// is not currently tracked logs a warning and returns nil, since the lock may
// already have been released (or expired).
// CWE-755: Now returns error instead of silently failing
func (a *distributedMutexAdapter) Release(ctx context.Context, key string) error {
	a.logger.Debug("releasing lock", zap.String("key", key))
	// The lockInstances map is shared across goroutines, so access must be
	// serialized with the adapter's mutex. The entry is removed inside the
	// critical section; the (potentially slow) Redis release happens outside it.
	a.mutex.Lock()
	lockInstance, found := a.lockInstances[key]
	if found {
		delete(a.lockInstances, key)
	}
	a.mutex.Unlock()
	if !found {
		// Lock not found - this is a warning but not an error (may have already been released)
		a.logger.Warn("lock not found for release", zap.String("key", key))
		return nil // Not an error, just not found
	}
	if err := lockInstance.Release(ctx); err != nil {
		a.logger.Error("failed to release lock",
			zap.String("key", key),
			zap.Error(err))
		return fmt.Errorf("failed to release lock: %w", err)
	}
	a.logger.Debug("lock released", zap.String("key", key))
	return nil // Success
}
// Releasef function releases the lock for a formatted key.
// The key is built with fmt.Sprintf(format, args...) before delegating to Release.
// CWE-755: Now returns error from Release
func (a *distributedMutexAdapter) Releasef(ctx context.Context, format string, args ...any) error {
	key := fmt.Sprintf(format, args...)
	return a.Release(ctx, key)
}

View file

@ -0,0 +1,70 @@
package distributedmutex
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// mockRedisClient implements minimal required methods for testing.
// It embeds redis.UniversalClient so the remaining interface methods are
// satisfied without being implemented (calling one of those would panic on
// the nil embedded value — acceptable inside these tests).
type mockRedisClient struct {
	redis.UniversalClient
}
// Get returns a fresh, empty StringCmd (stands in for "no value").
func (m *mockRedisClient) Get(ctx context.Context, key string) *redis.StringCmd {
	return redis.NewStringCmd(ctx)
}

// Set returns a fresh, empty StatusCmd (write is a no-op).
func (m *mockRedisClient) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd {
	return redis.NewStatusCmd(ctx)
}

// Eval returns a fresh, empty Cmd; used by redislock's acquire script.
func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...any) *redis.Cmd {
	return redis.NewCmd(ctx)
}

// EvalSha returns a fresh, empty Cmd; used by redislock's cached scripts.
func (m *mockRedisClient) EvalSha(ctx context.Context, sha string, keys []string, args ...any) *redis.Cmd {
	return redis.NewCmd(ctx)
}

// ScriptExists returns a fresh, empty BoolSliceCmd.
func (m *mockRedisClient) ScriptExists(ctx context.Context, scripts ...string) *redis.BoolSliceCmd {
	return redis.NewBoolSliceCmd(ctx)
}

// ScriptLoad returns a fresh, empty StringCmd.
func (m *mockRedisClient) ScriptLoad(ctx context.Context, script string) *redis.StringCmd {
	return redis.NewStringCmd(ctx)
}
// TestNewAdapter verifies the constructor returns a non-nil Adapter when
// given a development logger and the stub Redis client.
func TestNewAdapter(t *testing.T) {
	logger, _ := zap.NewDevelopment()
	adapter := NewAdapter(logger, &mockRedisClient{})
	if adapter == nil {
		t.Fatal("expected non-nil adapter")
	}
}
// TestAcquireAndRelease exercises the plain and formatted acquire/release
// paths against the stub Redis client. Returned errors are intentionally
// ignored: with the stub backend the lock may not actually be obtainable,
// and this test only guards against panics and deadlocks.
// NOTE(review): consider asserting on the returned errors once the stub's
// Eval/EvalSha results are pinned down — TODO confirm what redislock sees.
func TestAcquireAndRelease(t *testing.T) {
	ctx := context.Background()
	logger, _ := zap.NewDevelopment()
	adapter := NewAdapter(logger, &mockRedisClient{})
	// Test basic acquire/release
	adapter.Acquire(ctx, "test-key")
	adapter.Release(ctx, "test-key")
	// Test formatted acquire/release
	adapter.Acquiref(ctx, "test-key-%d", 1)
	adapter.Releasef(ctx, "test-key-%d", 1)
}
// TestReleaseNonExistentLock verifies that releasing a key that was never
// acquired does not panic (the adapter logs a warning and returns nil).
func TestReleaseNonExistentLock(t *testing.T) {
	ctx := context.Background()
	logger, _ := zap.NewDevelopment()
	adapter := NewAdapter(logger, &mockRedisClient{})
	// This should not panic, just log a warning
	adapter.Release(ctx, "non-existent-key")
}

View file

@ -0,0 +1,13 @@
package distributedmutex
import (
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// ProvideDistributedMutexAdapter creates a new distributed mutex adapter instance.
// Accepts *redis.Client which implements redis.UniversalClient interface.
// (Dependency-injection provider; thin wrapper over NewAdapter.)
func ProvideDistributedMutexAdapter(logger *zap.Logger, redisClient *redis.Client) Adapter {
	// redis.Client implements redis.UniversalClient, so we can pass it directly
	return NewAdapter(logger, redisClient)
}

View file

@ -0,0 +1,113 @@
package dns
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"

	"go.uber.org/zap"
)
// Verifier handles DNS TXT record verification
type Verifier struct {
	resolver *net.Resolver // pure-Go resolver (PreferGo set in ProvideVerifier)
	logger   *zap.Logger   // named sub-logger
}
// ProvideVerifier creates a new DNS Verifier backed by Go's built-in resolver.
func ProvideVerifier(logger *zap.Logger) *Verifier {
	resolver := net.Resolver{
		PreferGo: true, // Use Go's DNS resolver rather than the system resolver
	}
	return &Verifier{resolver: &resolver, logger: logger.Named("dns-verifier")}
}
// VerifyDomainOwnership checks if a domain has the correct TXT record
// Expected format: "maplepress-verify=TOKEN"
// It returns (true, nil) when a matching record is found, (false, nil) when
// the lookup succeeded but no record matched, and (false, err) when the
// lookup itself failed.
//
// Fix: the *net.DNSError is now extracted with errors.As (the LookupTXT error
// may be wrapped, so a direct type assertion can miss it), and the deadline
// check uses errors.Is on the context error.
func (v *Verifier) VerifyDomainOwnership(ctx context.Context, domain string, expectedToken string) (bool, error) {
	v.logger.Info("verifying domain ownership via DNS",
		zap.String("domain", domain))
	// Create context with timeout (10 seconds for DNS lookup)
	lookupCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	// Look up TXT records for the domain
	txtRecords, err := v.resolver.LookupTXT(lookupCtx, domain)
	if err != nil {
		// Check if it's a timeout
		if errors.Is(lookupCtx.Err(), context.DeadlineExceeded) {
			v.logger.Warn("DNS lookup timed out",
				zap.String("domain", domain))
			return false, fmt.Errorf("DNS lookup timed out after 10 seconds")
		}
		// Check if domain doesn't exist (works even when the error is wrapped)
		var dnsErr *net.DNSError
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			v.logger.Warn("domain not found",
				zap.String("domain", domain))
			return false, fmt.Errorf("domain not found: %s", domain)
		}
		v.logger.Error("failed to lookup TXT records",
			zap.String("domain", domain),
			zap.Error(err))
		return false, fmt.Errorf("failed to lookup DNS TXT records: %w", err)
	}
	// Expected verification record format
	expectedRecord := fmt.Sprintf("maplepress-verify=%s", expectedToken)
	// Check each TXT record
	for _, record := range txtRecords {
		v.logger.Debug("checking TXT record",
			zap.String("domain", domain),
			zap.String("record", record))
		// Normalize whitespace and compare
		if strings.TrimSpace(record) == expectedRecord {
			v.logger.Info("domain ownership verified",
				zap.String("domain", domain))
			return true, nil
		}
	}
	v.logger.Warn("verification record not found",
		zap.String("domain", domain),
		zap.String("expected", expectedRecord),
		zap.Int("records_checked", len(txtRecords)))
	return false, nil
}
// GetVerificationRecord returns the DNS TXT record value that proves
// ownership for the given verification token.
func GetVerificationRecord(token string) string {
	return "maplepress-verify=" + token
}
// GetVerificationInstructions returns user-friendly, step-by-step text telling
// the user how to add the verification TXT record for domain.
func GetVerificationInstructions(domain string, token string) string {
	// The record value is derived from the token via the shared helper so the
	// instructions always match what VerifyDomainOwnership expects.
	value := GetVerificationRecord(token)
	template := `To verify ownership of %s, add this DNS TXT record:
Host/Name: %s
Type: TXT
Value: %s
Instructions:
1. Log in to your domain registrar (GoDaddy, Namecheap, Cloudflare, etc.)
2. Find DNS settings or DNS management
3. Add a new TXT record with the values above
4. Wait 5-10 minutes for DNS propagation
5. Click "Verify Domain" in MaplePress
Note: DNS changes can take up to 48 hours to propagate globally, but usually complete within 10 minutes.`
	return fmt.Sprintf(template, domain, domain, value)
}

View file

@ -0,0 +1,61 @@
package mailgun
// MailgunConfigurationProvider exposes the Mailgun-related configuration
// values needed to construct and operate an emailer.
type MailgunConfigurationProvider interface {
	GetSenderEmail() string // default sender ("from") address
	// Deprecated: see GetBackendDomainName / GetFrontendDomainName.
	GetDomainName() string
	GetBackendDomainName() string
	GetFrontendDomainName() string
	GetMaintenanceEmail() string
	GetAPIKey() string  // Mailgun private API key
	GetAPIBase() string // Mailgun API base URL
}
// mailgunConfigurationProviderImpl is a plain value-holder implementation of
// MailgunConfigurationProvider; each getter returns the corresponding field.
type mailgunConfigurationProviderImpl struct {
	senderEmail      string
	domain           string
	apiBase          string
	maintenanceEmail string
	frontendDomain   string
	backendDomain    string
	apiKey           string
}
// NewMailgunConfigurationProvider builds a MailgunConfigurationProvider from
// the individual configuration values. Note the argument order: senderEmail,
// domain, apiBase, maintenanceEmail, frontendDomain, backendDomain, apiKey.
func NewMailgunConfigurationProvider(senderEmail, domain, apiBase, maintenanceEmail, frontendDomain, backendDomain, apiKey string) MailgunConfigurationProvider {
	provider := mailgunConfigurationProviderImpl{
		senderEmail:      senderEmail,
		domain:           domain,
		apiBase:          apiBase,
		maintenanceEmail: maintenanceEmail,
		frontendDomain:   frontendDomain,
		backendDomain:    backendDomain,
		apiKey:           apiKey,
	}
	return &provider
}
// GetDomainName returns the configured Mailgun domain.
// Deprecated: see the MailgunConfigurationProvider interface.
func (me *mailgunConfigurationProviderImpl) GetDomainName() string {
	return me.domain
}

// GetSenderEmail returns the configured sender ("from") address.
func (me *mailgunConfigurationProviderImpl) GetSenderEmail() string {
	return me.senderEmail
}

// GetBackendDomainName returns the configured backend domain name.
func (me *mailgunConfigurationProviderImpl) GetBackendDomainName() string {
	return me.backendDomain
}

// GetFrontendDomainName returns the configured frontend domain name.
func (me *mailgunConfigurationProviderImpl) GetFrontendDomainName() string {
	return me.frontendDomain
}

// GetMaintenanceEmail returns the configured maintenance contact address.
func (me *mailgunConfigurationProviderImpl) GetMaintenanceEmail() string {
	return me.maintenanceEmail
}

// GetAPIKey returns the configured Mailgun API key.
func (me *mailgunConfigurationProviderImpl) GetAPIKey() string {
	return me.apiKey
}

// GetAPIBase returns the configured Mailgun API base URL.
func (me *mailgunConfigurationProviderImpl) GetAPIBase() string {
	return me.apiBase
}

View file

@ -0,0 +1,12 @@
package mailgun
import "context"
// Emailer sends transactional email and exposes the sending configuration.
type Emailer interface {
	// Send delivers an HTML email from sender to recipient.
	Send(ctx context.Context, sender, subject, recipient, htmlContent string) error
	GetSenderEmail() string
	// Deprecated: see GetBackendDomainName / GetFrontendDomainName.
	GetDomainName() string
	GetBackendDomainName() string
	GetFrontendDomainName() string
	GetMaintenanceEmail() string
}

View file

@ -0,0 +1,86 @@
package mailgun
import (
"context"
"time"
"github.com/mailgun/mailgun-go/v4"
"go.uber.org/zap"
)
// mailgunEmailer is the Mailgun-backed implementation of Emailer.
type mailgunEmailer struct {
	config  MailgunConfigurationProvider // source of domain/key/address values
	logger  *zap.Logger                  // named sub-logger
	Mailgun *mailgun.MailgunImpl         // configured Mailgun API client
}
// NewEmailer constructs an Emailer backed by the Mailgun API using the
// supplied configuration provider.
func NewEmailer(config MailgunConfigurationProvider, logger *zap.Logger) Emailer {
	emailLogger := logger.Named("mailgun-emailer")
	// Initialize Mailgun client
	client := mailgun.NewMailgun(config.GetDomainName(), config.GetAPIKey())
	// Override to support our custom email requirements.
	client.SetAPIBase(config.GetAPIBase())
	emailLogger.Info("✓ Mailgun emailer initialized",
		zap.String("domain", config.GetDomainName()),
		zap.String("api_base", config.GetAPIBase()))
	emailer := mailgunEmailer{config: config, logger: emailLogger, Mailgun: client}
	return &emailer
}
// Send delivers an HTML email through Mailgun. The API call is bounded by a
// 10 second timeout layered on top of the caller's context.
func (me *mailgunEmailer) Send(ctx context.Context, sender, subject, recipient, body string) error {
	me.logger.Debug("Sending email",
		zap.String("sender", sender),
		zap.String("recipient", recipient),
		zap.String("subject", subject))
	// Plain-text part is empty; the content is HTML only.
	message := me.Mailgun.NewMessage(sender, subject, "", recipient)
	message.SetHtml(body)
	// Send the message with a 10 second timeout
	sendCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	resp, id, err := me.Mailgun.Send(sendCtx, message)
	if err != nil {
		me.logger.Error("Failed to send email",
			zap.String("sender", sender),
			zap.String("recipient", recipient),
			zap.String("subject", subject),
			zap.Error(err))
		return err
	}
	me.logger.Info("Email sent successfully",
		zap.String("recipient", recipient),
		zap.String("subject", subject),
		zap.String("message_id", id),
		zap.String("response", resp))
	return nil
}
// GetDomainName returns the configured Mailgun domain.
// Deprecated: see the Emailer interface.
func (me *mailgunEmailer) GetDomainName() string {
	return me.config.GetDomainName()
}

// GetSenderEmail returns the configured sender ("from") address.
func (me *mailgunEmailer) GetSenderEmail() string {
	return me.config.GetSenderEmail()
}

// GetBackendDomainName returns the configured backend domain name.
func (me *mailgunEmailer) GetBackendDomainName() string {
	return me.config.GetBackendDomainName()
}

// GetFrontendDomainName returns the configured frontend domain name.
func (me *mailgunEmailer) GetFrontendDomainName() string {
	return me.config.GetFrontendDomainName()
}

// GetMaintenanceEmail returns the configured maintenance contact address.
func (me *mailgunEmailer) GetMaintenanceEmail() string {
	return me.config.GetMaintenanceEmail()
}

View file

@ -0,0 +1,26 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/emailer/mailgun/provider.go
package mailgun
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideMailgunConfigurationProvider creates a new Mailgun configuration provider from the application config.
// It simply forwards the Mailgun section of the app config into the provider constructor.
func ProvideMailgunConfigurationProvider(cfg *config.Config) MailgunConfigurationProvider {
	m := cfg.Mailgun
	return NewMailgunConfigurationProvider(
		m.SenderEmail,
		m.Domain,
		m.APIBase,
		m.MaintenanceEmail,
		m.FrontendDomain,
		m.BackendDomain,
		m.APIKey,
	)
}
// ProvideEmailer creates a new Mailgun emailer from the configuration provider.
// Thin Wire DI constructor; all setup work happens inside NewEmailer.
func ProvideEmailer(config MailgunConfigurationProvider, logger *zap.Logger) Emailer {
	return NewEmailer(config, logger)
}

View file

@ -0,0 +1,187 @@
package httperror
import (
"encoding/json"
"net/http"
)
// ErrorResponse represents an HTTP error response (legacy format)
type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Code int `json:"code"`
}
// ProblemDetail represents an RFC 9457 compliant error response
// See: https://datatracker.ietf.org/doc/html/rfc9457
type ProblemDetail struct {
Type string `json:"type"` // URI reference identifying the problem type
Title string `json:"title"` // Short, human-readable summary
Status int `json:"status"` // HTTP status code
Detail string `json:"detail,omitempty"` // Human-readable explanation
Instance string `json:"instance,omitempty"` // URI reference identifying the specific occurrence
Errors map[string][]string `json:"errors,omitempty"` // Validation errors (extension field)
Extra map[string]interface{} `json:"-"` // Additional extension members
}
// WriteError writes an error response with pretty printing (legacy format)
func WriteError(w http.ResponseWriter, code int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
response := ErrorResponse{
Error: http.StatusText(code),
Message: message,
Code: code,
}
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
encoder.Encode(response)
}
// WriteProblemDetail writes an RFC 9457 compliant error response
func WriteProblemDetail(w http.ResponseWriter, problem *ProblemDetail) {
w.Header().Set("Content-Type", "application/problem+json")
w.WriteHeader(problem.Status)
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
encoder.Encode(problem)
}
// BadRequest writes a 400 error in the legacy JSON format.
func BadRequest(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusBadRequest, message)
}

// Unauthorized writes a 401 error in the legacy JSON format.
func Unauthorized(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusUnauthorized, message)
}

// Forbidden writes a 403 error in the legacy JSON format.
func Forbidden(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusForbidden, message)
}

// NotFound writes a 404 error in the legacy JSON format.
func NotFound(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusNotFound, message)
}

// Conflict writes a 409 error in the legacy JSON format.
func Conflict(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusConflict, message)
}

// TooManyRequests writes a 429 error in the legacy JSON format.
func TooManyRequests(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusTooManyRequests, message)
}

// InternalServerError writes a 500 error in the legacy JSON format.
func InternalServerError(w http.ResponseWriter, message string) {
	WriteError(w, http.StatusInternalServerError, message)
}
// ValidationError writes an RFC 9457 validation error response (400).
// The errors map carries per-field validation messages; detail falls back to
// a generic message when empty.
func ValidationError(w http.ResponseWriter, errors map[string][]string, detail string) {
	if detail == "" {
		detail = "One or more validation errors occurred"
	}
	problem := &ProblemDetail{
		Type:   "about:blank", // Using about:blank as per RFC 9457 when no specific problem type URI is defined
		Title:  "Validation Error",
		Status: http.StatusBadRequest,
		Detail: detail,
		Errors: errors,
	}
	WriteProblemDetail(w, problem)
}

// writeSimpleProblem writes an RFC 9457 problem with the "about:blank" type
// (used when no specific problem type URI is defined). It backs all the
// fixed-status Problem* helpers below.
func writeSimpleProblem(w http.ResponseWriter, status int, title, detail string) {
	WriteProblemDetail(w, &ProblemDetail{
		Type:   "about:blank",
		Title:  title,
		Status: status,
		Detail: detail,
	})
}

// ProblemBadRequest writes an RFC 9457 bad request error (400).
func ProblemBadRequest(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusBadRequest, "Bad Request", detail)
}

// ProblemUnauthorized writes an RFC 9457 unauthorized error (401).
func ProblemUnauthorized(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusUnauthorized, "Unauthorized", detail)
}

// ProblemForbidden writes an RFC 9457 forbidden error (403).
func ProblemForbidden(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusForbidden, "Forbidden", detail)
}

// ProblemNotFound writes an RFC 9457 not found error (404).
func ProblemNotFound(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusNotFound, "Not Found", detail)
}

// ProblemConflict writes an RFC 9457 conflict error (409).
func ProblemConflict(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusConflict, "Conflict", detail)
}

// ProblemTooManyRequests writes an RFC 9457 too many requests error (429).
func ProblemTooManyRequests(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusTooManyRequests, "Too Many Requests", detail)
}

// ProblemInternalServerError writes an RFC 9457 internal server error (500).
func ProblemInternalServerError(w http.ResponseWriter, detail string) {
	writeSimpleProblem(w, http.StatusInternalServerError, "Internal Server Error", detail)
}

View file

@ -0,0 +1,31 @@
package httpresponse
import (
"encoding/json"
"net/http"
)
// JSON writes a JSON response with pretty printing (indented)
func JSON(w http.ResponseWriter, code int, data interface{}) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
return encoder.Encode(data)
}
// OK writes a 200 JSON response with pretty printing.
func OK(w http.ResponseWriter, data interface{}) error {
	return JSON(w, http.StatusOK, data)
}

// Created writes a 201 JSON response with pretty printing.
func Created(w http.ResponseWriter, data interface{}) error {
	return JSON(w, http.StatusCreated, data)
}

// NoContent writes a 204 No Content response.
// No body and no Content-Type header are written.
func NoContent(w http.ResponseWriter) {
	w.WriteHeader(http.StatusNoContent)
}

View file

@ -0,0 +1,70 @@
package httpvalidation
import (
"errors"
"net/http"
"strings"
)
var (
// ErrInvalidContentType is returned when Content-Type header is not application/json
ErrInvalidContentType = errors.New("Content-Type must be application/json")
// ErrMissingContentType is returned when Content-Type header is missing
ErrMissingContentType = errors.New("Content-Type header is required")
)
// ValidateJSONContentType validates that the request has application/json Content-Type
// CWE-436: Validates Content-Type before parsing to prevent interpretation conflicts
// Accepts both "application/json" and "application/json; charset=utf-8"
func ValidateJSONContentType(r *http.Request) error {
contentType := r.Header.Get("Content-Type")
// Accept empty Content-Type for backward compatibility (some clients don't set it)
if contentType == "" {
return nil
}
// Check for exact match or charset variant
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
return nil
}
return ErrInvalidContentType
}
// RequireJSONContentType validates that the request has application/json Content-Type
// CWE-436: Strict validation that requires Content-Type header
// Use this for new endpoints where you want to enforce the header
func RequireJSONContentType(r *http.Request) error {
contentType := r.Header.Get("Content-Type")
if contentType == "" {
return ErrInvalidContentType
}
// Check for exact match or charset variant
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
return nil
}
return ErrInvalidContentType
}
// ValidateJSONContentTypeStrict validates that the request has application/json Content-Type
// CWE-16: Configuration - Enforces strict Content-Type validation
// This version REQUIRES the Content-Type header and returns specific error for missing header
func ValidateJSONContentTypeStrict(r *http.Request) error {
contentType := r.Header.Get("Content-Type")
// Require Content-Type header (no empty allowed)
if contentType == "" {
return ErrMissingContentType
}
// Check for exact match or charset variant
if contentType == "application/json" || strings.HasPrefix(contentType, "application/json;") {
return nil
}
return ErrInvalidContentType
}

View file

@ -0,0 +1,136 @@
// Package leaderelection provides distributed leader election for multiple application instances.
// It ensures only one instance acts as the leader at any given time, with automatic failover.
package leaderelection
import (
"context"
"time"
)
// LeaderElection provides distributed leader election across multiple application instances.
// It uses Redis to coordinate which instance is the current leader, with automatic failover
// if the leader crashes or becomes unavailable.
type LeaderElection interface {
	// Start begins participating in leader election.
	// This method blocks and runs the election loop until ctx is cancelled or an error occurs.
	// The instance will automatically attempt to become leader and maintain leadership.
	Start(ctx context.Context) error
	// IsLeader returns true if this instance is currently the leader.
	// This is a local check and does not require network communication.
	IsLeader() bool
	// GetLeaderID returns the unique identifier of the current leader instance.
	// Returns empty string (with nil error) if no leader exists; should be rare.
	GetLeaderID() (string, error)
	// GetLeaderInfo returns detailed information about the current leader,
	// or nil when no leader is currently registered.
	GetLeaderInfo() (*LeaderInfo, error)
	// OnBecomeLeader registers a callback function that will be executed when
	// this instance becomes the leader. Multiple callbacks can be registered.
	OnBecomeLeader(callback func())
	// OnLoseLeadership registers a callback function that will be executed when
	// this instance loses leadership (either voluntarily or due to failure).
	// Multiple callbacks can be registered.
	OnLoseLeadership(callback func())
	// Stop gracefully stops leader election participation.
	// If this instance is the leader, it releases leadership allowing another instance to take over.
	// This should be called during application shutdown.
	Stop() error
	// GetInstanceID returns the unique identifier for this instance.
	GetInstanceID() string
}
// LeaderInfo contains information about the current leader, as published to
// Redis by the leader itself on every heartbeat.
type LeaderInfo struct {
	// InstanceID is the unique identifier of the leader instance.
	InstanceID string `json:"instance_id"`
	// Hostname is the hostname of the leader instance.
	Hostname string `json:"hostname"`
	// StartedAt is when this instance became the leader.
	StartedAt time.Time `json:"started_at"`
	// LastHeartbeat is the last time the leader renewed its lock.
	LastHeartbeat time.Time `json:"last_heartbeat"`
}
// Config contains configuration for leader election.
type Config struct {
// RedisKeyName is the Redis key used for leader election.
// Default: "maplefile:leader:lock"
RedisKeyName string
// RedisInfoKeyName is the Redis key used to store leader information.
// Default: "maplefile:leader:info"
RedisInfoKeyName string
// LockTTL is how long the leader lock lasts before expiring.
// The leader must renew the lock before this time expires.
// Default: 10 seconds
// Recommended: 10-30 seconds
LockTTL time.Duration
// HeartbeatInterval is how often the leader renews its lock.
// This should be significantly less than LockTTL (e.g., LockTTL / 3).
// Default: 3 seconds
// Recommended: LockTTL / 3
HeartbeatInterval time.Duration
// RetryInterval is how often followers check for leadership opportunity.
// Default: 2 seconds
// Recommended: 1-5 seconds
RetryInterval time.Duration
// InstanceID uniquely identifies this application instance.
// If empty, will be auto-generated from hostname + random suffix.
// Default: auto-generated
InstanceID string
// Hostname is the hostname of this instance.
// If empty, will be auto-detected.
// Default: os.Hostname()
Hostname string
}
// DefaultConfig returns a Config with sensible defaults.
func DefaultConfig() *Config {
return &Config{
RedisKeyName: "maplefile:leader:lock",
RedisInfoKeyName: "maplefile:leader:info",
LockTTL: 10 * time.Second,
HeartbeatInterval: 3 * time.Second,
RetryInterval: 2 * time.Second,
}
}
// Validate checks if the configuration is valid and returns an error if not.
func (c *Config) Validate() error {
if c.RedisKeyName == "" {
c.RedisKeyName = "maplefile:leader:lock"
}
if c.RedisInfoKeyName == "" {
c.RedisInfoKeyName = "maplefile:leader:info"
}
if c.LockTTL <= 0 {
c.LockTTL = 10 * time.Second
}
if c.HeartbeatInterval <= 0 {
c.HeartbeatInterval = 3 * time.Second
}
if c.RetryInterval <= 0 {
c.RetryInterval = 2 * time.Second
}
// HeartbeatInterval should be less than LockTTL
if c.HeartbeatInterval >= c.LockTTL {
c.HeartbeatInterval = c.LockTTL / 3
}
return nil
}

View file

@ -0,0 +1,30 @@
package leaderelection
import (
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideLeaderElection provides a LeaderElection instance for Wire DI.
//
// InstanceID and Hostname are left at their zero values so that
// NewRedisLeaderElection derives them from the machine hostname.
func ProvideLeaderElection(
	cfg *config.Config,
	redisClient *redis.Client,
	logger *zap.Logger,
) (LeaderElection, error) {
	electionConfig := &Config{
		RedisKeyName:      "maplepress:leader:lock",
		RedisInfoKeyName:  "maplepress:leader:info",
		LockTTL:           cfg.LeaderElection.LockTTL,
		HeartbeatInterval: cfg.LeaderElection.HeartbeatInterval,
		RetryInterval:     cfg.LeaderElection.RetryInterval,
	}
	// *redis.Client satisfies the redis.UniversalClient interface.
	return NewRedisLeaderElection(electionConfig, redisClient, logger)
}

View file

@ -0,0 +1,355 @@
package leaderelection
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"os"
"sync"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// redisLeaderElection implements LeaderElection using Redis.
// A SETNX lock under config.RedisKeyName holds leadership; the holder renews
// it on every heartbeat and publishes LeaderInfo under config.RedisInfoKeyName.
type redisLeaderElection struct {
	config *Config
	redis  redis.UniversalClient
	logger *zap.Logger

	// Identity of this instance, fixed for the process lifetime.
	instanceID string
	hostname   string

	// Locally cached leadership flag, guarded by leaderMutex.
	isLeader    bool
	leaderMutex sync.RWMutex

	// Registered lifecycle callbacks, guarded by callbackMutex.
	becomeLeaderCbs   []func()
	loseLeadershipCbs []func()
	callbackMutex     sync.RWMutex

	// stopChan signals the Start loop to exit; stoppedChan is closed when it has.
	stopChan    chan struct{}
	stoppedChan chan struct{}

	// leaderStartTime records when leadership was last acquired.
	leaderStartTime time.Time
	// lastHeartbeat records the latest successful lock renewal,
	// guarded by lastHeartbeatMutex.
	lastHeartbeat      time.Time
	lastHeartbeatMutex sync.RWMutex
}
// NewRedisLeaderElection creates a new Redis-based leader election instance.
//
// The config is validated (zero values are defaulted) before use. When
// config.InstanceID or config.Hostname are empty they are derived from the
// machine hostname, which is resolved exactly once; the instance ID gets an
// additional random numeric suffix so multiple instances on one host remain
// distinguishable.
func NewRedisLeaderElection(
	config *Config,
	redisClient redis.UniversalClient,
	logger *zap.Logger,
) (LeaderElection, error) {
	logger = logger.Named("LeaderElection")

	// Validate configuration (also fills in defaults for zero values).
	if err := config.Validate(); err != nil {
		return nil, fmt.Errorf("invalid configuration: %w", err)
	}

	// Resolve the machine hostname once; both the generated instance ID and
	// the hostname field fall back to it.
	sysHostname, err := os.Hostname()
	if err != nil {
		sysHostname = "unknown"
	}

	// Generate instance ID if not provided.
	instanceID := config.InstanceID
	if instanceID == "" {
		// Add a random suffix to make it unique per process.
		instanceID = fmt.Sprintf("%s-%d", sysHostname, rand.Intn(100000))
		logger.Info("Generated instance ID", zap.String("instance_id", instanceID))
	}

	// Use the configured hostname when given, otherwise the detected one.
	hostname := config.Hostname
	if hostname == "" {
		hostname = sysHostname
	}

	return &redisLeaderElection{
		config:            config,
		redis:             redisClient,
		logger:            logger,
		instanceID:        instanceID,
		hostname:          hostname,
		isLeader:          false,
		becomeLeaderCbs:   make([]func(), 0),
		loseLeadershipCbs: make([]func(), 0),
		stopChan:          make(chan struct{}),
		stoppedChan:       make(chan struct{}),
	}, nil
}
// Start begins participating in leader election.
// It blocks, running the election loop until ctx is cancelled or Stop is
// called; on exit it releases leadership (if held) and closes stoppedChan so
// Stop can observe completion.
func (le *redisLeaderElection) Start(ctx context.Context) error {
	le.logger.Info("Starting leader election",
		zap.String("instance_id", le.instanceID),
		zap.String("hostname", le.hostname),
		zap.Duration("lock_ttl", le.config.LockTTL),
		zap.Duration("heartbeat_interval", le.config.HeartbeatInterval),
	)
	defer close(le.stoppedChan)

	// Main election loop: a single ticker drives both heartbeats (while
	// leading) and acquisition retries (while following).
	ticker := time.NewTicker(le.config.RetryInterval)
	defer ticker.Stop()

	// Try to become leader immediately on startup.
	le.tryBecomeLeader(ctx)

	for {
		select {
		case <-ctx.Done():
			// Use a fresh background context so the lock can still be
			// released even though ctx is already cancelled.
			le.logger.Info("Context cancelled, stopping leader election")
			le.releaseLeadership(context.Background())
			return ctx.Err()
		case <-le.stopChan:
			le.logger.Info("Stop signal received, stopping leader election")
			le.releaseLeadership(context.Background())
			return nil
		case <-ticker.C:
			if le.IsLeader() {
				// If we're the leader, renew the lock.
				if err := le.sendHeartbeat(ctx); err != nil {
					// A failed renewal means the lock may lapse and another
					// instance can take over; demote locally right away.
					le.logger.Error("Failed to send heartbeat, lost leadership",
						zap.Error(err))
					le.setLeaderStatus(false)
					le.executeCallbacks(le.loseLeadershipCbs)
				}
			} else {
				// If we're not the leader, try to become leader.
				le.tryBecomeLeader(ctx)
			}
		}
	}
}
// tryBecomeLeader attempts to acquire leadership.
// SETNX with the lock TTL is the atomic acquisition primitive: it succeeds
// only while no unexpired leader key exists. Redis errors are logged and
// swallowed — the next tick of the election loop retries.
func (le *redisLeaderElection) tryBecomeLeader(ctx context.Context) {
	// Try to set the leader key with NX (only if not exists) and EX (expiry).
	success, err := le.redis.SetNX(ctx, le.config.RedisKeyName, le.instanceID, le.config.LockTTL).Result()
	if err != nil {
		le.logger.Error("Failed to attempt leader election",
			zap.Error(err))
		return
	}
	if success {
		// We became the leader!
		le.logger.Info("🎉 Became the leader!",
			zap.String("instance_id", le.instanceID))
		// NOTE(review): leaderStartTime is written without a lock; it looks
		// like it is only ever touched from the election-loop goroutine
		// (here and via updateLeaderInfo) — confirm before adding readers.
		le.leaderStartTime = time.Now()
		le.setLeaderStatus(true)
		le.updateLeaderInfo(ctx)
		le.executeCallbacks(le.becomeLeaderCbs)
	} else {
		// Someone else is the leader.
		if !le.IsLeader() {
			// Only log at debug level when we weren't already aware.
			currentLeader, _ := le.GetLeaderID()
			le.logger.Debug("Another instance is the leader",
				zap.String("leader_id", currentLeader))
		}
	}
}
// sendHeartbeat renews the leader lock.
//
// The ownership check and the TTL renewal run atomically in one Lua script
// (mirroring the ownership-checked delete in releaseLeadership), so the lock
// cannot expire — and be acquired by another instance — between a successful
// GET and the subsequent expiry refresh. Returns an error when the lock is
// no longer held by this instance or Redis is unreachable.
func (le *redisLeaderElection) sendHeartbeat(ctx context.Context) error {
	// PEXPIRE fires only while the lock value still matches our instance ID.
	script := `
	if redis.call("GET", KEYS[1]) == ARGV[1] then
		return redis.call("PEXPIRE", KEYS[1], ARGV[2])
	else
		return 0
	end
	`
	res, err := le.redis.Eval(ctx, script,
		[]string{le.config.RedisKeyName},
		le.instanceID,
		le.config.LockTTL.Milliseconds(),
	).Result()
	if err != nil {
		return fmt.Errorf("failed to renew lock: %w", err)
	}
	if renewed, ok := res.(int64); !ok || renewed == 0 {
		return fmt.Errorf("lock held by different instance")
	}
	// Record the renewal locally and refresh the published leader info.
	le.setLastHeartbeat(time.Now())
	le.updateLeaderInfo(ctx)
	le.logger.Debug("Heartbeat sent",
		zap.String("instance_id", le.instanceID))
	return nil
}
// updateLeaderInfo publishes this instance's LeaderInfo to Redis under the
// configured info key. The entry carries the same TTL as the lock, so stale
// info disappears together with an expired lock. Failures are logged, never
// returned — the info record is best-effort.
func (le *redisLeaderElection) updateLeaderInfo(ctx context.Context) {
	payload, err := json.Marshal(&LeaderInfo{
		InstanceID:    le.instanceID,
		Hostname:      le.hostname,
		StartedAt:     le.leaderStartTime,
		LastHeartbeat: le.getLastHeartbeat(),
	})
	if err != nil {
		le.logger.Error("Failed to marshal leader info", zap.Error(err))
		return
	}
	// Expire the info entry together with the lock.
	if err := le.redis.Set(ctx, le.config.RedisInfoKeyName, payload, le.config.LockTTL).Err(); err != nil {
		le.logger.Error("Failed to update leader info", zap.Error(err))
	}
}
// releaseLeadership voluntarily releases leadership.
// A Lua script deletes the lock only when this instance still owns it, so a
// lock that has already been taken over by another instance is never deleted
// by mistake. No-op when this instance is not the leader.
func (le *redisLeaderElection) releaseLeadership(ctx context.Context) {
	if !le.IsLeader() {
		return
	}
	le.logger.Info("Releasing leadership voluntarily",
		zap.String("instance_id", le.instanceID))
	// Only delete if we're still the owner.
	script := `
	if redis.call("GET", KEYS[1]) == ARGV[1] then
		return redis.call("DEL", KEYS[1])
	else
		return 0
	end
	`
	_, err := le.redis.Eval(ctx, script, []string{le.config.RedisKeyName}, le.instanceID).Result()
	if err != nil {
		le.logger.Error("Failed to release leadership", zap.Error(err))
	}
	// Delete the published leader info as well (best-effort).
	le.redis.Del(ctx, le.config.RedisInfoKeyName)
	le.setLeaderStatus(false)
	le.executeCallbacks(le.loseLeadershipCbs)
}
// IsLeader returns true if this instance is the leader.
// Purely local, mutex-protected read; it does not query Redis.
func (le *redisLeaderElection) IsLeader() bool {
	le.leaderMutex.RLock()
	defer le.leaderMutex.RUnlock()
	return le.isLeader
}

// GetLeaderID returns the ID of the current leader.
// It reads the lock key from Redis; an absent key (no current leader)
// yields ("", nil) rather than an error.
func (le *redisLeaderElection) GetLeaderID() (string, error) {
	ctx := context.Background()
	leaderID, err := le.redis.Get(ctx, le.config.RedisKeyName).Result()
	if err == redis.Nil {
		return "", nil
	}
	if err != nil {
		return "", fmt.Errorf("failed to get leader ID: %w", err)
	}
	return leaderID, nil
}
// GetLeaderInfo returns information about the current leader by reading the
// published info record from Redis. When no record exists (no leader, or the
// entry expired) it returns (nil, nil).
func (le *redisLeaderElection) GetLeaderInfo() (*LeaderInfo, error) {
	ctx := context.Background()
	raw, err := le.redis.Get(ctx, le.config.RedisInfoKeyName).Result()
	switch {
	case err == redis.Nil:
		// No leader info currently published.
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get leader info: %w", err)
	}
	info := &LeaderInfo{}
	if err := json.Unmarshal([]byte(raw), info); err != nil {
		return nil, fmt.Errorf("failed to unmarshal leader info: %w", err)
	}
	return info, nil
}
// OnBecomeLeader registers a callback for when this instance becomes leader.
// Registration is mutex-protected and safe from multiple goroutines;
// callbacks are later run in their own goroutines.
func (le *redisLeaderElection) OnBecomeLeader(callback func()) {
	le.callbackMutex.Lock()
	defer le.callbackMutex.Unlock()
	le.becomeLeaderCbs = append(le.becomeLeaderCbs, callback)
}

// OnLoseLeadership registers a callback for when this instance loses
// leadership, whether voluntarily (release) or through a failed heartbeat.
func (le *redisLeaderElection) OnLoseLeadership(callback func()) {
	le.callbackMutex.Lock()
	defer le.callbackMutex.Unlock()
	le.loseLeadershipCbs = append(le.loseLeadershipCbs, callback)
}
// Stop gracefully stops leader election.
// It signals the Start loop to exit (which releases leadership) and waits up
// to five seconds for the loop to confirm completion via stoppedChan.
// NOTE(review): calling Stop more than once panics on the second
// close(stopChan) — confirm callers invoke it exactly once.
func (le *redisLeaderElection) Stop() error {
	le.logger.Info("Stopping leader election")
	close(le.stopChan)
	// Wait for the election loop to finish (with timeout).
	select {
	case <-le.stoppedChan:
		le.logger.Info("Leader election stopped successfully")
		return nil
	case <-time.After(5 * time.Second):
		le.logger.Warn("Timeout waiting for leader election to stop")
		return fmt.Errorf("timeout waiting for leader election to stop")
	}
}
// GetInstanceID returns this instance's unique identifier.
func (le *redisLeaderElection) GetInstanceID() string {
	return le.instanceID
}

// setLeaderStatus updates the cached leader flag (thread-safe).
func (le *redisLeaderElection) setLeaderStatus(isLeader bool) {
	le.leaderMutex.Lock()
	defer le.leaderMutex.Unlock()
	le.isLeader = isLeader
}

// setLastHeartbeat records the time of the latest successful lock renewal
// (thread-safe).
func (le *redisLeaderElection) setLastHeartbeat(t time.Time) {
	le.lastHeartbeatMutex.Lock()
	defer le.lastHeartbeatMutex.Unlock()
	le.lastHeartbeat = t
}

// getLastHeartbeat returns the last recorded heartbeat time (thread-safe).
// It is the zero time until the first heartbeat has been recorded.
func (le *redisLeaderElection) getLastHeartbeat() time.Time {
	le.lastHeartbeatMutex.RLock()
	defer le.lastHeartbeatMutex.RUnlock()
	return le.lastHeartbeat
}
// executeCallbacks executes a list of callbacks, each in its own goroutine.
// Every callback runs under a recover guard so a panicking callback cannot
// take down the election loop.
// NOTE(review): callers pass le.becomeLeaderCbs / le.loseLeadershipCbs by
// reading the field outside callbackMutex (the lock is only taken here) —
// confirm registration never races with execution.
func (le *redisLeaderElection) executeCallbacks(callbacks []func()) {
	le.callbackMutex.RLock()
	defer le.callbackMutex.RUnlock()
	for _, callback := range callbacks {
		go func(cb func()) {
			defer func() {
				if r := recover(); r != nil {
					le.logger.Error("Panic in leader election callback",
						zap.Any("panic", r))
				}
			}()
			cb()
		}(callback)
	}
}

View file

@ -0,0 +1,120 @@
package logger
import (
"fmt"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// emojiCore wraps a zapcore.Core to add an emoji icon field to
// warning-and-above log entries (see Write).
type emojiCore struct {
	zapcore.Core
}

// With returns a copy of the core with the given fields attached, keeping
// the emoji wrapper in place around the derived core.
func (c *emojiCore) With(fields []zapcore.Field) zapcore.Core {
	return &emojiCore{c.Core.With(fields)}
}

// Check adds this core to the checked entry when the entry's level is
// enabled, so that Write (and its emoji decoration) runs for the entry.
func (c *emojiCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if c.Enabled(entry.Level) {
		return ce.AddCore(entry, c)
	}
	return ce
}
// Write forwards the entry to the wrapped core, prepending an "ico" field
// with a colored-circle emoji for warning-and-above entries. Debug and info
// entries are written unchanged to keep output clean.
func (c *emojiCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
	var emoji string
	switch entry.Level {
	case zapcore.WarnLevel:
		emoji = "🟡" // Yellow circle for warnings
	case zapcore.ErrorLevel, zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel:
		emoji = "🔴" // Red circle for errors and worse
	default:
		// No emoji for debug and info levels; write as-is.
		return c.Core.Write(entry, fields)
	}
	// Prepend the icon field so it renders first in the output.
	fieldsWithEmoji := make([]zapcore.Field, 0, len(fields)+1)
	fieldsWithEmoji = append(fieldsWithEmoji, zap.String("ico", emoji))
	fieldsWithEmoji = append(fieldsWithEmoji, fields...)
	return c.Core.Write(entry, fieldsWithEmoji)
}
// ProvideLogger creates a new zap logger based on configuration.
//
// The environment selects the zap preset (production vs development), the
// configured level and format are applied on top, and the core is wrapped
// with emojiCore so warnings/errors carry an icon field.
func ProvideLogger(cfg *config.Config) (*zap.Logger, error) {
	// Start from the preset matching the runtime environment.
	var zc zap.Config
	if cfg.App.Environment == "production" {
		zc = zap.NewProductionConfig()
	} else {
		zc = zap.NewDevelopmentConfig()
	}

	// Apply the configured minimum level.
	level, err := zapcore.ParseLevel(cfg.Logger.Level)
	if err != nil {
		return nil, fmt.Errorf("invalid log level %s: %w", cfg.Logger.Level, err)
	}
	zc.Level = zap.NewAtomicLevelAt(level)

	// Pick the output encoding; console gets colored level names.
	if cfg.Logger.Format == "console" {
		zc.Encoding = "console"
		zc.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	} else {
		zc.Encoding = "json"
	}

	var opts []zap.Option
	if cfg.App.Environment != "production" {
		// Caller information eases debugging outside production.
		opts = append(opts, zap.AddCaller(), zap.AddCallerSkip(0))
	}
	opts = append(opts,
		// Stack traces for error level and above.
		zap.AddStacktrace(zapcore.ErrorLevel),
		// Wrap the core so warning/error entries get the icon field.
		zap.WrapCore(func(core zapcore.Core) zapcore.Core {
			return &emojiCore{core}
		}),
	)

	logger, err := zc.Build(opts...)
	if err != nil {
		return nil, fmt.Errorf("failed to build logger: %w", err)
	}
	logger.Info("✓ Logger initialized",
		zap.String("level", cfg.Logger.Level),
		zap.String("format", cfg.Logger.Format),
		zap.String("environment", cfg.App.Environment))
	return logger, nil
}

View file

@ -0,0 +1,231 @@
package logger
import (
"crypto/sha256"
"encoding/hex"
"regexp"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// SensitiveFieldRedactor provides methods to redact sensitive data before logging.
// This addresses CWE-532 (Insertion of Sensitive Information into Log File).
type SensitiveFieldRedactor struct {
	emailRegex *regexp.Regexp
}

// NewSensitiveFieldRedactor creates a new redactor for sensitive data.
func NewSensitiveFieldRedactor() *SensitiveFieldRedactor {
	return &SensitiveFieldRedactor{
		emailRegex: regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`),
	}
}

// RedactEmail redacts an email address for logging.
// Example: "john.doe@example.com" -> "jo***@example.com".
// Empty input yields "[empty]"; anything that fails validation yields
// "[invalid-email]".
func (r *SensitiveFieldRedactor) RedactEmail(email string) string {
	if email == "" {
		return "[empty]"
	}
	if !r.emailRegex.MatchString(email) {
		return "[invalid-email]"
	}
	at := strings.IndexByte(email, '@')
	if at < 0 {
		// Unreachable after regexp validation; kept as a defensive guard.
		return "[invalid-email]"
	}
	local, domain := email[:at], email[at+1:]
	// Show at most the first 2 characters of the local part.
	if len(local) <= 2 {
		return "**@" + domain
	}
	return local[:2] + "***@" + domain
}

// HashForLogging creates a consistent hash for unique identification without
// exposing the original value, allowing correlation across log entries
// without storing PII. Returns the first 8 bytes of SHA-256 as 16 hex chars.
func (r *SensitiveFieldRedactor) HashForLogging(value string) string {
	if value == "" {
		return "[empty]"
	}
	sum := sha256.Sum256([]byte(value))
	return hex.EncodeToString(sum[:8])
}

// RedactTenantSlug redacts a tenant slug for logging.
// Example: "my-company" -> "my***"; slugs of 3 characters or fewer become "***".
func (r *SensitiveFieldRedactor) RedactTenantSlug(slug string) string {
	switch {
	case slug == "":
		return "[empty]"
	case len(slug) <= 3:
		return "***"
	default:
		return slug[:2] + "***"
	}
}

// RedactAPIKey redacts an API key for logging, showing only the recognized
// prefix ("live_sk_" / "test_sk_") and the last 4 characters.
// Example: "live_sk_abc123def456ghi789" -> "live_sk_***i789".
func (r *SensitiveFieldRedactor) RedactAPIKey(apiKey string) string {
	if apiKey == "" {
		return "[empty]"
	}
	if strings.HasPrefix(apiKey, "live_sk_") || strings.HasPrefix(apiKey, "test_sk_") {
		prefix := apiKey[:8] // "live_sk_" or "test_sk_"
		if len(apiKey) > 12 {
			return prefix + "***" + apiKey[len(apiKey)-4:]
		}
		// Too short to safely reveal a suffix.
		return prefix + "***"
	}
	// Unrecognized format: show only the last 4 characters, if any.
	if len(apiKey) > 4 {
		return "***" + apiKey[len(apiKey)-4:]
	}
	return "***"
}

// RedactJWTToken redacts a JWT token for logging, showing only the first and
// last 8 characters; tokens shorter than 16 characters become "***".
func (r *SensitiveFieldRedactor) RedactJWTToken(token string) string {
	switch {
	case token == "":
		return "[empty]"
	case len(token) < 16:
		return "***"
	default:
		return token[:8] + "..." + token[len(token)-8:]
	}
}

// RedactIPAddress partially redacts an IP address.
// IPv4: "192.168.1.100" -> "192.168.*.*".
// IPv6: keeps the first 4 groups and masks the rest.
// Anything unrecognized becomes "***".
func (r *SensitiveFieldRedactor) RedactIPAddress(ip string) string {
	if ip == "" {
		return "[empty]"
	}
	// IPv4
	if strings.Contains(ip, ".") {
		if parts := strings.Split(ip, "."); len(parts) == 4 {
			return parts[0] + "." + parts[1] + ".*.*"
		}
	}
	// IPv6
	if strings.Contains(ip, ":") {
		if parts := strings.Split(ip, ":"); len(parts) >= 4 {
			return strings.Join(parts[:4], ":") + ":****"
		}
	}
	return "***"
}
// Zap Field Helpers - Provide convenient zap.Field constructors
// SafeEmail creates a zap field with redacted email
func SafeEmail(key string, email string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String(key, redactor.RedactEmail(email))
}
// EmailHash creates a zap field with hashed email for correlation
func EmailHash(email string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String("email_hash", redactor.HashForLogging(email))
}
// HashString hashes a string value for safe logging
// Returns the hash string directly (not a zap.Field)
func HashString(value string) string {
redactor := NewSensitiveFieldRedactor()
return redactor.HashForLogging(value)
}
// SafeTenantSlug creates a zap field with redacted tenant slug
func SafeTenantSlug(key string, slug string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String(key, redactor.RedactTenantSlug(slug))
}
// TenantSlugHash creates a zap field with hashed tenant slug for correlation
func TenantSlugHash(slug string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String("tenant_slug_hash", redactor.HashForLogging(slug))
}
// SafeAPIKey creates a zap field with redacted API key
func SafeAPIKey(key string, apiKey string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String(key, redactor.RedactAPIKey(apiKey))
}
// SafeJWTToken creates a zap field with redacted JWT token
func SafeJWTToken(key string, token string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String(key, redactor.RedactJWTToken(token))
}
// SafeIPAddress creates a zap field with redacted IP address
func SafeIPAddress(key string, ip string) zapcore.Field {
redactor := NewSensitiveFieldRedactor()
return zap.String(key, redactor.RedactIPAddress(ip))
}
// UserIdentifier returns the safe identification field set for a user:
// user_id (logged as-is), email_hash (for correlation) and email_redacted.
func UserIdentifier(userID string, email string) []zapcore.Field {
	red := NewSensitiveFieldRedactor()
	fields := make([]zapcore.Field, 0, 3)
	fields = append(fields, zap.String("user_id", userID))
	fields = append(fields, zap.String("email_hash", red.HashForLogging(email)))
	fields = append(fields, zap.String("email_redacted", red.RedactEmail(email)))
	return fields
}
// TenantIdentifier returns the safe identification field set for a tenant:
// tenant_id (logged as-is), tenant_slug_hash and tenant_slug_redacted.
func TenantIdentifier(tenantID string, slug string) []zapcore.Field {
	red := NewSensitiveFieldRedactor()
	fields := make([]zapcore.Field, 0, 3)
	fields = append(fields, zap.String("tenant_id", tenantID))
	fields = append(fields, zap.String("tenant_slug_hash", red.HashForLogging(slug)))
	fields = append(fields, zap.String("tenant_slug_redacted", red.RedactTenantSlug(slug)))
	return fields
}
// Constants for field names.
//
// Canonical zap field keys emitted by the helpers in this file, exported
// so log-query code and handlers can reference them without repeating
// string literals.
const (
	FieldUserID             = "user_id"
	FieldEmailHash          = "email_hash"
	FieldEmailRedacted      = "email_redacted"
	FieldTenantID           = "tenant_id"
	FieldTenantSlugHash     = "tenant_slug_hash"
	FieldTenantSlugRedacted = "tenant_slug_redacted"
	FieldAPIKeyRedacted     = "api_key_redacted"
	FieldJWTTokenRedacted   = "jwt_token_redacted"
	FieldIPAddressRedacted  = "ip_address_redacted"
)

View file

@ -0,0 +1,345 @@
package logger
import (
"testing"
)
// TestRedactEmail verifies email redaction across normal, short, empty,
// and malformed inputs.
func TestRedactEmail(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in, want string
	}{
		{"normal email", "john.doe@example.com", "jo***@example.com"},
		{"short local part", "ab@example.com", "**@example.com"},
		{"single character local part", "a@example.com", "**@example.com"},
		{"empty email", "", "[empty]"},
		{"invalid email", "notanemail", "[invalid-email]"},
		{"long email", "very.long.email.address@subdomain.example.com", "ve***@subdomain.example.com"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := red.RedactEmail(tc.in); got != tc.want {
				t.Errorf("RedactEmail(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestHashForLogging verifies hashes are deterministic, fixed-length,
// input-sensitive, and never echo the original value.
func TestHashForLogging(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in string
	}{
		{"email", "john.doe@example.com"},
		{"tenant slug", "my-company"},
		{"another email", "jane.smith@test.com"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			first := red.HashForLogging(tc.in)
			second := red.HashForLogging(tc.in)
			// Hash should be consistent
			if first != second {
				t.Errorf("HashForLogging is not consistent: %q != %q", first, second)
			}
			// Hash should be 16 characters (8 bytes in hex)
			if len(first) != 16 {
				t.Errorf("HashForLogging length = %d, want 16", len(first))
			}
			// Hash should not contain original value
			if first == tc.in {
				t.Errorf("HashForLogging returned original value")
			}
		})
	}
	// Different inputs should produce different hashes
	if red.HashForLogging("john.doe@example.com") == red.HashForLogging("jane.smith@example.com") {
		t.Error("Different inputs produced same hash")
	}
	// Empty string
	if got := red.HashForLogging(""); got != "[empty]" {
		t.Errorf("HashForLogging(\"\") = %q, want [empty]", got)
	}
}
// TestRedactTenantSlug verifies slug redaction for normal, short, empty
// and long slugs.
func TestRedactTenantSlug(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in, want string
	}{
		{"normal slug", "my-company", "my***"},
		{"short slug", "abc", "***"},
		{"very short slug", "ab", "***"},
		{"empty slug", "", "[empty]"},
		{"long slug", "very-long-company-name", "ve***"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := red.RedactTenantSlug(tc.in); got != tc.want {
				t.Errorf("RedactTenantSlug(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestRedactAPIKey verifies API-key redaction for live/test prefixes,
// short keys, other formats, and empty input.
func TestRedactAPIKey(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in, want string
	}{
		{"live API key", "live_sk_abc123def456ghi789", "live_sk_***i789"},
		{"test API key", "test_sk_xyz789uvw456rst123", "test_sk_***t123"},
		{"short live key", "live_sk_abc", "live_sk_***"},
		{"other format", "sk_abc123def456", "***f456"},
		{"very short key", "abc", "***"},
		{"empty key", "", "[empty]"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := red.RedactAPIKey(tc.in); got != tc.want {
				t.Errorf("RedactAPIKey(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestRedactJWTToken verifies JWT redaction keeps only a leading and
// trailing window of the token.
func TestRedactJWTToken(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in, want string
	}{
		{"normal JWT", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U", "eyJhbGci...P0THsR8U"},
		{"short token", "short", "***"},
		{"empty token", "", "[empty]"},
		{"minimum length token", "1234567890123456", "12345678...90123456"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := red.RedactJWTToken(tc.in); got != tc.want {
				t.Errorf("RedactJWTToken(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestRedactIPAddress verifies IPv4/IPv6 redaction plus empty and invalid
// inputs.
func TestRedactIPAddress(t *testing.T) {
	red := NewSensitiveFieldRedactor()
	cases := []struct {
		name, in, want string
	}{
		{"IPv4 address", "192.168.1.100", "192.168.*.*"},
		{"IPv4 public", "8.8.8.8", "8.8.*.*"},
		{"IPv6 address", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", "2001:0db8:85a3:0000:****"},
		{"IPv6 shortened", "2001:db8::1", "2001:db8::1:****"},
		{"empty IP", "", "[empty]"},
		{"invalid IP", "notanip", "***"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := red.RedactIPAddress(tc.in); got != tc.want {
				t.Errorf("RedactIPAddress(%q) = %q, want %q", tc.in, got, tc.want)
			}
		})
	}
}
// TestUserIdentifier verifies the field set produced for a user contains
// exactly the expected keys.
func TestUserIdentifier(t *testing.T) {
	fields := UserIdentifier("user_123", "john.doe@example.com")
	if len(fields) != 3 {
		t.Errorf("UserIdentifier returned %d fields, want 3", len(fields))
	}
	// Check that fields contain expected keys
	seen := make(map[string]bool, len(fields))
	for _, f := range fields {
		seen[f.Key] = true
	}
	for _, key := range []string{"user_id", "email_hash", "email_redacted"} {
		if !seen[key] {
			t.Errorf("UserIdentifier missing key: %s", key)
		}
	}
}
// TestTenantIdentifier verifies the field set produced for a tenant
// contains exactly the expected keys.
func TestTenantIdentifier(t *testing.T) {
	fields := TenantIdentifier("tenant_123", "my-company")
	if len(fields) != 3 {
		t.Errorf("TenantIdentifier returned %d fields, want 3", len(fields))
	}
	// Check that fields contain expected keys
	seen := make(map[string]bool, len(fields))
	for _, f := range fields {
		seen[f.Key] = true
	}
	for _, key := range []string{"tenant_id", "tenant_slug_hash", "tenant_slug_redacted"} {
		if !seen[key] {
			t.Errorf("TenantIdentifier missing key: %s", key)
		}
	}
}

View file

@ -0,0 +1,327 @@
package ratelimit
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
// LoginRateLimiter provides specialized rate limiting for login attempts
// with account lockout functionality (CWE-307 brute-force protection).
// Accounts are tracked by email and client IPs are limited independently.
type LoginRateLimiter interface {
	// CheckAndRecordAttempt checks if login attempt is allowed and records it
	// Returns: allowed (bool), isLocked (bool), remainingAttempts (int), error
	CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error)
	// RecordFailedAttempt records a failed login attempt
	RecordFailedAttempt(ctx context.Context, email string, clientIP string) error
	// RecordSuccessfulLogin records a successful login and resets counters
	RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error
	// IsAccountLocked checks if an account is locked due to too many failed attempts.
	// The returned Duration is the time remaining until the lock expires.
	IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error)
	// UnlockAccount manually unlocks an account (admin function)
	UnlockAccount(ctx context.Context, email string) error
	// GetFailedAttempts returns the number of failed attempts for an email
	GetFailedAttempts(ctx context.Context, email string) (int, error)
}
// LoginRateLimiterConfig holds configuration for login rate limiting.
// IP limits and per-account lockout are configured independently.
type LoginRateLimiterConfig struct {
	// MaxAttemptsPerIP is the maximum login attempts per IP in the window
	MaxAttemptsPerIP int
	// IPWindow is the time window for IP-based rate limiting
	IPWindow time.Duration
	// MaxFailedAttemptsPerAccount is the maximum failed attempts before account lockout
	MaxFailedAttemptsPerAccount int
	// AccountLockoutDuration is how long to lock an account after too many failures.
	// It also bounds the lifetime of the failed-attempt counter itself.
	AccountLockoutDuration time.Duration
	// KeyPrefix is the prefix for Redis keys
	KeyPrefix string
}
// DefaultLoginRateLimiterConfig returns the recommended configuration:
// 10 attempts per IP in 15 minutes, and a 30-minute account lock after
// 10 failed attempts.
func DefaultLoginRateLimiterConfig() LoginRateLimiterConfig {
	cfg := LoginRateLimiterConfig{KeyPrefix: "login_rl"}
	cfg.MaxAttemptsPerIP = 10                     // 10 attempts per IP
	cfg.IPWindow = 15 * time.Minute               // in 15 minutes
	cfg.MaxFailedAttemptsPerAccount = 10          // 10 failed attempts per account
	cfg.AccountLockoutDuration = 30 * time.Minute // lock for 30 minutes
	return cfg
}
// loginRateLimiter is the Redis-backed implementation of LoginRateLimiter.
type loginRateLimiter struct {
	client *redis.Client // shared Redis connection; not owned by this type
	config LoginRateLimiterConfig
	logger *zap.Logger
}
// NewLoginRateLimiter creates a login rate limiter backed by the given
// Redis client, with a sub-logger named "login-rate-limiter".
func NewLoginRateLimiter(client *redis.Client, config LoginRateLimiterConfig, logger *zap.Logger) LoginRateLimiter {
	rl := &loginRateLimiter{
		client: client,
		config: config,
		logger: logger.Named("login-rate-limiter"),
	}
	return rl
}
// CheckAndRecordAttempt checks if login attempt is allowed
// CWE-307: Implements protection against brute force attacks
//
// Return values are (allowed, isLocked, remainingAttempts, err). The
// check runs in stages: account lockout first, then the per-IP sliding
// window; if both pass, the attempt is recorded for the IP and the
// account's remaining failed-attempt budget is reported.
//
// NOTE(review): on Redis errors this method "fails open" by returning
// allowed=true, but it also returns the non-nil error. A caller that
// treats any non-nil error as a denial will effectively fail closed —
// confirm call sites honor the allowed flag in that case.
func (r *loginRateLimiter) CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error) {
	// Check account lockout first
	locked, remaining, err := r.IsAccountLocked(ctx, email)
	if err != nil {
		r.logger.Error("failed to check account lockout",
			zap.String("email_hash", hashEmail(email)),
			zap.Error(err))
		// Fail open on Redis error
		return true, false, 0, err
	}
	if locked {
		r.logger.Warn("login attempt on locked account",
			zap.String("email_hash", hashEmail(email)),
			zap.String("ip", clientIP),
			zap.Duration("remaining_lockout", remaining))
		return false, true, 0, nil
	}
	// Check IP-based rate limit
	ipKey := r.getIPKey(clientIP)
	allowed, err := r.checkIPRateLimit(ctx, ipKey)
	if err != nil {
		r.logger.Error("failed to check IP rate limit",
			zap.String("ip", clientIP),
			zap.Error(err))
		// Fail open on Redis error
		return true, false, 0, err
	}
	if !allowed {
		r.logger.Warn("IP rate limit exceeded",
			zap.String("ip", clientIP))
		return false, false, 0, nil
	}
	// Record the attempt for IP. Best-effort: a failure here is logged but
	// does not block the login attempt.
	if err := r.recordIPAttempt(ctx, ipKey); err != nil {
		r.logger.Error("failed to record IP attempt",
			zap.String("ip", clientIP),
			zap.Error(err))
	}
	// Get remaining attempts for account. Also best-effort: on error
	// failedAttempts stays 0, so remainingAttempts reports the full budget.
	failedAttempts, err := r.GetFailedAttempts(ctx, email)
	if err != nil {
		r.logger.Error("failed to get failed attempts",
			zap.String("email_hash", hashEmail(email)),
			zap.Error(err))
	}
	remainingAttempts := r.config.MaxFailedAttemptsPerAccount - failedAttempts
	if remainingAttempts < 0 {
		remainingAttempts = 0
	}
	r.logger.Debug("login attempt check passed",
		zap.String("email_hash", hashEmail(email)),
		zap.String("ip", clientIP),
		zap.Int("remaining_attempts", remainingAttempts))
	return true, false, remainingAttempts, nil
}
// RecordFailedAttempt records a failed login attempt
// CWE-307: Tracks failed attempts to enable account lockout
//
// The counter lives under a hashed-email key and expires after
// AccountLockoutDuration; once it reaches MaxFailedAttemptsPerAccount,
// a separate lock marker is written with the same TTL.
func (r *loginRateLimiter) RecordFailedAttempt(ctx context.Context, email string, clientIP string) error {
	accountKey := r.getAccountKey(email)
	// Increment failed attempt counter
	count, err := r.client.Incr(ctx, accountKey).Result()
	if err != nil {
		r.logger.Error("failed to increment failed attempts",
			zap.String("email_hash", hashEmail(email)),
			zap.Error(err))
		return err
	}
	// Set expiration on first failed attempt.
	// NOTE(review): Incr and Expire are separate round trips and Expire's
	// error is discarded; if Expire fails (or the process dies between the
	// two calls) the counter persists without a TTL — confirm acceptable.
	if count == 1 {
		r.client.Expire(ctx, accountKey, r.config.AccountLockoutDuration)
	}
	// Check if account should be locked
	if count >= int64(r.config.MaxFailedAttemptsPerAccount) {
		lockKey := r.getLockKey(email)
		err := r.client.Set(ctx, lockKey, "locked", r.config.AccountLockoutDuration).Err()
		if err != nil {
			r.logger.Error("failed to lock account",
				zap.String("email_hash", hashEmail(email)),
				zap.Error(err))
			return err
		}
		r.logger.Warn("account locked due to too many failed attempts",
			zap.String("email_hash", hashEmail(email)),
			zap.String("ip", clientIP),
			zap.Int64("failed_attempts", count),
			zap.Duration("lockout_duration", r.config.AccountLockoutDuration))
	}
	r.logger.Info("failed login attempt recorded",
		zap.String("email_hash", hashEmail(email)),
		zap.String("ip", clientIP),
		zap.Int64("total_failed_attempts", count))
	return nil
}
// RecordSuccessfulLogin clears both the failed-attempt counter and any
// lock marker for the account after a successful authentication.
func (r *loginRateLimiter) RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error {
	// Remove counter and lock marker in one round trip.
	pipe := r.client.Pipeline()
	pipe.Del(ctx, r.getAccountKey(email))
	pipe.Del(ctx, r.getLockKey(email))
	if _, err := pipe.Exec(ctx); err != nil {
		r.logger.Error("failed to reset login counters",
			zap.String("email_hash", hashEmail(email)),
			zap.Error(err))
		return err
	}
	r.logger.Info("successful login recorded, counters reset",
		zap.String("email_hash", hashEmail(email)),
		zap.String("ip", clientIP))
	return nil
}
// IsAccountLocked reports whether the account is currently locked and,
// if so, how long remains until the lock expires.
func (r *loginRateLimiter) IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error) {
	ttl, err := r.client.TTL(ctx, r.getLockKey(email)).Result()
	if err != nil {
		return false, 0, err
	}
	// Redis TTL is -2 for a missing key and -1 for a key with no
	// expiration; either way we treat the account as unlocked.
	if ttl >= 0 {
		return true, ttl, nil
	}
	return false, 0, nil
}
// UnlockAccount manually unlocks an account (admin function), removing
// both the failed-attempt counter and the lock marker.
func (r *loginRateLimiter) UnlockAccount(ctx context.Context, email string) error {
	pipe := r.client.Pipeline()
	pipe.Del(ctx, r.getAccountKey(email))
	pipe.Del(ctx, r.getLockKey(email))
	if _, err := pipe.Exec(ctx); err != nil {
		r.logger.Error("failed to unlock account",
			zap.String("email_hash", hashEmail(email)),
			zap.Error(err))
		return err
	}
	r.logger.Info("account unlocked",
		zap.String("email_hash", hashEmail(email)))
	return nil
}
// GetFailedAttempts returns the current failed-attempt count for the
// account; a missing counter means zero failures.
func (r *loginRateLimiter) GetFailedAttempts(ctx context.Context, email string) (int, error) {
	n, err := r.client.Get(ctx, r.getAccountKey(email)).Int()
	switch {
	case err == redis.Nil:
		// No counter yet — no failed attempts recorded.
		return 0, nil
	case err != nil:
		return 0, err
	default:
		return n, nil
	}
}
// checkIPRateLimit reports whether the IP (keyed by ipKey) is still under
// the per-IP attempt limit for the sliding window.
//
// The sorted set holds one member per attempt, scored by nanosecond
// timestamp; entries older than the window are pruned before counting.
func (r *loginRateLimiter) checkIPRateLimit(ctx context.Context, ipKey string) (bool, error) {
	now := time.Now()
	windowStart := now.Add(-r.config.IPWindow)
	// Prune entries that fell out of the window. This is best-effort
	// housekeeping (the ZCount below excludes stale entries anyway), but
	// the error was previously discarded silently — surface it in the log
	// without blocking the check.
	if err := r.client.ZRemRangeByScore(ctx, ipKey, "0", fmt.Sprintf("%d", windowStart.UnixNano())).Err(); err != nil && err != redis.Nil {
		r.logger.Warn("failed to prune expired IP attempts",
			zap.String("key", ipKey),
			zap.Error(err))
	}
	// Count current attempts inside the window.
	count, err := r.client.ZCount(ctx, ipKey,
		fmt.Sprintf("%d", windowStart.UnixNano()),
		"+inf").Result()
	if err != nil && err != redis.Nil {
		return false, err
	}
	return count < int64(r.config.MaxAttemptsPerIP), nil
}
// recordIPAttempt appends one attempt (scored and named by its nanosecond
// timestamp) to the IP's sliding-window set and refreshes the key's TTL.
func (r *loginRateLimiter) recordIPAttempt(ctx context.Context, ipKey string) error {
	ts := time.Now().UnixNano()
	pipe := r.client.Pipeline()
	pipe.ZAdd(ctx, ipKey, redis.Z{
		Score:  float64(ts),
		Member: fmt.Sprintf("%d", ts),
	})
	// TTL is padded by a minute so the key outlives the window slightly.
	pipe.Expire(ctx, ipKey, r.config.IPWindow+time.Minute)
	_, err := pipe.Exec(ctx)
	return err
}
// Key generation helpers

// getIPKey builds the Redis key tracking login attempts from one IP.
func (r *loginRateLimiter) getIPKey(ip string) string {
	return r.config.KeyPrefix + ":ip:" + ip
}
// getAccountKey builds the Redis key holding an account's failed-attempt
// counter; the email is hashed so no PII lands in key names.
func (r *loginRateLimiter) getAccountKey(email string) string {
	return r.config.KeyPrefix + ":account:" + hashEmail(email) + ":attempts"
}
// getLockKey builds the Redis key marking an account as locked; the email
// is hashed so no PII lands in key names.
func (r *loginRateLimiter) getLockKey(email string) string {
	return r.config.KeyPrefix + ":account:" + hashEmail(email) + ":locked"
}
// hashEmail creates a consistent hash of an email for use as a Redis key.
// CWE-532: Prevents PII (raw email addresses) from appearing in Redis keys.
//
// Uses SHA-256 truncated to 8 bytes (16 hex characters). The previous
// 31-multiply rolling hash over an int could collide across distinct
// emails (merging two accounts' lockout state) and could overflow to a
// negative value, which %x renders with a leading minus sign. Keys derived
// under the old scheme simply expire on their own TTLs after deployment.
func hashEmail(email string) string {
	sum := sha256.Sum256([]byte(email))
	return hex.EncodeToString(sum[:8])
}

View file

@ -0,0 +1,45 @@
package ratelimit
import (
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideRateLimiter provides a rate limiter for dependency injection
// (registration endpoints), configured from the registration limits in cfg.
func ProvideRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
	return NewRateLimiter(redisClient, Config{
		MaxRequests: cfg.RateLimit.RegistrationMaxRequests,
		Window:      cfg.RateLimit.RegistrationWindow,
		KeyPrefix:   "ratelimit:registration",
	}, logger)
}
// ProvideGenericRateLimiter provides a rate limiter for generic CRUD
// endpoints (CWE-770): authenticated tenant/user/site management and admin
// endpoints. Strategy: user-based limiting (authenticated user ID from JWT).
func ProvideGenericRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
	return NewRateLimiter(redisClient, Config{
		MaxRequests: cfg.RateLimit.GenericMaxRequests,
		Window:      cfg.RateLimit.GenericWindow,
		KeyPrefix:   "ratelimit:generic",
	}, logger)
}
// ProvidePluginAPIRateLimiter provides a rate limiter for WordPress plugin
// API endpoints (CWE-770) — the core business/revenue endpoints.
// Strategy: site-based limiting (API key → site_id).
func ProvidePluginAPIRateLimiter(redisClient *redis.Client, cfg *config.Config, logger *zap.Logger) RateLimiter {
	return NewRateLimiter(redisClient, Config{
		MaxRequests: cfg.RateLimit.PluginAPIMaxRequests,
		Window:      cfg.RateLimit.PluginAPIWindow,
		KeyPrefix:   "ratelimit:plugin",
	}, logger)
}

View file

@ -0,0 +1,23 @@
package ratelimit
import (
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideLoginRateLimiter creates a LoginRateLimiter for dependency
// injection, wired from environment-driven configuration.
// CWE-307: rate limiting and account lockout against brute-force attacks.
func ProvideLoginRateLimiter(client *redis.Client, cfg *config.Config, logger *zap.Logger) LoginRateLimiter {
	return NewLoginRateLimiter(client, LoginRateLimiterConfig{
		MaxAttemptsPerIP:            cfg.RateLimit.LoginMaxAttemptsPerIP,
		IPWindow:                    cfg.RateLimit.LoginIPWindow,
		MaxFailedAttemptsPerAccount: cfg.RateLimit.LoginMaxFailedAttemptsPerAccount,
		AccountLockoutDuration:      cfg.RateLimit.LoginAccountLockoutDuration,
		KeyPrefix:                   "login_rl",
	}, logger)
}

View file

@ -0,0 +1,172 @@
package ratelimit
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// RateLimiter provides rate limiting functionality using Redis.
// Implementations key each limit on an arbitrary caller-chosen string
// (user ID, IP, site ID, ...).
type RateLimiter interface {
	// Allow checks if a request should be allowed based on the key
	// Returns true if allowed, false if rate limit exceeded
	Allow(ctx context.Context, key string) (bool, error)
	// AllowN checks if N requests should be allowed
	AllowN(ctx context.Context, key string, n int) (bool, error)
	// Reset resets the rate limit for a key
	Reset(ctx context.Context, key string) error
	// GetRemaining returns the number of remaining requests
	GetRemaining(ctx context.Context, key string) (int, error)
}
// Config holds rate limiter configuration: at most MaxRequests requests
// per Window, with keys namespaced under KeyPrefix.
type Config struct {
	// MaxRequests is the maximum number of requests allowed
	MaxRequests int
	// Window is the time window for rate limiting
	Window time.Duration
	// KeyPrefix is the prefix for Redis keys
	KeyPrefix string
}
// rateLimiter is the Redis-backed sliding-window implementation of
// RateLimiter.
type rateLimiter struct {
	client *redis.Client // shared Redis connection; not owned by this type
	config Config
	logger *zap.Logger
}
// NewRateLimiter creates a rate limiter backed by the given Redis client,
// with a sub-logger named "rate-limiter".
func NewRateLimiter(client *redis.Client, config Config, logger *zap.Logger) RateLimiter {
	rl := &rateLimiter{
		client: client,
		config: config,
		logger: logger.Named("rate-limiter"),
	}
	return rl
}
// Allow reports whether a single request for key should be admitted;
// it is shorthand for AllowN(ctx, key, 1).
func (r *rateLimiter) Allow(ctx context.Context, key string) (bool, error) {
	return r.AllowN(ctx, key, 1)
}
// AllowN checks if N requests should be allowed using sliding window counter
//
// Algorithm: each granted request is a member of a Redis sorted set scored
// by its nanosecond timestamp. Entries older than the window are pruned,
// the survivors counted, and the batch admitted only if count+n stays
// within MaxRequests. Redis failures "fail open": the request is allowed
// and the error is returned for the caller to observe.
//
// NOTE(review): the count (pipeline 1) and the insert (pipeline 2) are not
// atomic, so concurrent callers can slip slightly past the limit. A Lua
// script would close the gap — confirm the approximation is acceptable.
func (r *rateLimiter) AllowN(ctx context.Context, key string, n int) (bool, error) {
	redisKey := r.getRedisKey(key)
	now := time.Now()
	windowStart := now.Add(-r.config.Window)
	// Use Redis transaction to ensure atomicity
	pipe := r.client.Pipeline()
	// Remove old entries outside the window
	pipe.ZRemRangeByScore(ctx, redisKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
	// Count current requests in window
	countCmd := pipe.ZCount(ctx, redisKey, fmt.Sprintf("%d", windowStart.UnixNano()), "+inf")
	// Execute pipeline
	_, err := pipe.Exec(ctx)
	if err != nil && err != redis.Nil {
		r.logger.Error("failed to check rate limit",
			zap.String("key", key),
			zap.Error(err))
		// Fail open: allow request if Redis is down
		return true, err
	}
	currentCount := countCmd.Val()
	// Check if adding N requests would exceed limit
	if currentCount+int64(n) > int64(r.config.MaxRequests) {
		r.logger.Warn("rate limit exceeded",
			zap.String("key", key),
			zap.Int64("current_count", currentCount),
			zap.Int("max_requests", r.config.MaxRequests))
		return false, nil
	}
	// Add the new request(s) to the sorted set
	pipe2 := r.client.Pipeline()
	for i := 0; i < n; i++ {
		// Use nanosecond timestamp with incremental offset to ensure uniqueness
		timestamp := now.Add(time.Duration(i) * time.Nanosecond).UnixNano()
		pipe2.ZAdd(ctx, redisKey, redis.Z{
			Score:  float64(timestamp),
			Member: fmt.Sprintf("%d-%d", timestamp, i),
		})
	}
	// Set expiration on the key (window + buffer)
	pipe2.Expire(ctx, redisKey, r.config.Window+time.Minute)
	// Execute pipeline
	_, err = pipe2.Exec(ctx)
	if err != nil && err != redis.Nil {
		r.logger.Error("failed to record request",
			zap.String("key", key),
			zap.Error(err))
		// Already counted, so return true
		return true, err
	}
	r.logger.Debug("rate limit check passed",
		zap.String("key", key),
		zap.Int64("current_count", currentCount),
		zap.Int("max_requests", r.config.MaxRequests))
	return true, nil
}
// Reset deletes all recorded requests for key, restoring its full quota.
func (r *rateLimiter) Reset(ctx context.Context, key string) error {
	redisKey := r.getRedisKey(key)
	if err := r.client.Del(ctx, redisKey).Err(); err != nil {
		r.logger.Error("failed to reset rate limit",
			zap.String("key", key),
			zap.Error(err))
		return err
	}
	r.logger.Info("rate limit reset",
		zap.String("key", key))
	return nil
}
// GetRemaining returns how many requests are still available for key in
// the current sliding window (never negative).
func (r *rateLimiter) GetRemaining(ctx context.Context, key string) (int, error) {
	windowStart := time.Now().Add(-r.config.Window)
	// Count current requests inside the window.
	count, err := r.client.ZCount(ctx, r.getRedisKey(key),
		fmt.Sprintf("%d", windowStart.UnixNano()),
		"+inf").Result()
	if err != nil && err != redis.Nil {
		r.logger.Error("failed to get remaining requests",
			zap.String("key", key),
			zap.Error(err))
		return 0, err
	}
	if left := r.config.MaxRequests - int(count); left > 0 {
		return left, nil
	}
	return 0, nil
}
// getRedisKey namespaces key under the configured prefix.
func (r *rateLimiter) getRedisKey(key string) string {
	return r.config.KeyPrefix + ":" + key
}

View file

@ -0,0 +1,18 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/search/config.go
package search
// Config holds Meilisearch configuration.
type Config struct {
	Host        string // Meilisearch server URL
	APIKey      string // API key; sent as a Bearer token on raw HTTP calls
	IndexPrefix string // e.g., "maplepress_" or "site_"
}
// NewConfig creates a new Meilisearch configuration from its three parts.
func NewConfig(host, apiKey, indexPrefix string) *Config {
	cfg := &Config{}
	cfg.Host = host
	cfg.APIKey = apiKey
	cfg.IndexPrefix = indexPrefix
	return cfg
}

View file

@ -0,0 +1,216 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/search/index.go
package search
import (
"fmt"
"github.com/meilisearch/meilisearch-go"
)
// PageDocument represents a document in the search index. The "id" field
// is used as the Meilisearch primary key (see CreateIndex); timestamps are
// Unix seconds-style int64 values used for sorting.
type PageDocument struct {
	ID          string `json:"id"`        // page_id
	SiteID      string `json:"site_id"`   // for filtering (though each site has its own index)
	TenantID    string `json:"tenant_id"` // for additional isolation
	Title       string `json:"title"`
	Content     string `json:"content"` // HTML stripped
	Excerpt     string `json:"excerpt"`
	URL         string `json:"url"`
	Status      string `json:"status"`       // publish, draft, trash
	PostType    string `json:"post_type"`    // page, post
	Author      string `json:"author"`
	PublishedAt int64  `json:"published_at"` // Unix timestamp for sorting
	ModifiedAt  int64  `json:"modified_at"`  // Unix timestamp for sorting
}
// CreateIndex creates a new index for a site and then applies the standard
// settings via ConfigureIndex.
func (c *Client) CreateIndex(siteID string) error {
	indexName := c.GetIndexName(siteID)
	// Create the index with the page ID ("id") as its primary key.
	// (Note: the index NAME is derived from siteID; the primary KEY is the
	// document id.)
	_, err := c.client.CreateIndex(&meilisearch.IndexConfig{
		Uid:        indexName,
		PrimaryKey: "id", // page_id is the primary key
	})
	if err != nil {
		return fmt.Errorf("failed to create index %s: %w", indexName, err)
	}
	// Configure index settings
	return c.ConfigureIndex(siteID)
}
// ConfigureIndex applies the standard settings to a site's index: search
// priority, filterable fields, ranking rules, and the attribute set
// returned in results (full content is excluded to keep payloads small).
func (c *Client) ConfigureIndex(siteID string) error {
	idx := c.client.Index(c.GetIndexName(siteID))

	// Searchable attributes, highest priority first.
	searchable := []string{"title", "excerpt", "content"}
	if _, err := idx.UpdateSearchableAttributes(&searchable); err != nil {
		return fmt.Errorf("failed to set searchable attributes: %w", err)
	}

	// Attributes usable in filter expressions.
	filterable := []interface{}{"status", "post_type", "author", "published_at"}
	if _, err := idx.UpdateFilterableAttributes(&filterable); err != nil {
		return fmt.Errorf("failed to set filterable attributes: %w", err)
	}

	// Relevance ordering applied to matches.
	ranking := []string{"words", "typo", "proximity", "attribute", "sort", "exactness"}
	if _, err := idx.UpdateRankingRules(&ranking); err != nil {
		return fmt.Errorf("failed to set ranking rules: %w", err)
	}

	// Fields returned in search hits; "content" is intentionally omitted.
	displayed := []string{
		"id", "title", "excerpt", "url", "status",
		"post_type", "author", "published_at", "modified_at",
	}
	if _, err := idx.UpdateDisplayedAttributes(&displayed); err != nil {
		return fmt.Errorf("failed to set displayed attributes: %w", err)
	}
	return nil
}
// IndexExists reports whether a search index already exists for the site.
// A Meilisearch 404 is translated into (false, nil); any other failure is
// returned as an error.
func (c *Client) IndexExists(siteID string) (bool, error) {
	indexName := c.GetIndexName(siteID)
	if _, err := c.client.GetIndex(indexName); err != nil {
		// Check if error is "index not found" (status code 404)
		if meiliErr, ok := err.(*meilisearch.Error); ok && meiliErr.StatusCode == 404 {
			return false, nil
		}
		return false, fmt.Errorf("failed to check index existence: %w", err)
	}
	return true, nil
}
// DeleteIndex removes a site's search index entirely.
func (c *Client) DeleteIndex(siteID string) error {
	name := c.GetIndexName(siteID)
	if _, err := c.client.DeleteIndex(name); err != nil {
		return fmt.Errorf("failed to delete index %s: %w", name, err)
	}
	return nil
}
// AddDocuments adds or updates documents in the site's index and returns
// the async task handle reported by Meilisearch.
func (c *Client) AddDocuments(siteID string, documents []PageDocument) (*meilisearch.TaskInfo, error) {
	name := c.GetIndexName(siteID)
	task, err := c.client.Index(name).AddDocuments(documents, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to add documents to index %s: %w", name, err)
	}
	return task, nil
}
// UpdateDocuments updates documents in the site's index and returns the
// async task handle reported by Meilisearch.
func (c *Client) UpdateDocuments(siteID string, documents []PageDocument) (*meilisearch.TaskInfo, error) {
	name := c.GetIndexName(siteID)
	task, err := c.client.Index(name).UpdateDocuments(documents, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to update documents in index %s: %w", name, err)
	}
	return task, nil
}
// DeleteDocument removes a single document from the site's index and
// returns the async task handle reported by Meilisearch.
func (c *Client) DeleteDocument(siteID string, documentID string) (*meilisearch.TaskInfo, error) {
	name := c.GetIndexName(siteID)
	task, err := c.client.Index(name).DeleteDocument(documentID)
	if err != nil {
		return nil, fmt.Errorf("failed to delete document %s from index %s: %w", documentID, name, err)
	}
	return task, nil
}
// DeleteDocuments removes multiple documents from the site's index and
// returns the async task handle reported by Meilisearch.
func (c *Client) DeleteDocuments(siteID string, documentIDs []string) (*meilisearch.TaskInfo, error) {
	name := c.GetIndexName(siteID)
	task, err := c.client.Index(name).DeleteDocuments(documentIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to delete documents from index %s: %w", name, err)
	}
	return task, nil
}
// DeleteAllDocuments empties the site's index (the index itself remains)
// and returns the async task handle reported by Meilisearch.
func (c *Client) DeleteAllDocuments(siteID string) (*meilisearch.TaskInfo, error) {
	name := c.GetIndexName(siteID)
	task, err := c.client.Index(name).DeleteAllDocuments()
	if err != nil {
		return nil, fmt.Errorf("failed to delete all documents from index %s: %w", name, err)
	}
	return task, nil
}
// GetStats returns Meilisearch statistics for the site's index.
func (c *Client) GetStats(siteID string) (*meilisearch.StatsIndex, error) {
	name := c.GetIndexName(siteID)
	stats, err := c.client.Index(name).GetStats()
	if err != nil {
		return nil, fmt.Errorf("failed to get stats for index %s: %w", name, err)
	}
	return stats, nil
}

View file

@ -0,0 +1,47 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/search/meilisearch.go
package search
import (
"fmt"
"github.com/meilisearch/meilisearch-go"
"go.uber.org/zap"
)
// Client wraps the Meilisearch client together with its configuration and
// a named logger; per-site index names are derived from config.IndexPrefix.
type Client struct {
	client meilisearch.ServiceManager
	config *Config
	logger *zap.Logger
}
// NewClient creates a new Meilisearch client from config, rejecting an
// empty host. The logger gets a "meilisearch" sub-name.
func NewClient(config *Config, logger *zap.Logger) (*Client, error) {
	if config.Host == "" {
		return nil, fmt.Errorf("meilisearch host is required")
	}
	c := &Client{
		client: meilisearch.New(config.Host, meilisearch.WithAPIKey(config.APIKey)),
		config: config,
		logger: logger.Named("meilisearch"),
	}
	return c, nil
}
// GetIndexName returns the full (prefixed) index name for a site.
func (c *Client) GetIndexName(siteID string) string {
	return fmt.Sprintf("%s%s", c.config.IndexPrefix, siteID)
}
// Health checks whether the Meilisearch server responds to its health
// endpoint, returning any error encountered.
func (c *Client) Health() error {
	_, err := c.client.Health()
	return err
}
// GetClient exposes the underlying Meilisearch service manager for
// advanced operations not wrapped by this package.
func (c *Client) GetClient() meilisearch.ServiceManager {
	return c.client
}

View file

@ -0,0 +1,22 @@
package search
import (
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"go.uber.org/zap"
)
// ProvideClient provides a Meilisearch client for dependency injection,
// built from the Meilisearch section of the application config.
func ProvideClient(cfg *config.Config, logger *zap.Logger) (*Client, error) {
	searchCfg := NewConfig(
		cfg.Meilisearch.Host,
		cfg.Meilisearch.APIKey,
		cfg.Meilisearch.IndexPrefix,
	)
	return NewClient(searchCfg, logger)
}

View file

@ -0,0 +1,155 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/search/search.go
package search
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"go.uber.org/zap"
)
// SearchRequest represents a search request against a site's index.
type SearchRequest struct {
	Query  string // Full-text query string (sent as Meilisearch "q")
	Limit  int64  // Maximum number of hits to return
	Offset int64  // Number of hits to skip (pagination)
	Filter string // Optional Meilisearch filter expression, e.g., "status = publish"
}
// SearchResult represents a search result returned by Search. JSON tags use
// snake_case for this application's API; the raw Meilisearch response uses
// camelCase and is re-mapped inside Search.
type SearchResult struct {
	Hits             []map[string]interface{} `json:"hits"`               // Matched documents
	Query            string                   `json:"query"`              // Echo of the query string
	ProcessingTimeMs int64                    `json:"processing_time_ms"` // Server-side processing time
	TotalHits        int64                    `json:"total_hits"`         // Estimated total matches
	Limit            int64                    `json:"limit"`              // Requested page size
	Offset           int64                    `json:"offset"`             // Requested offset
}
// Search performs a search query on the site's index.
//
// The request body is built by hand (rather than via the meilisearch-go SDK
// request types) so that SDK-only fields such as "hybrid" are never
// serialized into the payload. Returns the parsed hits plus paging metadata,
// or an error when the HTTP round trip or response decoding fails.
func (c *Client) Search(siteID string, req SearchRequest) (*SearchResult, error) {
	indexName := c.GetIndexName(siteID)
	c.logger.Info("initiating search",
		zap.String("site_id", siteID),
		zap.String("index_name", indexName),
		zap.String("query", req.Query),
		zap.Int64("limit", req.Limit),
		zap.Int64("offset", req.Offset),
		zap.String("filter", req.Filter))
	// Build search request manually to avoid hybrid field
	searchBody := map[string]interface{}{
		"q":                     req.Query,
		"limit":                 req.Limit,
		"offset":                req.Offset,
		"attributesToHighlight": []string{"title", "excerpt", "content"},
	}
	// Add filter if provided
	if req.Filter != "" {
		searchBody["filter"] = req.Filter
	}
	jsonData, err := json.Marshal(searchBody)
	if err != nil {
		c.logger.Error("failed to marshal search request", zap.Error(err))
		return nil, fmt.Errorf("failed to marshal search request: %w", err)
	}
	// Uncomment for debugging: shows exact JSON payload sent to Meilisearch
	// c.logger.Debug("search request payload", zap.String("json", string(jsonData)))
	searchURL := fmt.Sprintf("%s/indexes/%s/search", c.config.Host, indexName)
	// Uncomment for debugging: shows the Meilisearch API endpoint
	// c.logger.Debug("search URL", zap.String("url", searchURL))
	httpReq, err := http.NewRequest("POST", searchURL, bytes.NewBuffer(jsonData))
	if err != nil {
		c.logger.Error("failed to create HTTP request", zap.Error(err))
		return nil, fmt.Errorf("failed to create search request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.APIKey))
	// Bound the request so a stalled Meilisearch instance cannot hang the
	// caller indefinitely (http.Client has no timeout by default).
	httpClient := &http.Client{Timeout: 30 * time.Second}
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		c.logger.Error("failed to execute HTTP request", zap.Error(err))
		return nil, fmt.Errorf("failed to execute search request: %w", err)
	}
	// On a nil error, http.Client.Do is guaranteed to return a non-nil
	// response, so no explicit nil check is needed before closing the body.
	defer resp.Body.Close()
	c.logger.Info("received search response",
		zap.Int("status_code", resp.StatusCode),
		zap.String("status", resp.Status))
	// Read the full body so it is available for both decoding and error logs.
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		c.logger.Error("failed to read response body", zap.Error(err))
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	// Uncomment for debugging: shows full Meilisearch response
	// c.logger.Debug("search response body", zap.String("body", string(bodyBytes)))
	if resp.StatusCode != http.StatusOK {
		c.logger.Error("search request failed",
			zap.Int("status_code", resp.StatusCode),
			zap.String("response_body", string(bodyBytes)))
		return nil, fmt.Errorf("search request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}
	// Meilisearch responds in camelCase; decode into an anonymous struct and
	// map onto our snake_case SearchResult below.
	var searchResp struct {
		Hits               []map[string]interface{} `json:"hits"`
		Query              string                   `json:"query"`
		ProcessingTimeMs   int64                    `json:"processingTimeMs"`
		EstimatedTotalHits int                      `json:"estimatedTotalHits"`
		Limit              int64                    `json:"limit"`
		Offset             int64                    `json:"offset"`
	}
	if err := json.Unmarshal(bodyBytes, &searchResp); err != nil {
		c.logger.Error("failed to decode search response", zap.Error(err))
		return nil, fmt.Errorf("failed to decode search response: %w", err)
	}
	c.logger.Info("search completed successfully",
		zap.Int("hits_count", len(searchResp.Hits)),
		zap.Int("total_hits", searchResp.EstimatedTotalHits),
		zap.Int64("processing_time_ms", searchResp.ProcessingTimeMs))
	return &SearchResult{
		Hits:             searchResp.Hits,
		Query:            searchResp.Query,
		ProcessingTimeMs: searchResp.ProcessingTimeMs,
		TotalHits:        int64(searchResp.EstimatedTotalHits),
		Limit:            req.Limit,
		Offset:           req.Offset,
	}, nil
}
// SearchWithHighlights performs a search with custom highlighting.
// Note: Currently uses same implementation as Search(); the highlightTags
// argument is ignored and the default highlighting of title/excerpt/content
// from Search applies.
func (c *Client) SearchWithHighlights(siteID string, req SearchRequest, highlightTags []string) (*SearchResult, error) {
	// For now, just use the regular Search method
	// TODO: Implement custom highlight tags if needed
	return c.Search(siteID, req)
}

View file

@ -0,0 +1,520 @@
# Security Package
This package provides secure password hashing and memory-safe storage for sensitive data.
## Packages
### Password (`pkg/security/password`)
Provides Argon2id-based password hashing and verification with secure default parameters following OWASP recommendations.
### SecureString (`pkg/security/securestring`)
Memory-safe string storage using `memguard` to protect sensitive data like passwords and API keys from memory dumps and swap files.
### SecureBytes (`pkg/security/securebytes`)
Memory-safe byte slice storage using `memguard` to protect sensitive binary data.
### IPCountryBlocker (`pkg/security/ipcountryblocker`)
GeoIP-based country blocking using MaxMind's GeoLite2 database to block requests from specific countries.
## Installation
The packages are included in the project. Required dependencies:
- `github.com/awnumar/memguard` - For secure memory management
- `golang.org/x/crypto/argon2` - For password hashing
- `github.com/oschwald/geoip2-golang` - For GeoIP lookups
## Usage
### Password Hashing
```go
import (
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/password"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
)
// Create password provider
passwordProvider := password.NewPasswordProvider()
// Hash a password
plainPassword := "mySecurePassword123!"
securePass, err := securestring.NewSecureString(plainPassword)
if err != nil {
// Handle error
}
defer securePass.Wipe() // Always wipe after use
hashedPassword, err := passwordProvider.GenerateHashFromPassword(securePass)
if err != nil {
// Handle error
}
// Verify a password
match, err := passwordProvider.ComparePasswordAndHash(securePass, hashedPassword)
if err != nil {
// Handle error
}
if match {
// Password is correct
}
```
### Secure String Storage
```go
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
// Store sensitive data securely
apiKey := "secret-api-key-12345"
secureKey, err := securestring.NewSecureString(apiKey)
if err != nil {
// Handle error
}
defer secureKey.Wipe() // Always wipe when done
// Use the secure string
keyValue := secureKey.String() // Get the value when needed
// ... use keyValue ...
// The original string should be cleared
apiKey = ""
```
### Secure Bytes Storage
```go
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securebytes"
// Store sensitive binary data
sensitiveData := []byte{0x01, 0x02, 0x03, 0x04}
secureData, err := securebytes.NewSecureBytes(sensitiveData)
if err != nil {
// Handle error
}
defer secureData.Wipe()
// Use the secure bytes
data := secureData.Bytes()
// ... use data ...
// Clear the original slice
for i := range sensitiveData {
sensitiveData[i] = 0
}
```
### Generate Random Values
```go
passwordProvider := password.NewPasswordProvider()
// Generate random bytes
randomBytes, err := passwordProvider.GenerateSecureRandomBytes(32)
// Generate random hex string (length * 2 characters)
randomString, err := passwordProvider.GenerateSecureRandomString(16)
// Returns a 32-character hex string
```
### IP Country Blocking
```go
import (
"context"
"net"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/ipcountryblocker"
)
// Create the blocker (typically done via dependency injection)
cfg, _ := config.Load()
blocker := ipcountryblocker.NewProvider(cfg, logger)
defer blocker.Close()
// Check if an IP is blocked
ip := net.ParseIP("192.0.2.1")
if blocker.IsBlockedIP(context.Background(), ip) {
// Handle blocked IP
return errors.New("access denied: your country is blocked")
}
// Check if a country code is blocked
if blocker.IsBlockedCountry("CN") {
// Country is in the blocked list
}
// Get country code for an IP
countryCode, err := blocker.GetCountryCode(context.Background(), ip)
if err != nil {
// Handle error
}
// countryCode will be ISO 3166-1 alpha-2 code like "US", "CA", "GB"
```
**Configuration**:
```bash
# Environment variables
APP_GEOLITE_DB_PATH=/path/to/GeoLite2-Country.mmdb
APP_BANNED_COUNTRIES=CN,RU,KP # Comma-separated ISO 3166-1 alpha-2 codes
```
## Password Hashing Details
### Algorithm: Argon2id
Argon2id is the recommended password hashing algorithm by OWASP. It combines:
- Argon2i: Resistant to side-channel attacks
- Argon2d: Resistant to GPU cracking attacks
### Default Parameters
```
Memory: 64 MB (65536 KB)
Iterations: 3
Parallelism: 2 threads
Salt Length: 16 bytes
Key Length: 32 bytes
```
These parameters provide strong security while maintaining reasonable performance for authentication systems.
### Hash Format
```
$argon2id$v=19$m=65536,t=3,p=2$<base64-salt>$<base64-hash>
```
Example:
```
$argon2id$v=19$m=65536,t=3,p=2$YWJjZGVmZ2hpamtsbW5vcA$9XJqrJ8fQvVrMz0FqJ7gBGqKvYLvLxC8HzPqKvYLvLxC
```
The hash includes all parameters, so it can be verified even if you change the default parameters later.
## Security Best Practices
### 1. Always Wipe Sensitive Data
```go
securePass, _ := securestring.NewSecureString(password)
defer securePass.Wipe() // Ensures cleanup even on panic
// ... use securePass ...
```
### 2. Clear Original Data
After creating a secure string/bytes, clear the original data:
```go
password := "secret"
securePass, _ := securestring.NewSecureString(password)
password = "" // Clear the original string
// Even better for byte slices:
data := []byte("secret")
secureData, _ := securebytes.NewSecureBytes(data)
for i := range data {
data[i] = 0
}
```
### 3. Minimize Exposure Time
Get values from secure storage only when needed:
```go
// Bad - exposes value for too long
value := secureString.String()
// ... lots of code ...
useValue(value)
// Good - get value right before use
// ... lots of code ...
useValue(secureString.String())
```
### 4. Use Dependency Injection
```go
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/password"
// In your Wire provider set
wire.NewSet(
password.ProvidePasswordProvider,
// ... other providers
)
// Use in your service
type AuthService struct {
passwordProvider password.PasswordProvider
}
func NewAuthService(pp password.PasswordProvider) *AuthService {
return &AuthService{passwordProvider: pp}
}
```
### 5. Handle Errors Properly
```go
securePass, err := securestring.NewSecureString(password)
if err != nil {
return fmt.Errorf("failed to create secure string: %w", err)
}
defer securePass.Wipe()
```
### 6. Clean Up GeoIP Resources
```go
import "codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/ipcountryblocker"
// Always close the provider when done to release database resources
blocker := ipcountryblocker.NewProvider(cfg, logger)
defer blocker.Close()
```
## Memory Safety
### How It Works
The `memguard` library provides:
1. **Locked Memory**: Prevents sensitive data from being swapped to disk
2. **Guarded Heap**: Detects buffer overflows and underflows
3. **Secure Wiping**: Overwrites memory with random data before freeing
4. **Read Protection**: Makes memory pages read-only when not in use
### When to Use
Use secure storage for:
- Passwords and password hashes (during verification)
- API keys and tokens
- Encryption keys
- Private keys
- Database credentials
- OAuth secrets
- JWT signing keys
- Session tokens
- Any sensitive user data
### When NOT to Use
Don't use for:
- Public data
- Non-sensitive configuration
- Data that needs to be logged
- Data that will be stored long-term in memory
## Performance Considerations
### Password Hashing
Argon2id is intentionally slow to prevent brute-force attacks:
- Expected time: ~50-100ms per hash on modern hardware
- This is acceptable for authentication (login) operations
- DO NOT use for high-throughput operations
### Memory Usage
SecureString/SecureBytes use locked memory:
- Each instance locks a page in RAM (typically 4KB minimum)
- Don't create thousands of instances
- Reuse instances when possible
- Always wipe when done
## Examples
### Complete Login Example
```go
func (s *AuthService) Login(ctx context.Context, email, password string) (*User, error) {
// Create secure string from password
securePass, err := securestring.NewSecureString(password)
if err != nil {
return nil, err
}
defer securePass.Wipe()
// Clear the original password
password = ""
// Get user from database
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
return nil, err
}
// Verify password
match, err := s.passwordProvider.ComparePasswordAndHash(securePass, user.PasswordHash)
if err != nil {
return nil, err
}
if !match {
return nil, ErrInvalidCredentials
}
return user, nil
}
```
### Complete Registration Example
```go
func (s *AuthService) Register(ctx context.Context, email, password string) (*User, error) {
// Validate password strength first
if len(password) < 8 {
return nil, ErrWeakPassword
}
// Create secure string from password
securePass, err := securestring.NewSecureString(password)
if err != nil {
return nil, err
}
defer securePass.Wipe()
// Clear the original password
password = ""
// Hash the password
hashedPassword, err := s.passwordProvider.GenerateHashFromPassword(securePass)
if err != nil {
return nil, err
}
// Create user with hashed password
user := &User{
Email: email,
PasswordHash: hashedPassword,
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
return user, nil
}
```
## Troubleshooting
### "failed to create buffer"
**Problem**: memguard couldn't allocate locked memory
**Solutions**:
- Check system limits for locked memory (`ulimit -l`)
- Reduce number of concurrent SecureString/SecureBytes instances
- Ensure proper cleanup with `Wipe()`
### "buffer is not alive"
**Problem**: Trying to use a SecureString/SecureBytes after it was wiped
**Solutions**:
- Don't use secure data after calling `Wipe()`
- Check your defer ordering
- Create new instances if you need the data again
### Slow Performance
**Problem**: Password hashing is too slow
**Solutions**:
- This is by design for security
- Don't hash passwords in high-throughput operations
- Consider caching authentication results (with care)
- Use async operations for registration/password changes
### "failed to open GeoLite2 DB"
**Problem**: Cannot open the GeoIP2 database
**Solutions**:
- Verify APP_GEOLITE_DB_PATH points to a valid .mmdb file
- Download the GeoLite2-Country database from MaxMind
- Check file permissions
- Ensure the database file is not corrupted
### "no country found for IP"
**Problem**: IP address lookup returns no country
**Solutions**:
- This is normal for private IP ranges (10.x.x.x, 192.168.x.x, etc.)
- The IP might not be in the GeoIP2 database
- Update to a newer GeoLite2 database
- By default, unknown IPs are allowed (returns false from IsBlockedIP)
## IP Country Blocking Details
### GeoLite2 Database
The IP country blocker uses MaxMind's GeoLite2-Country database for IP geolocation.
**How to Get the Database**:
1. Create a free account at https://www.maxmind.com/en/geolite2/signup
2. Generate a license key
3. Download GeoLite2-Country database (.mmdb format)
4. Set APP_GEOLITE_DB_PATH to the file location
**Database Updates**:
- MaxMind updates GeoLite2 databases weekly
- Set up automated updates for production systems
- Database file is typically 5-10 MB
### Country Codes
Uses ISO 3166-1 alpha-2 country codes:
- US - United States
- CA - Canada
- GB - United Kingdom
- CN - China
- RU - Russia
- KP - North Korea
- etc.
Full list: https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2
### Blocking Behavior
**Default Behavior**:
- If IP lookup fails → Allow (returns false)
- If country not found → Allow (returns false)
- If country is blocked → Block (returns true)
**To block unknown IPs**, modify `IsBlockedIP` in `ipcountryblocker.go` to return `true` when the country lookup fails (see the developer note in that function).
### Thread Safety
The provider is thread-safe:
- Uses sync.RWMutex for concurrent access to blocked countries map
- GeoIP2 Reader is thread-safe by design
- Safe to use in HTTP middleware and concurrent handlers
### Performance
**Lookup Speed**:
- In-memory database lookups are very fast (~microseconds)
- Database is memory-mapped for efficiency
- Suitable for high-traffic applications
**Memory Usage**:
- GeoLite2-Country database: ~5-10 MB in memory
- Blocked countries map: negligible (few KB)
## References
- [Argon2 RFC](https://tools.ietf.org/html/rfc9106)
- [OWASP Password Storage Cheat Sheet](https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html)
- [memguard Documentation](https://github.com/awnumar/memguard)
- [Alex Edwards: How to Hash and Verify Passwords With Argon2 in Go](https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go)
- [MaxMind GeoLite2 Free Geolocation Data](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data)
- [ISO 3166-1 Country Codes](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2)

View file

@ -0,0 +1,96 @@
package apikey
import (
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
)
const (
	// PrefixLive is the prefix for production API keys
	PrefixLive = "live_sk_"
	// PrefixTest is the prefix for test/sandbox API keys
	PrefixTest = "test_sk_"
	// KeyLength is the number of random bytes drawn per batch
	// (30 bytes encode to 40 base64url characters).
	KeyLength = 30
	// randomPartLength is the exact length of the random key suffix.
	randomPartLength = 40
)

// Generator generates API keys
type Generator interface {
	// Generate creates a new live API key
	Generate() (string, error)
	// GenerateTest creates a new test API key
	GenerateTest() (string, error)
}

type generator struct{}

// NewGenerator creates a new API key generator
func NewGenerator() Generator {
	return &generator{}
}

// Generate creates a new live API key
func (g *generator) Generate() (string, error) {
	return g.generateWithPrefix(PrefixLive)
}

// GenerateTest creates a new test API key
func (g *generator) GenerateTest() (string, error) {
	return g.generateWithPrefix(PrefixTest)
}

// generateWithPrefix returns prefix + a 40-character alphanumeric suffix
// built from cryptographically secure random bytes. base64url output may
// contain '-' and '_', which are dropped for a purely alphanumeric key, so
// random bytes are drawn until 40 kept characters have accumulated. Every
// rand.Read error is checked (never silently ignored), and the loop makes
// out-of-range slicing impossible.
func (g *generator) generateWithPrefix(prefix string) (string, error) {
	var key strings.Builder
	key.Grow(randomPartLength)
	for key.Len() < randomPartLength {
		raw := make([]byte, KeyLength)
		if _, err := rand.Read(raw); err != nil {
			return "", fmt.Errorf("failed to generate random bytes: %w", err)
		}
		for _, r := range base64.RawURLEncoding.EncodeToString(raw) {
			if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
				key.WriteRune(r)
				if key.Len() == randomPartLength {
					break
				}
			}
		}
	}
	return prefix + key.String(), nil
}

// ExtractPrefix extracts the identifying head of an API key: the 8-character
// prefix plus the first 5 random characters (13 characters total, e.g.
// "live_sk_a1b2c"). Returns "" for keys shorter than 13 characters.
func ExtractPrefix(apiKey string) string {
	if len(apiKey) < 13 {
		return ""
	}
	return apiKey[:13]
}

// ExtractLastFour extracts the last 4 characters from an API key.
// Returns "" for keys shorter than 4 characters.
func ExtractLastFour(apiKey string) string {
	if len(apiKey) < 4 {
		return ""
	}
	return apiKey[len(apiKey)-4:]
}

// IsValid checks if an API key has a valid format (a recognized prefix).
func IsValid(apiKey string) bool {
	return strings.HasPrefix(apiKey, PrefixLive) || strings.HasPrefix(apiKey, PrefixTest)
}

View file

@ -0,0 +1,35 @@
package apikey
import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
)
// Hasher hashes and verifies API keys using SHA-256
type Hasher interface {
// Hash creates a deterministic SHA-256 hash of the API key
Hash(apiKey string) string
// Verify checks if the API key matches the hash using constant-time comparison
Verify(apiKey string, hash string) bool
}
type hasher struct{}
// NewHasher creates a new API key hasher
func NewHasher() Hasher {
return &hasher{}
}
// Hash creates a deterministic SHA-256 hash of the API key
func (h *hasher) Hash(apiKey string) string {
hash := sha256.Sum256([]byte(apiKey))
return base64.StdEncoding.EncodeToString(hash[:])
}
// Verify checks if the API key matches the hash using constant-time comparison
// This prevents timing attacks
func (h *hasher) Verify(apiKey string, expectedHash string) bool {
actualHash := h.Hash(apiKey)
return subtle.ConstantTimeCompare([]byte(actualHash), []byte(expectedHash)) == 1
}

View file

@ -0,0 +1,11 @@
package apikey
// ProvideGenerator provides an API key generator for dependency injection
// (e.g. inclusion in a Wire provider set).
func ProvideGenerator() Generator {
	return NewGenerator()
}

// ProvideHasher provides an API key hasher for dependency injection
// (e.g. inclusion in a Wire provider set).
func ProvideHasher() Hasher {
	return NewHasher()
}

View file

@ -0,0 +1,168 @@
package clientip
import (
"net"
"net/http"
"strings"
"go.uber.org/zap"
)
// Extractor provides secure client IP address extraction
// CWE-348: Prevents X-Forwarded-For header spoofing by validating trusted proxies
type Extractor struct {
	trustedProxies []*net.IPNet // CIDR ranges whose X-Forwarded-For headers are honored
	logger         *zap.Logger  // Named "client-ip-extractor" sub-logger
}
// NewExtractor creates a new IP extractor with trusted proxy configuration.
// trustedProxyCIDRs should contain CIDR blocks of trusted reverse proxies,
// e.g. []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}.
// Returns an error if any CIDR fails to parse.
func NewExtractor(trustedProxyCIDRs []string, logger *zap.Logger) (*Extractor, error) {
	trusted := make([]*net.IPNet, 0, len(trustedProxyCIDRs))
	for _, cidr := range trustedProxyCIDRs {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			logger.Error("failed to parse trusted proxy CIDR",
				zap.String("cidr", cidr),
				zap.Error(err))
			return nil, err
		}
		trusted = append(trusted, network)
	}
	logger.Info("client IP extractor initialized",
		zap.Int("trusted_proxy_ranges", len(trusted)))
	return &Extractor{
		trustedProxies: trusted,
		logger:         logger.Named("client-ip-extractor"),
	}, nil
}
// NewDefaultExtractor creates an extractor with no trusted proxies.
// This is safe for direct connections but will ignore X-Forwarded-For headers.
func NewDefaultExtractor(logger *zap.Logger) *Extractor {
	logger.Warn("client IP extractor initialized with NO trusted proxies - X-Forwarded-For will be ignored")
	e := &Extractor{
		trustedProxies: []*net.IPNet{},
		logger:         logger.Named("client-ip-extractor"),
	}
	return e
}
// Extract extracts the real client IP address from the HTTP request.
// CWE-348: Secure implementation that prevents header spoofing —
// X-Forwarded-For is honored only when the immediate peer is a configured
// trusted proxy; otherwise the connection's RemoteAddr is used verbatim.
func (e *Extractor) Extract(r *http.Request) string {
	// Step 1: Get the immediate connection's remote address
	remoteAddr := r.RemoteAddr
	// Remove port from RemoteAddr (format: "IP:port" or "[IPv6]:port")
	remoteIP := e.stripPort(remoteAddr)
	// Step 2: Parse the remote IP
	parsedRemoteIP := net.ParseIP(remoteIP)
	if parsedRemoteIP == nil {
		e.logger.Warn("failed to parse remote IP address",
			zap.String("remote_addr", remoteAddr))
		return remoteIP // Return as-is if we can't parse it
	}
	// Step 3: Only a trusted proxy may speak for the client. If the peer is
	// not trusted, do NOT trust X-Forwarded-For — this prevents clients from
	// spoofing their IP by setting the header themselves.
	if !e.isTrustedProxy(parsedRemoteIP) {
		e.logger.Debug("remote IP is not a trusted proxy, using RemoteAddr",
			zap.String("remote_ip", remoteIP))
		return remoteIP
	}
	// Step 4: Remote IP is trusted, check X-Forwarded-For header
	// Format: "client, proxy1, proxy2" (leftmost is original client)
	xff := r.Header.Get("X-Forwarded-For")
	if xff == "" {
		// No X-Forwarded-For header, use RemoteAddr
		e.logger.Debug("no X-Forwarded-For header from trusted proxy",
			zap.String("remote_ip", remoteIP))
		return remoteIP
	}
	// Step 5: Take the FIRST (leftmost) entry, which should be the original
	// client. strings.Split always returns at least one element for a
	// non-empty input, so no length check is needed here.
	clientIP := strings.TrimSpace(strings.Split(xff, ",")[0])
	// Step 6: Validate the forwarded client IP; fall back to RemoteAddr on garbage.
	if net.ParseIP(clientIP) == nil {
		e.logger.Warn("invalid IP in X-Forwarded-For header",
			zap.String("xff", xff),
			zap.String("client_ip", clientIP))
		return remoteIP // Fall back to RemoteAddr
	}
	e.logger.Debug("extracted client IP from X-Forwarded-For",
		zap.String("client_ip", clientIP),
		zap.String("remote_proxy", remoteIP),
		zap.String("xff_chain", xff))
	return clientIP
}
// ExtractOrDefault extracts the client IP, or returns defaultIP when
// extraction yields an empty string.
func (e *Extractor) ExtractOrDefault(r *http.Request, defaultIP string) string {
	if ip := e.Extract(r); ip != "" {
		return ip
	}
	return defaultIP
}
// isTrustedProxy reports whether ip falls inside any configured trusted
// proxy CIDR range.
func (e *Extractor) isTrustedProxy(ip net.IP) bool {
	for _, network := range e.trustedProxies {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
// stripPort removes the port from an address string.
// Handles IPv4 ("192.168.1.1:8080") and bracketed IPv6 ("[::1]:8080")
// formats, and leaves addresses without a port untouched — including bare
// IPv6 addresses, which must not be split on their final colon.
func (e *Extractor) stripPort(addr string) string {
	// Bare IP with no port (e.g. "192.168.1.1" or "2001:db8::1"): return
	// unchanged. Without this guard, the LastIndex(":") fallback below would
	// truncate a bare IPv6 address at its last colon.
	if net.ParseIP(addr) != nil {
		return addr
	}
	// IPv6 with port: [::1]:8080
	if strings.HasPrefix(addr, "[") {
		if idx := strings.LastIndex(addr, "]:"); idx != -1 {
			return addr[1:idx] // Extract IP between [ and ]
		}
		// Malformed bracketed address; return as-is.
		return addr
	}
	// IPv4 with port: split on the last colon.
	if idx := strings.LastIndex(addr, ":"); idx != -1 {
		return addr[:idx]
	}
	// No port found
	return addr
}
// GetTrustedProxyCount returns the number of configured trusted proxy ranges.
func (e *Extractor) GetTrustedProxyCount() int {
	return len(e.trustedProxies)
}

// HasTrustedProxies returns true if any trusted proxies are configured
// (i.e. whether X-Forwarded-For headers can ever be honored).
func (e *Extractor) HasTrustedProxies() bool {
	return len(e.trustedProxies) > 0
}

View file

@ -0,0 +1,19 @@
package clientip
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideExtractor provides a client IP extractor configured from the
// application config.
func ProvideExtractor(cfg *config.Config, logger *zap.Logger) (*Extractor, error) {
	proxies := cfg.Security.TrustedProxies
	// Without trusted proxies, fall back to the safe default that ignores
	// X-Forwarded-For entirely.
	if len(proxies) == 0 {
		logger.Info("no trusted proxies configured - X-Forwarded-For headers will be ignored for security")
		return NewDefaultExtractor(logger), nil
	}
	return NewExtractor(proxies, logger)
}

View file

@ -0,0 +1,127 @@
package ipcountryblocker
import (
"context"
"fmt"
"net"
"sync"
"github.com/oschwald/geoip2-golang"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// Provider defines the interface for IP-based country blocking operations.
// It provides methods to check if an IP or country is blocked and to retrieve
// country codes for given IP addresses.
type Provider interface {
	// IsBlockedCountry checks if a country is in the blocked list.
	// isoCode must be an ISO 3166-1 alpha-2 country code (e.g. "US", "CN").
	IsBlockedCountry(isoCode string) bool
	// IsBlockedIP determines if an IP address originates from a blocked country.
	// Returns false (fail-open) for nil IP addresses or if country lookup fails.
	IsBlockedIP(ctx context.Context, ip net.IP) bool
	// GetCountryCode returns the ISO 3166-1 alpha-2 country code for an IP address.
	// Returns an error if the lookup fails or no country is found.
	GetCountryCode(ctx context.Context, ip net.IP) (string, error)
	// Close releases resources associated with the provider (the GeoIP2 reader).
	Close() error
}
// provider implements the Provider interface using MaxMind's GeoIP2 database.
type provider struct {
	db               *geoip2.Reader      // GeoLite2 country database reader
	blockedCountries map[string]struct{} // Set of banned ISO codes; empty struct values use no memory
	logger           *zap.Logger
	mu               sync.RWMutex // Protects concurrent access to blockedCountries
}
// NewProvider creates a new IP country blocking provider using the provided
// configuration. It opens the GeoIP2 database and indexes the configured
// banned countries. Calls logger.Fatal — terminating the entire application —
// if the database cannot be opened.
func NewProvider(cfg *config.Config, logger *zap.Logger) Provider {
	dbPath := cfg.App.GeoLiteDBPath
	logger.Info("⏳ Loading GeoIP2 database...",
		zap.String("db_path", dbPath))
	db, err := geoip2.Open(dbPath)
	if err != nil {
		logger.Fatal("Failed to open GeoLite2 database",
			zap.String("db_path", dbPath),
			zap.Error(err))
	}
	banned := cfg.App.BannedCountries
	blocked := make(map[string]struct{}, len(banned))
	for _, isoCode := range banned {
		blocked[isoCode] = struct{}{}
	}
	logger.Info("✓ IP country blocker initialized",
		zap.Int("blocked_countries", len(banned)))
	return &provider{
		db:               db,
		blockedCountries: blocked,
		logger:           logger,
	}
}
// IsBlockedCountry checks if a country code exists in the blocked countries
// map. Thread-safe: takes the read lock for the duration of the lookup.
func (p *provider) IsBlockedCountry(isoCode string) bool {
	p.mu.RLock()
	_, blocked := p.blockedCountries[isoCode]
	p.mu.RUnlock()
	return blocked
}
// IsBlockedIP performs a country lookup for the IP and checks if it's blocked.
// Returns false for nil IPs or failed lookups to fail safely (fail-open).
func (p *provider) IsBlockedIP(ctx context.Context, ip net.IP) bool {
	if ip == nil {
		return false
	}
	code, err := p.GetCountryCode(ctx, ip)
	if err != nil {
		// Lookup failures (e.g. private ranges or IPs absent from the GeoIP2
		// database) are deliberately not logged here: doing so floods the
		// server log. We fail open — an IP whose country cannot be
		// determined is allowed through. If that is a concern for your
		// deployment, change this to `return true` to block every IP that
		// cannot be categorized by country.
		return false
	}
	return p.IsBlockedCountry(code)
}
// GetCountryCode performs a GeoIP2 database lookup to determine an IP's
// country. Returns the ISO 3166-1 alpha-2 code, or an error if the lookup
// fails or the database has no country for the address.
func (p *provider) GetCountryCode(ctx context.Context, ip net.IP) (string, error) {
	record, err := p.db.Country(ip)
	if err != nil {
		return "", fmt.Errorf("lookup country: %w", err)
	}
	if record != nil {
		if code := record.Country.IsoCode; code != "" {
			return code, nil
		}
	}
	return "", fmt.Errorf("no country found for IP: %v", ip)
}
// Close cleanly shuts down the GeoIP2 database connection. Must be called
// when the provider is no longer needed so the database file is released.
func (p *provider) Close() error {
	return p.db.Close()
}

View file

@ -0,0 +1,12 @@
package ipcountryblocker
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideIPCountryBlocker creates a new IP country blocker provider instance.
// Intended for dependency-injection wiring; note that NewProvider calls
// logger.Fatal if the GeoIP2 database cannot be opened.
func ProvideIPCountryBlocker(cfg *config.Config, logger *zap.Logger) Provider {
	return NewProvider(cfg, logger)
}

View file

@ -0,0 +1,221 @@
package ipcrypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"time"
"go.uber.org/zap"
)
// IPEncryptor provides secure IP address encryption for GDPR compliance
// Uses AES-GCM (Galois/Counter Mode) for authenticated encryption
// Encrypts IP addresses before storage and provides expiration checking
type IPEncryptor struct {
	gcm    cipher.AEAD // Pre-built AES-GCM sealer/opener (the raw key is not retained)
	logger *zap.Logger // Named "ip-encryptor" sub-logger
}
// NewIPEncryptor creates a new IP encryptor with the given encryption key.
// keyHex should be a 32-character hex string (16 bytes for AES-128) or a
// 64-character hex string (32 bytes for AES-256).
// Example: "0123456789abcdef0123456789abcdef" (AES-128)
// Recommended: Use AES-256 with a 64-character hex key.
func NewIPEncryptor(keyHex string, logger *zap.Logger) (*IPEncryptor, error) {
	// Decode hex key to raw bytes.
	key, err := hex.DecodeString(keyHex)
	if err != nil {
		return nil, fmt.Errorf("invalid hex key: %w", err)
	}
	// AES accepts only 128-, 192-, or 256-bit keys.
	switch len(key) {
	case 16, 24, 32:
		// valid key size
	default:
		return nil, fmt.Errorf("key must be 16, 24, or 32 bytes (32, 48, or 64 hex characters), got %d bytes", len(key))
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher: %w", err)
	}
	// GCM gives authenticated encryption: confidentiality plus tamper detection.
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}
	logger.Info("IP encryptor initialized with AES-GCM",
		zap.Int("key_length_bytes", len(key)),
		zap.Int("nonce_size", gcm.NonceSize()),
		zap.Int("overhead", gcm.Overhead()))
	return &IPEncryptor{
		gcm:    gcm,
		logger: logger.Named("ip-encryptor"),
	}, nil
}
// Encrypt seals an IP address with AES-GCM for at-rest storage and returns
// base64(nonce || ciphertext || tag). Both IPv4 and IPv6 inputs are accepted;
// an empty input is passed through unchanged.
//
// Security Properties:
//   - Semantic security: a fresh random nonce makes equal inputs encrypt differently
//   - Authentication: tampering with the stored value is detected on decrypt
func (e *IPEncryptor) Encrypt(ipAddress string) (string, error) {
	if ipAddress == "" {
		return "", nil // empty stays empty
	}
	// Reject anything that is not a well-formed IP address.
	parsed := net.ParseIP(ipAddress)
	if parsed == nil {
		e.logger.Warn("invalid IP address format",
			zap.String("ip", ipAddress))
		return "", fmt.Errorf("invalid IP address: %s", ipAddress)
	}
	// Normalize to the 16-byte form (IPv4 is embedded in IPv6 space).
	plaintext := parsed.To16()
	if plaintext == nil {
		return "", fmt.Errorf("failed to convert IP to 16-byte format")
	}
	// Every GCM encryption must use its own unique random nonce.
	nonce := make([]byte, e.gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		e.logger.Error("failed to generate nonce", zap.Error(err))
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	// Seal appends ciphertext+tag after the nonce prefix, producing the
	// stored layout nonce || ciphertext || tag in one call.
	sealed := e.gcm.Seal(nonce, nonce, plaintext, nil)
	// Base64 keeps the value text-safe for database storage.
	encoded := base64.StdEncoding.EncodeToString(sealed)
	e.logger.Debug("IP address encrypted with AES-GCM",
		zap.Int("plaintext_length", len(plaintext)),
		zap.Int("nonce_length", len(nonce)),
		zap.Int("ciphertext_length", len(sealed)-len(nonce)),
		zap.Int("total_encrypted_length", len(sealed)),
		zap.Int("base64_length", len(encoded)))
	return encoded, nil
}
// Decrypt reverses Encrypt: it base64-decodes the stored value, splits off
// the nonce prefix, and opens the AES-GCM envelope. The authentication tag
// is verified, so tampered or corrupted values fail with an error. Empty
// input passes through unchanged.
func (e *IPEncryptor) Decrypt(encryptedBase64 string) (string, error) {
	if encryptedBase64 == "" {
		return "", nil // empty stays empty
	}
	raw, err := base64.StdEncoding.DecodeString(encryptedBase64)
	if err != nil {
		e.logger.Warn("invalid base64-encoded encrypted IP",
			zap.String("base64", encryptedBase64),
			zap.Error(err))
		return "", fmt.Errorf("invalid base64 encoding: %w", err)
	}
	// The nonce was stored as a prefix; ensure enough bytes are present.
	ns := e.gcm.NonceSize()
	if len(raw) < ns {
		return "", fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", ns, len(raw))
	}
	nonce, ciphertext := raw[:ns], raw[ns:]
	// Open both decrypts and verifies the authentication tag.
	plaintext, err := e.gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		e.logger.Warn("failed to decrypt IP address (authentication failed or corrupted data)",
			zap.Error(err))
		return "", fmt.Errorf("decryption failed: %w", err)
	}
	ip := net.IP(plaintext)
	if ip == nil {
		return "", fmt.Errorf("failed to parse decrypted IP bytes")
	}
	result := ip.String()
	e.logger.Debug("IP address decrypted with AES-GCM",
		zap.Int("encrypted_length", len(raw)),
		zap.Int("decrypted_length", len(plaintext)))
	return result, nil
}
// IsExpired reports whether an IP timestamp is older than the 90-day GDPR
// retention window. A zero (unset) timestamp is treated as not expired; such
// rows are handled by a separate cleanup path.
func (e *IPEncryptor) IsExpired(timestamp time.Time) bool {
	if timestamp.IsZero() {
		return false // unset timestamp: leave for later cleanup
	}
	// Whole days elapsed since the timestamp.
	days := int(time.Since(timestamp).Hours() / 24)
	if days <= 90 {
		return false
	}
	e.logger.Debug("IP timestamp expired",
		zap.Time("timestamp", timestamp),
		zap.Int("age_days", days))
	return true
}
// ShouldCleanup reports whether an IP record should be cleaned up based on
// its timestamp. It returns true only when the timestamp is set AND older
// than 90 days; zero (unset) timestamps are NOT cleaned up immediately.
// (Note: this deliberately differs from cleaning up unset rows — see below.)
func (e *IPEncryptor) ShouldCleanup(timestamp time.Time) bool {
	// Unset timestamps are deliberately skipped for backwards compatibility:
	// rows written before timestamps existed are not deleted right away.
	if timestamp.IsZero() {
		return false // don't cleanup unset timestamps immediately
	}
	return e.IsExpired(timestamp)
}
// ValidateKey reports whether keyHex is usable as an AES key for IP
// encryption. The key must be a hex string of exactly 32, 48, or 64
// characters (AES-128, AES-192, or AES-256 respectively); a nil return
// means the key is valid.
func ValidateKey(keyHex string) error {
	// The hex length must map onto a 16-, 24-, or 32-byte AES key.
	switch len(keyHex) {
	case 32, 48, 64:
		// acceptable lengths
	default:
		return fmt.Errorf("key must be 32, 48, or 64 hex characters, got %d characters", len(keyHex))
	}
	// The string must also decode cleanly as hex.
	if _, err := hex.DecodeString(keyHex); err != nil {
		return fmt.Errorf("key must be valid hex string: %w", err)
	}
	return nil
}

View file

@ -0,0 +1,13 @@
package ipcrypt
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideIPEncryptor provides an IP encryptor instance for dependency
// injection, wiring the configured encryption key into NewIPEncryptor.
// CWE-359: GDPR compliance for IP address storage.
// Any key-validation error from NewIPEncryptor is propagated to the injector.
func ProvideIPEncryptor(cfg *config.Config, logger *zap.Logger) (*IPEncryptor, error) {
	return NewIPEncryptor(cfg.Security.IPEncryptionKey, logger)
}

View file

@ -0,0 +1,110 @@
package jwt
import (
"fmt"
"log"
"time"
"github.com/golang-jwt/jwt/v5"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/validator"
)
// Provider provides interface for JWT token generation and validation.
type Provider interface {
	// GenerateToken issues a signed token embedding sessionID, valid for duration.
	GenerateToken(sessionID string, duration time.Duration) (string, time.Time, error)
	// GenerateTokenPair issues an access/refresh token pair for the same session.
	GenerateTokenPair(sessionID string, accessDuration time.Duration, refreshDuration time.Duration) (accessToken string, accessExpiry time.Time, refreshToken string, refreshExpiry time.Time, err error)
	// ValidateToken verifies a token's signature and extracts its session ID.
	ValidateToken(tokenString string) (sessionID string, err error)
}
// provider implements Provider using an HMAC (HS256) signing secret.
type provider struct {
	secret []byte // raw bytes of the configured JWT secret
}
// NewProvider creates a JWT provider after checking the configured secret
// for strength. A weak secret is fatal (panic) in production and a logged
// warning in every other environment.
func NewProvider(cfg *config.Config) Provider {
	// CWE-798: reject weak or placeholder JWT secrets before use.
	if err := validator.NewCredentialValidator().ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
		log.Printf("[SECURITY ERROR] %s", err.Error())
		// Production must never start with a weak secret.
		if cfg.App.Environment == "production" {
			panic(fmt.Sprintf("SECURITY: Invalid JWT secret in production environment: %s", err.Error()))
		}
		// Non-production environments may continue, but loudly.
		log.Printf("[WARNING] Continuing with weak JWT secret in %s environment. This is NOT safe for production!", cfg.App.Environment)
	}
	return &provider{secret: []byte(cfg.App.JWTSecret)}
}
// GenerateToken signs a single HS256 JWT carrying the session ID and an
// expiry claim, returning the token string together with its expiry time.
func (p *provider) GenerateToken(sessionID string, duration time.Duration) (string, time.Time, error) {
	expiresAt := time.Now().Add(duration)
	claims := jwt.MapClaims{
		"session_id": sessionID,
		"exp":        expiresAt.Unix(),
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(p.secret)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("failed to sign token: %w", err)
	}
	return signed, expiresAt, nil
}
// GenerateTokenPair issues an access token and a refresh token for the same
// session, each with its own lifetime.
func (p *provider) GenerateTokenPair(sessionID string, accessDuration time.Duration, refreshDuration time.Duration) (string, time.Time, string, time.Time, error) {
	access, accessExp, err := p.GenerateToken(sessionID, accessDuration)
	if err != nil {
		return "", time.Time{}, "", time.Time{}, fmt.Errorf("failed to generate access token: %w", err)
	}
	refresh, refreshExp, err := p.GenerateToken(sessionID, refreshDuration)
	if err != nil {
		return "", time.Time{}, "", time.Time{}, fmt.Errorf("failed to generate refresh token: %w", err)
	}
	return access, accessExp, refresh, refreshExp, nil
}
// ValidateToken parses and verifies a JWT (HMAC signature) and extracts the
// embedded session ID, returning an error for unparsable or invalid tokens.
func (p *provider) ValidateToken(tokenString string) (string, error) {
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		// Only HMAC-family signing is acceptable; anything else is rejected
		// to prevent algorithm-confusion attacks.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return p.secret, nil
	}
	token, err := jwt.Parse(tokenString, keyFunc)
	switch {
	case err != nil:
		return "", fmt.Errorf("failed to parse token: %w", err)
	case !token.Valid:
		return "", fmt.Errorf("invalid token")
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", fmt.Errorf("invalid token claims")
	}
	sessionID, ok := claims["session_id"].(string)
	if !ok {
		return "", fmt.Errorf("session_id not found in token")
	}
	return sessionID, nil
}

View file

@ -0,0 +1,10 @@
package jwt
import (
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideProvider provides a JWT provider instance for Wire dependency injection.
// It delegates to NewProvider, which also validates the configured secret.
func ProvideProvider(cfg *config.Config) Provider {
	return NewProvider(cfg)
}

View file

@ -0,0 +1,149 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/breachcheck.go
package password
import (
"context"
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"go.uber.org/zap"
)
var (
	// ErrPasswordBreached indicates the password has been found in known data breaches.
	ErrPasswordBreached = fmt.Errorf("password has been found in data breaches")
)
// BreachChecker checks if passwords have been compromised in known data breaches
// using the Have I Been Pwned API's k-anonymity model (only a 5-character
// SHA-1 prefix ever leaves the process).
type BreachChecker interface {
	// CheckPassword checks if a password has been breached.
	// Returns the number of times the password was found in breaches (0 = safe).
	CheckPassword(ctx context.Context, password string) (int, error)
	// IsPasswordBreached returns true if password has been found in breaches.
	IsPasswordBreached(ctx context.Context, password string) (bool, error)
}
// breachChecker is the HTTP-backed implementation of BreachChecker.
type breachChecker struct {
	httpClient *http.Client // client with a 10s timeout, set by the constructor
	apiURL     string       // HIBP range-query endpoint prefix
	userAgent  string       // sent on every request, required by the HIBP API
	logger     *zap.Logger
}
// NewBreachChecker builds a Have I Been Pwned password breach checker.
// CWE-521: breach checking via the HIBP k-anonymity API — only the first
// five characters of the password's SHA-1 hash are ever transmitted.
func NewBreachChecker(logger *zap.Logger) BreachChecker {
	// A bounded timeout keeps login flows from hanging on a slow API.
	client := &http.Client{Timeout: 10 * time.Second}
	return &breachChecker{
		httpClient: client,
		apiURL:     "https://api.pwnedpasswords.com/range/",
		userAgent:  "MaplePress-Backend-Password-Checker",
		logger:     logger.Named("breach-checker"),
	}
}
// CheckPassword reports how many times a password appears in known breaches
// using the HIBP k-anonymity API (0 means not found). The password is SHA-1
// hashed locally and only the first five hex characters of the digest are
// sent over the wire; the rest is matched against the response locally.
// CWE-521: breach checking without transmitting the full password.
func (bc *breachChecker) CheckPassword(ctx context.Context, password string) (int, error) {
	// Hash locally and split the uppercase hex digest into prefix/suffix.
	digest := sha1.Sum([]byte(password))
	hexDigest := strings.ToUpper(hex.EncodeToString(digest[:]))
	prefix, suffix := hexDigest[:5], hexDigest[5:]
	// Only the 5-character prefix is part of the request URL.
	req, err := http.NewRequestWithContext(ctx, "GET", bc.apiURL+prefix, nil)
	if err != nil {
		bc.logger.Error("failed to create HIBP request", zap.Error(err))
		return 0, fmt.Errorf("failed to create request: %w", err)
	}
	// HIBP requires a User-Agent; padding hides the true response size.
	req.Header.Set("User-Agent", bc.userAgent)
	req.Header.Set("Add-Padding", "true")
	bc.logger.Debug("checking password against HIBP",
		zap.String("prefix", prefix))
	resp, err := bc.httpClient.Do(req)
	if err != nil {
		bc.logger.Error("failed to query HIBP API", zap.Error(err))
		return 0, fmt.Errorf("failed to query breach database: %w", err)
	}
	if resp == nil {
		bc.logger.Error("received nil response from HIBP API")
		return 0, fmt.Errorf("received nil response from breach database")
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bc.logger.Error("HIBP API returned non-OK status",
			zap.Int("status", resp.StatusCode))
		return 0, fmt.Errorf("breach database returned status %d", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		bc.logger.Error("failed to read HIBP response", zap.Error(err))
		return 0, fmt.Errorf("failed to read response: %w", err)
	}
	// Each response line is "SUFFIX:COUNT"; scan for our hash suffix.
	for _, line := range strings.Split(string(body), "\r\n") {
		if line == "" {
			continue
		}
		parts := strings.Split(line, ":")
		if len(parts) != 2 {
			continue
		}
		if parts[0] != suffix {
			continue
		}
		count, err := strconv.Atoi(parts[1])
		if err != nil {
			bc.logger.Warn("failed to parse breach count",
				zap.String("line", line),
				zap.Error(err))
			return 0, fmt.Errorf("failed to parse breach count: %w", err)
		}
		bc.logger.Warn("password found in data breaches",
			zap.Int("breach_count", count))
		return count, nil
	}
	// Suffix absent: password not found in any known breach.
	bc.logger.Debug("password not found in breaches")
	return 0, nil
}
// IsPasswordBreached is a convenience wrapper around CheckPassword that
// reports true when the password appears in any known breach.
func (bc *breachChecker) IsPasswordBreached(ctx context.Context, password string) (bool, error) {
	count, err := bc.CheckPassword(ctx, password)
	if err != nil {
		return false, err
	}
	// Any non-zero occurrence count means the password is compromised.
	return count > 0, nil
}

View file

@ -0,0 +1,200 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/password.go
package password
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"strings"
"golang.org/x/crypto/argon2"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
)
var (
	// ErrInvalidHash is returned when an encoded hash cannot be parsed.
	ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
	// ErrIncompatibleVersion is returned when a hash was produced by a
	// different argon2 version than the one linked into this binary.
	ErrIncompatibleVersion = errors.New("incompatible version of argon2")
	// Length-policy violations.
	ErrPasswordTooShort = errors.New("password must be at least 8 characters")
	ErrPasswordTooLong = errors.New("password must not exceed 128 characters")
	// Granular password strength errors (CWE-521: Weak Password Requirements)
	ErrPasswordNoUppercase = errors.New("password must contain at least one uppercase letter (A-Z)")
	ErrPasswordNoLowercase = errors.New("password must contain at least one lowercase letter (a-z)")
	ErrPasswordNoNumber = errors.New("password must contain at least one number (0-9)")
	ErrPasswordNoSpecialChar = errors.New("password must contain at least one special character (!@#$%^&*()_+-=[]{}; etc.)")
	// ErrPasswordTooWeak summarizes all character-class requirements at once.
	ErrPasswordTooWeak = errors.New("password must contain uppercase, lowercase, number, and special character")
)
// PasswordProvider provides secure password hashing and verification using Argon2id.
type PasswordProvider interface {
	// GenerateHashFromPassword hashes a password into a self-describing encoded string.
	GenerateHashFromPassword(password *securestring.SecureString) (string, error)
	// ComparePasswordAndHash verifies a password against an encoded hash.
	ComparePasswordAndHash(password *securestring.SecureString, hash string) (bool, error)
	// AlgorithmName names the underlying hash algorithm.
	AlgorithmName() string
	// GenerateSecureRandomBytes returns length cryptographically random bytes.
	GenerateSecureRandomBytes(length int) ([]byte, error)
	// GenerateSecureRandomString returns a random hex string (2*length chars).
	GenerateSecureRandomString(length int) (string, error)
}
// passwordProvider holds the Argon2id cost parameters used for hashing.
type passwordProvider struct {
	memory uint32 // memory cost passed to argon2.IDKey (KiB — TODO confirm against argon2 docs)
	iterations uint32 // time cost (number of passes)
	parallelism uint8 // degree of parallelism (lanes)
	saltLength uint32 // bytes of random salt per hash
	keyLength uint32 // bytes of derived key
}
// NewPasswordProvider creates a password provider configured with
// OWASP-recommended Argon2id defaults:
//   - Memory: 64 MB
//   - Iterations: 3
//   - Parallelism: 2
//   - Salt length: 16 bytes
//   - Key length: 32 bytes
func NewPasswordProvider() PasswordProvider {
	// DEVELOPERS NOTE:
	// Parameters adapted from "How to Hash and Verify Passwords With Argon2 in Go"
	// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
	p := &passwordProvider{}
	p.memory = 64 * 1024 // 64 MB
	p.iterations = 3
	p.parallelism = 2
	p.saltLength = 16
	p.keyLength = 32
	return p
}
// GenerateHashFromPassword hashes a secure string with Argon2id and returns
// the self-describing encoded form:
//
//	$argon2id$v=19$m=65536,t=3,p=2$<base64-salt>$<base64-hash>
//
// Everything needed for later verification is embedded in the string.
func (p *passwordProvider) GenerateHashFromPassword(password *securestring.SecureString) (string, error) {
	// A fresh random salt per password defeats precomputed-table attacks.
	salt, err := generateRandomBytes(p.saltLength)
	if err != nil {
		return "", fmt.Errorf("failed to generate salt: %w", err)
	}
	// Derive the key with Argon2id using this provider's parameters.
	derived := argon2.IDKey(password.Bytes(), salt, p.iterations, p.memory, p.parallelism, p.keyLength)
	// Emit the standard encoded representation (unpadded base64 fields).
	return fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
		argon2.Version, p.memory, p.iterations, p.parallelism,
		base64.RawStdEncoding.EncodeToString(salt),
		base64.RawStdEncoding.EncodeToString(derived)), nil
}
// ComparePasswordAndHash verifies a password against an encoded Argon2id
// hash. The comparison of derived keys is constant-time to resist timing
// attacks. It returns true when the password matches, false otherwise.
func (p *passwordProvider) ComparePasswordAndHash(password *securestring.SecureString, encodedHash string) (bool, error) {
	// DEVELOPERS NOTE:
	// Adapted from "How to Hash and Verify Passwords With Argon2 in Go"
	// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
	// Recover the parameters, salt, and expected key from the encoded hash.
	params, salt, want, err := decodeHash(encodedHash)
	if err != nil {
		return false, err
	}
	// Re-derive the key from the candidate password with identical parameters.
	got := argon2.IDKey(password.Bytes(), salt, params.iterations, params.memory, params.parallelism, params.keyLength)
	// subtle.ConstantTimeCompare avoids leaking the match position via timing.
	return subtle.ConstantTimeCompare(want, got) == 1, nil
}
// AlgorithmName returns the name of the hashing algorithm used ("argon2id").
func (p *passwordProvider) AlgorithmName() string {
	return "argon2id"
}
// GenerateSecureRandomBytes returns length cryptographically secure random
// bytes drawn from crypto/rand.
func (p *passwordProvider) GenerateSecureRandomBytes(length int) ([]byte, error) {
	out := make([]byte, length)
	if _, err := rand.Read(out); err != nil {
		return nil, fmt.Errorf("failed to generate secure random bytes: %w", err)
	}
	return out, nil
}
// GenerateSecureRandomString returns a cryptographically secure random hex
// string built from length random bytes; the result is 2*length characters
// (two hex characters per byte).
func (p *passwordProvider) GenerateSecureRandomString(length int) (string, error) {
	raw, err := p.GenerateSecureRandomBytes(length)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
// generateRandomBytes returns n cryptographically secure random bytes.
func generateRandomBytes(n uint32) ([]byte, error) {
	// DEVELOPERS NOTE:
	// Adapted from "How to Hash and Verify Passwords With Argon2 in Go"
	// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
	out := make([]byte, n)
	if _, err := rand.Read(out); err != nil {
		return nil, err
	}
	return out, nil
}
// decodeHash parses an encoded Argon2id hash of the form
// $argon2id$v=19$m=...,t=...,p=...$<base64-salt>$<base64-hash> and returns
// the embedded cost parameters, salt, and derived key. Malformed strings
// and hashes from an incompatible argon2 version are rejected.
func decodeHash(encodedHash string) (p *passwordProvider, salt, hash []byte, err error) {
	// DEVELOPERS NOTE:
	// Adapted from "How to Hash and Verify Passwords With Argon2 in Go"
	// via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
	fields := strings.Split(encodedHash, "$")
	if len(fields) != 6 {
		return nil, nil, nil, ErrInvalidHash
	}
	// The version field must match the linked argon2 implementation.
	var version int
	if _, err = fmt.Sscanf(fields[2], "v=%d", &version); err != nil {
		return nil, nil, nil, err
	}
	if version != argon2.Version {
		return nil, nil, nil, ErrIncompatibleVersion
	}
	p = &passwordProvider{}
	if _, err = fmt.Sscanf(fields[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism); err != nil {
		return nil, nil, nil, err
	}
	if salt, err = base64.RawStdEncoding.Strict().DecodeString(fields[4]); err != nil {
		return nil, nil, nil, err
	}
	p.saltLength = uint32(len(salt))
	if hash, err = base64.RawStdEncoding.Strict().DecodeString(fields[5]); err != nil {
		return nil, nil, nil, err
	}
	p.keyLength = uint32(len(hash))
	return p, salt, hash, nil
}

View file

@ -0,0 +1,6 @@
package password
// ProvidePasswordProvider creates a new password provider instance for
// dependency injection, delegating to NewPasswordProvider's secure defaults.
func ProvidePasswordProvider() PasswordProvider {
	return NewPasswordProvider()
}

View file

@ -0,0 +1,44 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/security/password/timing.go
package password
import (
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/securestring"
)
// DummyPasswordHash is a pre-computed valid Argon2id hash used for timing
// attack mitigation. Its embedded parameters (m=65536, t=3, p=2) match the
// defaults used for real password hashes, so a dummy comparison costs the
// same amount of work.
// CWE-208: Observable Timing Discrepancy - Prevents user enumeration via timing attacks
const DummyPasswordHash = "$argon2id$v=19$m=65536,t=3,p=2$c29tZXJhbmRvbXNhbHQxMjM0$kixiIQQ/y8E7dSH0j8p8KPBUlCMUGQOvH2kP7XYPkVs"
// ComparePasswordWithDummy performs a full-cost password comparison against
// a dummy hash. It is called when a user does not exist so the request still
// pays the same Argon2 hashing cost as a real login attempt. It always
// returns nil; authentication for the non-existent user fails upstream.
// CWE-208: Observable Timing Discrepancy - Mitigates timing-based user enumeration
func (p *passwordProvider) ComparePasswordWithDummy(password *securestring.SecureString) error {
	// Perform the same expensive operation (Argon2 hashing) even for non-existent users
	// This ensures the timing is constant regardless of whether the user exists
	_, _ = p.ComparePasswordAndHash(password, DummyPasswordHash)
	// The comparison result is deliberately discarded — what matters is that
	// the same amount of time was spent as for a real user.
	return nil
}
// TimingSafeCompare compares a password against a stored hash in a way that
// takes the same time whether or not the user exists: unknown users trigger
// a dummy Argon2 comparison instead of an early return.
// CWE-208: Observable Timing Discrepancy - Prevents timing attacks
func TimingSafeCompare(provider PasswordProvider, password *securestring.SecureString, hash string, userExists bool) (bool, error) {
	if userExists {
		// Real user: compare against the stored hash.
		return provider.ComparePasswordAndHash(password, hash)
	}
	// Unknown user: burn the same hashing cost, then report failure.
	if pp, ok := provider.(*passwordProvider); ok {
		_ = pp.ComparePasswordWithDummy(password)
	} else {
		// Fallback when the provider is not our concrete type.
		_, _ = provider.ComparePasswordAndHash(password, DummyPasswordHash)
	}
	return false, nil
}

View file

@ -0,0 +1,90 @@
package password
import (
"regexp"
"unicode"
)
const (
	// MinPasswordLength is the minimum required password length
	// (measured in bytes, as checked via len()).
	MinPasswordLength = 8
	// MaxPasswordLength is the maximum allowed password length
	// (measured in bytes, as checked via len()).
	MaxPasswordLength = 128
)
var (
	// specialCharRegex matches any single character that satisfies the
	// "special character" strength requirement. Compiled once at package init.
	specialCharRegex = regexp.MustCompile(`[!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]`)
)
// PasswordValidator provides password strength validation.
type PasswordValidator interface {
	// ValidatePasswordStrength returns nil when the password meets all
	// strength requirements, or a granular error describing the first
	// unmet requirement.
	ValidatePasswordStrength(password string) error
}
// passwordValidator is the stateless default implementation.
type passwordValidator struct{}
// NewPasswordValidator creates a new password validator.
func NewPasswordValidator() PasswordValidator {
	return &passwordValidator{}
}
// ValidatePasswordStrength checks a password against the strength policy:
//   - length between MinPasswordLength and MaxPasswordLength
//   - at least one uppercase letter, one lowercase letter, one digit,
//     and one special character
//
// CWE-521: the first unmet requirement is reported with a granular error so
// users know exactly what to fix.
func (v *passwordValidator) ValidatePasswordStrength(password string) error {
	// Length bounds first — they are the cheapest checks.
	switch {
	case len(password) < MinPasswordLength:
		return ErrPasswordTooShort
	case len(password) > MaxPasswordLength:
		return ErrPasswordTooLong
	}
	// Single pass recording which character classes are present.
	var hasUpper, hasLower, hasNumber bool
	for _, r := range password {
		hasUpper = hasUpper || unicode.IsUpper(r)
		hasLower = hasLower || unicode.IsLower(r)
		hasNumber = hasNumber || unicode.IsNumber(r)
	}
	// Report the first missing requirement for actionable feedback.
	switch {
	case !hasUpper:
		return ErrPasswordNoUppercase
	case !hasLower:
		return ErrPasswordNoLowercase
	case !hasNumber:
		return ErrPasswordNoNumber
	case !specialCharRegex.MatchString(password):
		return ErrPasswordNoSpecialChar
	}
	return nil
}

View file

@ -0,0 +1,20 @@
package security
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/clientip"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/pkg/security/jwt"
)
// ProvideJWTProvider provides a JWT provider instance for dependency
// injection; secret validation happens inside jwt.NewProvider.
func ProvideJWTProvider(cfg *config.Config) jwt.Provider {
	return jwt.NewProvider(cfg)
}
// ProvideClientIPExtractor provides a client IP extractor instance.
// CWE-348: Secure IP extraction with X-Forwarded-For validation.
// Any error from clientip.ProvideExtractor is propagated to the injector.
func ProvideClientIPExtractor(cfg *config.Config, logger *zap.Logger) (*clientip.Extractor, error) {
	return clientip.ProvideExtractor(cfg, logger)
}

View file

@ -0,0 +1,49 @@
package securebytes
import (
"errors"
"github.com/awnumar/memguard"
)
// SecureBytes is used to store a byte slice securely in memory.
// It uses memguard to protect sensitive data from being exposed in memory dumps,
// swap files, or other memory scanning attacks.
type SecureBytes struct {
	buffer *memguard.LockedBuffer // set to nil by Wipe
}
// NewSecureBytes copies b into a memguard-locked buffer and returns the
// wrapper. Callers should wipe the original slice afterwards so the
// sensitive bytes exist only inside the protected buffer. An empty slice
// is rejected.
func NewSecureBytes(b []byte) (*SecureBytes, error) {
	if len(b) == 0 {
		return nil, errors.New("byte slice cannot be empty")
	}
	locked := memguard.NewBuffer(len(b))
	if locked == nil {
		return nil, errors.New("failed to create buffer")
	}
	// Copy the caller's bytes into the protected region.
	copy(locked.Bytes(), b)
	return &SecureBytes{buffer: locked}, nil
}
// Bytes returns the securely stored byte slice.
// After Wipe has been called (buffer is nil) it returns nil instead of
// panicking, matching the nil-guard behavior of SecureString.Bytes.
// WARNING: The returned bytes are still protected by memguard, but any copies
// made from this slice will not be protected. Use with caution.
func (sb *SecureBytes) Bytes() []byte {
	if sb.buffer == nil {
		return nil
	}
	return sb.buffer.Bytes()
}
// Wipe zeroes the protected buffer and releases the reference, making the
// data unrecoverable. It is safe to call more than once: a repeated call is
// a no-op instead of a nil-pointer panic (mirroring SecureString.Wipe).
// After calling Wipe, the SecureBytes instance should not be used.
func (sb *SecureBytes) Wipe() error {
	if sb.buffer != nil {
		sb.buffer.Wipe()
		sb.buffer = nil
	}
	return nil
}

View file

@ -0,0 +1,71 @@
package securestring
import (
"errors"
"github.com/awnumar/memguard"
)
// SecureString is used to store a string securely in memory.
// It uses memguard to protect sensitive data like passwords, API keys, etc.
// from being exposed in memory dumps, swap files, or other memory scanning attacks.
type SecureString struct {
	buffer *memguard.LockedBuffer // set to nil by Wipe
}
// NewSecureString copies s into a memguard-locked buffer and returns the
// wrapper. The caller should avoid keeping other copies of the string so
// the secret lives only in the protected buffer. Empty strings are rejected.
func NewSecureString(s string) (*SecureString, error) {
	if len(s) == 0 {
		return nil, errors.New("string cannot be empty")
	}
	// memguard's NewBufferFromBytes builds the locked buffer from the bytes.
	locked := memguard.NewBufferFromBytes([]byte(s))
	if locked == nil {
		return nil, errors.New("failed to create buffer")
	}
	return &SecureString{buffer: locked}, nil
}
// String returns a copy of the securely stored string, or "" when the
// buffer has been wiped or destroyed.
// WARNING: the returned copy is NOT protected by memguard; wipe it after
// use where possible.
func (ss *SecureString) String() string {
	if ss.buffer == nil || !ss.buffer.IsAlive() {
		return ""
	}
	return ss.buffer.String()
}
// Bytes returns the protected byte view of the stored string, or nil when
// the buffer has been wiped or destroyed.
// WARNING: the returned bytes are still protected by memguard, but any
// copies made from this slice will not be protected. Use with caution.
func (ss *SecureString) Bytes() []byte {
	if ss.buffer == nil || !ss.buffer.IsAlive() {
		return nil
	}
	return ss.buffer.Bytes()
}
// Wipe destroys the protected buffer, making the string unrecoverable.
// It is safe to call multiple times; after Wipe the instance must not be
// reused.
func (ss *SecureString) Wipe() error {
	if ss.buffer != nil && ss.buffer.IsAlive() {
		ss.buffer.Destroy()
	}
	ss.buffer = nil
	return nil
}

View file

@ -0,0 +1,435 @@
package validator
import (
"fmt"
"math"
"strings"
"unicode"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
const (
	// MinJWTSecretLength is the minimum required length for JWT secrets
	// (32 characters, i.e. 256 bits of raw material).
	MinJWTSecretLength = 32
	// RecommendedJWTSecretLength is the recommended length for JWT secrets
	// (64 characters, i.e. 512 bits).
	RecommendedJWTSecretLength = 64
	// MinEntropyBits is the minimum Shannon entropy in bits per character
	// For reference: random base64 has ~6 bits/char, we require minimum 4.0
	MinEntropyBits = 4.0
	// MinProductionEntropyBits is the minimum entropy required for production
	// (stricter than the general minimum).
	MinProductionEntropyBits = 4.5
	// MaxRepeatingCharacters is the maximum allowed consecutive repeating characters
	MaxRepeatingCharacters = 3
)
// WeakSecrets contains common weak/default secrets that should never be used.
// Matching in ValidateJWTSecret is case-insensitive and also triggers on
// substrings, so a secret merely containing one of these values is rejected.
var WeakSecrets = []string{
	"secret",
	"password",
	"changeme",
	"change-me",
	"change_me",
	"12345",
	"123456",
	"1234567",
	"12345678",
	"123456789",
	"1234567890",
	"default",
	"test",
	"testing",
	"admin",
	"administrator",
	"root",
	"qwerty",
	"qwertyuiop",
	"letmein",
	"welcome",
	"monkey",
	"dragon",
	"master",
	"sunshine",
	"princess",
	"football",
	"starwars",
	"baseball",
	"superman",
	"iloveyou",
	"trustno1",
	"hello",
	"abc123",
	"password123",
	"admin123",
	"guest",
	"user",
	"demo",
	"sample",
	"example",
}
// DangerousPatterns contains substrings that indicate a secret is a
// placeholder or default value that should be changed before deployment.
var DangerousPatterns = []string{
	"change",
	"replace",
	"update",
	"modify",
	"sample",
	"example",
	"todo",
	"fixme",
	"temp",
	"temporary",
}
// CredentialValidator validates credentials and secrets for security issues.
type CredentialValidator interface {
	// ValidateJWTSecret checks a JWT secret's length, entropy, and known-weak
	// patterns; environment selects stricter thresholds for "production".
	ValidateJWTSecret(secret string, environment string) error
	// ValidateAllCredentials runs credential checks across the whole config.
	ValidateAllCredentials(cfg *config.Config) error
}
// credentialValidator is the stateless default implementation.
type credentialValidator struct{}
// NewCredentialValidator creates a new credential validator.
func NewCredentialValidator() CredentialValidator {
	return &credentialValidator{}
}
// ValidateJWTSecret validates JWT secret strength and security
// CWE-798: Comprehensive validation to prevent hard-coded/weak credentials
//
// Checks run in order: minimum length, known-weak values, placeholder
// patterns, repeated characters, sequential runs, Shannon entropy, and — when
// environment == "production" — recommended length, character-class
// complexity, and base64-like shape. The first failing check is returned.
func (v *credentialValidator) ValidateJWTSecret(secret string, environment string) error {
	// Check minimum length (byte length; secrets are expected to be ASCII/base64).
	if len(secret) < MinJWTSecretLength {
		return fmt.Errorf(
			"JWT secret is too short (%d characters). Minimum required: %d characters (256 bits). "+
				"Generate a secure secret with: openssl rand -base64 64",
			len(secret),
			MinJWTSecretLength,
		)
	}
	// Check for common weak secrets (case-insensitive). strings.Contains also
	// covers exact equality, so no separate == comparison is needed.
	secretLower := strings.ToLower(secret)
	for _, weak := range WeakSecrets {
		if strings.Contains(secretLower, weak) {
			return fmt.Errorf(
				"JWT secret cannot contain common weak value: '%s'. "+
					"Generate a secure secret with: openssl rand -base64 64",
				weak,
			)
		}
	}
	// Check for dangerous patterns indicating default/placeholder values.
	for _, pattern := range DangerousPatterns {
		if strings.Contains(secretLower, pattern) {
			return fmt.Errorf(
				"JWT secret contains suspicious pattern '%s' which suggests it's a placeholder. "+
					"Generate a secure secret with: openssl rand -base64 64",
				pattern,
			)
		}
	}
	// Check for repeating character patterns (e.g., "aaaa", "1111").
	if err := checkRepeatingPatterns(secret); err != nil {
		return fmt.Errorf(
			"JWT secret validation failed: %s. "+
				"Generate a secure secret with: openssl rand -base64 64",
			err.Error(),
		)
	}
	// Check for sequential patterns (e.g., "abcd", "1234").
	if hasSequentialPattern(secret) {
		return fmt.Errorf(
			"JWT secret contains sequential patterns (e.g., 'abcd', '1234') which reduces entropy. "+
				"Generate a secure secret with: openssl rand -base64 64",
		)
	}
	// Calculate Shannon entropy; production demands a higher floor.
	entropy := calculateShannonEntropy(secret)
	minEntropy := MinEntropyBits
	if environment == "production" {
		minEntropy = MinProductionEntropyBits
	}
	if entropy < minEntropy {
		return fmt.Errorf(
			"JWT secret has insufficient entropy: %.2f bits/char (minimum: %.1f bits/char for %s). "+
				"The secret appears to have low randomness. "+
				"Generate a secure secret with: openssl rand -base64 64",
			entropy,
			minEntropy,
			environment,
		)
	}
	// In production, enforce stricter requirements.
	if environment == "production" {
		// Check recommended length for production.
		if len(secret) < RecommendedJWTSecretLength {
			return fmt.Errorf(
				"JWT secret is too short for production environment (%d characters). "+
					"Recommended: %d characters (512 bits). "+
					"Generate a secure secret with: openssl rand -base64 64",
				len(secret),
				RecommendedJWTSecretLength,
			)
		}
		// Check for sufficient character complexity (>= 3 of 4 classes).
		if !hasSufficientComplexity(secret) {
			return fmt.Errorf(
				"JWT secret has insufficient complexity for production. It should contain a mix of uppercase, lowercase, " +
					"digits, and special characters (at least 3 types). Generate a secure secret with: openssl rand -base64 64",
			)
		}
		// Validate base64-like characteristics (recommended generation method).
		if !looksLikeBase64(secret) {
			return fmt.Errorf(
				"JWT secret does not appear to be randomly generated (expected base64-like characteristics). "+
					"Generate a secure secret with: openssl rand -base64 64",
			)
		}
	}
	return nil
}
// ValidateAllCredentials validates all credentials in the configuration and
// reports every problem found in one aggregated error (nil when clean).
func (v *credentialValidator) ValidateAllCredentials(cfg *config.Config) error {
	var problems []string
	// The JWT secret is validated in every environment.
	if err := v.ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
		problems = append(problems, fmt.Sprintf("JWT Secret validation failed: %s", err.Error()))
	}
	// Production additionally forbids defaults/placeholders and localhost hosts.
	if cfg.App.Environment == "production" {
		switch {
		case cfg.Meilisearch.APIKey == "":
			problems = append(problems, "Meilisearch API key must be set in production")
		case containsDangerousPattern(cfg.Meilisearch.APIKey):
			problems = append(problems, "Meilisearch API key appears to be a placeholder/default value")
		}
		for _, host := range cfg.Database.Hosts {
			if strings.Contains(strings.ToLower(host), "localhost") || host == "127.0.0.1" {
				problems = append(problems, "Database hosts should not use localhost in production")
				break
			}
		}
		if strings.Contains(strings.ToLower(cfg.Cache.Host), "localhost") || cfg.Cache.Host == "127.0.0.1" {
			problems = append(problems, "Cache host should not use localhost in production")
		}
	}
	if len(problems) == 0 {
		return nil
	}
	return fmt.Errorf("credential validation failed:\n - %s", strings.Join(problems, "\n - "))
}
// calculateShannonEntropy returns the Shannon entropy of s in bits per
// character: H(X) = -Σ p(x)·log2(p(x)), where p(x) is the relative frequency
// of character x. Higher values indicate more randomness.
func calculateShannonEntropy(s string) float64 {
	if len(s) == 0 {
		return 0
	}
	// Tally per-rune occurrence counts.
	freq := map[rune]int{}
	for _, r := range s {
		freq[r]++
	}
	// NOTE: the denominator is the byte length of s, which equals the rune
	// count for the ASCII/base64 secrets this validator targets.
	total := float64(len(s))
	var bits float64
	for _, count := range freq {
		p := float64(count) / total
		bits -= p * math.Log2(p)
	}
	return bits
}
// hasSufficientComplexity reports whether secret mixes at least 3 of the 4
// character classes (uppercase, lowercase, digit, special). Used as the
// production complexity gate.
func hasSufficientComplexity(secret string) bool {
	// classes: [0]=upper, [1]=lower, [2]=digit, [3]=special
	var classes [4]bool
	for _, r := range secret {
		switch {
		case unicode.IsUpper(r):
			classes[0] = true
		case unicode.IsLower(r):
			classes[1] = true
		case unicode.IsDigit(r):
			classes[2] = true
		default:
			classes[3] = true
		}
	}
	seen := 0
	for _, present := range classes {
		if present {
			seen++
		}
	}
	// Require at least 3 out of 4 character types.
	return seen >= 3
}
// checkRepeatingPatterns rejects strings containing more than
// MaxRepeatingCharacters identical consecutive characters (e.g. "aaaa").
// Returns nil when the string is acceptable.
//
// Fix: the previous version seeded the comparison with rune(s[0]) — the first
// BYTE of s — while the loop iterated runes, so a multibyte first character
// was never matched correctly. Runes are now decoded uniformly.
func checkRepeatingPatterns(s string) error {
	if len(s) < 2 {
		return nil
	}
	var (
		lastChar    rune
		repeatCount int
	)
	for i, char := range s {
		// i is the byte offset of the rune; i > 0 distinguishes the first rune.
		if i > 0 && char == lastChar {
			repeatCount++
			if repeatCount > MaxRepeatingCharacters {
				return fmt.Errorf(
					"contains %d consecutive repeating characters ('%c'), maximum allowed: %d",
					repeatCount,
					lastChar,
					MaxRepeatingCharacters,
				)
			}
		} else {
			repeatCount = 1
			lastChar = char
		}
	}
	return nil
}
// hasSequentialPattern reports whether s contains a run of at least 4
// consecutive ascending or descending byte values ("abcd", "4321", ...).
func hasSequentialPattern(s string) bool {
	// Slide a 4-byte window; i indexes the window's last byte.
	for i := 3; i < len(s); i++ {
		base := s[i-3]
		ascending := s[i-2] == base+1 && s[i-1] == base+2 && s[i] == base+3
		descending := s[i-2] == base-1 && s[i-1] == base-2 && s[i] == base-3
		if ascending || descending {
			return true
		}
	}
	// Strings shorter than 4 bytes never enter the loop and are accepted.
	return false
}
// looksLikeBase64 heuristically checks whether s looks like randomly generated
// base64: only base64/base64url characters, all three of the upper/lower/digit
// classes present, minimum length, and at least 40% unique characters.
//
// Fix: removed the validChars counter from the previous version — it was
// incremented on every branch but never read (dead code).
func looksLikeBase64(s string) bool {
	if len(s) < MinJWTSecretLength {
		return false
	}
	var hasUpper, hasLower, hasDigit bool
	for _, char := range s {
		switch {
		case char >= 'A' && char <= 'Z':
			hasUpper = true
		case char >= 'a' && char <= 'z':
			hasLower = true
		case char >= '0' && char <= '9':
			hasDigit = true
		case char == '+' || char == '/' || char == '=' || char == '-' || char == '_':
			// Standard padding plus URL-safe variants are acceptable.
		default:
			// Any other character rules out base64.
			return false
		}
	}
	// Random base64 of this length virtually always mixes uppercase,
	// lowercase, and digits; require all three.
	if !hasUpper || !hasLower || !hasDigit {
		return false
	}
	// A low unique-character ratio suggests a repeated pattern rather than
	// randomness (e.g. "AbCd12+/" repeated).
	unique := make(map[rune]bool, len(s))
	for _, char := range s {
		unique[char] = true
	}
	uniqueRatio := float64(len(unique)) / float64(len(s))
	return uniqueRatio >= 0.4 // At least 40% unique characters
}
// containsDangerousPattern reports whether value case-insensitively contains
// any of the placeholder indicators listed in DangerousPatterns.
func containsDangerousPattern(value string) bool {
	lowered := strings.ToLower(value)
	for _, pattern := range DangerousPatterns {
		if strings.Contains(lowered, pattern) {
			return true
		}
	}
	return false
}

View file

@ -0,0 +1,113 @@
package validator
import (
"testing"
)
// Simplified comprehensive test for JWT secret validation
// Good fixtures are exactly 32 (dev minimum) and 64 (prod recommendation)
// characters; bad fixtures are sized/shaped so the asserted check is the
// first one to fire in ValidateJWTSecret's check order.
func TestJWTSecretValidation(t *testing.T) {
	validator := NewCredentialValidator()
	// Good secrets - these should pass
	goodSecrets := []struct {
		name   string
		secret string
		env    string
	}{
		{
			name:   "Good 32-char for dev",
			secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
			env:    "development",
		},
		{
			name:   "Good 64-char for prod",
			secret: "1WDduocStecRuIv+Us1t/RnYDoW1ZcEEbU+H+WykJG+IT5WnijzBb8uUPzGKju+D",
			env:    "production",
		},
	}
	for _, tt := range goodSecrets {
		t.Run(tt.name, func(t *testing.T) {
			err := validator.ValidateJWTSecret(tt.secret, tt.env)
			if err != nil {
				t.Errorf("Expected no error for valid secret, got: %v", err)
			}
		})
	}
	// Bad secrets - these should fail
	badSecrets := []struct {
		name        string
		secret      string
		env         string
		mustContain string
	}{
		{
			name:        "Too short",
			secret:      "short",
			env:         "development",
			mustContain: "too short",
		},
		{
			name:        "Common weak - password",
			secret:      "password-is-not-secure-but-32char",
			env:         "development",
			mustContain: "common weak value",
		},
		{
			name:        "Dangerous pattern",
			secret:      "please-change-this-ima7xR+9nT0Yz",
			env:         "development",
			mustContain: "suspicious pattern",
		},
		{
			name:        "Repeating characters",
			secret:      "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
			env:         "development",
			mustContain: "consecutive repeating characters",
		},
		{
			name:        "Sequential pattern",
			secret:      "abcdefghijklmnopqrstuvwxyzabcdef",
			env:         "development",
			mustContain: "sequential patterns",
		},
		{
			name:        "Low entropy",
			secret:      "abababababababababababababababab",
			env:         "development",
			mustContain: "insufficient entropy",
		},
		{
			name:        "Prod too short",
			secret:      "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
			env:         "production",
			mustContain: "too short for production",
		},
	}
	for _, tt := range badSecrets {
		t.Run(tt.name, func(t *testing.T) {
			err := validator.ValidateJWTSecret(tt.secret, tt.env)
			if err == nil {
				t.Errorf("Expected error containing '%s', got no error", tt.mustContain)
			} else if !contains(err.Error(), tt.mustContain) {
				t.Errorf("Expected error containing '%s', got: %v", tt.mustContain, err)
			}
		})
	}
}
// contains reports whether substr occurs within s. The previous body
// duplicated findSubstring's logic behind a convoluted boolean expression;
// findSubstring alone already covers every case (empty substr matches any s,
// a longer substr never matches).
func contains(s, substr string) bool {
	return findSubstring(s, substr)
}

// findSubstring performs a naive substring scan. An empty substr matches any
// string; a substr longer than s never enters the loop and returns false.
func findSubstring(s, substr string) bool {
	for i := 0; i <= len(s)-len(substr); i++ {
		if s[i:i+len(substr)] == substr {
			return true
		}
	}
	return false
}

View file

@ -0,0 +1,535 @@
package validator
import (
"strings"
"testing"
)
// TestCalculateShannonEntropy spans inputs from degenerate (empty/uniform) to
// high-entropy random base64. Bounds are ranges because the exact bit value
// depends on each input's character frequency distribution.
func TestCalculateShannonEntropy(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		minBits  float64
		maxBits  float64
		expected string
	}{
		{
			name:     "Empty string",
			input:    "",
			minBits:  0,
			maxBits:  0,
			expected: "should have 0 entropy",
		},
		{
			name:     "All same character",
			input:    "aaaaaaaaaa",
			minBits:  0,
			maxBits:  0,
			expected: "should have very low entropy",
		},
		{
			name:     "Low entropy - repeated pattern",
			input:    "abcabcabcabc",
			minBits:  1.5,
			maxBits:  2.0,
			expected: "should have low entropy",
		},
		{
			name:     "Medium entropy - simple password",
			input:    "Password123",
			minBits:  3.0,
			maxBits:  4.5,
			expected: "should have medium entropy",
		},
		{
			name:     "High entropy - random base64",
			input:    "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
			minBits:  4.0,
			maxBits:  6.0,
			expected: "should have high entropy",
		},
		{
			name:     "Very high entropy - long random base64",
			input:    "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
			minBits:  4.5,
			maxBits:  6.5,
			expected: "should have very high entropy",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			entropy := calculateShannonEntropy(tt.input)
			if entropy < tt.minBits || entropy > tt.maxBits {
				t.Errorf("%s: got %.2f bits/char, expected between %.1f and %.1f", tt.expected, entropy, tt.minBits, tt.maxBits)
			}
		})
	}
}
// TestHasSufficientComplexity covers every count of character classes from 0
// through 4; the pass threshold is 3 of the 4 classes.
func TestHasSufficientComplexity(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected bool
	}{
		{
			name:     "Empty string",
			input:    "",
			expected: false,
		},
		{
			name:     "Only lowercase",
			input:    "abcdefghijklmnop",
			expected: false,
		},
		{
			name:     "Only uppercase",
			input:    "ABCDEFGHIJKLMNOP",
			expected: false,
		},
		{
			name:     "Only digits",
			input:    "1234567890",
			expected: false,
		},
		{
			name:     "Lowercase + uppercase",
			input:    "AbCdEfGhIjKl",
			expected: false,
		},
		{
			name:     "Lowercase + digits",
			input:    "abc123def456",
			expected: false,
		},
		{
			name:     "Uppercase + digits",
			input:    "ABC123DEF456",
			expected: false,
		},
		{
			name:     "Lowercase + uppercase + digits",
			input:    "Abc123Def456",
			expected: true,
		},
		{
			name:     "Lowercase + uppercase + special",
			input:    "AbC+DeF/GhI=",
			expected: true,
		},
		{
			name:     "Lowercase + digits + special",
			input:    "abc123+def456/",
			expected: true,
		},
		{
			name:     "All four types",
			input:    "Abc123+Def456/",
			expected: true,
		},
		{
			name:     "Base64 string",
			input:    "K8vN2mP9sQ4tR7wY3zA6b+xK8vN2mP9sQ4tR7wY3zA6b=",
			expected: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := hasSufficientComplexity(tt.input)
			if result != tt.expected {
				t.Errorf("hasSufficientComplexity(%q) = %v, expected %v", tt.input, result, tt.expected)
			}
		})
	}
}
// TestCheckRepeatingPatterns pins the boundary at MaxRepeatingCharacters = 3:
// runs of up to 3 identical characters pass, 4 or more fail, regardless of
// where the run appears in the string.
func TestCheckRepeatingPatterns(t *testing.T) {
	tests := []struct {
		name      string
		input     string
		shouldErr bool
	}{
		{
			name:      "Empty string",
			input:     "",
			shouldErr: false,
		},
		{
			name:      "Single character",
			input:     "a",
			shouldErr: false,
		},
		{
			name:      "No repeating",
			input:     "abcdefgh",
			shouldErr: false,
		},
		{
			name:      "Two repeating (ok)",
			input:     "aabcdeef",
			shouldErr: false,
		},
		{
			name:      "Three repeating (ok)",
			input:     "aaabcdeee",
			shouldErr: false,
		},
		{
			name:      "Four repeating (error)",
			input:     "aaaabcde",
			shouldErr: true,
		},
		{
			name:      "Five repeating (error)",
			input:     "aaaaabcde",
			shouldErr: true,
		},
		{
			name:      "Multiple groups of three (ok)",
			input:     "aaabbbccc",
			shouldErr: false,
		},
		{
			name:      "Repeating in middle (error)",
			input:     "abcdddddef",
			shouldErr: true,
		},
		{
			name:      "Repeating at end (error)",
			input:     "abcdefgggg",
			shouldErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := checkRepeatingPatterns(tt.input)
			if (err != nil) != tt.shouldErr {
				t.Errorf("checkRepeatingPatterns(%q) error = %v, shouldErr = %v", tt.input, err, tt.shouldErr)
			}
		})
	}
}
// TestHasSequentialPattern checks detection of 4-character ascending and
// descending byte runs, and that step-of-two sequences and random base64 do
// not false-positive.
func TestHasSequentialPattern(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected bool
	}{
		{
			name:     "Empty string",
			input:    "",
			expected: false,
		},
		{
			name:     "Too short",
			input:    "abc",
			expected: false,
		},
		{
			name:     "No sequential",
			input:    "acegikmo",
			expected: false,
		},
		{
			name:     "Ascending sequence - abcd",
			input:    "xyzabcdefg",
			expected: true,
		},
		{
			name:     "Descending sequence - dcba",
			input:    "xyzdcbafg",
			expected: true,
		},
		{
			name:     "Ascending digits - 1234",
			input:    "abc1234def",
			expected: true,
		},
		{
			name:     "Descending digits - 4321",
			input:    "abc4321def",
			expected: true,
		},
		{
			name:     "Random characters",
			input:    "xK8vN2mP9sQ4",
			expected: false,
		},
		{
			name:     "Base64-like",
			input:    "K8vN2mP9sQ4tR7wY3zA6b",
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := hasSequentialPattern(tt.input)
			if result != tt.expected {
				t.Errorf("hasSequentialPattern(%q) = %v, expected %v", tt.input, result, tt.expected)
			}
		})
	}
}
// TestLooksLikeBase64 exercises the base64-shape heuristic: character set,
// class mix, minimum length, and the >= 40% unique-character ratio.
func TestLooksLikeBase64(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected bool
	}{
		{
			name:     "Empty string",
			input:    "",
			expected: false,
		},
		{
			name:     "Too short",
			input:    "abc",
			expected: false,
		},
		{
			name:     "Only lowercase",
			input:    "abcdefghijklmnopqrstuvwxyzabcdef",
			expected: false,
		},
		{
			name:     "Real base64",
			input:    "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b=",
			expected: true,
		},
		{
			name:     "Base64 without padding",
			input:    "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
			expected: true,
		},
		{
			name:     "Base64 with URL-safe chars",
			input:    "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b-_",
			expected: true,
		},
		{
			name: "Generated secret",
			// Fixed fixture: the previous value was a 22-char unit repeated
			// three times, so its unique-character ratio (22/66 ~= 0.33) fell
			// below the 40% threshold and the function correctly returned
			// false. A genuinely random 64-char value is required here.
			input:    "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
			expected: true,
		},
		{
			name:     "Simple password",
			input:    "Password123!Password123!Password123!",
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := looksLikeBase64(tt.input)
			if result != tt.expected {
				t.Errorf("looksLikeBase64(%q) = %v, expected %v", tt.input, result, tt.expected)
			}
		})
	}
}
// TestValidateJWTSecret exercises every rejection branch of ValidateJWTSecret
// plus known-good secrets for development and production.
//
// Several fixtures in the original table were broken: some secrets were 31 or
// 63 characters (one short of the documented 32/64 limits) and failed the
// length check instead of the branch under test, some tripped an earlier
// check (e.g. "12345" is a weak value, caught before the sequential check),
// and the "low entropy" fixture had 32 distinct characters (5.0 bits/char,
// which actually passes). All fixtures below hit the asserted branch.
func TestValidateJWTSecret(t *testing.T) {
	validator := NewCredentialValidator()
	tests := []struct {
		name        string
		secret      string
		environment string
		shouldErr   bool
		errContains string
	}{
		{
			name:        "Too short - 20 chars",
			secret:      "12345678901234567890",
			environment: "development",
			shouldErr:   true,
			errContains: "too short",
		},
		{
			name: "Minimum length - 32 chars (acceptable for dev)",
			// Exactly 32 characters (the previous fixture was 31).
			secret:      "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx2",
			environment: "development",
			shouldErr:   false,
		},
		{
			name:        "Common weak secret - contains password",
			secret:      "my-password-is-secure-123456789012",
			environment: "development",
			shouldErr:   true,
			errContains: "common weak value",
		},
		{
			name:        "Common weak secret - secret",
			secret:      "secretsecretsecretsecretsecretsec",
			environment: "development",
			shouldErr:   true,
			errContains: "common weak value",
		},
		{
			name: "Common weak secret - contains 12345",
			// Must be >= 32 chars so the length check does not fire first.
			secret:      "abcd12345efghijklmnopqrstuvwxyz0",
			environment: "development",
			shouldErr:   true,
			errContains: "common weak value",
		},
		{
			name:        "Dangerous pattern - change",
			secret:      "please-change-this-j8EJm9ZKnuTYxcVK",
			environment: "development",
			shouldErr:   true,
			errContains: "suspicious pattern",
		},
		{
			name: "Dangerous pattern - replace",
			// "replace" is a dangerous pattern that is NOT also in
			// WeakSecrets ("sample" and "secret" are, and would trip the
			// earlier weak-value check instead).
			secret:      "replace-this-j8EJm9ZKnuTYxcVKQw1",
			environment: "development",
			shouldErr:   true,
			errContains: "suspicious pattern",
		},
		{
			name:        "Repeating characters",
			secret:      "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
			environment: "development",
			shouldErr:   true,
			errContains: "consecutive repeating characters",
		},
		{
			name:        "Sequential pattern - abcd",
			secret:      "abcdefghijklmnopqrstuvwxyzabcdef",
			environment: "development",
			shouldErr:   true,
			errContains: "sequential patterns",
		},
		{
			name: "Sequential pattern - 4321",
			// Descending digits: ascending runs like "12345" are caught by
			// the earlier weak-value check, not the sequential check.
			secret:      "98765432109876543210987654321098",
			environment: "development",
			shouldErr:   true,
			errContains: "sequential patterns",
		},
		{
			name: "Low entropy secret",
			// Only 8 distinct characters over 32 positions (~3.0 bits/char),
			// with no repeats or sequential byte runs.
			secret:      "aAbBcCdDaAbBcCdDaAbBcCdDaAbBcCdD",
			environment: "development",
			shouldErr:   true,
			errContains: "insufficient entropy",
		},
		{
			name:        "Good secret - base64 style (dev)",
			secret:      "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx2",
			environment: "development",
			shouldErr:   false,
		},
		{
			name:        "Good secret - longer (dev)",
			secret:      "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
			environment: "development",
			shouldErr:   false,
		},
		{
			name:        "Production - too short (32 chars)",
			secret:      "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx2",
			environment: "production",
			shouldErr:   true,
			errContains: "too short for production",
		},
		{
			name: "Production - insufficient complexity",
			// Lowercase + digits only (2 of 4 classes), interleaved so no
			// weak/sequential check fires and entropy stays above 4.5.
			secret:      "q7w3e9r5t1y8u2i6o4p0a7s3d9f5g1h8j2k6l4z0x7c3v9b5n1m8q2w6e4r0t7y3",
			environment: "production",
			shouldErr:   true,
			errContains: "insufficient complexity",
		},
		{
			name:        "Production - low entropy pattern",
			secret:      strings.Repeat("AbCd12!@", 8), // 64 chars but repetitive
			environment: "production",
			shouldErr:   true,
			errContains: "insufficient entropy",
		},
		{
			name:        "Production - good secret",
			secret:      "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
			environment: "production",
			shouldErr:   false,
		},
		{
			name: "Production - excellent secret with padding",
			// Exactly 64 characters (the previous fixture was 63).
			secret:      "7mK2nP8sR4wT6xZ3bA5cxK7mN1oQ9uS4vY2zA6bxK7mN1oQ9uS4vY2zA6b+W0EJ=",
			environment: "production",
			shouldErr:   false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validator.ValidateJWTSecret(tt.secret, tt.environment)
			if tt.shouldErr {
				if err == nil {
					t.Errorf("ValidateJWTSecret() expected error containing %q, got no error", tt.errContains)
				} else if !strings.Contains(err.Error(), tt.errContains) {
					t.Errorf("ValidateJWTSecret() error = %q, should contain %q", err.Error(), tt.errContains)
				}
			} else {
				if err != nil {
					t.Errorf("ValidateJWTSecret() unexpected error: %v", err)
				}
			}
		})
	}
}
// TestValidateJWTSecret_EdgeCases pins boundary behavior at the exact minimum
// (32) and recommended (64) secret lengths.
func TestValidateJWTSecret_EdgeCases(t *testing.T) {
	validator := NewCredentialValidator()
	t.Run("Secret with mixed weak patterns", func(t *testing.T) {
		secret := "password123admin" // Contains multiple weak patterns
		err := validator.ValidateJWTSecret(secret, "development")
		if err == nil {
			t.Error("Expected error for secret containing weak patterns, got nil")
		}
	})
	t.Run("Secret exactly at minimum length", func(t *testing.T) {
		// 32 characters exactly (the previous fixture was only 31 and
		// therefore failed the minimum-length check).
		secret := "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx2"
		err := validator.ValidateJWTSecret(secret, "development")
		if err != nil {
			t.Errorf("Expected no error for 32-char secret in development, got: %v", err)
		}
	})
	t.Run("Secret exactly at recommended length", func(t *testing.T) {
		// 64 characters exactly - using real random base64 (the previous
		// fixture was only 63 and failed the production length check).
		secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
		err := validator.ValidateJWTSecret(secret, "production")
		if err != nil {
			t.Errorf("Expected no error for 64-char secret in production, got: %v", err)
		}
	})
}
// Benchmark tests to ensure validation is performant.
func BenchmarkCalculateShannonEntropy(b *testing.B) {
	const input = "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		calculateShannonEntropy(input)
	}
}
// BenchmarkValidateJWTSecret measures the full production validation path.
func BenchmarkValidateJWTSecret(b *testing.B) {
	const input = "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
	v := NewCredentialValidator()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		_ = v.ValidateJWTSecret(input, "production")
	}
}

View file

@ -0,0 +1,6 @@
package validator
// ProvideCredentialValidator provides a credential validator for dependency
// injection; it delegates to NewCredentialValidator so construction logic
// lives in one place.
func ProvideCredentialValidator() CredentialValidator {
	validator := NewCredentialValidator()
	return validator
}

View file

@ -0,0 +1,33 @@
package cache
import (
"context"
"fmt"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// ProvideRedisClient creates a new Redis client and verifies connectivity
// with a PING before handing it to the container.
func ProvideRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
	addr := fmt.Sprintf("%s:%d", cfg.Cache.Host, cfg.Cache.Port)
	logger.Info("connecting to Redis",
		zap.String("host", cfg.Cache.Host),
		zap.Int("port", cfg.Cache.Port))
	client := redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: cfg.Cache.Password,
		DB:       cfg.Cache.DB,
	})
	// Fail fast at startup rather than surfacing connection errors later.
	if err := client.Ping(context.Background()).Err(); err != nil {
		return nil, fmt.Errorf("failed to connect to Redis: %w", err)
	}
	logger.Info("successfully connected to Redis")
	return client, nil
}

View file

@ -0,0 +1,121 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/storage/database/cassandra/cassandra.go
package database
import (
"fmt"
"strings"
"time"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
"github.com/gocql/gocql"
"go.uber.org/zap"
)
// gocqlLogger wraps a zap logger to filter out noisy gocql warnings.
type gocqlLogger struct {
	logger *zap.Logger
}

// log drops the harmless "Found invalid peer" gossip warnings (an artifact of
// Docker networking) and forwards everything else at debug level.
func (l *gocqlLogger) log(msg string) {
	if strings.Contains(msg, "Found invalid peer") {
		return
	}
	l.logger.Debug(msg)
}

// Print implements gocql's Logger interface.
func (l *gocqlLogger) Print(v ...interface{}) {
	l.log(fmt.Sprint(v...))
}

// Printf implements gocql's Logger interface.
func (l *gocqlLogger) Printf(format string, v ...interface{}) {
	l.log(fmt.Sprintf(format, v...))
}

// Println implements gocql's Logger interface.
func (l *gocqlLogger) Println(v ...interface{}) {
	l.log(fmt.Sprintln(v...))
}
// ProvideCassandraSession creates a new Cassandra session from the configured
// hosts, keyspace, and consistency level, with bounded timeouts and an
// exponential-backoff retry policy.
func ProvideCassandraSession(cfg *config.Config, logger *zap.Logger) (*gocql.Session, error) {
	logger.Info("⏳ Connecting to Cassandra...",
		zap.Strings("hosts", cfg.Database.Hosts),
		zap.String("keyspace", cfg.Database.Keyspace))
	cc := gocql.NewCluster(cfg.Database.Hosts...)
	cc.Keyspace = cfg.Database.Keyspace
	cc.Consistency = parseConsistency(cfg.Database.Consistency)
	cc.ProtoVersion = 4
	cc.ConnectTimeout = 10 * time.Second
	cc.Timeout = 10 * time.Second
	cc.NumConns = 2
	// Route driver logs through zap, dropping harmless gossip noise.
	cc.Logger = &gocqlLogger{logger: logger.Named("gocql")}
	// Back off exponentially between retries instead of hammering the cluster.
	cc.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
		NumRetries: 3,
		Min:        1 * time.Second,
		Max:        10 * time.Second,
	}
	sess, err := cc.CreateSession()
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Cassandra: %w", err)
	}
	logger.Info("✓ Cassandra connected",
		zap.String("consistency", cfg.Database.Consistency),
		zap.Int("connections", cc.NumConns))
	return sess, nil
}
// parseConsistency converts a consistency-level name (e.g. "LOCAL_QUORUM")
// to its gocql.Consistency value, defaulting to QUORUM for unknown input.
func parseConsistency(consistency string) gocql.Consistency {
	levels := map[string]gocql.Consistency{
		"ANY":          gocql.Any,
		"ONE":          gocql.One,
		"TWO":          gocql.Two,
		"THREE":        gocql.Three,
		"QUORUM":       gocql.Quorum,
		"ALL":          gocql.All,
		"LOCAL_QUORUM": gocql.LocalQuorum,
		"EACH_QUORUM":  gocql.EachQuorum,
		"LOCAL_ONE":    gocql.LocalOne,
	}
	if level, ok := levels[consistency]; ok {
		return level
	}
	return gocql.Quorum // Default to QUORUM
}

View file

@ -0,0 +1,199 @@
package database
import (
"fmt"
"github.com/gocql/gocql"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/cassandra"
_ "github.com/golang-migrate/migrate/v4/source/file"
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// silentGocqlLogger implements gocql's Logger interface by discarding every
// message, including the harmless "invalid peer" warnings emitted under
// Docker networking.
type silentGocqlLogger struct{}

func (l *silentGocqlLogger) Print(v ...interface{})                 {}
func (l *silentGocqlLogger) Printf(format string, v ...interface{}) {}
func (l *silentGocqlLogger) Println(v ...interface{})               {}
// Migrator handles database schema migrations.
// This encapsulates all migration logic and makes it testable.
type Migrator struct {
	config *config.Config
	logger *zap.Logger
}

// NewMigrator creates a new migration manager. A nil logger is replaced with
// a no-op logger so callers are never forced to supply one.
func NewMigrator(cfg *config.Config, logger *zap.Logger) *Migrator {
	log := logger
	if log == nil {
		log = zap.NewNop()
	}
	return &Migrator{config: cfg, logger: log}
}
// Up runs all pending migrations with dirty state recovery
func (m *Migrator) Up() error {
	// Ensure keyspace exists before running migrations
	m.logger.Debug("Ensuring keyspace exists...")
	if err := m.ensureKeyspaceExists(); err != nil {
		return fmt.Errorf("failed to ensure keyspace exists: %w", err)
	}
	m.logger.Debug("Creating migrator...")
	migrateInstance, err := m.createMigrate()
	if err != nil {
		return fmt.Errorf("failed to create migrator: %w", err)
	}
	defer migrateInstance.Close()
	m.logger.Debug("Checking migration version...")
	// ErrNilVersion just means no migration has ever been applied.
	version, dirty, err := migrateInstance.Version()
	if err != nil && err != migrate.ErrNilVersion {
		return fmt.Errorf("failed to get migration version: %w", err)
	}
	// A dirty flag means a previous run died mid-migration. Forcing the
	// recorded version clears the flag so Up can proceed; the interrupted
	// migration is NOT rolled back here.
	if dirty {
		m.logger.Warn("Database is in dirty state, attempting to force clean state",
			zap.Uint("version", uint(version)))
		if err := migrateInstance.Force(int(version)); err != nil {
			return fmt.Errorf("failed to force clean migration state: %w", err)
		}
	}
	// Run migrations; ErrNoChange means the schema is already current.
	if err := migrateInstance.Up(); err != nil && err != migrate.ErrNoChange {
		return fmt.Errorf("failed to run migrations: %w", err)
	}
	// Get final version (purely informational — failures are only logged).
	finalVersion, _, err := migrateInstance.Version()
	if err != nil && err != migrate.ErrNilVersion {
		m.logger.Warn("Could not get final migration version", zap.Error(err))
	} else if err != migrate.ErrNilVersion {
		m.logger.Debug("Database migrations completed successfully",
			zap.Uint("version", uint(finalVersion)))
	} else {
		m.logger.Debug("Database migrations completed successfully (no migrations applied)")
	}
	return nil
}
// Down rolls back the most recently applied migration.
// Useful for development and rollback scenarios.
func (m *Migrator) Down() error {
	inst, err := m.createMigrate()
	if err != nil {
		return fmt.Errorf("failed to create migrator: %w", err)
	}
	defer inst.Close()
	// A single negative step undoes exactly one migration.
	if err = inst.Steps(-1); err != nil {
		return fmt.Errorf("failed to rollback migration: %w", err)
	}
	return nil
}
// Version returns the current migration version along with the dirty flag.
func (m *Migrator) Version() (uint, bool, error) {
	inst, err := m.createMigrate()
	if err != nil {
		return 0, false, fmt.Errorf("failed to create migrator: %w", err)
	}
	defer inst.Close()
	return inst.Version()
}
// ForceVersion forces the migration version (useful for fixing dirty states).
func (m *Migrator) ForceVersion(version int) error {
	inst, err := m.createMigrate()
	if err != nil {
		return fmt.Errorf("failed to create migrator: %w", err)
	}
	defer inst.Close()
	if err = inst.Force(version); err != nil {
		return fmt.Errorf("failed to force version %d: %w", version, err)
	}
	m.logger.Info("Successfully forced migration version", zap.Int("version", version))
	return nil
}
// createMigrate creates a migrate instance with proper configuration.
// Returns an error (instead of panicking) when no Cassandra hosts are
// configured, since the DSN is built from the first host.
func (m *Migrator) createMigrate() (*migrate.Migrate, error) {
	// Guard the Hosts[0] access below against an empty/misconfigured host list.
	if len(m.config.Database.Hosts) == 0 {
		return nil, fmt.Errorf("no Cassandra hosts configured for migrations")
	}
	// Set global gocql logger to suppress "invalid peer" warnings.
	// This affects the internal gocql connections used by golang-migrate.
	gocql.Logger = &silentGocqlLogger{}
	// Build Cassandra connection string.
	// Format: cassandra://host:port/keyspace?consistency=level
	databaseURL := fmt.Sprintf("cassandra://%s/%s?consistency=%s",
		m.config.Database.Hosts[0], // Use first host for migrations
		m.config.Database.Keyspace,
		m.config.Database.Consistency,
	)
	// Create migrate instance from the configured migrations source path.
	migrateInstance, err := migrate.New(m.config.Database.MigrationsPath, databaseURL)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize migrate: %w", err)
	}
	return migrateInstance, nil
}
// ensureKeyspaceExists creates the keyspace if it doesn't exist
// This must be done before running migrations since golang-migrate requires the keyspace to exist
func (m *Migrator) ensureKeyspaceExists() error {
	// Create cluster configuration without keyspace
	cluster := gocql.NewCluster(m.config.Database.Hosts...)
	// NOTE(review): port 9042 and the 'datacenter1' name below are hard-coded;
	// confirm they match the deployment topology before reusing elsewhere.
	cluster.Port = 9042
	cluster.Consistency = gocql.Quorum
	cluster.ProtoVersion = 4
	// Suppress noisy "invalid peer" warnings from gocql
	// Use a minimal logger that discards these harmless Docker networking warnings
	cluster.Logger = &silentGocqlLogger{}
	// Create session to system keyspace
	session, err := cluster.CreateSession()
	if err != nil {
		return fmt.Errorf("failed to connect to Cassandra: %w", err)
	}
	defer session.Close()
	// Create keyspace if it doesn't exist
	// NOTE(review): the keyspace name is interpolated into the CQL string; CQL
	// identifiers cannot be bound as parameters, so this value must remain
	// operator-controlled (it comes from config, not user input).
	replicationFactor := m.config.Database.Replication
	createKeyspaceQuery := fmt.Sprintf(`
		CREATE KEYSPACE IF NOT EXISTS %s
		WITH replication = {'class': 'NetworkTopologyStrategy', 'datacenter1': %d}
		AND durable_writes = true
	`, m.config.Database.Keyspace, replicationFactor)
	m.logger.Debug("Creating keyspace if it doesn't exist",
		zap.String("keyspace", m.config.Database.Keyspace))
	if err := session.Query(createKeyspaceQuery).Exec(); err != nil {
		return fmt.Errorf("failed to create keyspace: %w", err)
	}
	m.logger.Debug("Keyspace is ready", zap.String("keyspace", m.config.Database.Keyspace))
	return nil
}

View file

@ -0,0 +1,54 @@
package s3
// S3ObjectStorageConfigurationProvider exposes read-only access to the
// settings needed to connect to an S3-compatible object store.
type S3ObjectStorageConfigurationProvider interface {
	GetAccessKey() string    // S3 access key ID
	GetSecretKey() string    // S3 secret access key
	GetEndpoint() string     // Endpoint URL of the S3-compatible service
	GetRegion() string       // Region identifier
	GetBucketName() string   // Name of the target bucket
	GetIsPublicBucket() bool // Whether objects default to public visibility
}
// s3ObjectStorageConfigurationProviderImpl is the immutable default
// implementation of S3ObjectStorageConfigurationProvider; all fields are
// set once by the constructor and only read afterwards.
type s3ObjectStorageConfigurationProviderImpl struct {
	accessKey      string
	secretKey      string
	endpoint       string
	region         string
	bucketName     string
	isPublicBucket bool
}
// NewS3ObjectStorageConfigurationProvider constructs an immutable
// configuration provider from the given S3 connection settings.
func NewS3ObjectStorageConfigurationProvider(accessKey, secretKey, endpoint, region, bucketName string, isPublicBucket bool) S3ObjectStorageConfigurationProvider {
	provider := s3ObjectStorageConfigurationProviderImpl{
		accessKey:      accessKey,
		secretKey:      secretKey,
		endpoint:       endpoint,
		region:         region,
		bucketName:     bucketName,
		isPublicBucket: isPublicBucket,
	}
	return &provider
}
// GetAccessKey returns the configured S3 access key ID.
func (s *s3ObjectStorageConfigurationProviderImpl) GetAccessKey() string {
	return s.accessKey
}

// GetSecretKey returns the configured S3 secret access key.
func (s *s3ObjectStorageConfigurationProviderImpl) GetSecretKey() string {
	return s.secretKey
}

// GetEndpoint returns the configured endpoint URL.
func (s *s3ObjectStorageConfigurationProviderImpl) GetEndpoint() string {
	return s.endpoint
}

// GetRegion returns the configured region identifier.
func (s *s3ObjectStorageConfigurationProviderImpl) GetRegion() string {
	return s.region
}

// GetBucketName returns the configured bucket name.
func (s *s3ObjectStorageConfigurationProviderImpl) GetBucketName() string {
	return s.bucketName
}

// GetIsPublicBucket reports whether objects default to public visibility.
func (s *s3ObjectStorageConfigurationProviderImpl) GetIsPublicBucket() bool {
	return s.isPublicBucket
}

View file

@ -0,0 +1,23 @@
package s3
import (
"go.uber.org/zap"
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
)
// ProvideS3ObjectStorage is a provider function that builds the application's
// S3 object storage from the AWS section of the application config.
// The bucket is configured with private visibility by default.
func ProvideS3ObjectStorage(cfg *config.Config, logger *zap.Logger) S3ObjectStorage {
	// Translate the flat AWS config values into a configuration provider.
	provider := NewS3ObjectStorageConfigurationProvider(
		cfg.AWS.AccessKey,
		cfg.AWS.SecretKey,
		cfg.AWS.Endpoint,
		cfg.AWS.Region,
		cfg.AWS.BucketName,
		false, // default to a private bucket
	)
	// Hand the provider to the storage constructor and return the result.
	return NewObjectStorage(provider, logger)
}

View file

@ -0,0 +1,508 @@
// monorepo/cloud/maplefileapps-backend/pkg/storage/object/s3/s3.go
package s3
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
"go.uber.org/zap"
)
// ACL constants for public and private objects.
// These are the canned-ACL string values passed to the AWS SDK as
// types.ObjectCannedACL when uploading or copying objects.
const (
	ACLPrivate    = "private"     // only the bucket owner can read the object
	ACLPublicRead = "public-read" // anyone with the URL can read the object
)
// S3ObjectStorage is the application-facing contract for an S3-compatible
// object store: uploads, downloads, copy/move, batch delete, listing, and
// presigned upload/download URLs.
//
// NOTE: "Mulipart" in the method names below is a typo for "Multipart"; it is
// kept as-is because renaming would break every existing implementation and caller.
type S3ObjectStorage interface {
	UploadContent(ctx context.Context, objectKey string, content []byte) error
	UploadContentWithVisibility(ctx context.Context, objectKey string, content []byte, isPublic bool) error
	UploadContentFromMulipart(ctx context.Context, objectKey string, file multipart.File) error
	UploadContentFromMulipartWithVisibility(ctx context.Context, objectKey string, file multipart.File, isPublic bool) error
	BucketExists(ctx context.Context, bucketName string) (bool, error)
	DeleteByKeys(ctx context.Context, key []string) error
	Cut(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error
	CutWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error
	Copy(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error
	CopyWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error
	GetBinaryData(ctx context.Context, objectKey string) (io.ReadCloser, error)
	DownloadToLocalfile(ctx context.Context, objectKey string, filePath string) (string, error)
	ListAllObjects(ctx context.Context) (*s3.ListObjectsOutput, error)
	FindMatchingObjectKey(s3Objects *s3.ListObjectsOutput, partialKey string) string
	IsPublicBucket() bool
	// GeneratePresignedUploadURL creates a presigned URL for uploading objects
	GeneratePresignedUploadURL(ctx context.Context, key string, duration time.Duration) (string, error)
	GetDownloadablePresignedURL(ctx context.Context, key string, duration time.Duration) (string, error)
	ObjectExists(ctx context.Context, key string) (bool, error)
	GetObjectSize(ctx context.Context, key string) (int64, error)
}
// s3ObjectStorage implements S3ObjectStorage on top of the AWS SDK v2 client.
type s3ObjectStorage struct {
	S3Client      *s3.Client        // low-level S3 API client
	PresignClient *s3.PresignClient // presigned-URL generator built from S3Client
	Logger        *zap.Logger       // named "s3-object-storage" by the constructor
	BucketName    string            // bucket that all operations target
	IsPublic      bool              // default visibility for uploads/copies
}
// NewObjectStorage connects to a specific S3 bucket instance and returns a connected
// instance structure.
//
// Startup behavior: any failure here (config load error, inaccessible or
// missing bucket) calls logger.Fatal and terminates the process. This is
// deliberate — the constructor is used as a google wire provider, which
// requires a signature without an error return.
func NewObjectStorage(s3Config S3ObjectStorageConfigurationProvider, logger *zap.Logger) S3ObjectStorage {
	logger = logger.Named("s3-object-storage")
	// DEVELOPERS NOTE:
	// How can I use the AWS SDK v2 for Go with DigitalOcean Spaces? via https://stackoverflow.com/a/74284205
	logger.Info("⏳ Connecting to S3-compatible storage...",
		zap.String("endpoint", s3Config.GetEndpoint()),
		zap.String("bucket", s3Config.GetBucketName()),
		zap.String("region", s3Config.GetRegion()))
	// STEP 1: initialize the custom `endpoint` we will connect to.
	// NOTE(review): aws.EndpointResolverWithOptionsFunc is deprecated in newer
	// AWS SDK v2 releases in favor of per-service BaseEndpoint — confirm the
	// pinned SDK version before migrating off it.
	customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...any) (aws.Endpoint, error) {
		return aws.Endpoint{
			URL: s3Config.GetEndpoint(),
		}, nil
	})
	// STEP 2: Configure the SDK with region, custom endpoint and static credentials.
	sdkConfig, err := config.LoadDefaultConfig(
		context.TODO(), config.WithRegion(s3Config.GetRegion()),
		config.WithEndpointResolverWithOptions(customResolver),
		config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(s3Config.GetAccessKey(), s3Config.GetSecretKey(), "")),
	)
	if err != nil {
		logger.Fatal("S3ObjectStorage failed loading default config", zap.Error(err)) // We need to crash the program at start to satisfy google wire requirement of having no errors.
	}
	// STEP 3: Load up s3 instance.
	s3Client := s3.NewFromConfig(sdkConfig)
	// Create our storage handler.
	s3Storage := &s3ObjectStorage{
		S3Client:      s3Client,
		PresignClient: s3.NewPresignClient(s3Client),
		Logger:        logger,
		BucketName:    s3Config.GetBucketName(),
		IsPublic:      s3Config.GetIsPublicBucket(),
	}
	logger.Debug("Verifying bucket exists...")
	// STEP 4: Connect to the s3 bucket instance and confirm that bucket exists.
	doesExist, err := s3Storage.BucketExists(context.TODO(), s3Config.GetBucketName())
	if err != nil {
		logger.Fatal("S3ObjectStorage failed checking if bucket exists",
			zap.String("bucket", s3Config.GetBucketName()),
			zap.Error(err)) // We need to crash the program at start to satisfy google wire requirement of having no errors.
	}
	if !doesExist {
		logger.Fatal("S3ObjectStorage failed - bucket does not exist",
			zap.String("bucket", s3Config.GetBucketName())) // We need to crash the program at start to satisfy google wire requirement of having no errors.
	}
	logger.Info("✓ S3-compatible storage connected",
		zap.String("bucket", s3Config.GetBucketName()),
		zap.Bool("public", s3Config.GetIsPublicBucket()))
	// Return our s3 storage handler.
	return s3Storage
}
// IsPublicBucket reports whether uploads and copies default to public
// ("public-read") visibility for this storage instance.
func (s *s3ObjectStorage) IsPublicBucket() bool {
	return s.IsPublic
}
// UploadContent uploads content under objectKey using the bucket's default
// visibility setting (see IsPublicBucket).
func (s *s3ObjectStorage) UploadContent(ctx context.Context, objectKey string, content []byte) error {
	return s.UploadContentWithVisibility(ctx, objectKey, content, s.IsPublic)
}
// UploadContentWithVisibility uploads content to the bucket under objectKey
// with the requested visibility: a "public-read" canned ACL when isPublic is
// true, otherwise "private". The S3 client error is returned unmodified so
// callers can inspect it.
func (s *s3ObjectStorage) UploadContentWithVisibility(ctx context.Context, objectKey string, content []byte, isPublic bool) error {
	acl := ACLPrivate
	if isPublic {
		acl = ACLPublicRead
	}
	s.Logger.Debug("Uploading content with visibility",
		zap.String("objectKey", objectKey),
		zap.Bool("isPublic", isPublic),
		zap.String("acl", acl))
	_, err := s.S3Client.PutObject(ctx, &s3.PutObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(objectKey),
		Body:   bytes.NewReader(content),
		ACL:    types.ObjectCannedACL(acl),
	})
	if err != nil {
		// Fix: use zap.Error instead of zap.Any so the error is logged under
		// the standard "error" key, consistent with the rest of this file.
		s.Logger.Error("Failed to upload content",
			zap.String("objectKey", objectKey),
			zap.Bool("isPublic", isPublic),
			zap.Error(err))
		return err
	}
	return nil
}
// UploadContentFromMulipart uploads file using the default bucket visibility setting.
// NOTE: "Mulipart" is a typo for "Multipart"; it is kept because renaming
// would break the S3ObjectStorage interface and its callers.
func (s *s3ObjectStorage) UploadContentFromMulipart(ctx context.Context, objectKey string, file multipart.File) error {
	return s.UploadContentFromMulipartWithVisibility(ctx, objectKey, file, s.IsPublic)
}
// UploadContentFromMulipartWithVisibility uploads a multipart file to the
// bucket under objectKey with the requested visibility ("public-read" or
// "private" canned ACL). The file is streamed to S3; it is not buffered here.
// (The "Mulipart" typo is preserved to keep the interface stable.)
func (s *s3ObjectStorage) UploadContentFromMulipartWithVisibility(ctx context.Context, objectKey string, file multipart.File, isPublic bool) error {
	acl := ACLPrivate
	if isPublic {
		acl = ACLPublicRead
	}
	s.Logger.Debug("Uploading multipart file with visibility",
		zap.String("objectKey", objectKey),
		zap.Bool("isPublic", isPublic),
		zap.String("acl", acl))
	// Create the S3 upload input parameters
	params := &s3.PutObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(objectKey),
		Body:   file,
		ACL:    types.ObjectCannedACL(acl),
	}
	// Perform the file upload to S3
	_, err := s.S3Client.PutObject(ctx, params)
	if err != nil {
		// Fix: zap.Error (not zap.Any) keeps the field under the standard
		// "error" key, consistent with the rest of this file.
		s.Logger.Error("Failed to upload multipart file",
			zap.String("objectKey", objectKey),
			zap.Bool("isPublic", isPublic),
			zap.Error(err))
		return err
	}
	return nil
}
// BucketExists reports whether the given bucket exists and is accessible.
//
// Return semantics (note carefully):
//   - (false, nil): HeadBucket returned NotFound — the bucket does not exist.
//   - (true, nil):  the bucket exists and is accessible.
//   - (true, err):  another error occurred (e.g. no access); the boolean is
//     meaningless in this case and callers must check err first.
func (s *s3ObjectStorage) BucketExists(ctx context.Context, bucketName string) (bool, error) {
	// Note: https://docs.aws.amazon.com/code-library/latest/ug/go_2_s3_code_examples.html#actions
	_, err := s.S3Client.HeadBucket(ctx, &s3.HeadBucketInput{
		Bucket: aws.String(bucketName),
	})
	exists := true
	if err != nil {
		var apiError smithy.APIError
		if errors.As(err, &apiError) {
			switch apiError.(type) {
			case *types.NotFound:
				// NotFound is the expected "does not exist" signal, not a failure.
				s.Logger.Debug("Bucket is available", zap.String("bucket", bucketName))
				exists = false
				err = nil
			default:
				s.Logger.Error("Either you don't have access to bucket or another error occurred",
					zap.String("bucket", bucketName),
					zap.Error(err))
			}
		}
	}
	return exists, err
}
// GetDownloadablePresignedURL returns a presigned GET URL for key that
// expires after duration. The Content-Disposition override makes browsers
// download the object as an attachment instead of rendering it inline.
func (s *s3ObjectStorage) GetDownloadablePresignedURL(ctx context.Context, key string, duration time.Duration) (string, error) {
	// DEVELOPERS NOTE:
	// AWS S3 Bucket — presigned URL APIs with Go (2022) via https://ronen-niv.medium.com/aws-s3-handling-presigned-urls-2718ab247d57
	//
	// Fix: use the caller's ctx (previously context.Background()) so that
	// cancellation and deadlines propagate to the presign call.
	presignedUrl, err := s.PresignClient.PresignGetObject(ctx,
		&s3.GetObjectInput{
			Bucket:                     aws.String(s.BucketName),
			Key:                        aws.String(key),
			ResponseContentDisposition: aws.String("attachment"), // This field allows the file to download it directly from your browser
		},
		s3.WithPresignExpires(duration))
	if err != nil {
		return "", err
	}
	return presignedUrl.URL, nil
}
// DeleteByKeys deletes the given object keys from the bucket in a single
// batch request, bounded by a 15 second timeout.
//
// Fix: an empty key list is now a no-op — the S3 DeleteObjects API rejects a
// request with an empty Objects list, so we return early instead of erroring.
// NOTE(review): DeleteObjects accepts at most 1000 keys per request; callers
// with larger batches must chunk — confirm whether any caller exceeds this.
func (s *s3ObjectStorage) DeleteByKeys(ctx context.Context, objectKeys []string) error {
	if len(objectKeys) == 0 {
		return nil
	}
	ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()
	// Pre-size the identifier slice; length is known up front.
	objectIds := make([]types.ObjectIdentifier, 0, len(objectKeys))
	for _, key := range objectKeys {
		objectIds = append(objectIds, types.ObjectIdentifier{Key: aws.String(key)})
	}
	_, err := s.S3Client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
		Bucket: aws.String(s.BucketName),
		Delete: &types.Delete{Objects: objectIds},
	})
	if err != nil {
		s.Logger.Error("Couldn't delete objects from bucket",
			zap.String("bucket", s.BucketName),
			zap.Error(err))
	}
	return err
}
// Cut moves (copy-then-delete) an object from sourceObjectKey to
// destinationObjectKey using the bucket's default visibility setting.
func (s *s3ObjectStorage) Cut(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error {
	return s.CutWithVisibility(ctx, sourceObjectKey, destinationObjectKey, s.IsPublic)
}
// CutWithVisibility moves a file with specified visibility.
//
// "Cut" = copy the object to destinationObjectKey with the requested ACL,
// then delete the original at sourceObjectKey. If the delete fails, the copy
// is NOT rolled back, so both objects may exist afterwards.
func (s *s3ObjectStorage) CutWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error {
	ctx, cancel := context.WithTimeout(ctx, 60*time.Second) // Increase timeout so it runs longer than usual to handle this unique case.
	defer cancel()
	// First copy the object with the desired visibility
	if err := s.CopyWithVisibility(ctx, sourceObjectKey, destinationObjectKey, isPublic); err != nil {
		return err
	}
	// Delete the original object
	_, deleteErr := s.S3Client.DeleteObject(ctx, &s3.DeleteObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(sourceObjectKey),
	})
	if deleteErr != nil {
		s.Logger.Error("Failed to delete original object:", zap.Any("deleteErr", deleteErr))
		return deleteErr
	}
	s.Logger.Debug("Original object deleted.")
	return nil
}
// Copy duplicates an object from sourceObjectKey to destinationObjectKey
// using the bucket's default visibility setting.
func (s *s3ObjectStorage) Copy(ctx context.Context, sourceObjectKey string, destinationObjectKey string) error {
	return s.CopyWithVisibility(ctx, sourceObjectKey, destinationObjectKey, s.IsPublic)
}
// CopyWithVisibility copies the object at sourceObjectKey to
// destinationObjectKey within the same bucket, applying a "public-read" or
// "private" canned ACL on the copy according to isPublic. Bounded by a
// 60 second timeout because server-side copies of large objects can be slow.
func (s *s3ObjectStorage) CopyWithVisibility(ctx context.Context, sourceObjectKey string, destinationObjectKey string, isPublic bool) error {
	ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
	defer cancel()
	acl := ACLPrivate
	if isPublic {
		acl = ACLPublicRead
	}
	s.Logger.Debug("Copying object with visibility",
		zap.String("sourceKey", sourceObjectKey),
		zap.String("destinationKey", destinationObjectKey),
		zap.Bool("isPublic", isPublic),
		zap.String("acl", acl))
	// CopySource is "bucket/key" — the source must be addressed through the bucket.
	_, copyErr := s.S3Client.CopyObject(ctx, &s3.CopyObjectInput{
		Bucket:     aws.String(s.BucketName),
		CopySource: aws.String(s.BucketName + "/" + sourceObjectKey),
		Key:        aws.String(destinationObjectKey),
		ACL:        types.ObjectCannedACL(acl),
	})
	if copyErr != nil {
		// Fix: use zap.Error instead of zap.Any so the error is logged under
		// the standard "error" key, consistent with the rest of this file.
		s.Logger.Error("Failed to copy object:",
			zap.String("sourceKey", sourceObjectKey),
			zap.String("destinationKey", destinationObjectKey),
			zap.Bool("isPublic", isPublic),
			zap.Error(copyErr))
		return copyErr
	}
	s.Logger.Debug("Object copied successfully.")
	return nil
}
// GetBinaryData returns a streaming reader over the object stored at
// objectKey. The caller owns the returned ReadCloser and must Close it.
func (s *s3ObjectStorage) GetBinaryData(ctx context.Context, objectKey string) (io.ReadCloser, error) {
	out, err := s.S3Client.GetObject(ctx, &s3.GetObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(objectKey),
	})
	if err != nil {
		return nil, err
	}
	return out.Body, nil
}
// DownloadToLocalfile streams the object at objectKey into a newly created
// local file at filePath and returns the path.
//
// Fixes: the S3 response body is now closed (it was previously leaked,
// pinning the underlying HTTP connection), and filePath is returned
// consistently alongside any error (the io.Copy failure path returned "").
func (s *s3ObjectStorage) DownloadToLocalfile(ctx context.Context, objectKey string, filePath string) (string, error) {
	body, err := s.GetBinaryData(ctx, objectKey)
	if err != nil {
		return filePath, err
	}
	defer body.Close() // release the underlying HTTP connection

	out, err := os.Create(filePath)
	if err != nil {
		return filePath, err
	}
	defer out.Close()

	if _, err := io.Copy(out, body); err != nil {
		return filePath, err
	}
	return filePath, nil
}
// ListAllObjects returns a single ListObjects page for the bucket.
//
// NOTE(review): despite the name, the S3 ListObjects API returns at most
// 1000 keys per call and this method does not paginate — for larger buckets
// the result is truncated (see the output's IsTruncated field). Confirm
// whether callers rely on seeing the complete listing.
func (s *s3ObjectStorage) ListAllObjects(ctx context.Context) (*s3.ListObjectsOutput, error) {
	input := &s3.ListObjectsInput{
		Bucket: aws.String(s.BucketName),
	}
	objects, err := s.S3Client.ListObjects(ctx, input)
	if err != nil {
		return nil, err
	}
	return objects, nil
}
// FindMatchingObjectKey returns the first object key in s3Objects that
// contains partialKey as a substring, or "" when no key matches. This is how
// a partial key supplied by a caller is resolved to the ACTUAL key stored in
// the bucket.
func (s *s3ObjectStorage) FindMatchingObjectKey(s3Objects *s3.ListObjectsOutput, partialKey string) string {
	for _, obj := range s3Objects.Contents {
		// Fix: fold the un-idiomatic `match == true` comparison directly
		// into the condition.
		if strings.Contains(*obj.Key, partialKey) {
			return *obj.Key
		}
	}
	return ""
}
// GeneratePresignedUploadURL creates a presigned URL that allows a client to
// upload an object directly to S3 at the given key; the URL expires after
// duration. No ACL is set on the signed request, so uploaded files inherit
// the bucket's default privacy settings and the uploader does not have to
// send an x-amz-acl header.
func (s *s3ObjectStorage) GeneratePresignedUploadURL(ctx context.Context, key string, duration time.Duration) (string, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	req := &s3.PutObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(key),
	}
	signed, err := s.PresignClient.PresignPutObject(ctx, req, s3.WithPresignExpires(duration))
	if err != nil {
		s.Logger.Error("Failed to generate presigned upload URL",
			zap.String("key", key),
			zap.Duration("duration", duration),
			zap.Error(err))
		return "", err
	}

	s.Logger.Debug("Generated presigned upload URL",
		zap.String("key", key),
		zap.Duration("duration", duration))
	return signed.URL, nil
}
// ObjectExists reports whether an object exists at the given key using a
// HeadObject request with a 30 second timeout.
//
// Returns (false, nil) when the object is missing (NotFound/NoSuchKey),
// (true, nil) when it exists, and (false, err) on any other failure.
//
// Fix: the three duplicated error-logging branches (type-switch default plus
// the non-API path) are collapsed; missing-object detection now uses typed
// errors.As targets instead of a smithy.APIError type switch.
func (s *s3ObjectStorage) ObjectExists(ctx context.Context, key string) (bool, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	_, err := s.S3Client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(key),
	})
	if err == nil {
		s.Logger.Debug("Object exists",
			zap.String("key", key))
		return true, nil
	}
	// Both NotFound (HeadObject) and NoSuchKey signal a missing object.
	var notFound *types.NotFound
	var noSuchKey *types.NoSuchKey
	if errors.As(err, &notFound) || errors.As(err, &noSuchKey) {
		s.Logger.Debug("Object does not exist",
			zap.String("key", key))
		return false, nil
	}
	// Any other failure (API or transport level).
	s.Logger.Error("Error checking object existence",
		zap.String("key", key),
		zap.Error(err))
	return false, err
}
// GetObjectSize returns the size in bytes of the object at key, using a
// HeadObject request with a 30 second timeout. A missing object yields an
// "object not found" error.
//
// Fix: missing-object detection uses typed errors.As targets instead of a
// smithy.APIError type switch, and the duplicated error-logging branches
// are collapsed into one.
func (s *s3ObjectStorage) GetObjectSize(ctx context.Context, key string) (int64, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	result, err := s.S3Client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(s.BucketName),
		Key:    aws.String(key),
	})
	if err != nil {
		// Both NotFound (HeadObject) and NoSuchKey signal a missing object.
		var notFound *types.NotFound
		var noSuchKey *types.NoSuchKey
		if errors.As(err, &notFound) || errors.As(err, &noSuchKey) {
			s.Logger.Debug("Object not found when getting size",
				zap.String("key", key))
			return 0, errors.New("object not found")
		}
		s.Logger.Error("Error getting object size",
			zap.String("key", key),
			zap.Error(err))
		return 0, err
	}
	// aws.ToInt64 safely dereferences the *int64 ContentLength (0 when nil).
	size := aws.ToInt64(result.ContentLength)
	s.Logger.Debug("Retrieved object size",
		zap.String("key", key),
		zap.Int64("size", size))
	return size, nil
}

View file

@ -0,0 +1,516 @@
package transaction
import (
"context"
"go.uber.org/zap"
)
// Package transaction provides a SAGA pattern implementation for managing distributed transactions.
//
// # What is SAGA Pattern?
//
// SAGA is a pattern for managing distributed transactions through a sequence of local transactions,
// each with a corresponding compensating transaction that undoes its effects if a later step fails.
//
// # When to Use SAGA
//
// Use SAGA when you have multiple database operations that need to succeed or fail together,
// but you can't use traditional ACID transactions (e.g., with Cassandra, distributed services,
// or operations across multiple bounded contexts).
//
// # Key Concepts
//
// - Forward Transaction: A database write operation (e.g., CreateTenant)
// - Compensating Transaction: An undo operation (e.g., DeleteTenant)
// - LIFO Execution: Compensations execute in reverse order (Last In, First Out)
//
// # Example Usage: User Registration Flow
//
// Problem: When registering a user, we create a tenant, then create a user.
// If user creation fails, the tenant becomes orphaned in the database.
//
// Solution: Use SAGA to automatically delete the tenant if user creation fails.
//
// func (s *RegisterService) Register(ctx context.Context, input *RegisterInput) (*RegisterResponse, error) {
// // Step 1: Create SAGA instance
// saga := transaction.NewSaga("user-registration", s.logger)
//
// // Step 2: Validate input (no DB writes, no compensation needed)
// if err := s.validateInputUC.Execute(input); err != nil {
// return nil, err
// }
//
// // Step 3: Create tenant (FIRST DB WRITE - register compensation)
// tenantOutput, err := s.createTenantUC.Execute(ctx, input)
// if err != nil {
// return nil, err // No rollback needed - tenant creation failed
// }
//
// // Register compensation: if anything fails later, delete this tenant
// saga.AddCompensation(func(ctx context.Context) error {
// s.logger.Warn("compensating: deleting tenant",
// zap.String("tenant_id", tenantOutput.ID))
// return s.deleteTenantUC.Execute(ctx, tenantOutput.ID)
// })
//
// // Step 4: Create user (SECOND DB WRITE)
// userOutput, err := s.createUserUC.Execute(ctx, tenantOutput.ID, input)
// if err != nil {
// s.logger.Error("user creation failed - rolling back tenant",
// zap.Error(err))
//
// // Execute SAGA rollback - this will delete the tenant
// saga.Rollback(ctx)
//
// return nil, err
// }
//
// // Success! Both tenant and user created, no rollback needed
// return &RegisterResponse{
// TenantID: tenantOutput.ID,
// UserID: userOutput.ID,
// }, nil
// }
//
// # Example Usage: Multi-Step Saga
//
// For operations with many steps, register multiple compensations:
//
// func (uc *ComplexOperationUseCase) Execute(ctx context.Context) error {
// saga := transaction.NewSaga("complex-operation", uc.logger)
//
// // Step 1: Create resource A
// resourceA, err := uc.createResourceA(ctx)
// if err != nil {
// return err
// }
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteResourceA(ctx, resourceA.ID)
// })
//
// // Step 2: Create resource B
// resourceB, err := uc.createResourceB(ctx)
// if err != nil {
// saga.Rollback(ctx) // Deletes A
// return err
// }
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteResourceB(ctx, resourceB.ID)
// })
//
// // Step 3: Create resource C
// resourceC, err := uc.createResourceC(ctx)
// if err != nil {
// saga.Rollback(ctx) // Deletes B, then A (LIFO order)
// return err
// }
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteResourceC(ctx, resourceC.ID)
// })
//
// // All steps succeeded - no rollback needed
// return nil
// }
//
// # Important Notes for Junior Developers
//
// 1. LIFO Order: Compensations execute in REVERSE order of registration
// If you create: Tenant → User → Email
// Rollback deletes: Email → User → Tenant
//
// 2. Idempotency: Compensating operations should be idempotent (safe to call multiple times)
// Your DeleteTenant should not error if tenant is already deleted
//
// 3. Failures Continue: If one compensation fails, others still execute
// This ensures maximum cleanup even if some operations fail
//
// 4. Logging: All operations are logged with emoji icons (🔴 for errors, 🟡 for warnings)
// Monitor logs for "saga rollback had failures" - indicates manual intervention needed
//
// 5. When NOT to Use SAGA:
// - Single database operation (no need for compensation)
// - Read-only operations (no state changes to rollback)
// - Operations where compensation isn't possible (e.g., sending an email can't be unsent)
//
// 6. Testing: Always test your rollback scenarios!
// Mock the second operation to fail and verify the first is rolled back
//
// # Common Pitfalls to Avoid
//
// - DON'T register compensations before the operation succeeds
// - DON'T forget to call saga.Rollback(ctx) when an operation fails
// - DON'T assume compensations will always succeed (they might fail too)
// - DON'T use SAGA for operations that can use database transactions
// - DO make your compensating operations idempotent
// - DO log all compensation failures for investigation
//
// # See Also
//
// For real-world examples, see:
// - internal/service/gateway/register.go (user registration with SAGA)
// - internal/usecase/tenant/delete.go (compensating transaction example)
// - internal/usecase/user/delete.go (compensating transaction example)
// Compensator defines a function that undoes a previously executed operation.
//
// A compensator is the "undo" function for a database write operation.
// For example:
//   - Forward operation: CreateTenant
//   - Compensator: DeleteTenant
//
// Compensators must:
//   - Accept a context (for cancellation/timeouts)
//   - Honor cancellation/deadlines on that context where possible
//   - Return an error if compensation fails
//   - Be idempotent (safe to call multiple times)
//   - Clean up the exact resources created by the forward operation
//
// Example:
//
//	// Forward operation: Create tenant
//	tenantID := "tenant-123"
//	err := tenantRepo.Create(ctx, tenant)
//
//	// Compensator: Delete tenant
//	compensator := func(ctx context.Context) error {
//	    return tenantRepo.Delete(ctx, tenantID)
//	}
//
//	saga.AddCompensation(compensator)
type Compensator func(ctx context.Context) error
// Saga manages a sequence of operations with compensating transactions.
//
// A Saga coordinates a multi-step workflow where each step that performs a database
// write registers a compensating transaction. If any step fails, all registered
// compensations are executed in reverse order (LIFO) to undo previous changes.
//
// # How it Works
//
// 1. Create a Saga instance with NewSaga()
// 2. Execute your operations in sequence
// 3. After each successful write, call AddCompensation() with the undo operation
// 4. If any operation fails, call Rollback() to undo all previous changes
// 5. If all operations succeed, no action needed (compensations are never called)
//
// # Thread Safety
//
// Saga is NOT thread-safe. Do not share a single Saga instance across goroutines.
// Each workflow execution should create its own Saga instance.
//
// # Fields
//
// - name: Human-readable name for logging (e.g., "user-registration")
// - compensators: Stack of undo functions, executed in LIFO order
// - logger: Structured logger for tracking saga execution and failures
type Saga struct {
	name         string        // Human-readable saga name, used only for logging
	compensators []Compensator // Stack of compensating transactions; Rollback runs them in reverse (LIFO) registration order
	logger       *zap.Logger   // Logger named "saga" and tagged with the saga_name field
}
// NewSaga creates a new SAGA instance with the given name.
//
// The name parameter should be a descriptive identifier for the workflow
// (e.g., "user-registration", "order-processing", "account-setup").
// This name appears in all log messages for easy tracking and debugging.
//
// # Parameters
//
// - name: A descriptive name for this saga workflow (used in logging)
// - logger: A zap logger instance (will be enhanced with saga-specific fields)
//
// # Returns
//
// A new Saga instance ready to coordinate multi-step operations.
//
// # Example
//
// // In your use case
// func (uc *RegisterUseCase) Execute(ctx context.Context, input *Input) error {
// // Create a new saga for this registration workflow
// saga := transaction.NewSaga("user-registration", uc.logger)
//
// // ... use saga for your operations ...
// }
//
// # Important
//
// Each workflow execution should create its own Saga instance.
// Do NOT reuse a Saga instance across multiple workflow executions.
func NewSaga(name string, logger *zap.Logger) *Saga {
	// Derive a saga-scoped logger so every log line carries the saga name.
	sagaLogger := logger.Named("saga").With(zap.String("saga_name", name))
	return &Saga{
		name:         name,
		compensators: make([]Compensator, 0),
		logger:       sagaLogger,
	}
}
// AddCompensation registers a compensating transaction for rollback.
//
// Call this method IMMEDIATELY AFTER a successful database write operation
// to register the corresponding undo operation.
//
// # Execution Order: LIFO (Last In, First Out)
//
// Compensations are executed in REVERSE order of registration during rollback.
// This ensures proper cleanup order:
// - If you create: Tenant → User → Subscription
// - Rollback deletes: Subscription → User → Tenant
//
// # Parameters
//
// - compensate: A function that undoes the operation (e.g., DeleteTenant)
//
// # When to Call
//
// // ✅ CORRECT: Register compensation AFTER operation succeeds
// tenantOutput, err := uc.createTenantUC.Execute(ctx, input)
// if err != nil {
// return nil, err // Operation failed - no compensation needed
// }
// // Operation succeeded - NOW register the undo operation
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteTenantUC.Execute(ctx, tenantOutput.ID)
// })
//
// // ❌ WRONG: Don't register compensation BEFORE operation
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteTenantUC.Execute(ctx, tenantOutput.ID)
// })
// tenantOutput, err := uc.createTenantUC.Execute(ctx, input) // Might fail!
//
// # Example: Basic Usage
//
// // Step 1: Create tenant
// tenant, err := uc.createTenantUC.Execute(ctx, input)
// if err != nil {
// return nil, err
// }
//
// // Step 2: Register compensation for tenant
// saga.AddCompensation(func(ctx context.Context) error {
// uc.logger.Warn("rolling back: deleting tenant",
// zap.String("tenant_id", tenant.ID))
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
// })
//
// # Example: Capturing Variables in Closure
//
// // Be careful with variable scope in closures!
// for _, item := range items {
// created, err := uc.createItem(ctx, item)
// if err != nil {
// saga.Rollback(ctx)
// return err
// }
//
// // ✅ CORRECT: Capture the variable value
// itemID := created.ID // Capture in local variable
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteItem(ctx, itemID) // Use captured value
// })
//
// // ❌ WRONG: Variable will have wrong value at rollback time
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteItem(ctx, created.ID) // 'created' may change!
// })
// }
//
// # Tips for Writing Good Compensators
//
// 1. Make them idempotent (safe to call multiple times)
// 2. Log what you're compensating for easier debugging
// 3. Capture all necessary IDs before the closure
// 4. Handle "not found" errors gracefully (resource may already be deleted)
// 5. Return errors if compensation truly fails (logged but doesn't stop other compensations)
func (s *Saga) AddCompensation(compensate Compensator) {
	// Push onto the LIFO stack; Rollback will execute in reverse order.
	s.compensators = append(s.compensators, compensate)
	total := len(s.compensators)
	s.logger.Debug("compensation registered", zap.Int("total_compensations", total))
}
// Rollback executes all registered compensating transactions in reverse order (LIFO).
//
// Call this method when any operation in your workflow fails AFTER you've started
// registering compensations. This will undo all previously successful operations
// by executing their compensating transactions in reverse order.
//
// # When to Call
//
// tenant, err := uc.createTenantUC.Execute(ctx, input)
// if err != nil {
// return nil, err // No compensations registered yet - no rollback needed
// }
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
// })
//
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
// if err != nil {
// // Compensations ARE registered - MUST call rollback!
// saga.Rollback(ctx)
// return nil, err
// }
//
// # Execution Behavior
//
// 1. LIFO Order: Compensations execute in REVERSE order of registration
// - If you registered: [DeleteTenant, DeleteUser, DeleteSubscription]
// - Rollback executes: DeleteSubscription → DeleteUser → DeleteTenant
//
// 2. Best Effort: If a compensation fails, it's logged but others still execute
// - This maximizes cleanup even if some operations fail
// - Failed compensations are logged with 🔴 emoji for investigation
//
// 3. No Panic: Rollback never panics, even if all compensations fail
// - Failures are logged for manual intervention
// - Returns without error (compensation failures are logged, not returned)
//
// # Example: Basic Rollback
//
// func (uc *RegisterUseCase) Execute(ctx context.Context, input *Input) error {
// saga := transaction.NewSaga("user-registration", uc.logger)
//
// // Step 1: Create tenant
// tenant, err := uc.createTenantUC.Execute(ctx, input)
// if err != nil {
// return err // No rollback needed
// }
// saga.AddCompensation(func(ctx context.Context) error {
// return uc.deleteTenantUC.Execute(ctx, tenant.ID)
// })
//
// // Step 2: Create user
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
// if err != nil {
// uc.logger.Error("user creation failed", zap.Error(err))
// saga.Rollback(ctx) // ← Deletes tenant
// return err
// }
//
// // Both operations succeeded - no rollback needed
// return nil
// }
//
// # Log Output Example
//
// Successful rollback:
//
// WARN 🟡 executing saga rollback {"saga_name": "user-registration", "compensation_count": 1}
// INFO executing compensation {"step": 1, "index": 0}
// INFO deleting tenant {"tenant_id": "tenant-123"}
// INFO tenant deleted successfully {"tenant_id": "tenant-123"}
// INFO compensation succeeded {"step": 1}
// WARN 🟡 saga rollback completed {"total_compensations": 1, "successes": 1, "failures": 0}
//
// Failed compensation:
//
// WARN 🟡 executing saga rollback
// INFO executing compensation
// ERROR 🔴 failed to delete tenant {"error": "connection lost"}
// ERROR 🔴 compensation failed {"step": 1, "error": "..."}
// WARN 🟡 saga rollback completed {"successes": 0, "failures": 1}
// ERROR 🔴 saga rollback had failures - manual intervention may be required
//
// # Important Notes
//
// 1. Always call Rollback if you've registered ANY compensations and a later step fails
// 2. Don't call Rollback if no compensations have been registered yet
// 3. Rollback is safe to call multiple times (idempotent) but wasteful
// 4. Monitor logs for "saga rollback had failures" - indicates manual cleanup needed
// 5. Context cancellation is respected - compensations will see cancelled context
//
// # Parameters
//
// - ctx: Context for cancellation/timeout (passed to each compensating function)
//
// # What Gets Logged
//
// - Start of rollback (warning level with 🟡 emoji)
// - Each compensation execution attempt
// - Success or failure of each compensation
// - Summary of rollback results
// - Alert if any compensations failed (error level with 🔴 emoji)
func (s *Saga) Rollback(ctx context.Context) {
	total := len(s.compensators)
	if total == 0 {
		s.logger.Info("no compensations to execute")
		return
	}
	s.logger.Warn("executing saga rollback",
		zap.Int("compensation_count", total))
	var succeeded, failed int
	// Walk the registered compensations backwards (LIFO): the most recently
	// registered operation is undone first.
	for idx := total - 1; idx >= 0; idx-- {
		step := total - idx
		s.logger.Info("executing compensation",
			zap.Int("step", step),
			zap.Int("index", idx))
		err := s.compensators[idx](ctx)
		if err == nil {
			succeeded++
			s.logger.Info("compensation succeeded",
				zap.Int("step", step),
				zap.Int("index", idx))
			continue
		}
		// Best effort: record the failure and keep compensating the rest.
		failed++
		s.logger.Error("compensation failed",
			zap.Int("step", step),
			zap.Int("index", idx),
			zap.Error(err))
	}
	s.logger.Warn("saga rollback completed",
		zap.Int("total_compensations", total),
		zap.Int("successes", succeeded),
		zap.Int("failures", failed))
	// Any failure here means leftover state that needs a human to clean up.
	if failed > 0 {
		s.logger.Error("saga rollback had failures - manual intervention may be required",
			zap.Int("failed_compensations", failed))
	}
}
// MustRollback is a convenience method that executes rollback.
//
// This method currently has the same behavior as Rollback() - it executes
// all compensating transactions but does NOT panic on failure.
//
// # When to Use
//
// Use this method when you want to make it explicit in your code that rollback
// is critical and must be executed, even though the actual behavior is the same
// as Rollback().
//
// # Example
//
// user, err := uc.createUserUC.Execute(ctx, tenant.ID, input)
// if err != nil {
// // Make it explicit that rollback is critical
// saga.MustRollback(ctx)
// return nil, err
// }
//
// # Note for Junior Developers
//
// Despite the name "MustRollback", this method does NOT panic if compensations fail.
// Compensation failures are logged for manual intervention, but the method returns normally.
//
// The name "Must" indicates that YOU must call this method if compensations are registered,
// not that the rollback itself must succeed.
//
// If you need actual panic behavior on compensation failure, you would need to check
// logs or implement custom panic logic.
func (s *Saga) MustRollback(ctx context.Context) {
	// Pure delegate: "Must" signals the call is mandatory once compensations
	// exist, not that this panics (it never does).
	s.Rollback(ctx)
}

View file

@ -0,0 +1,275 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/validation/email.go
package validation
import (
"fmt"
"strings"
)
// EmailValidator provides comprehensive email validation and normalization
// CWE-20: Improper Input Validation - Ensures email addresses are properly validated
type EmailValidator struct {
	validator *Validator // shared generic validator used for the base format checks
}
// NewEmailValidator creates a new email validator backed by the
// general-purpose Validator.
func NewEmailValidator() *EmailValidator {
	ev := &EmailValidator{validator: NewValidator()}
	return ev
}
// ValidateAndNormalize validates and normalizes an email address.
// Returns the normalized (trimmed, lowercased) email, or a validation error.
func (ev *EmailValidator) ValidateAndNormalize(email, fieldName string) (string, error) {
	// Reject syntactically invalid input before canonicalizing it
	// (validate-before-canonicalize ordering).
	if err := ev.validator.ValidateEmail(email, fieldName); err != nil {
		return "", err
	}
	normalized := ev.Normalize(email)
	// Run the stricter security checks against the canonical form.
	if err := ev.ValidateSecurityConstraints(normalized, fieldName); err != nil {
		return "", err
	}
	return normalized, nil
}
// Normalize normalizes an email address for consistent storage and comparison.
// CWE-180: Incorrect Behavior Order: Validate Before Canonicalize
func (ev *EmailValidator) Normalize(email string) string {
	// Lowercase the trimmed address; most providers treat the local part
	// case-insensitively even though RFC 5321 says it is case-sensitive.
	cleaned := strings.ToLower(strings.TrimSpace(email))
	// Strip embedded null bytes so downstream storage never sees them.
	cleaned = strings.ReplaceAll(cleaned, "\x00", "")
	// Gmail-specific normalization (dots / plus-aliases) is intentionally
	// disabled; call ev.normalizeGmail here if alias abuse becomes an issue.
	return cleaned
}
// ValidateSecurityConstraints performs additional security validation on an
// already-normalized email address.
func (ev *EmailValidator) ValidateSecurityConstraints(email, fieldName string) error {
	// 1. Excessive special characters can indicate obfuscation attempts.
	specials := 0
	for _, r := range email {
		switch r {
		case '+', '.', '_', '-', '%':
			specials++
		}
	}
	if specials > 10 {
		return fmt.Errorf("%s: contains too many special characters", fieldName)
	}
	// 2. Disposable-provider detection is advisory only for now: the result
	// is computed but deliberately not acted upon. Wire this up to reject
	// or flag addresses once a policy decision is made.
	_ = ev.isLikelyDisposable(email)
	// 3. Suggest a correction when the domain looks like a typo of a
	// well-known provider.
	if suggestion := ev.detectCommonDomainTypo(email); suggestion != "" {
		return fmt.Errorf("%s: possible typo detected, did you mean %s?", fieldName, suggestion)
	}
	// 4. Reject IP-literal domains outright.
	if ev.hasIPAddress(email) {
		return fmt.Errorf("%s: IP-based email addresses are not allowed", fieldName)
	}
	return nil
}
// isLikelyDisposable checks if email is from a known disposable email provider.
// This is a basic implementation - in production, use a service like:
// - https://github.com/disposable/disposable-email-domains
// - or an API service
func (ev *EmailValidator) isLikelyDisposable(email string) bool {
	pieces := strings.Split(email, "@")
	if len(pieces) != 2 {
		return false
	}
	domain := strings.ToLower(pieces[1])
	// Substring fragments that show up in many throwaway-mail domains.
	for _, fragment := range []string{
		"temp", "disposable", "throwaway", "guerrilla",
		"mailinator", "10minute", "trashmail", "yopmail", "fakeinbox",
	} {
		if strings.Contains(domain, fragment) {
			return true
		}
	}
	// Exact-match blocklist (small sample - expand as needed).
	switch domain {
	case "mailinator.com", "guerrillamail.com", "10minutemail.com",
		"tempmailaddress.com", "yopmail.com", "fakeinbox.com",
		"trashmail.com", "throwaway.email":
		return true
	}
	return false
}
// detectCommonDomainTypo checks for common typos in popular email domains.
// Returns the suggested corrected address, or "" when no typo is recognized.
func (ev *EmailValidator) detectCommonDomainTypo(email string) string {
	pieces := strings.Split(email, "@")
	if len(pieces) != 2 {
		return ""
	}
	local, domain := pieces[0], strings.ToLower(pieces[1])
	// Frequently mistyped domains mapped to their intended spelling.
	corrections := map[string]string{
		"gmial.com":     "gmail.com",
		"gmai.com":      "gmail.com",
		"gmil.com":      "gmail.com",
		"yahooo.com":    "yahoo.com",
		"yaho.com":      "yahoo.com",
		"hotmial.com":   "hotmail.com",
		"hotmal.com":    "hotmail.com",
		"outlok.com":    "outlook.com",
		"outloo.com":    "outlook.com",
		"iclodu.com":    "icloud.com",
		"iclod.com":     "icloud.com",
		"protonmai.com": "protonmail.com",
		"protonmal.com": "protonmail.com",
	}
	correct, found := corrections[domain]
	if !found {
		return ""
	}
	return local + "@" + correct
}
// hasIPAddress checks if email domain is an IP address.
func (ev *EmailValidator) hasIPAddress(email string) bool {
	pieces := strings.Split(email, "@")
	if len(pieces) != 2 {
		return false
	}
	domain := pieces[1]
	// Bracketed address literals, e.g. user@[192.168.1.1].
	if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") {
		return true
	}
	// Unbracketed heuristic: nothing but digits and dots, with at least
	// three dots (e.g. 192.168.1.1). Not a full IPv4 parse — just a cheap filter.
	for _, r := range domain {
		if r != '.' && (r < '0' || r > '9') {
			return false
		}
	}
	return strings.Count(domain, ".") >= 3
}
// normalizeGmail normalizes Gmail addresses by removing dots and plus-aliases.
// Gmail ignores dots in the local part and treats everything after + as an alias.
// Example: john.doe+test@gmail.com -> johndoe@gmail.com
func (ev *EmailValidator) normalizeGmail(email string) string {
	pieces := strings.Split(email, "@")
	if len(pieces) != 2 {
		return email
	}
	local, domain := pieces[0], strings.ToLower(pieces[1])
	// Only Google-operated domains get this treatment.
	if domain != "gmail.com" && domain != "googlemail.com" {
		return email
	}
	// Dots are ignored by Gmail; drop them from the local part.
	local = strings.ReplaceAll(local, ".", "")
	// Everything after '+' is an alias; strip it.
	if cut := strings.Index(local, "+"); cut >= 0 {
		local = local[:cut]
	}
	return local + "@" + domain
}
// ValidateEmailList validates a list of email addresses.
// Returns the normalized list, or the first error encountered.
func (ev *EmailValidator) ValidateEmailList(emails []string, fieldName string) ([]string, error) {
	out := make([]string, 0, len(emails))
	for idx, raw := range emails {
		// Each entry is reported with its index, e.g. "recipients[2]".
		label := fmt.Sprintf("%s[%d]", fieldName, idx)
		clean, err := ev.ValidateAndNormalize(raw, label)
		if err != nil {
			return nil, err
		}
		out = append(out, clean)
	}
	return out, nil
}
// IsValidEmailDomain checks if a domain is likely valid (has proper structure).
// This is a lightweight check - for production, consider DNS MX record validation.
func (ev *EmailValidator) IsValidEmailDomain(email string) bool {
	pieces := strings.Split(email, "@")
	if len(pieces) != 2 {
		return false
	}
	domain := strings.ToLower(pieces[1])
	labels := strings.Split(domain, ".")
	// A bare hostname with no dot (e.g. "localhost") is rejected.
	if len(labels) < 2 {
		return false
	}
	// The final label (TLD) must be at least two characters long.
	return len(labels[len(labels)-1]) >= 2
}

View file

@ -0,0 +1,120 @@
package validation
import (
"fmt"
"net/http"
"strconv"
)
// ValidatePathUUID validates a UUID path parameter.
// CWE-20: Improper Input Validation
func ValidatePathUUID(r *http.Request, paramName string) (string, error) {
	raw := r.PathValue(paramName)
	if raw == "" {
		return "", fmt.Errorf("%s is required", paramName)
	}
	// Delegate format checking to the shared validator.
	if err := NewValidator().ValidateUUID(raw, paramName); err != nil {
		return "", err
	}
	return raw, nil
}
// ValidatePathSlug validates a slug path parameter.
// CWE-20: Improper Input Validation
func ValidatePathSlug(r *http.Request, paramName string) (string, error) {
	raw := r.PathValue(paramName)
	if raw == "" {
		return "", fmt.Errorf("%s is required", paramName)
	}
	// Delegate format checking to the shared validator.
	if err := NewValidator().ValidateSlug(raw, paramName); err != nil {
		return "", err
	}
	return raw, nil
}
// ValidatePathInt validates an integer path parameter
// CWE-20: Improper Input Validation
func ValidatePathInt(r *http.Request, paramName string) (int64, error) {
valueStr := r.PathValue(paramName)
if valueStr == "" {
return 0, fmt.Errorf("%s is required", paramName)
}
value, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("%s must be a valid integer", paramName)
}
if value <= 0 {
return 0, fmt.Errorf("%s must be greater than 0", paramName)
}
return value, nil
}
// ValidatePagination validates pagination query parameters.
// Returns limit and offset with defaults and bounds checking:
// limit defaults to defaultLimit and must be 1..100; offset defaults to 0
// and must be non-negative.
func ValidatePagination(r *http.Request, defaultLimit int) (int, int, error) {
	limit, offset := defaultLimit, 0
	query := r.URL.Query()
	// Optional "limit": 1..100 inclusive.
	if raw := query.Get("limit"); raw != "" {
		n, err := strconv.Atoi(raw)
		if err != nil || n <= 0 || n > 100 {
			return 0, 0, fmt.Errorf("limit must be between 1 and 100")
		}
		limit = n
	}
	// Optional "offset": any non-negative integer.
	if raw := query.Get("offset"); raw != "" {
		n, err := strconv.Atoi(raw)
		if err != nil || n < 0 {
			return 0, 0, fmt.Errorf("offset must be >= 0")
		}
		offset = n
	}
	return limit, offset, nil
}
// ValidateSortField validates sort field against whitelist.
// CWE-89: SQL Injection prevention via whitelist
func ValidateSortField(r *http.Request, allowedFields []string) (string, error) {
	requested := r.URL.Query().Get("sort_by")
	// sort_by is optional; absence is not an error.
	if requested == "" {
		return "", nil
	}
	for _, candidate := range allowedFields {
		if candidate == requested {
			return requested, nil
		}
	}
	return "", fmt.Errorf("invalid sort_by field (allowed: %v)", allowedFields)
}
// ValidateQueryEmail validates an email query parameter.
// CWE-20: Improper Input Validation
func ValidateQueryEmail(r *http.Request, paramName string) (string, error) {
	raw := r.URL.Query().Get(paramName)
	if raw == "" {
		return "", fmt.Errorf("%s is required", paramName)
	}
	// Full validation plus normalization (trim, lowercase, null-byte strip).
	normalized, err := NewEmailValidator().ValidateAndNormalize(raw, paramName)
	if err != nil {
		return "", err
	}
	return normalized, nil
}

View file

@ -0,0 +1,6 @@
package validation
// ProvideValidator provides a Validator instance
// (dependency-injection provider function; simply delegates to NewValidator).
func ProvideValidator() *Validator {
	return NewValidator()
}

View file

@ -0,0 +1,498 @@
package validation
import (
	"errors"
	"fmt"
	"net/mail"
	"net/url"
	"regexp"
	"strings"
	"time"
	"unicode"
)
// Common validation errors.
//
// These are sentinel errors: callers wrap them with fmt.Errorf("...: %w", ...)
// and test with errors.Is. They are created with errors.New because no
// formatting verbs are needed (fmt.Errorf without verbs is just a slower
// errors.New).
var (
	ErrRequired          = errors.New("field is required")
	ErrInvalidEmail      = errors.New("invalid email format")
	ErrInvalidURL        = errors.New("invalid URL format")
	ErrInvalidDomain     = errors.New("invalid domain format")
	ErrTooShort          = errors.New("value is too short")
	ErrTooLong           = errors.New("value is too long")
	ErrInvalidCharacters = errors.New("contains invalid characters")
	ErrInvalidFormat     = errors.New("invalid format")
	ErrInvalidValue      = errors.New("invalid value")
	ErrWhitespaceOnly    = errors.New("cannot contain only whitespace")
	ErrContainsHTML      = errors.New("cannot contain HTML tags")
	ErrInvalidSlug       = errors.New("invalid slug format")
)
// Regex patterns for validation
var (
	// Email validation: pragmatic pattern — stricter than RFC 5322 (it rejects
	// quoted local parts and requires a dotted alphabetic TLD); full parsing
	// is additionally delegated to net/mail.ParseAddress in ValidateEmail
	emailRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._%+\-]*[a-zA-Z0-9]@[a-zA-Z0-9][a-zA-Z0-9.\-]*[a-zA-Z0-9]\.[a-zA-Z]{2,}$`)
	// Domain validation: dot-separated alphanumeric labels (hyphens allowed
	// inside a label), ending in an alphabetic TLD of 2+ characters
	domainRegex = regexp.MustCompile(`^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
	// Slug validation: lowercase alphanumeric runs separated by single hyphens
	slugRegex = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
	// HTML tag detection: any <...> sequence
	htmlTagRegex = regexp.MustCompile(`<[^>]+>`)
	// UUID validation (version 4, lowercase hex, correct version and variant nibbles)
	uuidRegex = regexp.MustCompile(`^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$`)
	// Alphanumeric only (ASCII letters and digits)
	alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
)
// Reserved slugs that cannot be used for tenant names
// (routing prefixes, infrastructure hostnames, and auth/app page names that
// would collide with tenant subdomains or URL paths).
var ReservedSlugs = map[string]bool{
	"api":       true,
	"admin":     true,
	"www":       true,
	"mail":      true,
	"email":     true,
	"health":    true,
	"status":    true,
	"metrics":   true,
	"static":    true,
	"cdn":       true,
	"assets":    true,
	"blog":      true,
	"docs":      true,
	"help":      true,
	"support":   true,
	"login":     true,
	"logout":    true,
	"signup":    true,
	"register":  true,
	"app":       true,
	"dashboard": true,
	"settings":  true,
	"account":   true,
	"profile":   true,
	"root":      true,
	"system":    true,
	"public":    true,
	"private":   true,
}
// Validator provides input validation utilities.
// It holds no state; methods rely only on package-level compiled patterns.
type Validator struct{}

// NewValidator creates a new validator instance.
func NewValidator() *Validator {
	return &Validator{}
}
// ==================== String Validation ====================
// ValidateRequired checks if a string is not empty
// (whitespace-only values count as empty).
func (v *Validator) ValidateRequired(value, fieldName string) error {
	if len(strings.TrimSpace(value)) > 0 {
		return nil
	}
	return fmt.Errorf("%s: %w", fieldName, ErrRequired)
}
// ValidateLength checks if string length is within range.
// Length is measured in bytes after trimming surrounding whitespace;
// a max of 0 (or negative) disables the upper bound.
func (v *Validator) ValidateLength(value, fieldName string, min, max int) error {
	n := len(strings.TrimSpace(value))
	switch {
	case n < min:
		return fmt.Errorf("%s: %w (minimum %d characters)", fieldName, ErrTooShort, min)
	case max > 0 && n > max:
		return fmt.Errorf("%s: %w (maximum %d characters)", fieldName, ErrTooLong, max)
	default:
		return nil
	}
}
// ValidateNotWhitespaceOnly ensures the string contains non-whitespace characters.
// An empty string passes; only non-empty, all-whitespace input is rejected.
func (v *Validator) ValidateNotWhitespaceOnly(value, fieldName string) error {
	if len(value) > 0 && strings.TrimSpace(value) == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrWhitespaceOnly)
	}
	return nil
}
// ValidateNoHTML checks that the string doesn't contain HTML tags.
func (v *Validator) ValidateNoHTML(value, fieldName string) error {
	if !htmlTagRegex.MatchString(value) {
		return nil
	}
	return fmt.Errorf("%s: %w", fieldName, ErrContainsHTML)
}
// ValidateAlphanumeric checks if string contains only alphanumeric characters
// (ASCII letters and digits; the empty string fails).
func (v *Validator) ValidateAlphanumeric(value, fieldName string) error {
	if alphanumericRegex.MatchString(value) {
		return nil
	}
	return fmt.Errorf("%s: %w (only letters and numbers allowed)", fieldName, ErrInvalidCharacters)
}
// ValidatePrintable ensures string contains only printable characters
// (whitespace runes are also accepted).
func (v *Validator) ValidatePrintable(value, fieldName string) error {
	for _, r := range value {
		if unicode.IsPrint(r) || unicode.IsSpace(r) {
			continue
		}
		return fmt.Errorf("%s: %w (contains non-printable characters)", fieldName, ErrInvalidCharacters)
	}
	return nil
}
// ==================== Email Validation ====================
// ValidateEmail validates email format.
// It combines a strict regex, stdlib net/mail parsing, and a few explicit
// structural checks (length cap, consecutive dots, dot placement).
func (v *Validator) ValidateEmail(email, fieldName string) error {
	addr := strings.TrimSpace(email)
	if addr == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	// RFC 5321 caps a full address at 320 octets.
	if len(addr) > 320 {
		return fmt.Errorf("%s: %w (maximum 320 characters)", fieldName, ErrTooLong)
	}
	// Fast structural check first, then the stdlib parser as a second opinion.
	if !emailRegex.MatchString(addr) {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
	}
	if _, err := mail.ParseAddress(addr); err != nil {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
	}
	// Reject consecutive dots anywhere in the address.
	if strings.Contains(addr, "..") {
		return fmt.Errorf("%s: %w (consecutive dots not allowed)", fieldName, ErrInvalidEmail)
	}
	// The local part must not begin or end with a dot. (The regex guarantees
	// exactly one '@' by this point.)
	if local, _, ok := strings.Cut(addr, "@"); ok {
		if strings.HasPrefix(local, ".") || strings.HasSuffix(local, ".") {
			return fmt.Errorf("%s: %w (local part cannot start or end with dot)", fieldName, ErrInvalidEmail)
		}
	}
	return nil
}
// ==================== URL Validation ====================
// ValidateURL validates URL format and ensures it has a valid scheme.
// Only absolute http/https URLs with a non-empty host are accepted.
func (v *Validator) ValidateURL(urlStr, fieldName string) error {
	raw := strings.TrimSpace(urlStr)
	if raw == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	// Cap length at 2048 — a common practical URL limit.
	if len(raw) > 2048 {
		return fmt.Errorf("%s: %w (maximum 2048 characters)", fieldName, ErrTooLong)
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidURL)
	}
	switch {
	case parsed.Scheme == "":
		return fmt.Errorf("%s: %w (missing scheme)", fieldName, ErrInvalidURL)
	case parsed.Scheme != "http" && parsed.Scheme != "https":
		return fmt.Errorf("%s: %w (only http and https schemes allowed)", fieldName, ErrInvalidURL)
	case parsed.Host == "":
		return fmt.Errorf("%s: %w (missing host)", fieldName, ErrInvalidURL)
	}
	return nil
}
// ValidateHTTPSURL validates URL and ensures it uses HTTPS.
//
// Fix: the value is trimmed before re-parsing. Previously ValidateURL
// validated the trimmed string but the scheme check re-parsed the raw input,
// so a URL with surrounding whitespace could pass the general check and then
// fail here with a misleading "invalid URL format" error.
func (v *Validator) ValidateHTTPSURL(urlStr, fieldName string) error {
	// Run the general URL checks first (scheme, host, length).
	if err := v.ValidateURL(urlStr, fieldName); err != nil {
		return err
	}
	parsed, err := url.Parse(strings.TrimSpace(urlStr))
	if err != nil {
		// Practically unreachable: ValidateURL already parsed this value.
		return fmt.Errorf("%s: invalid URL format", fieldName)
	}
	if parsed.Scheme != "https" {
		return fmt.Errorf("%s: must use HTTPS protocol", fieldName)
	}
	return nil
}
// ==================== Domain Validation ====================
// ValidateDomain validates domain name format.
// Supports standard domains (example.com) and localhost with ports (localhost:8081) for development.
//
// Fix: the development shortcut now requires the host to be exactly
// "localhost" (optionally with a port). Previously any domain merely
// *starting* with "localhost" (e.g. "localhostabc") skipped all further
// validation and was accepted.
func (v *Validator) ValidateDomain(domain, fieldName string) error {
	domain = strings.TrimSpace(strings.ToLower(domain))
	if domain == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	// RFC 1035 limits a full domain name to 253 characters.
	if len(domain) > 253 {
		return fmt.Errorf("%s: %w (maximum 253 characters)", fieldName, ErrTooLong)
	}
	if len(domain) < 4 {
		return fmt.Errorf("%s: %w (minimum 4 characters)", fieldName, ErrTooShort)
	}
	// Development allowance: exactly "localhost", optionally with a numeric
	// port (localhost, localhost:8080, localhost:3000, ...).
	if host, port, hasPort := strings.Cut(domain, ":"); host == "localhost" {
		if hasPort {
			if port == "" {
				return fmt.Errorf("%s: %w (missing port number)", fieldName, ErrInvalidDomain)
			}
			if strings.Contains(port, ":") {
				return fmt.Errorf("%s: %w (invalid localhost format)", fieldName, ErrInvalidDomain)
			}
			for _, c := range port {
				if c < '0' || c > '9' {
					return fmt.Errorf("%s: %w (port must be numeric)", fieldName, ErrInvalidDomain)
				}
			}
		}
		return nil
	}
	// Development allowance: loopback/private IPv4 prefixes, optionally with
	// a single ":port" suffix. (Heuristic only — not a full IP parse.)
	if strings.HasPrefix(domain, "127.") || strings.HasPrefix(domain, "192.168.") || strings.HasPrefix(domain, "10.") {
		if strings.Contains(domain, ":") && strings.Count(domain, ":") != 1 {
			return fmt.Errorf("%s: %w (invalid IP format)", fieldName, ErrInvalidDomain)
		}
		return nil
	}
	// Standard domain: validated by regex, then per-label length (RFC 1035
	// caps each label at 63 characters).
	if !domainRegex.MatchString(domain) {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidDomain)
	}
	for _, label := range strings.Split(domain, ".") {
		if len(label) > 63 {
			return fmt.Errorf("%s: %w (label exceeds 63 characters)", fieldName, ErrInvalidDomain)
		}
	}
	return nil
}
// ==================== Slug Validation ====================
// ValidateSlug validates slug format (lowercase alphanumeric with hyphens).
// Input is trimmed and lowercased before checking; reserved names are rejected.
func (v *Validator) ValidateSlug(slug, fieldName string) error {
	cleaned := strings.TrimSpace(strings.ToLower(slug))
	if cleaned == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	// Length must be 3-63 characters (DNS-label-compatible).
	if len(cleaned) < 3 {
		return fmt.Errorf("%s: %w (minimum 3 characters)", fieldName, ErrTooShort)
	}
	if len(cleaned) > 63 {
		return fmt.Errorf("%s: %w (maximum 63 characters)", fieldName, ErrTooLong)
	}
	if !slugRegex.MatchString(cleaned) {
		return fmt.Errorf("%s: %w (only lowercase letters, numbers, and hyphens allowed)", fieldName, ErrInvalidSlug)
	}
	// Names used by infrastructure/routing are off limits.
	if ReservedSlugs[cleaned] {
		return fmt.Errorf("%s: '%s' is a reserved slug and cannot be used", fieldName, cleaned)
	}
	return nil
}
// GenerateSlug generates a URL-friendly slug from a name.
// Converts to lowercase, collapses runs of non-alphanumeric characters into
// single hyphens, and enforces the 3-63 character bounds of ValidateSlug.
//
// Fix: when the cleaned name is empty (e.g. all punctuation), the old code
// produced a slug with a leading hyphen ("-1234"), which fails slugRegex.
// The timestamp-derived suffix is now zero-padded to four digits and joined
// with a hyphen only when there is a non-empty base, so the result always
// satisfies the slug format.
func (v *Validator) GenerateSlug(name string) string {
	// Lowercase, then collapse every non-alphanumeric run into one hyphen.
	var b strings.Builder
	lastHyphen := false
	for _, r := range strings.TrimSpace(strings.ToLower(name)) {
		switch {
		case (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9'):
			b.WriteRune(r)
			lastHyphen = false
		case !lastHyphen:
			b.WriteRune('-')
			lastHyphen = true
		}
	}
	slug := strings.Trim(b.String(), "-")
	// Too short: append a pseudo-random, zero-padded numeric suffix so the
	// result is at least 3 characters and never starts with a hyphen.
	if len(slug) < 3 {
		suffix := fmt.Sprintf("%04d", time.Now().UnixNano()%10000)
		if slug == "" {
			slug = suffix
		} else {
			slug = slug + "-" + suffix
		}
	}
	// Too long: truncate to 63 and drop any trailing hyphen left behind.
	if len(slug) > 63 {
		slug = strings.TrimRight(slug[:63], "-")
	}
	return slug
}
// ==================== UUID Validation ====================
// ValidateUUID validates UUID format (version 4).
// Input is trimmed and lowercased first, so uppercase UUIDs are accepted.
func (v *Validator) ValidateUUID(id, fieldName string) error {
	normalized := strings.TrimSpace(strings.ToLower(id))
	if normalized == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	if !uuidRegex.MatchString(normalized) {
		return fmt.Errorf("%s: %w (must be a valid UUID v4)", fieldName, ErrInvalidFormat)
	}
	return nil
}
// ==================== Enum Validation ====================
// ValidateEnum checks if value is in the allowed list (whitelist validation).
func (v *Validator) ValidateEnum(value, fieldName string, allowedValues []string) error {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}
	// Linear scan is fine: enum lists are small.
	for _, candidate := range allowedValues {
		if candidate == trimmed {
			return nil
		}
	}
	return fmt.Errorf("%s: %w (allowed values: %s)", fieldName, ErrInvalidValue, strings.Join(allowedValues, ", "))
}
// ==================== Number Validation ====================
// ValidateRange checks if a number is within the specified range.
// A max of 0 (or negative) disables the upper bound.
func (v *Validator) ValidateRange(value int, fieldName string, min, max int) error {
	switch {
	case value < min:
		return fmt.Errorf("%s: value must be at least %d", fieldName, min)
	case max > 0 && value > max:
		return fmt.Errorf("%s: value must be at most %d", fieldName, max)
	}
	return nil
}
// ==================== Sanitization ====================
// SanitizeString removes potentially dangerous characters and trims whitespace.
// Currently: trims surrounding whitespace and strips null bytes.
// Note: Unicode normalization (golang.org/x/text/unicode/norm) is not applied.
func (v *Validator) SanitizeString(value string) string {
	trimmed := strings.TrimSpace(value)
	return strings.ReplaceAll(trimmed, "\x00", "")
}
// StripHTML removes all HTML tags from a string
// (anything matching <...>; entities and text content are left intact).
func (v *Validator) StripHTML(value string) string {
	return htmlTagRegex.ReplaceAllString(value, "")
}
// ==================== Combined Validations ====================
// ValidateAndSanitizeString performs validation and sanitization.
// Sanitizes first (trim, strip null bytes), then checks required, length
// bounds, and printability. Returns the sanitized value on success.
func (v *Validator) ValidateAndSanitizeString(value, fieldName string, minLen, maxLen int) (string, error) {
	cleaned := v.SanitizeString(value)
	// Run each validation in order; the first failure wins.
	checks := []func() error{
		func() error { return v.ValidateRequired(cleaned, fieldName) },
		func() error { return v.ValidateLength(cleaned, fieldName, minLen, maxLen) },
		func() error { return v.ValidatePrintable(cleaned, fieldName) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return "", err
		}
	}
	return cleaned, nil
}

View file

@ -0,0 +1,472 @@
package validation
import (
"strings"
"testing"
)
// TestValidateRequired covers empty, whitespace-only, and valid inputs.
func TestValidateRequired(t *testing.T) {
	v := NewValidator()
	cases := []struct {
		name    string
		value   string
		wantErr bool
	}{
		{"Valid non-empty string", "test", false},
		{"Empty string", "", true},
		{"Whitespace only", " ", true},
		{"Tab only", "\t", true},
		{"Newline only", "\n", true},
		{"Valid with spaces", "hello world", false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := v.ValidateRequired(tc.value, "test_field")
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("ValidateRequired() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateLength covers the min/max bounds, the unbounded max=0 case,
// and whitespace trimming before measurement.
func TestValidateLength(t *testing.T) {
	v := NewValidator()
	cases := []struct {
		name    string
		value   string
		min     int
		max     int
		wantErr bool
	}{
		{"Valid length", "hello", 3, 10, false},
		{"Too short", "ab", 3, 10, true},
		{"Too long", "hello world this is too long", 3, 10, true},
		{"Exact minimum", "abc", 3, 10, false},
		{"Exact maximum", "0123456789", 3, 10, false},
		{"No maximum (0)", "very long string here", 3, 0, false},
		{"Whitespace counted correctly", " test ", 4, 10, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := v.ValidateLength(tc.value, "test_field", tc.min, tc.max)
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("ValidateLength() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateEmail exercises ValidateEmail with a table of accepted and
// rejected addresses, including length, dot-placement, and character checks.
func TestValidateEmail(t *testing.T) {
	v := NewValidator()
	tests := []struct {
		name      string
		email     string
		wantError bool
	}{
		// Valid emails
		{"Valid email", "user@example.com", false},
		{"Valid email with plus", "user+tag@example.com", false},
		{"Valid email with dot", "first.last@example.com", false},
		{"Valid email with hyphen", "user-name@example-domain.com", false},
		{"Valid email with numbers", "user123@example456.com", false},
		{"Valid email with subdomain", "user@sub.example.com", false},
		// Invalid emails
		{"Empty email", "", true},
		{"Whitespace only", " ", true},
		{"Missing @", "userexample.com", true},
		{"Missing domain", "user@", true},
		{"Missing local part", "@example.com", true},
		{"No TLD", "user@localhost", true},
		{"Consecutive dots in local", "user..name@example.com", true},
		{"Leading dot in local", ".user@example.com", true},
		{"Trailing dot in local", "user.@example.com", true},
		{"Double @", "user@@example.com", true},
		{"Spaces in email", "user name@example.com", true},
		{"Invalid characters", "user<>@example.com", true},
		// 320 a's plus the domain exceeds the 320-character cap
		{"Too long", strings.Repeat("a", 320) + "@example.com", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := v.ValidateEmail(tt.email, "email")
			if (err != nil) != tt.wantError {
				t.Errorf("ValidateEmail() error = %v, wantError %v", err, tt.wantError)
			}
		})
	}
}
// TestValidateURL exercises ValidateURL with valid http/https URLs and a
// range of invalid inputs (missing scheme/host, disallowed scheme, length).
func TestValidateURL(t *testing.T) {
	v := NewValidator()
	tests := []struct {
		name      string
		url       string
		wantError bool
	}{
		// Valid URLs
		{"Valid HTTP URL", "http://example.com", false},
		{"Valid HTTPS URL", "https://example.com", false},
		{"Valid URL with path", "https://example.com/path/to/resource", false},
		{"Valid URL with query", "https://example.com?param=value", false},
		{"Valid URL with port", "https://example.com:8080", false},
		{"Valid URL with subdomain", "https://sub.example.com", false},
		// Invalid URLs
		{"Empty URL", "", true},
		{"Whitespace only", " ", true},
		{"Missing scheme", "example.com", true},
		{"Invalid scheme", "ftp://example.com", true},
		{"Missing host", "https://", true},
		{"Invalid characters", "https://exam ple.com", true},
		// Exceeds the 2048-character cap
		{"Too long", "https://" + strings.Repeat("a", 2048) + ".com", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := v.ValidateURL(tt.url, "url")
			if (err != nil) != tt.wantError {
				t.Errorf("ValidateURL() error = %v, wantError %v", err, tt.wantError)
			}
		})
	}
}
func TestValidateHTTPSURL(t *testing.T) {
	// Verifies that ValidateHTTPSURL accepts only https:// URLs and
	// rejects other schemes and malformed input.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		{"Valid HTTPS URL", "https://example.com", false},
		{"HTTP URL (should fail)", "http://example.com", true},
		{"FTP URL (should fail)", "ftp://example.com", true},
		{"Invalid URL", "not-a-url", true},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateHTTPSURL(tc.input, "url")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateHTTPSURL() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateDomain(t *testing.T) {
	// Table-driven coverage of DNS hostname validation: label and total
	// length limits, hyphen placement, dot placement, and character set.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		// Hostnames that must be accepted.
		{"Valid domain", "example.com", false},
		{"Valid subdomain", "sub.example.com", false},
		{"Valid deep subdomain", "deep.sub.example.com", false},
		{"Valid with hyphen", "my-site.example.com", false},
		{"Valid with numbers", "site123.example456.com", false},
		// Hostnames that must be rejected.
		{"Empty domain", "", true},
		{"Whitespace only", " ", true},
		{"Too short", "a.b", true},
		{"Too long", strings.Repeat("a", 254) + ".com", true},
		{"Label too long", strings.Repeat("a", 64) + ".example.com", true},
		{"No TLD", "localhost", true},
		{"Leading hyphen", "-example.com", true},
		{"Trailing hyphen", "example-.com", true},
		{"Double dot", "example..com", true},
		{"Leading dot", ".example.com", true},
		{"Trailing dot", "example.com.", true},
		{"Underscore", "my_site.example.com", true},
		{"Spaces", "my site.example.com", true},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateDomain(tc.input, "domain")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateDomain() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateSlug(t *testing.T) {
	// Table-driven coverage of slug validation: lowercase alphanumerics
	// with single interior hyphens, bounded length, and a reserved-word list.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		// Slugs that must be accepted.
		{"Valid slug", "my-company", false},
		{"Valid slug with numbers", "company123", false},
		{"Valid slug all lowercase", "testcompany", false},
		{"Valid slug with multiple hyphens", "my-test-company", false},
		// Slugs that must be rejected on shape.
		{"Empty slug", "", true},
		{"Whitespace only", " ", true},
		{"Too short", "ab", true},
		{"Too long", strings.Repeat("a", 64), true},
		{"Uppercase letters", "MyCompany", true},
		{"Leading hyphen", "-company", true},
		{"Trailing hyphen", "company-", true},
		{"Double hyphen", "my--company", true},
		{"Underscore", "my_company", true},
		{"Spaces", "my company", true},
		{"Special characters", "my@company", true},
		// Slugs that must be rejected because they are reserved.
		{"Reserved: api", "api", true},
		{"Reserved: admin", "admin", true},
		{"Reserved: www", "www", true},
		{"Reserved: login", "login", true},
		{"Reserved: register", "register", true},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateSlug(tc.input, "slug")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateSlug() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateUUID(t *testing.T) {
	// Table-driven coverage of UUID validation: hyphenated v4 format is
	// required; case is accepted in either form.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		{"Valid UUID v4", "550e8400-e29b-41d4-a716-446655440000", false},
		{"Valid UUID v4 lowercase", "123e4567-e89b-42d3-a456-426614174000", false},
		{"Empty UUID", "", true},
		{"Invalid format", "not-a-uuid", true},
		{"Invalid version", "550e8400-e29b-21d4-a716-446655440000", true},
		{"Missing hyphens", "550e8400e29b41d4a716446655440000", true},
		{"Too short", "550e8400-e29b-41d4-a716", true},
		{"With uppercase", "550E8400-E29B-41D4-A716-446655440000", false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateUUID(tc.input, "id")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateUUID() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateEnum(t *testing.T) {
	// Verifies membership checking against an allow-list: exact,
	// case-sensitive matches pass; anything else fails.
	v := NewValidator()
	allowed := []string{"free", "basic", "pro", "enterprise"}
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		{"Valid: free", "free", false},
		{"Valid: basic", "basic", false},
		{"Valid: pro", "pro", false},
		{"Valid: enterprise", "enterprise", false},
		{"Invalid: premium", "premium", true},
		{"Invalid: empty", "", true},
		{"Invalid: wrong case", "FREE", true},
		{"Invalid: typo", "basi", true},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateEnum(tc.input, "plan_tier", allowed)
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateEnum() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateRange(t *testing.T) {
	// Verifies inclusive integer range checking; a max of 0 disables the
	// upper bound.
	v := NewValidator()
	cases := []struct {
		label     string
		input     int
		lo        int
		hi        int
		expectErr bool
	}{
		{"Valid within range", 5, 1, 10, false},
		{"Valid at minimum", 1, 1, 10, false},
		{"Valid at maximum", 10, 1, 10, false},
		{"Below minimum", 0, 1, 10, true},
		{"Above maximum", 11, 1, 10, true},
		{"No maximum (0)", 1000, 1, 0, false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateRange(tc.input, "count", tc.lo, tc.hi)
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateRange() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestValidateNoHTML(t *testing.T) {
	// Verifies HTML-tag rejection: tag-like markup fails, while bare
	// comparison symbols in plain text are allowed.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		{"Plain text", "Hello world", false},
		{"Text with punctuation", "Hello, world!", false},
		{"HTML tag <script>", "<script>alert('xss')</script>", true},
		{"HTML tag <img>", "<img src='x'>", true},
		{"HTML tag <div>", "<div>content</div>", true},
		{"HTML tag <a>", "<a href='#'>link</a>", true},
		{"Less than symbol", "5 < 10", false},
		{"Greater than symbol", "10 > 5", false},
		{"Both symbols", "5 < x < 10", false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidateNoHTML(tc.input, "content")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateNoHTML() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}
func TestSanitizeString(t *testing.T) {
	// Verifies string sanitization: surrounding whitespace is trimmed and
	// embedded null bytes are stripped.
	v := NewValidator()
	cases := []struct {
		label string
		input string
		want  string
	}{
		{"Trim whitespace", "  hello  ", "hello"},
		{"Remove null bytes", "hello\x00world", "helloworld"},
		{"Already clean", "hello", "hello"},
		{"Empty string", "", ""},
		{"Only whitespace", "   ", ""},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := v.SanitizeString(tc.input)
			if got != tc.want {
				t.Errorf("SanitizeString() = %q, want %q", got, tc.want)
			}
		})
	}
}
func TestStripHTML(t *testing.T) {
	// Verifies tag removal: markup is deleted while the inner text is
	// retained unchanged.
	v := NewValidator()
	cases := []struct {
		label string
		input string
		want  string
	}{
		{"Remove script tag", "<script>alert('xss')</script>", "alert('xss')"},
		{"Remove div tag", "<div>content</div>", "content"},
		{"Remove multiple tags", "<p>Hello <b>world</b></p>", "Hello world"},
		{"No tags", "plain text", "plain text"},
		{"Empty string", "", ""},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := v.StripHTML(tc.input)
			if got != tc.want {
				t.Errorf("StripHTML() = %q, want %q", got, tc.want)
			}
		})
	}
}
func TestValidateAndSanitizeString(t *testing.T) {
	// Verifies the combined sanitize-then-validate path: input is cleaned
	// (trim, null-byte removal) before the length bounds are applied.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		lo        int
		hi        int
		want      string
		expectErr bool
	}{
		{"Valid and clean", "hello", 3, 10, "hello", false},
		{"Trim and validate", "  hello  ", 3, 10, "hello", false},
		{"Too short after trim", "  a  ", 3, 10, "", true},
		{"Too long", "hello world this is too long", 3, 10, "", true},
		{"Empty after trim", "   ", 3, 10, "", true},
		{"Valid with null byte removed", "hel\x00lo", 3, 10, "hello", false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got, err := v.ValidateAndSanitizeString(tc.input, "test_field", tc.lo, tc.hi)
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidateAndSanitizeString() error = %v, wantError %v", err, tc.expectErr)
			}
			// The sanitized value is only meaningful on success.
			if !tc.expectErr && got != tc.want {
				t.Errorf("ValidateAndSanitizeString() = %q, want %q", got, tc.want)
			}
		})
	}
}
func TestValidatePrintable(t *testing.T) {
	// Verifies printable-character validation: tabs/newlines and Unicode
	// text pass; other control characters are rejected.
	v := NewValidator()
	cases := []struct {
		label     string
		input     string
		expectErr bool
	}{
		{"All printable", "Hello World 123!", false},
		{"With tabs and newlines", "Hello\tWorld\n", false},
		{"With control character", "Hello\x01World", true},
		{"With bell character", "Hello\x07", true},
		{"Empty string", "", false},
		{"Unicode printable", "Hello 世界", false},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := v.ValidatePrintable(tc.input, "test_field")
			if failed := err != nil; failed != tc.expectErr {
				t.Errorf("ValidatePrintable() error = %v, wantError %v", err, tc.expectErr)
			}
		})
	}
}