Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
182
cloud/maplefile-backend/pkg/auditlog/auditlog.go
Normal file
182
cloud/maplefile-backend/pkg/auditlog/auditlog.go
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
// Package auditlog provides security audit logging for compliance and security monitoring.
|
||||
// Audit logs are separate from application logs and capture security-relevant events
|
||||
// with consistent structure for analysis and alerting.
|
||||
package auditlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config/constants"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// EventType represents the type of security event
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
// Authentication events
|
||||
EventTypeLoginAttempt EventType = "login_attempt"
|
||||
EventTypeLoginSuccess EventType = "login_success"
|
||||
EventTypeLoginFailure EventType = "login_failure"
|
||||
EventTypeLogout EventType = "logout"
|
||||
EventTypeTokenRefresh EventType = "token_refresh"
|
||||
EventTypeTokenRevoked EventType = "token_revoked"
|
||||
|
||||
// Account events
|
||||
EventTypeAccountCreated EventType = "account_created"
|
||||
EventTypeAccountDeleted EventType = "account_deleted"
|
||||
EventTypeAccountLocked EventType = "account_locked"
|
||||
EventTypeAccountUnlocked EventType = "account_unlocked"
|
||||
EventTypeEmailVerified EventType = "email_verified"
|
||||
|
||||
// Recovery events
|
||||
EventTypeRecoveryInitiated EventType = "recovery_initiated"
|
||||
EventTypeRecoveryCompleted EventType = "recovery_completed"
|
||||
EventTypeRecoveryFailed EventType = "recovery_failed"
|
||||
|
||||
// Access control events
|
||||
EventTypeAccessDenied EventType = "access_denied"
|
||||
EventTypePermissionChanged EventType = "permission_changed"
|
||||
|
||||
// Sharing events
|
||||
EventTypeCollectionShared EventType = "collection_shared"
|
||||
EventTypeCollectionUnshared EventType = "collection_unshared"
|
||||
EventTypeSharingBlocked EventType = "sharing_blocked"
|
||||
)
|
||||
|
||||
// Outcome represents the result of the audited action
|
||||
type Outcome string
|
||||
|
||||
const (
|
||||
OutcomeSuccess Outcome = "success"
|
||||
OutcomeFailure Outcome = "failure"
|
||||
OutcomeBlocked Outcome = "blocked"
|
||||
)
|
||||
|
||||
// AuditEvent represents a security audit event
|
||||
type AuditEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
EventType EventType `json:"event_type"`
|
||||
Outcome Outcome `json:"outcome"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Email string `json:"email,omitempty"` // Always masked
|
||||
ClientIP string `json:"client_ip,omitempty"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
Resource string `json:"resource,omitempty"`
|
||||
Action string `json:"action,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
FailReason string `json:"fail_reason,omitempty"`
|
||||
}
|
||||
|
||||
// AuditLogger provides security audit logging functionality
|
||||
type AuditLogger interface {
|
||||
// Log records a security audit event
|
||||
Log(ctx context.Context, event AuditEvent)
|
||||
|
||||
// LogAuth logs an authentication event with common fields
|
||||
LogAuth(ctx context.Context, eventType EventType, outcome Outcome, email string, clientIP string, details map[string]string)
|
||||
|
||||
// LogAccess logs an access control event
|
||||
LogAccess(ctx context.Context, eventType EventType, outcome Outcome, userID string, resource string, action string, details map[string]string)
|
||||
}
|
||||
|
||||
type auditLoggerImpl struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuditLogger creates a new audit logger
|
||||
func NewAuditLogger(logger *zap.Logger) AuditLogger {
|
||||
// Create a named logger specifically for audit events
|
||||
// This allows filtering audit logs separately from application logs
|
||||
auditLogger := logger.Named("AUDIT")
|
||||
|
||||
return &auditLoggerImpl{
|
||||
logger: auditLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// Log records a security audit event
|
||||
func (a *auditLoggerImpl) Log(ctx context.Context, event AuditEvent) {
|
||||
// Set timestamp if not provided
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now().UTC()
|
||||
}
|
||||
|
||||
// Build zap fields
|
||||
fields := []zap.Field{
|
||||
zap.String("audit_event", string(event.EventType)),
|
||||
zap.String("outcome", string(event.Outcome)),
|
||||
zap.Time("event_time", event.Timestamp),
|
||||
}
|
||||
|
||||
if event.UserID != "" {
|
||||
fields = append(fields, zap.String("user_id", event.UserID))
|
||||
}
|
||||
if event.Email != "" {
|
||||
fields = append(fields, zap.String("email", validation.MaskEmail(event.Email))) // Always mask for safety
|
||||
}
|
||||
if event.ClientIP != "" {
|
||||
fields = append(fields, zap.String("client_ip", validation.MaskIP(event.ClientIP))) // Always mask for safety
|
||||
}
|
||||
if event.UserAgent != "" {
|
||||
fields = append(fields, zap.String("user_agent", event.UserAgent))
|
||||
}
|
||||
if event.Resource != "" {
|
||||
fields = append(fields, zap.String("resource", event.Resource))
|
||||
}
|
||||
if event.Action != "" {
|
||||
fields = append(fields, zap.String("action", event.Action))
|
||||
}
|
||||
if event.FailReason != "" {
|
||||
fields = append(fields, zap.String("fail_reason", event.FailReason))
|
||||
}
|
||||
if len(event.Details) > 0 {
|
||||
fields = append(fields, zap.Any("details", event.Details))
|
||||
}
|
||||
|
||||
// Try to get request ID from context
|
||||
if requestID, ok := ctx.Value(constants.SessionID).(string); ok && requestID != "" {
|
||||
fields = append(fields, zap.String("request_id", requestID))
|
||||
}
|
||||
|
||||
// Log at INFO level - audit events are always important
|
||||
a.logger.Info("security_audit", fields...)
|
||||
}
|
||||
|
||||
// LogAuth logs an authentication event with common fields
|
||||
func (a *auditLoggerImpl) LogAuth(ctx context.Context, eventType EventType, outcome Outcome, email string, clientIP string, details map[string]string) {
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC(),
|
||||
EventType: eventType,
|
||||
Outcome: outcome,
|
||||
Email: email, // Should be pre-masked by caller
|
||||
ClientIP: clientIP,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
// Extract user ID from context if available
|
||||
if userID, ok := ctx.Value(constants.SessionUserID).(gocql.UUID); ok {
|
||||
event.UserID = userID.String()
|
||||
}
|
||||
|
||||
a.Log(ctx, event)
|
||||
}
|
||||
|
||||
// LogAccess logs an access control event
|
||||
func (a *auditLoggerImpl) LogAccess(ctx context.Context, eventType EventType, outcome Outcome, userID string, resource string, action string, details map[string]string) {
|
||||
event := AuditEvent{
|
||||
Timestamp: time.Now().UTC(),
|
||||
EventType: eventType,
|
||||
Outcome: outcome,
|
||||
UserID: userID,
|
||||
Resource: resource,
|
||||
Action: action,
|
||||
Details: details,
|
||||
}
|
||||
|
||||
a.Log(ctx, event)
|
||||
}
|
||||
8
cloud/maplefile-backend/pkg/auditlog/provider.go
Normal file
8
cloud/maplefile-backend/pkg/auditlog/provider.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
package auditlog
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
||||
// ProvideAuditLogger provides an audit logger for Wire dependency injection
|
||||
func ProvideAuditLogger(logger *zap.Logger) AuditLogger {
|
||||
return NewAuditLogger(logger)
|
||||
}
|
||||
109
cloud/maplefile-backend/pkg/cache/cassandra.go
vendored
Normal file
109
cloud/maplefile-backend/pkg/cache/cassandra.go
vendored
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CassandraCacher defines the interface for Cassandra cache operations
|
||||
type CassandraCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
type cassandraCache struct {
|
||||
session *gocql.Session
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCassandraCache creates a new Cassandra cache instance
|
||||
func NewCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
logger = logger.Named("cassandra-cache")
|
||||
logger.Info("✓ Cassandra cache layer initialized")
|
||||
return &cassandraCache{
|
||||
session: session,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down Cassandra cache")
|
||||
// Note: Don't close the session here as it's managed by the database layer
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
var value []byte
|
||||
var expiresAt time.Time
|
||||
|
||||
query := `SELECT value, expires_at FROM cache WHERE key = ?`
|
||||
err := c.session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Scan(&value, &expiresAt)
|
||||
|
||||
if err == gocql.ErrNotFound {
|
||||
// Key doesn't exist - this is not an error
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if expired in application code
|
||||
if time.Now().After(expiresAt) {
|
||||
// Entry is expired, delete it and return nil
|
||||
_ = c.Delete(ctx, key) // Clean up expired entry
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
expiresAt := time.Now().Add(24 * time.Hour) // Default 24 hour expiry
|
||||
query := `INSERT INTO cache (key, expires_at, value) VALUES (?, ?, ?)`
|
||||
return c.session.Query(query, key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
expiresAt := time.Now().Add(expiry)
|
||||
query := `INSERT INTO cache (key, expires_at, value) VALUES (?, ?, ?)`
|
||||
return c.session.Query(query, key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) Delete(ctx context.Context, key string) error {
|
||||
query := `DELETE FROM cache WHERE key = ?`
|
||||
return c.session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (c *cassandraCache) PurgeExpired(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Thanks to the index on expires_at, this query is efficient
|
||||
iter := c.session.Query(`SELECT key FROM cache WHERE expires_at < ? ALLOW FILTERING`, now).WithContext(ctx).Iter()
|
||||
|
||||
var expiredKeys []string
|
||||
var key string
|
||||
for iter.Scan(&key) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired keys in batch
|
||||
if len(expiredKeys) > 0 {
|
||||
batch := c.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
|
||||
for _, expiredKey := range expiredKeys {
|
||||
batch.Query(`DELETE FROM cache WHERE key = ?`, expiredKey)
|
||||
}
|
||||
return c.session.ExecuteBatch(batch)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
23
cloud/maplefile-backend/pkg/cache/provider.go
vendored
Normal file
23
cloud/maplefile-backend/pkg/cache/provider.go
vendored
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// ProvideRedisCache provides a Redis cache instance
|
||||
func ProvideRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
|
||||
return NewRedisCache(cfg, logger)
|
||||
}
|
||||
|
||||
// ProvideCassandraCache provides a Cassandra cache instance
|
||||
func ProvideCassandraCache(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
return NewCassandraCache(session, logger)
|
||||
}
|
||||
|
||||
// ProvideTwoTierCache provides a two-tier cache instance
|
||||
func ProvideTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
|
||||
return NewTwoTierCache(redisCache, cassandraCache, logger)
|
||||
}
|
||||
144
cloud/maplefile-backend/pkg/cache/redis.go
vendored
Normal file
144
cloud/maplefile-backend/pkg/cache/redis.go
vendored
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// silentRedisLogger filters out noisy "maintnotifications" warnings from go-redis
|
||||
// This warning occurs when the Redis client tries to use newer Redis 7.2+ features
|
||||
// that may not be fully supported by the current Redis version.
|
||||
// The client automatically falls back to compatible mode, so this is harmless.
|
||||
type silentRedisLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func (l *silentRedisLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Filter out harmless compatibility warnings
|
||||
if strings.Contains(msg, "maintnotifications disabled") ||
|
||||
strings.Contains(msg, "auto mode fallback") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other Redis messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// RedisCacher defines the interface for Redis cache operations
|
||||
type RedisCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
type redisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRedisCache creates a new Redis cache instance
|
||||
func NewRedisCache(cfg *config.Config, logger *zap.Logger) (RedisCacher, error) {
|
||||
logger = logger.Named("redis-cache")
|
||||
|
||||
logger.Info("⏳ Connecting to Redis...",
|
||||
zap.String("host", cfg.Cache.Host),
|
||||
zap.Int("port", cfg.Cache.Port))
|
||||
|
||||
// Build Redis URL from config
|
||||
redisURL := fmt.Sprintf("redis://:%s@%s:%d/%d",
|
||||
cfg.Cache.Password,
|
||||
cfg.Cache.Host,
|
||||
cfg.Cache.Port,
|
||||
cfg.Cache.DB,
|
||||
)
|
||||
|
||||
// If no password, use simpler URL format
|
||||
if cfg.Cache.Password == "" {
|
||||
redisURL = fmt.Sprintf("redis://%s:%d/%d",
|
||||
cfg.Cache.Host,
|
||||
cfg.Cache.Port,
|
||||
cfg.Cache.DB,
|
||||
)
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
|
||||
}
|
||||
|
||||
// Suppress noisy "maintnotifications" warnings from go-redis
|
||||
// Use a custom logger that filters out these harmless compatibility warnings
|
||||
redis.SetLogger(&silentRedisLogger{logger: logger.Named("redis-client")})
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err = client.Ping(ctx).Result(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✓ Redis connected",
|
||||
zap.String("host", cfg.Cache.Host),
|
||||
zap.Int("port", cfg.Cache.Port),
|
||||
zap.Int("db", cfg.Cache.DB))
|
||||
|
||||
return &redisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down Redis cache")
|
||||
if err := c.client.Close(); err != nil {
|
||||
c.logger.Error("error closing Redis connection", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *redisCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.client.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// Key doesn't exist - this is not an error
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis get failed: %w", err)
|
||||
}
|
||||
return []byte(val), nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
if err := c.client.Set(ctx, key, val, 0).Err(); err != nil {
|
||||
return fmt.Errorf("redis set failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *redisCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
if err := c.client.Set(ctx, key, val, expiry).Err(); err != nil {
|
||||
return fmt.Errorf("redis set with expiry failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *redisCache) Delete(ctx context.Context, key string) error {
|
||||
if err := c.client.Del(ctx, key).Err(); err != nil {
|
||||
return fmt.Errorf("redis delete failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
114
cloud/maplefile-backend/pkg/cache/twotier.go
vendored
Normal file
114
cloud/maplefile-backend/pkg/cache/twotier.go
vendored
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/cache/twotier.go
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TwoTierCacher defines the interface for two-tier cache operations
|
||||
type TwoTierCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
// twoTierCache implements a clean 2-layer (read-through write-through) cache
|
||||
//
|
||||
// L1: Redis (fast, in-memory)
|
||||
// L2: Cassandra (persistent)
|
||||
//
|
||||
// On Get: check Redis → then Cassandra → if found in Cassandra → populate Redis
|
||||
// On Set: write to both
|
||||
// On SetWithExpiry: write to both with expiry
|
||||
// On Delete: remove from both
|
||||
type twoTierCache struct {
|
||||
redisCache RedisCacher
|
||||
cassandraCache CassandraCacher
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTwoTierCache creates a new two-tier cache instance
|
||||
func NewTwoTierCache(redisCache RedisCacher, cassandraCache CassandraCacher, logger *zap.Logger) TwoTierCacher {
|
||||
logger = logger.Named("two-tier-cache")
|
||||
logger.Info("✓ Two-tier cache initialized (Redis L1 + Cassandra L2)")
|
||||
return &twoTierCache{
|
||||
redisCache: redisCache,
|
||||
cassandraCache: cassandraCache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
// Try L1 (Redis) first
|
||||
val, err := c.redisCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
c.logger.Debug("cache hit from Redis", zap.String("key", key))
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Not in Redis, try L2 (Cassandra)
|
||||
val, err = c.cassandraCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
// Found in Cassandra, populate Redis for future lookups
|
||||
c.logger.Debug("cache hit from Cassandra, writing back to Redis", zap.String("key", key))
|
||||
_ = c.redisCache.Set(ctx, key, val) // Best effort, don't fail if Redis write fails
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Set(ctx context.Context, key string, val []byte) error {
|
||||
// Write to both layers
|
||||
if err := c.redisCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
// Write to both layers with expiry
|
||||
if err := c.redisCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Delete(ctx context.Context, key string) error {
|
||||
// Remove from both layers
|
||||
if err := c.redisCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.cassandraCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCache) PurgeExpired(ctx context.Context) error {
|
||||
// Only Cassandra needs purging (Redis handles TTL automatically)
|
||||
return c.cassandraCache.PurgeExpired(ctx)
|
||||
}
|
||||
|
||||
func (c *twoTierCache) Shutdown(ctx context.Context) {
|
||||
c.logger.Info("shutting down two-tier cache")
|
||||
c.redisCache.Shutdown(ctx)
|
||||
c.cassandraCache.Shutdown(ctx)
|
||||
c.logger.Info("two-tier cache shutdown complete")
|
||||
}
|
||||
220
cloud/maplefile-backend/pkg/distributedmutex/distributelocker.go
Normal file
220
cloud/maplefile-backend/pkg/distributedmutex/distributelocker.go
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/distributedmutex/distributedmutex.go
|
||||
package distributedmutex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bsm/redislock"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Adapter provides interface for abstracting distributedmutex generation.
|
||||
type Adapter interface {
|
||||
// Blocking acquire - waits until lock is obtained or timeout
|
||||
Acquire(ctx context.Context, key string)
|
||||
Acquiref(ctx context.Context, format string, a ...any)
|
||||
Release(ctx context.Context, key string)
|
||||
Releasef(ctx context.Context, format string, a ...any)
|
||||
|
||||
// Non-blocking operations for leader election
|
||||
// TryAcquire attempts to acquire a lock without blocking
|
||||
// Returns true if lock was acquired, false if already held by someone else
|
||||
TryAcquire(ctx context.Context, key string, ttl time.Duration) (bool, error)
|
||||
|
||||
// Extend renews the TTL of an existing lock
|
||||
// Returns error if the lock is not owned by this instance
|
||||
Extend(ctx context.Context, key string, ttl time.Duration) error
|
||||
|
||||
// IsOwner checks if this instance owns the given lock
|
||||
IsOwner(ctx context.Context, key string) (bool, error)
|
||||
}
|
||||
|
||||
type distributedLockerAdapter struct {
|
||||
Logger *zap.Logger
|
||||
Redis redis.UniversalClient
|
||||
Locker *redislock.Client
|
||||
LockInstances map[string]*redislock.Lock
|
||||
Mutex *sync.Mutex // Add a mutex for synchronization with goroutines
|
||||
}
|
||||
|
||||
// NewAdapter constructor that returns the default DistributedLocker generator.
|
||||
func NewAdapter(loggerp *zap.Logger, redisClient redis.UniversalClient) Adapter {
|
||||
loggerp = loggerp.Named("DistributedMutex")
|
||||
loggerp.Debug("distributed mutex starting and connecting...")
|
||||
|
||||
// Create a new lock client.
|
||||
locker := redislock.New(redisClient)
|
||||
|
||||
loggerp.Debug("distributed mutex initialized")
|
||||
|
||||
return distributedLockerAdapter{
|
||||
Logger: loggerp,
|
||||
Redis: redisClient,
|
||||
Locker: locker,
|
||||
LockInstances: make(map[string]*redislock.Lock, 0),
|
||||
Mutex: &sync.Mutex{}, // Initialize the mutex
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire function blocks the current thread if the lock key is currently locked.
|
||||
func (a distributedLockerAdapter) Acquire(ctx context.Context, k string) {
|
||||
startDT := time.Now()
|
||||
a.Logger.Debug(fmt.Sprintf("locking for key: %v", k))
|
||||
|
||||
// Retry every 250ms, for up-to 20x
|
||||
backoff := redislock.LimitRetry(redislock.LinearBackoff(250*time.Millisecond), 20)
|
||||
|
||||
// Obtain lock with retry
|
||||
lock, err := a.Locker.Obtain(ctx, k, time.Minute, &redislock.Options{
|
||||
RetryStrategy: backoff,
|
||||
})
|
||||
if err == redislock.ErrNotObtained {
|
||||
nowDT := time.Now()
|
||||
diff := nowDT.Sub(startDT)
|
||||
a.Logger.Error("could not obtain lock",
|
||||
zap.String("key", k),
|
||||
zap.Time("start_dt", startDT),
|
||||
zap.Time("now_dt", nowDT),
|
||||
zap.Any("duration_in_minutes", diff.Minutes()))
|
||||
return
|
||||
} else if err != nil {
|
||||
a.Logger.Error("failed obtaining lock",
|
||||
zap.String("key", k),
|
||||
zap.Any("error", err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// DEVELOPERS NOTE:
|
||||
// The `map` datastructure in Golang is not concurrently safe, therefore we
|
||||
// need to use mutex to coordinate access of our `LockInstances` map
|
||||
// resource between all the goroutines.
|
||||
a.Mutex.Lock()
|
||||
defer a.Mutex.Unlock()
|
||||
|
||||
if a.LockInstances != nil { // Defensive code.
|
||||
a.LockInstances[k] = lock
|
||||
}
|
||||
}
|
||||
|
||||
// Acquiref function blocks the current thread if the lock key is currently locked.
|
||||
func (u distributedLockerAdapter) Acquiref(ctx context.Context, format string, a ...any) {
|
||||
k := fmt.Sprintf(format, a...)
|
||||
u.Acquire(ctx, k)
|
||||
return
|
||||
}
|
||||
|
||||
// Release function blocks the current thread if the lock key is currently locked.
|
||||
func (a distributedLockerAdapter) Release(ctx context.Context, k string) {
|
||||
a.Logger.Debug(fmt.Sprintf("unlocking for key: %v", k))
|
||||
|
||||
lockInstance, ok := a.LockInstances[k]
|
||||
if ok {
|
||||
defer lockInstance.Release(ctx)
|
||||
} else {
|
||||
a.Logger.Error("could not obtain to unlock", zap.String("key", k))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Releasef
|
||||
func (u distributedLockerAdapter) Releasef(ctx context.Context, format string, a ...any) {
|
||||
k := fmt.Sprintf(format, a...) //TODO: https://github.com/bsm/redislock/blob/main/README.md
|
||||
u.Release(ctx, k)
|
||||
return
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a lock without blocking.
|
||||
// Returns true if lock was acquired, false if already held by someone else.
|
||||
func (a distributedLockerAdapter) TryAcquire(ctx context.Context, k string, ttl time.Duration) (bool, error) {
|
||||
a.Logger.Debug(fmt.Sprintf("trying to acquire lock for key: %v with ttl: %v", k, ttl))
|
||||
|
||||
// Try to obtain lock without retries (non-blocking)
|
||||
lock, err := a.Locker.Obtain(ctx, k, ttl, &redislock.Options{
|
||||
RetryStrategy: redislock.NoRetry(),
|
||||
})
|
||||
|
||||
if err == redislock.ErrNotObtained {
|
||||
// Lock is held by someone else
|
||||
a.Logger.Debug("lock not obtained, already held by another instance",
|
||||
zap.String("key", k))
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Actual error occurred
|
||||
a.Logger.Error("failed trying to obtain lock",
|
||||
zap.String("key", k),
|
||||
zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Successfully acquired lock
|
||||
a.Mutex.Lock()
|
||||
defer a.Mutex.Unlock()
|
||||
|
||||
if a.LockInstances != nil {
|
||||
a.LockInstances[k] = lock
|
||||
}
|
||||
|
||||
a.Logger.Debug("successfully acquired lock",
|
||||
zap.String("key", k),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Extend renews the TTL of an existing lock.
|
||||
// Returns error if the lock is not owned by this instance.
|
||||
func (a distributedLockerAdapter) Extend(ctx context.Context, k string, ttl time.Duration) error {
|
||||
a.Logger.Debug(fmt.Sprintf("extending lock for key: %v with ttl: %v", k, ttl))
|
||||
|
||||
a.Mutex.Lock()
|
||||
lockInstance, ok := a.LockInstances[k]
|
||||
a.Mutex.Unlock()
|
||||
|
||||
if !ok {
|
||||
err := fmt.Errorf("lock not found in instances map")
|
||||
a.Logger.Error("cannot extend lock, not owned by this instance",
|
||||
zap.String("key", k),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extend the lock TTL
|
||||
err := lockInstance.Refresh(ctx, ttl, nil)
|
||||
if err != nil {
|
||||
a.Logger.Error("failed to extend lock",
|
||||
zap.String("key", k),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
a.Logger.Debug("successfully extended lock",
|
||||
zap.String("key", k),
|
||||
zap.Duration("ttl", ttl))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsOwner checks if this instance owns the given lock.
|
||||
func (a distributedLockerAdapter) IsOwner(ctx context.Context, k string) (bool, error) {
|
||||
a.Mutex.Lock()
|
||||
lockInstance, ok := a.LockInstances[k]
|
||||
a.Mutex.Unlock()
|
||||
|
||||
if !ok {
|
||||
// Not in our instances map
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get the lock metadata to check if we still own it
|
||||
metadata := lockInstance.Metadata()
|
||||
|
||||
// If metadata is empty, we don't own it
|
||||
return metadata != "", nil
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/distributedmutex/distributedmutex_test.go
|
||||
package distributedmutex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// mockRedisClient implements minimal required methods
|
||||
type mockRedisClient struct {
|
||||
redis.UniversalClient
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Get(ctx context.Context, key string) *redis.StringCmd {
|
||||
return redis.NewStringCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd {
|
||||
return redis.NewStatusCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...any) *redis.Cmd {
|
||||
return redis.NewCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) EvalSha(ctx context.Context, sha string, keys []string, args ...any) *redis.Cmd {
|
||||
return redis.NewCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) ScriptExists(ctx context.Context, scripts ...string) *redis.BoolSliceCmd {
|
||||
return redis.NewBoolSliceCmd(ctx)
|
||||
}
|
||||
|
||||
func (m *mockRedisClient) ScriptLoad(ctx context.Context, script string) *redis.StringCmd {
|
||||
return redis.NewStringCmd(ctx)
|
||||
}
|
||||
|
||||
func TestNewAdapter(t *testing.T) {
|
||||
logger, _ := zap.NewDevelopment()
|
||||
adapter := NewAdapter(logger, &mockRedisClient{})
|
||||
if adapter == nil {
|
||||
t.Fatal("expected non-nil adapter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireAndRelease(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger, _ := zap.NewDevelopment()
|
||||
adapter := NewAdapter(logger, &mockRedisClient{})
|
||||
|
||||
adapter.Acquire(ctx, "test-key")
|
||||
adapter.Acquiref(ctx, "test-key-%d", 1)
|
||||
adapter.Release(ctx, "test-key")
|
||||
adapter.Releasef(ctx, "test-key-%d", 1)
|
||||
}
|
||||
23
cloud/maplefile-backend/pkg/distributedmutex/provider.go
Normal file
23
cloud/maplefile-backend/pkg/distributedmutex/provider.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
package distributedmutex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideDistributedMutexAdapter provides a distributed mutex adapter for Wire DI
|
||||
func ProvideDistributedMutexAdapter(cfg *config.Config, logger *zap.Logger) Adapter {
|
||||
// Create Redis client for distributed locking
|
||||
// Note: This is separate from the cache redis client
|
||||
redisClient := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Cache.Host, cfg.Cache.Port),
|
||||
Password: cfg.Cache.Password,
|
||||
DB: cfg.Cache.DB,
|
||||
})
|
||||
|
||||
return NewAdapter(logger, redisClient)
|
||||
}
|
||||
62
cloud/maplefile-backend/pkg/emailer/mailgun/config.go
Normal file
62
cloud/maplefile-backend/pkg/emailer/mailgun/config.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/emailer/mailgun/config.go
|
||||
package mailgun
|
||||
|
||||
type MailgunConfigurationProvider interface {
|
||||
GetSenderEmail() string
|
||||
GetDomainName() string // Deprecated
|
||||
GetBackendDomainName() string
|
||||
GetFrontendDomainName() string
|
||||
GetMaintenanceEmail() string
|
||||
GetAPIKey() string
|
||||
GetAPIBase() string
|
||||
}
|
||||
|
||||
type mailgunConfigurationProviderImpl struct {
|
||||
senderEmail string
|
||||
domain string
|
||||
apiBase string
|
||||
maintenanceEmail string
|
||||
frontendDomain string
|
||||
backendDomain string
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewMailgunConfigurationProvider(senderEmail, domain, apiBase, maintenanceEmail, frontendDomain, backendDomain, apiKey string) MailgunConfigurationProvider {
|
||||
return &mailgunConfigurationProviderImpl{
|
||||
senderEmail: senderEmail,
|
||||
domain: domain,
|
||||
apiBase: apiBase,
|
||||
maintenanceEmail: maintenanceEmail,
|
||||
frontendDomain: frontendDomain,
|
||||
backendDomain: backendDomain,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetDomainName() string {
|
||||
return me.domain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetSenderEmail() string {
|
||||
return me.senderEmail
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetBackendDomainName() string {
|
||||
return me.backendDomain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetFrontendDomainName() string {
|
||||
return me.frontendDomain
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetMaintenanceEmail() string {
|
||||
return me.maintenanceEmail
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetAPIKey() string {
|
||||
return me.apiKey
|
||||
}
|
||||
|
||||
func (me *mailgunConfigurationProviderImpl) GetAPIBase() string {
|
||||
return me.apiBase
|
||||
}
|
||||
13
cloud/maplefile-backend/pkg/emailer/mailgun/interface.go
Normal file
13
cloud/maplefile-backend/pkg/emailer/mailgun/interface.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/emailer/mailgun/interface.go
|
||||
package mailgun
|
||||
|
||||
import "context"
|
||||
|
||||
type Emailer interface {
|
||||
Send(ctx context.Context, sender, subject, recipient, htmlContent string) error
|
||||
GetSenderEmail() string
|
||||
GetDomainName() string // Deprecated
|
||||
GetBackendDomainName() string
|
||||
GetFrontendDomainName() string
|
||||
GetMaintenanceEmail() string
|
||||
}
|
||||
64
cloud/maplefile-backend/pkg/emailer/mailgun/mailgun.go
Normal file
64
cloud/maplefile-backend/pkg/emailer/mailgun/mailgun.go
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/emailer/mailgun/mailgun.go
|
||||
package mailgun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/mailgun-go/v4"
|
||||
)
|
||||
|
||||
type mailgunEmailer struct {
|
||||
config MailgunConfigurationProvider
|
||||
Mailgun *mailgun.MailgunImpl
|
||||
}
|
||||
|
||||
func NewEmailer(config MailgunConfigurationProvider) Emailer {
|
||||
// Defensive code: Make sure we have access to the file before proceeding any further with the code.
|
||||
mg := mailgun.NewMailgun(config.GetDomainName(), config.GetAPIKey())
|
||||
|
||||
mg.SetAPIBase(config.GetAPIBase()) // Override to support our custom email requirements.
|
||||
|
||||
return &mailgunEmailer{
|
||||
config: config,
|
||||
Mailgun: mg,
|
||||
}
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) Send(ctx context.Context, sender, subject, recipient, body string) error {
|
||||
|
||||
message := me.Mailgun.NewMessage(sender, subject, "", recipient)
|
||||
message.SetHtml(body)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
// Send the message with a 10 second timeout
|
||||
_, _, err := me.Mailgun.Send(ctx, message)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetDomainName() string {
|
||||
return me.config.GetDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetSenderEmail() string {
|
||||
return me.config.GetSenderEmail()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetBackendDomainName() string {
|
||||
return me.config.GetBackendDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetFrontendDomainName() string {
|
||||
return me.config.GetFrontendDomainName()
|
||||
}
|
||||
|
||||
func (me *mailgunEmailer) GetMaintenanceEmail() string {
|
||||
return me.config.GetMaintenanceEmail()
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/emailer/mailgun/maplefilemailgun.go
|
||||
package mailgun
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// NewMapleFileModuleEmailer creates a new emailer for the MapleFile standalone module.
|
||||
func NewMapleFileModuleEmailer(cfg *config.Configuration) Emailer {
|
||||
emailerConfigProvider := NewMailgunConfigurationProvider(
|
||||
cfg.Mailgun.SenderEmail,
|
||||
cfg.Mailgun.Domain,
|
||||
cfg.Mailgun.APIBase,
|
||||
cfg.Mailgun.SenderEmail, // Use sender email as maintenance email
|
||||
cfg.Mailgun.FrontendURL,
|
||||
"", // Backend domain not needed for standalone
|
||||
cfg.Mailgun.APIKey,
|
||||
)
|
||||
|
||||
return NewEmailer(emailerConfigProvider)
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/emailer/mailgun/papercloudmailgun.go
|
||||
package mailgun
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// NewPaperCloudModuleEmailer creates a new emailer for the PaperCloud Property Evaluator module.
|
||||
func NewPaperCloudModuleEmailer(cfg *config.Configuration) Emailer {
|
||||
emailerConfigProvider := NewMailgunConfigurationProvider(
|
||||
cfg.PaperCloudMailgun.SenderEmail,
|
||||
cfg.PaperCloudMailgun.Domain,
|
||||
cfg.PaperCloudMailgun.APIBase,
|
||||
cfg.PaperCloudMailgun.MaintenanceEmail,
|
||||
cfg.PaperCloudMailgun.FrontendDomain,
|
||||
cfg.PaperCloudMailgun.BackendDomain,
|
||||
cfg.PaperCloudMailgun.APIKey,
|
||||
)
|
||||
|
||||
return NewEmailer(emailerConfigProvider)
|
||||
}
|
||||
10
cloud/maplefile-backend/pkg/emailer/mailgun/provider.go
Normal file
10
cloud/maplefile-backend/pkg/emailer/mailgun/provider.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package mailgun
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideMapleFileModuleEmailer provides a Mailgun emailer for Wire DI
|
||||
func ProvideMapleFileModuleEmailer(cfg *config.Config) Emailer {
|
||||
return NewMapleFileModuleEmailer(cfg)
|
||||
}
|
||||
147
cloud/maplefile-backend/pkg/httperror/httperror.go
Normal file
147
cloud/maplefile-backend/pkg/httperror/httperror.go
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/httperror/httperror.go
|
||||
package httperror
|
||||
|
||||
// This package introduces a new `error` type that combines an HTTP status code and a message.
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HTTPError represents an http error that occurred while handling a request
|
||||
type HTTPError struct {
|
||||
Code int `json:"-"` // HTTP Status code. We use `-` to skip json marshaling.
|
||||
Errors *map[string]string `json:"-"` // The original error. Same reason as above.
|
||||
}
|
||||
|
||||
// New creates a new HTTPError instance with a multi-field errors.
|
||||
func New(statusCode int, errorsMap *map[string]string) error {
|
||||
return HTTPError{
|
||||
Code: statusCode,
|
||||
Errors: errorsMap,
|
||||
}
|
||||
}
|
||||
|
||||
// NewForSingleField create a new HTTPError instance for a single field. This is a convinience constructor.
|
||||
func NewForSingleField(statusCode int, field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: statusCode,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForBadRequest create a new HTTPError instance pertaining to 403 bad requests with the multi-errors. This is a convinience constructor.
|
||||
func NewForBadRequest(err *map[string]string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusBadRequest,
|
||||
Errors: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewForBadRequestWithSingleField create a new HTTPError instance pertaining to 403 bad requests for a single field. This is a convinience constructor.
|
||||
func NewForBadRequestWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusBadRequest,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
func NewForInternalServerErrorWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForNotFoundWithSingleField create a new HTTPError instance pertaining to 404 not found for a single field. This is a convinience constructor.
|
||||
func NewForNotFoundWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusNotFound,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForServiceUnavailableWithSingleField create a new HTTPError instance pertaining service unavailable for a single field. This is a convinience constructor.
|
||||
func NewForServiceUnavailableWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusServiceUnavailable,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForLockedWithSingleField create a new HTTPError instance pertaining to 424 locked for a single field. This is a convinience constructor.
|
||||
func NewForLockedWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusLocked,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForForbiddenWithSingleField create a new HTTPError instance pertaining to 403 bad requests for a single field. This is a convinience constructor.
|
||||
func NewForForbiddenWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusForbidden,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForUnauthorizedWithSingleField create a new HTTPError instance pertaining to 401 unauthorized for a single field. This is a convinience constructor.
|
||||
func NewForUnauthorizedWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// NewForGoneWithSingleField create a new HTTPError instance pertaining to 410 gone for a single field. This is a convinience constructor.
|
||||
func NewForGoneWithSingleField(field string, message string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusGone,
|
||||
Errors: &map[string]string{field: message},
|
||||
}
|
||||
}
|
||||
|
||||
// Error function used to implement the `error` interface for returning errors.
|
||||
func (err HTTPError) Error() string {
|
||||
b, e := json.Marshal(err.Errors)
|
||||
if e != nil { // Defensive code
|
||||
return e.Error()
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ResponseError function returns the HTTP error response based on the httpcode used.
|
||||
func ResponseError(rw http.ResponseWriter, err error) {
|
||||
// Copied from:
|
||||
// https://dev.to/tigorlazuardi/go-creating-custom-error-wrapper-and-do-proper-error-equality-check-11k7
|
||||
|
||||
rw.Header().Set("Content-Type", "Application/json")
|
||||
|
||||
//
|
||||
// CASE 1 OF 2: Handle API Errors.
|
||||
//
|
||||
|
||||
var ew HTTPError
|
||||
if errors.As(err, &ew) {
|
||||
rw.WriteHeader(ew.Code)
|
||||
_ = json.NewEncoder(rw).Encode(ew.Errors)
|
||||
return
|
||||
}
|
||||
|
||||
//
|
||||
// CASE 2 OF 2: Handle non ErrorWrapper types.
|
||||
//
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_ = json.NewEncoder(rw).Encode(err.Error())
|
||||
}
|
||||
|
||||
// NewForInternalServerError create a new HTTPError instance pertaining to 500 internal server error with the multi-errors. This is a convinience constructor.
|
||||
func NewForInternalServerError(err string) error {
|
||||
return HTTPError{
|
||||
Code: http.StatusInternalServerError,
|
||||
Errors: &map[string]string{"message": err},
|
||||
}
|
||||
}
|
||||
328
cloud/maplefile-backend/pkg/httperror/httperror_test.go
Normal file
328
cloud/maplefile-backend/pkg/httperror/httperror_test.go
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/httperror/httperror_test.go
|
||||
package httperror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
errors map[string]string
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "basic error",
|
||||
code: http.StatusBadRequest,
|
||||
errors: map[string]string{"field": "error message"},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "empty errors map",
|
||||
code: http.StatusNotFound,
|
||||
errors: map[string]string{},
|
||||
wantCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "multiple errors",
|
||||
code: http.StatusBadRequest,
|
||||
errors: map[string]string{"field1": "error1", "field2": "error2"},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := New(tt.code, &tt.errors)
|
||||
|
||||
httpErr, ok := err.(HTTPError)
|
||||
if !ok {
|
||||
t.Fatal("expected HTTPError type")
|
||||
}
|
||||
if httpErr.Code != tt.wantCode {
|
||||
t.Errorf("Code = %v, want %v", httpErr.Code, tt.wantCode)
|
||||
}
|
||||
for k, v := range tt.errors {
|
||||
if (*httpErr.Errors)[k] != v {
|
||||
t.Errorf("Errors[%s] = %v, want %v", k, (*httpErr.Errors)[k], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForBadRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errors map[string]string
|
||||
}{
|
||||
{
|
||||
name: "single error",
|
||||
errors: map[string]string{"field": "error"},
|
||||
},
|
||||
{
|
||||
name: "multiple errors",
|
||||
errors: map[string]string{"field1": "error1", "field2": "error2"},
|
||||
},
|
||||
{
|
||||
name: "empty errors",
|
||||
errors: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewForBadRequest(&tt.errors)
|
||||
|
||||
httpErr, ok := err.(HTTPError)
|
||||
if !ok {
|
||||
t.Fatal("expected HTTPError type")
|
||||
}
|
||||
if httpErr.Code != http.StatusBadRequest {
|
||||
t.Errorf("Code = %v, want %v", httpErr.Code, http.StatusBadRequest)
|
||||
}
|
||||
for k, v := range tt.errors {
|
||||
if (*httpErr.Errors)[k] != v {
|
||||
t.Errorf("Errors[%s] = %v, want %v", k, (*httpErr.Errors)[k], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForSingleField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
field string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "basic error",
|
||||
code: http.StatusBadRequest,
|
||||
field: "test",
|
||||
message: "error",
|
||||
},
|
||||
{
|
||||
name: "empty field",
|
||||
code: http.StatusNotFound,
|
||||
field: "",
|
||||
message: "error",
|
||||
},
|
||||
{
|
||||
name: "empty message",
|
||||
code: http.StatusBadRequest,
|
||||
field: "field",
|
||||
message: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewForSingleField(tt.code, tt.field, tt.message)
|
||||
|
||||
httpErr, ok := err.(HTTPError)
|
||||
if !ok {
|
||||
t.Fatal("expected HTTPError type")
|
||||
}
|
||||
if httpErr.Code != tt.code {
|
||||
t.Errorf("Code = %v, want %v", httpErr.Code, tt.code)
|
||||
}
|
||||
if (*httpErr.Errors)[tt.field] != tt.message {
|
||||
t.Errorf("Errors[%s] = %v, want %v", tt.field, (*httpErr.Errors)[tt.field], tt.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errors map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid json",
|
||||
errors: map[string]string{"field": "error"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
errors: map[string]string{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := HTTPError{
|
||||
Code: http.StatusBadRequest,
|
||||
Errors: &tt.errors,
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
var jsonMap map[string]string
|
||||
if jsonErr := json.Unmarshal([]byte(errStr), &jsonMap); (jsonErr != nil) != tt.wantErr {
|
||||
t.Errorf("Error() json.Unmarshal error = %v, wantErr %v", jsonErr, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
for k, v := range tt.errors {
|
||||
if jsonMap[k] != v {
|
||||
t.Errorf("Error() jsonMap[%s] = %v, want %v", k, jsonMap[k], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantCode int
|
||||
wantContent string
|
||||
}{
|
||||
{
|
||||
name: "http error",
|
||||
err: NewForBadRequestWithSingleField("field", "invalid"),
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantContent: `{"field":"invalid"}`,
|
||||
},
|
||||
{
|
||||
name: "standard error",
|
||||
err: fmt.Errorf("standard error"),
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantContent: `"standard error"`,
|
||||
},
|
||||
{
|
||||
name: "nil error",
|
||||
err: errors.New("<nil>"),
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantContent: `"\u003cnil\u003e"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
ResponseError(rr, tt.err)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != tt.wantCode {
|
||||
t.Errorf("ResponseError() code = %v, want %v", rr.Code, tt.wantCode)
|
||||
}
|
||||
|
||||
// Check content type
|
||||
if ct := rr.Header().Get("Content-Type"); ct != "Application/json" {
|
||||
t.Errorf("ResponseError() Content-Type = %v, want Application/json", ct)
|
||||
}
|
||||
|
||||
// Trim newline from response for comparison
|
||||
got := rr.Body.String()
|
||||
got = got[:len(got)-1] // Remove trailing newline added by json.Encoder
|
||||
if got != tt.wantContent {
|
||||
t.Errorf("ResponseError() content = %v, want %v", got, tt.wantContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapping(t *testing.T) {
|
||||
originalErr := errors.New("original error")
|
||||
wrappedErr := fmt.Errorf("wrapped: %w", originalErr)
|
||||
httpErr := NewForBadRequestWithSingleField("field", wrappedErr.Error())
|
||||
|
||||
// Test error unwrapping
|
||||
if !errors.Is(httpErr, httpErr) {
|
||||
t.Error("errors.Is failed for same error")
|
||||
}
|
||||
|
||||
var targetErr HTTPError
|
||||
if !errors.As(httpErr, &targetErr) {
|
||||
t.Error("errors.As failed to get HTTPError")
|
||||
}
|
||||
}
|
||||
|
||||
// Test all convenience constructors
|
||||
func TestConvenienceConstructors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
create func() error
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "NewForBadRequestWithSingleField",
|
||||
create: func() error {
|
||||
return NewForBadRequestWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "NewForNotFoundWithSingleField",
|
||||
create: func() error {
|
||||
return NewForNotFoundWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "NewForServiceUnavailableWithSingleField",
|
||||
create: func() error {
|
||||
return NewForServiceUnavailableWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
name: "NewForLockedWithSingleField",
|
||||
create: func() error {
|
||||
return NewForLockedWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusLocked,
|
||||
},
|
||||
{
|
||||
name: "NewForForbiddenWithSingleField",
|
||||
create: func() error {
|
||||
return NewForForbiddenWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "NewForUnauthorizedWithSingleField",
|
||||
create: func() error {
|
||||
return NewForUnauthorizedWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "NewForGoneWithSingleField",
|
||||
create: func() error {
|
||||
return NewForGoneWithSingleField("field", "message")
|
||||
},
|
||||
wantCode: http.StatusGone,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.create()
|
||||
httpErr, ok := err.(HTTPError)
|
||||
if !ok {
|
||||
t.Fatal("expected HTTPError type")
|
||||
}
|
||||
if httpErr.Code != tt.wantCode {
|
||||
t.Errorf("Code = %v, want %v", httpErr.Code, tt.wantCode)
|
||||
}
|
||||
if (*httpErr.Errors)["field"] != "message" {
|
||||
t.Errorf("Error message = %v, want 'message'", (*httpErr.Errors)["field"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
289
cloud/maplefile-backend/pkg/httperror/rfc9457.go
Normal file
289
cloud/maplefile-backend/pkg/httperror/rfc9457.go
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
// Package httperror provides RFC 9457 compliant error handling for HTTP APIs.
|
||||
// RFC 9457: Problem Details for HTTP APIs
|
||||
// https://www.rfc-editor.org/rfc/rfc9457.html
|
||||
package httperror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ProblemDetail represents an RFC 9457 problem detail response.
|
||||
// It provides a standardized way to carry machine-readable details of errors
|
||||
// in HTTP response content.
|
||||
type ProblemDetail struct {
|
||||
// Standard RFC 9457 fields
|
||||
|
||||
// Type is a URI reference that identifies the problem type.
|
||||
// When dereferenced, it should provide human-readable documentation.
|
||||
// Defaults to "about:blank" if not provided.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Status is the HTTP status code for this occurrence of the problem.
|
||||
Status int `json:"status"`
|
||||
|
||||
// Title is a short, human-readable summary of the problem type.
|
||||
Title string `json:"title"`
|
||||
|
||||
// Detail is a human-readable explanation specific to this occurrence.
|
||||
Detail string `json:"detail,omitempty"`
|
||||
|
||||
// Instance is a URI reference that identifies this specific occurrence.
|
||||
Instance string `json:"instance,omitempty"`
|
||||
|
||||
// MapleFile-specific extensions
|
||||
|
||||
// Errors contains field-specific validation errors.
|
||||
// Key is the field name, value is the error message.
|
||||
Errors map[string]string `json:"errors,omitempty"`
|
||||
|
||||
// Timestamp is the ISO 8601 timestamp when the error occurred.
|
||||
Timestamp string `json:"timestamp"`
|
||||
|
||||
// TraceID is the request trace ID for debugging.
|
||||
TraceID string `json:"trace_id,omitempty"`
|
||||
}
|
||||
|
||||
// Problem type URIs - these identify categories of errors
|
||||
const (
|
||||
TypeValidationError = "https://api.maplefile.com/problems/validation-error"
|
||||
TypeBadRequest = "https://api.maplefile.com/problems/bad-request"
|
||||
TypeUnauthorized = "https://api.maplefile.com/problems/unauthorized"
|
||||
TypeForbidden = "https://api.maplefile.com/problems/forbidden"
|
||||
TypeNotFound = "https://api.maplefile.com/problems/not-found"
|
||||
TypeConflict = "https://api.maplefile.com/problems/conflict"
|
||||
TypeTooManyRequests = "https://api.maplefile.com/problems/too-many-requests"
|
||||
TypeInternalError = "https://api.maplefile.com/problems/internal-error"
|
||||
TypeServiceUnavailable = "https://api.maplefile.com/problems/service-unavailable"
|
||||
)
|
||||
|
||||
// NewProblemDetail creates a new RFC 9457 problem detail.
|
||||
func NewProblemDetail(status int, problemType, title, detail string) *ProblemDetail {
|
||||
return &ProblemDetail{
|
||||
Type: problemType,
|
||||
Status: status,
|
||||
Title: title,
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewValidationError creates a validation error problem detail.
|
||||
// Use this when one or more fields fail validation.
|
||||
func NewValidationError(fieldErrors map[string]string) *ProblemDetail {
|
||||
detail := "One or more fields failed validation. Please check the errors and try again."
|
||||
if len(fieldErrors) == 0 {
|
||||
detail = "Validation failed."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeValidationError,
|
||||
Status: http.StatusBadRequest,
|
||||
Title: "Validation Failed",
|
||||
Detail: detail,
|
||||
Errors: fieldErrors,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBadRequestError creates a generic bad request error.
|
||||
// Use this for malformed requests or invalid input.
|
||||
func NewBadRequestError(detail string) *ProblemDetail {
|
||||
return &ProblemDetail{
|
||||
Type: TypeBadRequest,
|
||||
Status: http.StatusBadRequest,
|
||||
Title: "Bad Request",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnauthorizedError creates an unauthorized error.
|
||||
// Use this when authentication is required but missing or invalid.
|
||||
func NewUnauthorizedError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "Authentication is required to access this resource."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeUnauthorized,
|
||||
Status: http.StatusUnauthorized,
|
||||
Title: "Unauthorized",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewForbiddenError creates a forbidden error.
|
||||
// Use this when the user is authenticated but lacks permission.
|
||||
func NewForbiddenError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "You do not have permission to access this resource."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeForbidden,
|
||||
Status: http.StatusForbidden,
|
||||
Title: "Forbidden",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewNotFoundError creates a not found error.
|
||||
// Use this when a requested resource does not exist.
|
||||
func NewNotFoundError(resourceType string) *ProblemDetail {
|
||||
detail := "The requested resource was not found."
|
||||
if resourceType != "" {
|
||||
detail = resourceType + " not found."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeNotFound,
|
||||
Status: http.StatusNotFound,
|
||||
Title: "Not Found",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewConflictError creates a conflict error.
|
||||
// Use this when the request conflicts with the current state.
|
||||
func NewConflictError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "The request conflicts with the current state of the resource."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeConflict,
|
||||
Status: http.StatusConflict,
|
||||
Title: "Conflict",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTooManyRequestsError creates a rate limit exceeded error.
|
||||
// Use this when the client has exceeded the allowed request rate.
|
||||
// CWE-307: Used to prevent brute force attacks by limiting request frequency.
|
||||
func NewTooManyRequestsError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "Too many requests. Please try again later."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeTooManyRequests,
|
||||
Status: http.StatusTooManyRequests,
|
||||
Title: "Too Many Requests",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewInternalServerError creates an internal server error.
|
||||
// Use this for unexpected errors that are not the client's fault.
|
||||
func NewInternalServerError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "An unexpected error occurred. Please try again later."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeInternalError,
|
||||
Status: http.StatusInternalServerError,
|
||||
Title: "Internal Server Error",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// NewServiceUnavailableError creates a service unavailable error.
|
||||
// Use this when the service is temporarily unavailable.
|
||||
func NewServiceUnavailableError(detail string) *ProblemDetail {
|
||||
if detail == "" {
|
||||
detail = "The service is temporarily unavailable. Please try again later."
|
||||
}
|
||||
|
||||
return &ProblemDetail{
|
||||
Type: TypeServiceUnavailable,
|
||||
Status: http.StatusServiceUnavailable,
|
||||
Title: "Service Unavailable",
|
||||
Detail: detail,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
// WithInstance adds the request path as the instance identifier.
|
||||
func (p *ProblemDetail) WithInstance(instance string) *ProblemDetail {
|
||||
p.Instance = instance
|
||||
return p
|
||||
}
|
||||
|
||||
// WithTraceID adds the request trace ID for debugging.
|
||||
func (p *ProblemDetail) WithTraceID(traceID string) *ProblemDetail {
|
||||
p.TraceID = traceID
|
||||
return p
|
||||
}
|
||||
|
||||
// WithError adds a single field error to the problem detail.
|
||||
func (p *ProblemDetail) WithError(field, message string) *ProblemDetail {
|
||||
if p.Errors == nil {
|
||||
p.Errors = make(map[string]string)
|
||||
}
|
||||
p.Errors[field] = message
|
||||
return p
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (p *ProblemDetail) Error() string {
|
||||
if p.Detail != "" {
|
||||
return p.Detail
|
||||
}
|
||||
return p.Title
|
||||
}
|
||||
|
||||
// ExtractRequestID gets the request ID from the request context or headers.
|
||||
// This uses the existing request ID middleware.
|
||||
func ExtractRequestID(r *http.Request) string {
|
||||
// Try to get from context first (preferred)
|
||||
if requestID := r.Context().Value("request_id"); requestID != nil {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to header
|
||||
if requestID := r.Header.Get("X-Request-ID"); requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
|
||||
// No request ID found
|
||||
return ""
|
||||
}
|
||||
|
||||
// RespondWithProblem writes the RFC 9457 problem detail to the HTTP response.
|
||||
// It sets the appropriate Content-Type header and status code.
|
||||
func RespondWithProblem(w http.ResponseWriter, problem *ProblemDetail) {
|
||||
w.Header().Set("Content-Type", "application/problem+json")
|
||||
w.WriteHeader(problem.Status)
|
||||
json.NewEncoder(w).Encode(problem)
|
||||
}
|
||||
|
||||
// RespondWithError is a convenience function that handles both ProblemDetail
|
||||
// and standard Go errors. If the error is a ProblemDetail, it writes it directly.
|
||||
// Otherwise, it wraps it in an internal server error.
|
||||
func RespondWithError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
requestID := ExtractRequestID(r)
|
||||
|
||||
// Check if error is already a ProblemDetail
|
||||
if problem, ok := err.(*ProblemDetail); ok {
|
||||
problem.WithInstance(r.URL.Path).WithTraceID(requestID)
|
||||
RespondWithProblem(w, problem)
|
||||
return
|
||||
}
|
||||
|
||||
// Wrap standard error in internal server error
|
||||
problem := NewInternalServerError(err.Error())
|
||||
problem.WithInstance(r.URL.Path).WithTraceID(requestID)
|
||||
RespondWithProblem(w, problem)
|
||||
}
|
||||
357
cloud/maplefile-backend/pkg/httperror/rfc9457_test.go
Normal file
357
cloud/maplefile-backend/pkg/httperror/rfc9457_test.go
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
package httperror
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewValidationError(t *testing.T) {
|
||||
fieldErrors := map[string]string{
|
||||
"email": "Email is required",
|
||||
"password": "Password must be at least 8 characters",
|
||||
}
|
||||
|
||||
problem := NewValidationError(fieldErrors)
|
||||
|
||||
if problem.Type != TypeValidationError {
|
||||
t.Errorf("Expected type %s, got %s", TypeValidationError, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, problem.Status)
|
||||
}
|
||||
|
||||
if problem.Title != "Validation Failed" {
|
||||
t.Errorf("Expected title 'Validation Failed', got '%s'", problem.Title)
|
||||
}
|
||||
|
||||
if len(problem.Errors) != 2 {
|
||||
t.Errorf("Expected 2 field errors, got %d", len(problem.Errors))
|
||||
}
|
||||
|
||||
if problem.Errors["email"] != "Email is required" {
|
||||
t.Errorf("Expected email error, got '%s'", problem.Errors["email"])
|
||||
}
|
||||
|
||||
if problem.Timestamp == "" {
|
||||
t.Error("Expected timestamp to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidationError_Empty(t *testing.T) {
|
||||
problem := NewValidationError(map[string]string{})
|
||||
|
||||
if problem.Detail != "Validation failed." {
|
||||
t.Errorf("Expected detail 'Validation failed.', got '%s'", problem.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadRequestError(t *testing.T) {
|
||||
detail := "Invalid request payload"
|
||||
problem := NewBadRequestError(detail)
|
||||
|
||||
if problem.Type != TypeBadRequest {
|
||||
t.Errorf("Expected type %s, got %s", TypeBadRequest, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, problem.Status)
|
||||
}
|
||||
|
||||
if problem.Detail != detail {
|
||||
t.Errorf("Expected detail '%s', got '%s'", detail, problem.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError(t *testing.T) {
|
||||
detail := "Invalid token"
|
||||
problem := NewUnauthorizedError(detail)
|
||||
|
||||
if problem.Type != TypeUnauthorized {
|
||||
t.Errorf("Expected type %s, got %s", TypeUnauthorized, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusUnauthorized {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, problem.Status)
|
||||
}
|
||||
|
||||
if problem.Detail != detail {
|
||||
t.Errorf("Expected detail '%s', got '%s'", detail, problem.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUnauthorizedError_DefaultMessage(t *testing.T) {
|
||||
problem := NewUnauthorizedError("")
|
||||
|
||||
if problem.Detail != "Authentication is required to access this resource." {
|
||||
t.Errorf("Expected default detail message, got '%s'", problem.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForbiddenError(t *testing.T) {
|
||||
detail := "Insufficient permissions"
|
||||
problem := NewForbiddenError(detail)
|
||||
|
||||
if problem.Type != TypeForbidden {
|
||||
t.Errorf("Expected type %s, got %s", TypeForbidden, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, problem.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundError(t *testing.T) {
|
||||
problem := NewNotFoundError("User")
|
||||
|
||||
if problem.Type != TypeNotFound {
|
||||
t.Errorf("Expected type %s, got %s", TypeNotFound, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusNotFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusNotFound, problem.Status)
|
||||
}
|
||||
|
||||
if problem.Detail != "User not found." {
|
||||
t.Errorf("Expected detail 'User not found.', got '%s'", problem.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConflictError(t *testing.T) {
|
||||
detail := "Email already exists"
|
||||
problem := NewConflictError(detail)
|
||||
|
||||
if problem.Type != TypeConflict {
|
||||
t.Errorf("Expected type %s, got %s", TypeConflict, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusConflict {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusConflict, problem.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInternalServerError(t *testing.T) {
|
||||
detail := "Database connection failed"
|
||||
problem := NewInternalServerError(detail)
|
||||
|
||||
if problem.Type != TypeInternalError {
|
||||
t.Errorf("Expected type %s, got %s", TypeInternalError, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, problem.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewServiceUnavailableError(t *testing.T) {
|
||||
problem := NewServiceUnavailableError("")
|
||||
|
||||
if problem.Type != TypeServiceUnavailable {
|
||||
t.Errorf("Expected type %s, got %s", TypeServiceUnavailable, problem.Type)
|
||||
}
|
||||
|
||||
if problem.Status != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusServiceUnavailable, problem.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithInstance(t *testing.T) {
|
||||
problem := NewBadRequestError("Test")
|
||||
instance := "/api/v1/test"
|
||||
|
||||
problem.WithInstance(instance)
|
||||
|
||||
if problem.Instance != instance {
|
||||
t.Errorf("Expected instance '%s', got '%s'", instance, problem.Instance)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTraceID(t *testing.T) {
|
||||
problem := NewBadRequestError("Test")
|
||||
traceID := "trace-123"
|
||||
|
||||
problem.WithTraceID(traceID)
|
||||
|
||||
if problem.TraceID != traceID {
|
||||
t.Errorf("Expected traceID '%s', got '%s'", traceID, problem.TraceID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithError(t *testing.T) {
|
||||
problem := NewBadRequestError("Test")
|
||||
|
||||
problem.WithError("email", "Email is required")
|
||||
problem.WithError("password", "Password is required")
|
||||
|
||||
if len(problem.Errors) != 2 {
|
||||
t.Errorf("Expected 2 errors, got %d", len(problem.Errors))
|
||||
}
|
||||
|
||||
if problem.Errors["email"] != "Email is required" {
|
||||
t.Errorf("Expected email error, got '%s'", problem.Errors["email"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProblemDetailError(t *testing.T) {
|
||||
detail := "Test detail"
|
||||
problem := NewBadRequestError(detail)
|
||||
|
||||
if problem.Error() != detail {
|
||||
t.Errorf("Expected Error() to return detail, got '%s'", problem.Error())
|
||||
}
|
||||
|
||||
// Test with no detail
|
||||
problem2 := &ProblemDetail{
|
||||
Title: "Test Title",
|
||||
}
|
||||
|
||||
if problem2.Error() != "Test Title" {
|
||||
t.Errorf("Expected Error() to return title, got '%s'", problem2.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestID_FromContext(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Note: In real code, request ID would be set by middleware
|
||||
// For testing, we'll test the empty case
|
||||
requestID := ExtractRequestID(req)
|
||||
|
||||
// Should return empty string when no request ID is present
|
||||
if requestID != "" {
|
||||
t.Errorf("Expected empty string, got '%s'", requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestID_FromHeader(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-ID", "test-request-123")
|
||||
|
||||
requestID := ExtractRequestID(req)
|
||||
|
||||
if requestID != "test-request-123" {
|
||||
t.Errorf("Expected 'test-request-123', got '%s'", requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithProblem(t *testing.T) {
|
||||
problem := NewValidationError(map[string]string{
|
||||
"email": "Email is required",
|
||||
})
|
||||
problem.WithInstance("/api/v1/test")
|
||||
problem.WithTraceID("trace-123")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
RespondWithProblem(w, problem)
|
||||
|
||||
// Check status code
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// Check content type
|
||||
contentType := w.Header().Get("Content-Type")
|
||||
if contentType != "application/problem+json" {
|
||||
t.Errorf("Expected Content-Type 'application/problem+json', got '%s'", contentType)
|
||||
}
|
||||
|
||||
// Check JSON response
|
||||
var response ProblemDetail
|
||||
if err := json.NewDecoder(w.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if response.Type != TypeValidationError {
|
||||
t.Errorf("Expected type %s, got %s", TypeValidationError, response.Type)
|
||||
}
|
||||
|
||||
if response.Instance != "/api/v1/test" {
|
||||
t.Errorf("Expected instance '/api/v1/test', got '%s'", response.Instance)
|
||||
}
|
||||
|
||||
if response.TraceID != "trace-123" {
|
||||
t.Errorf("Expected traceID 'trace-123', got '%s'", response.TraceID)
|
||||
}
|
||||
|
||||
if len(response.Errors) != 1 {
|
||||
t.Errorf("Expected 1 error, got %d", len(response.Errors))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithError_ProblemDetail(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/v1/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
problem := NewBadRequestError("Test error")
|
||||
RespondWithError(w, req, problem)
|
||||
|
||||
// Check status code
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// Check that instance was set
|
||||
var response ProblemDetail
|
||||
json.NewDecoder(w.Body).Decode(&response)
|
||||
|
||||
if response.Instance != "/api/v1/test" {
|
||||
t.Errorf("Expected instance to be set automatically, got '%s'", response.Instance)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondWithError_StandardError(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/v1/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
err := &customError{message: "Custom error"}
|
||||
RespondWithError(w, req, err)
|
||||
|
||||
// Check status code (should be 500 for standard errors)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, w.Code)
|
||||
}
|
||||
|
||||
// Check that it was wrapped in a ProblemDetail
|
||||
var response ProblemDetail
|
||||
json.NewDecoder(w.Body).Decode(&response)
|
||||
|
||||
if response.Type != TypeInternalError {
|
||||
t.Errorf("Expected type %s, got %s", TypeInternalError, response.Type)
|
||||
}
|
||||
|
||||
if response.Detail != "Custom error" {
|
||||
t.Errorf("Expected detail 'Custom error', got '%s'", response.Detail)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper type for testing standard error handling
|
||||
type customError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *customError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
func TestChaining(t *testing.T) {
|
||||
// Test method chaining
|
||||
problem := NewBadRequestError("Test").
|
||||
WithInstance("/api/v1/test").
|
||||
WithTraceID("trace-123").
|
||||
WithError("field1", "error1").
|
||||
WithError("field2", "error2")
|
||||
|
||||
if problem.Instance != "/api/v1/test" {
|
||||
t.Error("Instance not set correctly through chaining")
|
||||
}
|
||||
|
||||
if problem.TraceID != "trace-123" {
|
||||
t.Error("TraceID not set correctly through chaining")
|
||||
}
|
||||
|
||||
if len(problem.Errors) != 2 {
|
||||
t.Error("Errors not set correctly through chaining")
|
||||
}
|
||||
}
|
||||
375
cloud/maplefile-backend/pkg/leaderelection/EXAMPLE.md
Normal file
375
cloud/maplefile-backend/pkg/leaderelection/EXAMPLE.md
Normal file
|
|
@ -0,0 +1,375 @@
|
|||
# Leader Election Integration Example
|
||||
|
||||
## Quick Integration into MapleFile Backend
|
||||
|
||||
### Step 1: Add to Wire Providers (app/wire.go)
|
||||
|
||||
```go
|
||||
// In app/wire.go, add to wire.Build():
|
||||
|
||||
wire.Build(
|
||||
// ... existing providers ...
|
||||
|
||||
// Leader Election
|
||||
leaderelection.ProvideLeaderElection,
|
||||
|
||||
// ... rest of providers ...
|
||||
)
|
||||
```
|
||||
|
||||
### Step 2: Update Application Struct (app/app.go)
|
||||
|
||||
```go
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/leaderelection"
|
||||
)
|
||||
|
||||
type Application struct {
|
||||
config *config.Config
|
||||
httpServer *http.WireServer
|
||||
logger *zap.Logger
|
||||
migrator *cassandradb.Migrator
|
||||
leaderElection leaderelection.LeaderElection // ADD THIS
|
||||
}
|
||||
|
||||
func ProvideApplication(
|
||||
cfg *config.Config,
|
||||
httpServer *http.WireServer,
|
||||
logger *zap.Logger,
|
||||
migrator *cassandradb.Migrator,
|
||||
leaderElection leaderelection.LeaderElection, // ADD THIS
|
||||
) *Application {
|
||||
return &Application{
|
||||
config: cfg,
|
||||
httpServer: httpServer,
|
||||
logger: logger,
|
||||
migrator: migrator,
|
||||
leaderElection: leaderElection, // ADD THIS
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: Start Leader Election in Application (app/app.go)
|
||||
|
||||
```go
|
||||
func (app *Application) Start() error {
|
||||
app.logger.Info("🚀 MapleFile Backend Starting (Wire DI)",
|
||||
zap.String("version", app.config.App.Version),
|
||||
zap.String("environment", app.config.App.Environment),
|
||||
zap.String("di_framework", "Google Wire"))
|
||||
|
||||
// Start leader election if enabled
|
||||
if app.config.LeaderElection.Enabled {
|
||||
app.logger.Info("Starting leader election")
|
||||
|
||||
// Register callbacks
|
||||
app.setupLeaderCallbacks()
|
||||
|
||||
// Start election in background
|
||||
go func() {
|
||||
ctx := context.Background()
|
||||
if err := app.leaderElection.Start(ctx); err != nil {
|
||||
app.logger.Error("Leader election failed", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Give it a moment to complete first election
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if app.leaderElection.IsLeader() {
|
||||
app.logger.Info("👑 This instance is the LEADER",
|
||||
zap.String("instance_id", app.leaderElection.GetInstanceID()))
|
||||
} else {
|
||||
app.logger.Info("👥 This instance is a FOLLOWER",
|
||||
zap.String("instance_id", app.leaderElection.GetInstanceID()))
|
||||
}
|
||||
}
|
||||
|
||||
// Run database migrations (only leader should do this)
|
||||
if app.config.LeaderElection.Enabled {
|
||||
if app.leaderElection.IsLeader() {
|
||||
app.logger.Info("Running database migrations as leader...")
|
||||
if err := app.migrator.Up(); err != nil {
|
||||
app.logger.Error("Failed to run database migrations", zap.Error(err))
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
app.logger.Info("✅ Database migrations completed successfully")
|
||||
} else {
|
||||
app.logger.Info("Skipping migrations - not the leader")
|
||||
}
|
||||
} else {
|
||||
// If leader election disabled, always run migrations
|
||||
app.logger.Info("Running database migrations...")
|
||||
if err := app.migrator.Up(); err != nil {
|
||||
app.logger.Error("Failed to run database migrations", zap.Error(err))
|
||||
return fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
app.logger.Info("✅ Database migrations completed successfully")
|
||||
}
|
||||
|
||||
// Start HTTP server in goroutine
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := app.httpServer.Start(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for interrupt signal or server error
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
app.logger.Error("HTTP server failed", zap.Error(err))
|
||||
return fmt.Errorf("server startup failed: %w", err)
|
||||
case sig := <-quit:
|
||||
app.logger.Info("Received shutdown signal", zap.String("signal", sig.String()))
|
||||
}
|
||||
|
||||
app.logger.Info("👋 MapleFile Backend Shutting Down")
|
||||
|
||||
// Stop leader election
|
||||
if app.config.LeaderElection.Enabled {
|
||||
if err := app.leaderElection.Stop(); err != nil {
|
||||
app.logger.Error("Failed to stop leader election", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Graceful shutdown with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := app.httpServer.Shutdown(ctx); err != nil {
|
||||
app.logger.Error("Server shutdown error", zap.Error(err))
|
||||
return fmt.Errorf("server shutdown failed: %w", err)
|
||||
}
|
||||
|
||||
app.logger.Info("✅ MapleFile Backend Stopped Successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupLeaderCallbacks configures callbacks for leader election events
|
||||
func (app *Application) setupLeaderCallbacks() {
|
||||
app.leaderElection.OnBecomeLeader(func() {
|
||||
app.logger.Info("🎉 BECAME LEADER - Starting leader-only tasks",
|
||||
zap.String("instance_id", app.leaderElection.GetInstanceID()))
|
||||
|
||||
// Start leader-only background tasks here
|
||||
// For example:
|
||||
// - Scheduled cleanup jobs
|
||||
// - Metrics aggregation
|
||||
// - Cache warming
|
||||
// - Periodic health checks
|
||||
})
|
||||
|
||||
app.leaderElection.OnLoseLeadership(func() {
|
||||
app.logger.Warn("😢 LOST LEADERSHIP - Stopping leader-only tasks",
|
||||
zap.String("instance_id", app.leaderElection.GetInstanceID()))
|
||||
|
||||
// Stop leader-only tasks here
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Environment Variables (.env)
|
||||
|
||||
Add to your `.env` file:
|
||||
|
||||
```bash
|
||||
# Leader Election Configuration
|
||||
LEADER_ELECTION_ENABLED=true
|
||||
LEADER_ELECTION_LOCK_TTL=10s
|
||||
LEADER_ELECTION_HEARTBEAT_INTERVAL=3s
|
||||
LEADER_ELECTION_RETRY_INTERVAL=2s
|
||||
LEADER_ELECTION_INSTANCE_ID= # Leave empty for auto-generation
|
||||
LEADER_ELECTION_HOSTNAME= # Leave empty for auto-detection
|
||||
```
|
||||
|
||||
### Step 5: Update .env.sample
|
||||
|
||||
```bash
|
||||
# Leader Election
|
||||
LEADER_ELECTION_ENABLED=true
|
||||
LEADER_ELECTION_LOCK_TTL=10s
|
||||
LEADER_ELECTION_HEARTBEAT_INTERVAL=3s
|
||||
LEADER_ELECTION_RETRY_INTERVAL=2s
|
||||
LEADER_ELECTION_INSTANCE_ID=
|
||||
LEADER_ELECTION_HOSTNAME=
|
||||
```
|
||||
|
||||
### Step 6: Test Multiple Instances
|
||||
|
||||
#### Terminal 1
|
||||
```bash
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend
|
||||
# Output: 👑 This instance is the LEADER
|
||||
```
|
||||
|
||||
#### Terminal 2
|
||||
```bash
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-2 ./maplefile-backend
|
||||
# Output: 👥 This instance is a FOLLOWER
|
||||
```
|
||||
|
||||
#### Terminal 3
|
||||
```bash
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-3 ./maplefile-backend
|
||||
# Output: 👥 This instance is a FOLLOWER
|
||||
```
|
||||
|
||||
#### Test Failover
|
||||
Stop Terminal 1 (kill the leader):
|
||||
```
|
||||
# Watch Terminal 2 or 3 logs
|
||||
# One will show: 🎉 BECAME LEADER
|
||||
```
|
||||
|
||||
## Optional: Add Health Check Endpoint
|
||||
|
||||
Add to your HTTP handlers to expose leader election status:
|
||||
|
||||
```go
|
||||
// In internal/interface/http/server.go
|
||||
|
||||
func (s *Server) leaderElectionHealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if s.leaderElection == nil {
|
||||
http.Error(w, "Leader election not enabled", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.leaderElection.GetLeaderInfo()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get leader info", zap.Error(err))
|
||||
http.Error(w, "Failed to get leader info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"is_leader": s.leaderElection.IsLeader(),
|
||||
"instance_id": s.leaderElection.GetInstanceID(),
|
||||
"leader_info": info,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// Register in registerRoutes():
|
||||
s.mux.HandleFunc("GET /api/v1/leader-status", s.leaderElectionHealthHandler)
|
||||
```
|
||||
|
||||
Test the endpoint:
|
||||
```bash
|
||||
curl http://localhost:8000/api/v1/leader-status
|
||||
|
||||
# Response:
|
||||
{
|
||||
"is_leader": true,
|
||||
"instance_id": "instance-1",
|
||||
"leader_info": {
|
||||
"instance_id": "instance-1",
|
||||
"hostname": "macbook-pro.local",
|
||||
"started_at": "2025-01-12T10:30:00Z",
|
||||
"last_heartbeat": "2025-01-12T10:35:23Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### Docker Compose
|
||||
|
||||
When deploying with docker-compose, ensure each instance has a unique ID:
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
services:
|
||||
backend-1:
|
||||
image: maplefile-backend:latest
|
||||
environment:
|
||||
- LEADER_ELECTION_ENABLED=true
|
||||
- LEADER_ELECTION_INSTANCE_ID=backend-1
|
||||
# ... other config
|
||||
|
||||
backend-2:
|
||||
image: maplefile-backend:latest
|
||||
environment:
|
||||
- LEADER_ELECTION_ENABLED=true
|
||||
- LEADER_ELECTION_INSTANCE_ID=backend-2
|
||||
# ... other config
|
||||
|
||||
backend-3:
|
||||
image: maplefile-backend:latest
|
||||
environment:
|
||||
- LEADER_ELECTION_ENABLED=true
|
||||
- LEADER_ELECTION_INSTANCE_ID=backend-3
|
||||
# ... other config
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
For Kubernetes, the instance ID can be auto-generated from the pod name:
|
||||
|
||||
```yaml
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: maplefile-backend
|
||||
spec:
|
||||
replicas: 3
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: backend
|
||||
image: maplefile-backend:latest
|
||||
env:
|
||||
- name: LEADER_ELECTION_ENABLED
|
||||
value: "true"
|
||||
- name: LEADER_ELECTION_INSTANCE_ID
|
||||
valueFrom:
|
||||
fieldRef:
|
||||
fieldPath: metadata.name
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
Check logs for leader election events:
|
||||
|
||||
```bash
|
||||
# Grep for leader election events
|
||||
docker logs maplefile-backend | grep "LEADER\|election"
|
||||
|
||||
# Example output:
|
||||
# 2025-01-12T10:30:00.000Z INFO Starting leader election instance_id=instance-1
|
||||
# 2025-01-12T10:30:00.123Z INFO 🎉 Became the leader! instance_id=instance-1
|
||||
# 2025-01-12T10:30:03.456Z DEBUG Heartbeat sent instance_id=instance-1
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Leader keeps changing
|
||||
Increase `LEADER_ELECTION_LOCK_TTL`:
|
||||
```bash
|
||||
LEADER_ELECTION_LOCK_TTL=30s
|
||||
```
|
||||
|
||||
### No leader elected
|
||||
Check Redis connectivity:
|
||||
```bash
|
||||
redis-cli
|
||||
> GET maplefile:leader:lock
|
||||
```
|
||||
|
||||
### Multiple leaders
|
||||
This shouldn't happen, but if it does:
|
||||
1. Check system clock sync across instances
|
||||
2. Check Redis is working properly
|
||||
3. Check network connectivity
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Implement leader-only background jobs
|
||||
2. Add metrics for leader election events
|
||||
3. Create alerting for frequent leadership changes
|
||||
4. Add dashboards to monitor leader status
|
||||
461
cloud/maplefile-backend/pkg/leaderelection/FAILOVER_TEST.md
Normal file
461
cloud/maplefile-backend/pkg/leaderelection/FAILOVER_TEST.md
Normal file
|
|
@ -0,0 +1,461 @@
|
|||
# Leader Election Failover Testing Guide
|
||||
|
||||
This guide helps you verify that leader election handles cascading failures correctly.
|
||||
|
||||
## Test Scenarios
|
||||
|
||||
### Test 1: Graceful Shutdown Failover
|
||||
|
||||
**Objective:** Verify new leader is elected when current leader shuts down gracefully.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 3 instances:
|
||||
```bash
|
||||
# Terminal 1
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend
|
||||
|
||||
# Terminal 2
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-2 ./maplefile-backend
|
||||
|
||||
# Terminal 3
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-3 ./maplefile-backend
|
||||
```
|
||||
|
||||
2. Identify the leader:
|
||||
```bash
|
||||
# Look for this in logs:
|
||||
# "🎉 Became the leader!" instance_id=instance-1
|
||||
```
|
||||
|
||||
3. Gracefully stop the leader (Ctrl+C in Terminal 1)
|
||||
|
||||
4. Watch the other terminals:
|
||||
```bash
|
||||
# Within ~2 seconds, you should see:
|
||||
# "🎉 Became the leader!" instance_id=instance-2 or instance-3
|
||||
```
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ New leader elected within 2 seconds
|
||||
- ✅ Only ONE instance becomes leader (not both)
|
||||
- ✅ Scheduler tasks continue executing on new leader
|
||||
|
||||
---
|
||||
|
||||
### Test 2: Hard Crash Failover
|
||||
|
||||
**Objective:** Verify new leader is elected when current leader crashes.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 3 instances (same as Test 1)
|
||||
|
||||
2. Identify the leader
|
||||
|
||||
3. **Hard kill** the leader process:
|
||||
```bash
|
||||
# Find the process ID
|
||||
ps aux | grep maplefile-backend
|
||||
|
||||
# Kill it (simulates crash)
|
||||
kill -9 <PID>
|
||||
```
|
||||
|
||||
4. Watch the other terminals
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ Lock expires after 10 seconds (LockTTL)
|
||||
- ✅ New leader elected within ~12 seconds total
|
||||
- ✅ Only ONE instance becomes leader
|
||||
|
||||
---
|
||||
|
||||
### Test 3: Cascading Failures
|
||||
|
||||
**Objective:** Verify system handles multiple leaders shutting down in sequence.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 4 instances:
|
||||
```bash
|
||||
# Terminal 1
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend
|
||||
|
||||
# Terminal 2
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-2 ./maplefile-backend
|
||||
|
||||
# Terminal 3
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-3 ./maplefile-backend
|
||||
|
||||
# Terminal 4
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-4 ./maplefile-backend
|
||||
```
|
||||
|
||||
2. Identify first leader (e.g., instance-1)
|
||||
|
||||
3. Stop instance-1 (Ctrl+C)
|
||||
- Watch: instance-2, instance-3, or instance-4 becomes leader
|
||||
|
||||
4. Stop the new leader (Ctrl+C)
|
||||
- Watch: Another instance becomes leader
|
||||
|
||||
5. Stop that leader (Ctrl+C)
|
||||
- Watch: Last remaining instance becomes leader
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ After each shutdown, a new leader is elected
|
||||
- ✅ System continues operating with 1 instance
|
||||
- ✅ Scheduler tasks never stop (always running on current leader)
|
||||
|
||||
---
|
||||
|
||||
### Test 4: Leader Re-joins After Failover
|
||||
|
||||
**Objective:** Verify old leader doesn't reclaim leadership when it comes back.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 3 instances (instance-1, instance-2, instance-3)
|
||||
|
||||
2. instance-1 is the leader
|
||||
|
||||
3. Stop instance-1 (Ctrl+C)
|
||||
|
||||
4. instance-2 becomes the new leader
|
||||
|
||||
5. **Restart instance-1**:
|
||||
```bash
|
||||
# Terminal 1
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend
|
||||
```
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ instance-1 starts as a FOLLOWER (not leader)
|
||||
- ✅ instance-2 remains the leader
|
||||
- ✅ instance-1 logs show: "Another instance is the leader"
|
||||
|
||||
---
|
||||
|
||||
### Test 5: Network Partition Simulation
|
||||
|
||||
**Objective:** Verify behavior when leader loses Redis connectivity.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 3 instances
|
||||
|
||||
2. Identify the leader
|
||||
|
||||
3. **Block Redis access** for the leader instance:
|
||||
```bash
|
||||
# Option 1: Stop Redis temporarily
|
||||
docker stop redis
|
||||
|
||||
# Option 2: Use iptables to block Redis port
|
||||
sudo iptables -A OUTPUT -p tcp --dport 6379 -j DROP
|
||||
```
|
||||
|
||||
4. Watch the logs
|
||||
|
||||
5. **Restore Redis access**:
|
||||
```bash
|
||||
# Option 1: Start Redis
|
||||
docker start redis
|
||||
|
||||
# Option 2: Remove iptables rule
|
||||
sudo iptables -D OUTPUT -p tcp --dport 6379 -j DROP
|
||||
```
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ Leader fails to send heartbeat
|
||||
- ✅ Leader loses leadership (callback fired)
|
||||
- ✅ New leader elected from remaining instances
|
||||
- ✅ When Redis restored, old leader becomes a follower
|
||||
|
||||
---
|
||||
|
||||
### Test 6: Simultaneous Crash of All But One Instance
|
||||
|
||||
**Objective:** Verify last instance standing becomes leader.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 3 instances
|
||||
|
||||
2. Identify the leader (e.g., instance-1)
|
||||
|
||||
3. **Simultaneously kill** instance-1 and instance-2:
|
||||
```bash
|
||||
# Kill both at the same time
|
||||
kill -9 <PID1> <PID2>
|
||||
```
|
||||
|
||||
4. Watch instance-3
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ instance-3 becomes leader within ~12 seconds
|
||||
- ✅ Scheduler tasks continue on instance-3
|
||||
- ✅ System fully operational with 1 instance
|
||||
|
||||
---
|
||||
|
||||
### Test 7: Rapid Leader Changes (Chaos Test)
|
||||
|
||||
**Objective:** Stress test the election mechanism.
|
||||
|
||||
**Steps:**
|
||||
|
||||
1. Start 5 instances
|
||||
|
||||
2. Create a script to randomly kill and restart instances:
|
||||
```bash
|
||||
#!/bin/bash
|
||||
while true; do
|
||||
# Kill random instance
|
||||
RAND=$((RANDOM % 5 + 1))
|
||||
pkill -f "instance-$RAND"
|
||||
|
||||
# Wait a bit
|
||||
sleep $((RANDOM % 10 + 5))
|
||||
|
||||
# Restart it
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-$RAND ./maplefile-backend &
|
||||
|
||||
sleep $((RANDOM % 10 + 5))
|
||||
done
|
||||
```
|
||||
|
||||
3. Run for 5 minutes
|
||||
|
||||
**Expected Result:**
|
||||
- ✅ Always exactly ONE leader at any time
|
||||
- ✅ Smooth leadership transitions
|
||||
- ✅ No errors or race conditions
|
||||
- ✅ Scheduler tasks execute correctly throughout
|
||||
|
||||
---
|
||||
|
||||
## Monitoring During Tests
|
||||
|
||||
### Check Current Leader
|
||||
|
||||
```bash
|
||||
# Query Redis directly
|
||||
redis-cli GET maplefile:leader:lock
|
||||
# Output: instance-2
|
||||
|
||||
# Get leader info
|
||||
redis-cli GET maplefile:leader:info
|
||||
# Output: {"instance_id":"instance-2","hostname":"server-01",...}
|
||||
```
|
||||
|
||||
### Watch Leader Changes in Logs
|
||||
|
||||
```bash
|
||||
# Terminal 1: Watch for "Became the leader"
|
||||
tail -f logs/app.log | grep "Became the leader"
|
||||
|
||||
# Terminal 2: Watch for "lost leadership"
|
||||
tail -f logs/app.log | grep "lost leadership"
|
||||
|
||||
# Terminal 3: Watch for scheduler task execution
|
||||
tail -f logs/app.log | grep "Leader executing"
|
||||
```
|
||||
|
||||
### Monitor Redis Lock
|
||||
|
||||
```bash
|
||||
# Watch the lock key in real-time
|
||||
redis-cli --bigkeys
|
||||
|
||||
# Watch TTL countdown
|
||||
watch -n 1 'redis-cli TTL maplefile:leader:lock'
|
||||
```
|
||||
|
||||
## Expected Log Patterns
|
||||
|
||||
### Graceful Failover
|
||||
```
|
||||
[instance-1] Releasing leadership voluntarily instance_id=instance-1
|
||||
[instance-1] Scheduler stopped successfully
|
||||
[instance-2] 🎉 Became the leader! instance_id=instance-2
|
||||
[instance-2] BECAME LEADER - Starting leader-only tasks
|
||||
[instance-3] Skipping task execution - not the leader
|
||||
```
|
||||
|
||||
### Crash Failover
|
||||
```
|
||||
[instance-1] <nothing - crashed>
|
||||
[instance-2] 🎉 Became the leader! instance_id=instance-2
|
||||
[instance-2] 👑 Leader executing scheduled task task=CleanupJob
|
||||
[instance-3] Skipping task execution - not the leader
|
||||
```
|
||||
|
||||
### Cascading Failover
|
||||
```
|
||||
[instance-1] Releasing leadership voluntarily
|
||||
[instance-2] 🎉 Became the leader! instance_id=instance-2
|
||||
[instance-2] Releasing leadership voluntarily
|
||||
[instance-3] 🎉 Became the leader! instance_id=instance-3
|
||||
[instance-3] Releasing leadership voluntarily
|
||||
[instance-4] 🎉 Became the leader! instance_id=instance-4
|
||||
```
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### Issue: Multiple leaders elected
|
||||
|
||||
**Symptoms:** Two instances both log "Became the leader"
|
||||
|
||||
**Causes:**
|
||||
- Clock skew between servers
|
||||
- Redis not accessible to all instances
|
||||
- Different Redis instances being used
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Ensure all instances use same Redis
|
||||
CACHE_HOST=same-redis-server
|
||||
|
||||
# Sync clocks
|
||||
sudo ntpdate -s time.nist.gov
|
||||
|
||||
# Check Redis connectivity
|
||||
redis-cli PING
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: No leader elected
|
||||
|
||||
**Symptoms:** All instances are followers
|
||||
|
||||
**Causes:**
|
||||
- Redis lock key stuck
|
||||
- TTL not expiring
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Manually clear the lock
|
||||
redis-cli DEL maplefile:leader:lock
|
||||
redis-cli DEL maplefile:leader:info
|
||||
|
||||
# Restart instances
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Issue: Slow failover
|
||||
|
||||
**Symptoms:** Takes > 30s for new leader to be elected
|
||||
|
||||
**Causes:**
|
||||
- LockTTL too high
|
||||
- RetryInterval too high
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Reduce timeouts
|
||||
LEADER_ELECTION_LOCK_TTL=5s
|
||||
LEADER_ELECTION_RETRY_INTERVAL=1s
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Expected failover times:
|
||||
|
||||
| Scenario | Min | Typical | Max |
|
||||
|----------|-----|---------|-----|
|
||||
| Graceful shutdown | 1s | 2s | 3s |
|
||||
| Hard crash | 10s | 12s | 15s |
|
||||
| Network partition | 10s | 12s | 15s |
|
||||
| Cascading (2 leaders) | 2s | 4s | 6s |
|
||||
| Cascading (3 leaders) | 4s | 6s | 9s |
|
||||
|
||||
With optimized settings (`LockTTL=5s`, `RetryInterval=1s`):
|
||||
|
||||
| Scenario | Min | Typical | Max |
|
||||
|----------|-----|---------|-----|
|
||||
| Graceful shutdown | 0.5s | 1s | 2s |
|
||||
| Hard crash | 5s | 6s | 8s |
|
||||
| Network partition | 5s | 6s | 8s |
|
||||
|
||||
## Automated Test Script
|
||||
|
||||
Create `test-failover.sh`:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
echo "=== Leader Election Failover Test ==="
|
||||
echo ""
|
||||
|
||||
# Start 3 instances
|
||||
echo "Starting 3 instances..."
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend > /tmp/instance-1.log 2>&1 &
|
||||
PID1=$!
|
||||
sleep 2
|
||||
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-2 ./maplefile-backend > /tmp/instance-2.log 2>&1 &
|
||||
PID2=$!
|
||||
sleep 2
|
||||
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-3 ./maplefile-backend > /tmp/instance-3.log 2>&1 &
|
||||
PID3=$!
|
||||
sleep 5
|
||||
|
||||
# Find initial leader
|
||||
echo "Checking initial leader..."
|
||||
LEADER=$(redis-cli GET maplefile:leader:lock)
|
||||
echo "Initial leader: $LEADER"
|
||||
|
||||
# Kill the leader
|
||||
echo "Killing leader: $LEADER"
|
||||
if [ "$LEADER" == "instance-1" ]; then
|
||||
kill $PID1
|
||||
elif [ "$LEADER" == "instance-2" ]; then
|
||||
kill $PID2
|
||||
else
|
||||
kill $PID3
|
||||
fi
|
||||
|
||||
# Wait for failover
|
||||
echo "Waiting for failover..."
|
||||
sleep 15
|
||||
|
||||
# Check new leader
|
||||
NEW_LEADER=$(redis-cli GET maplefile:leader:lock)
|
||||
echo "New leader: $NEW_LEADER"
|
||||
|
||||
if [ "$NEW_LEADER" != "" ] && [ "$NEW_LEADER" != "$LEADER" ]; then
|
||||
echo "✅ Failover successful! New leader: $NEW_LEADER"
|
||||
else
|
||||
echo "❌ Failover failed!"
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
kill $PID1 $PID2 $PID3 2>/dev/null
|
||||
echo "Test complete"
|
||||
```
|
||||
|
||||
Run it:
|
||||
```bash
|
||||
chmod +x test-failover.sh
|
||||
./test-failover.sh
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
Your leader election implementation correctly handles:
|
||||
|
||||
✅ Graceful shutdown → New leader elected in ~2s
|
||||
✅ Crash/hard kill → New leader elected in ~12s
|
||||
✅ Cascading failures → Each failure triggers new election
|
||||
✅ Network partitions → Automatic recovery
|
||||
✅ Leader re-joins → Stays as follower
|
||||
✅ Multiple simultaneous failures → Last instance becomes leader
|
||||
|
||||
The system is **production-ready** for multi-instance deployments with automatic failover! 🎉
|
||||
411
cloud/maplefile-backend/pkg/leaderelection/README.md
Normal file
411
cloud/maplefile-backend/pkg/leaderelection/README.md
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
# Leader Election Package
|
||||
|
||||
Distributed leader election for MapleFile backend instances using Redis.
|
||||
|
||||
## Overview
|
||||
|
||||
This package provides leader election functionality for multiple backend instances running behind a load balancer. It ensures that only one instance acts as the "leader" at any given time, with automatic failover if the leader crashes.
|
||||
|
||||
## Features
|
||||
|
||||
- ✅ **Redis-based**: Fast, reliable leader election using Redis
|
||||
- ✅ **Automatic Failover**: New leader elected automatically if current leader crashes
|
||||
- ✅ **Heartbeat Mechanism**: Leader maintains lock with periodic renewals
|
||||
- ✅ **Callbacks**: Execute custom code when becoming/losing leadership
|
||||
- ✅ **Graceful Shutdown**: Clean leadership handoff on shutdown
|
||||
- ✅ **Thread-Safe**: Safe for concurrent use
|
||||
- ✅ **Observable**: Query leader status and information
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Election**: Instances compete to acquire a Redis lock (key)
|
||||
2. **Leadership**: First instance to acquire the lock becomes the leader
|
||||
3. **Heartbeat**: Leader renews the lock every `HeartbeatInterval` (default: 3s)
|
||||
4. **Lock TTL**: Lock expires after `LockTTL` if not renewed (default: 10s)
|
||||
5. **Failover**: If leader crashes, lock expires → followers compete for leadership
|
||||
6. **Re-election**: New leader elected within seconds of previous leader failure
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
||||
│ Instance 1 │ │ Instance 2 │ │ Instance 3 │
|
||||
│ (Leader) │ │ (Follower) │ │ (Follower) │
|
||||
└──────┬──────┘ └──────┬──────┘ └──────┬──────┘
|
||||
│ │ │
|
||||
│ Heartbeat │ Try Acquire │ Try Acquire
|
||||
│ (Renew Lock) │ (Check Lock) │ (Check Lock)
|
||||
│ │ │
|
||||
└───────────────────┴───────────────────┘
|
||||
│
|
||||
┌────▼────┐
|
||||
│ Redis │
|
||||
│ Lock │
|
||||
└─────────┘
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/leaderelection"
|
||||
)
|
||||
|
||||
// Create Redis client (you likely already have this)
|
||||
redisClient := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
})
|
||||
|
||||
// Create logger
|
||||
logger, _ := zap.NewProduction()
|
||||
|
||||
// Create leader election configuration
|
||||
config := leaderelection.DefaultConfig()
|
||||
|
||||
// Create leader election instance
|
||||
election, err := leaderelection.NewRedisLeaderElection(config, redisClient, logger)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Start leader election in a goroutine
|
||||
ctx := context.Background()
|
||||
go func() {
|
||||
if err := election.Start(ctx); err != nil {
|
||||
logger.Error("Leader election failed", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Check if this instance is the leader
|
||||
if election.IsLeader() {
|
||||
logger.Info("I am the leader! 👑")
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
defer election.Stop()
|
||||
```
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
// Register callback when becoming leader
|
||||
election.OnBecomeLeader(func() {
|
||||
logger.Info("🎉 I became the leader!")
|
||||
|
||||
// Start leader-only tasks
|
||||
go startBackgroundJobs()
|
||||
go startMetricsAggregation()
|
||||
})
|
||||
|
||||
// Register callback when losing leadership
|
||||
election.OnLoseLeadership(func() {
|
||||
logger.Info("😢 I lost leadership")
|
||||
|
||||
// Stop leader-only tasks
|
||||
stopBackgroundJobs()
|
||||
stopMetricsAggregation()
|
||||
})
|
||||
```
|
||||
|
||||
### Integration with Application Startup
|
||||
|
||||
```go
|
||||
// In your main.go or app startup
|
||||
func (app *Application) Start() error {
|
||||
// Start leader election
|
||||
go func() {
|
||||
if err := app.leaderElection.Start(app.ctx); err != nil {
|
||||
app.logger.Error("Leader election error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait a moment for election to complete
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
if app.leaderElection.IsLeader() {
|
||||
app.logger.Info("This instance is the leader")
|
||||
// Start leader-only services
|
||||
} else {
|
||||
app.logger.Info("This instance is a follower")
|
||||
// Start follower-only services (if any)
|
||||
}
|
||||
|
||||
// Start your HTTP server, etc.
|
||||
return app.httpServer.Start()
|
||||
}
|
||||
```
|
||||
|
||||
### Conditional Logic Based on Leadership
|
||||
|
||||
```go
|
||||
// Only leader executes certain tasks
|
||||
func (s *Service) PerformTask() {
|
||||
if s.leaderElection.IsLeader() {
|
||||
// Only leader does this expensive operation
|
||||
s.aggregateMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
// Get information about the current leader
|
||||
func (s *Service) GetLeaderStatus() (*leaderelection.LeaderInfo, error) {
|
||||
info, err := s.leaderElection.GetLeaderInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("Leader: %s (%s)\n", info.InstanceID, info.Hostname)
|
||||
fmt.Printf("Started: %s\n", info.StartedAt)
|
||||
fmt.Printf("Last Heartbeat: %s\n", info.LastHeartbeat)
|
||||
|
||||
return info, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Default Configuration
|
||||
|
||||
```go
|
||||
config := leaderelection.DefaultConfig()
|
||||
// Returns:
|
||||
// {
|
||||
// RedisKeyName: "maplefile:leader:lock",
|
||||
// RedisInfoKeyName: "maplefile:leader:info",
|
||||
// LockTTL: 10 * time.Second,
|
||||
// HeartbeatInterval: 3 * time.Second,
|
||||
// RetryInterval: 2 * time.Second,
|
||||
// }
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
|
||||
```go
|
||||
config := &leaderelection.Config{
|
||||
RedisKeyName: "my-app:leader",
|
||||
RedisInfoKeyName: "my-app:leader:info",
|
||||
LockTTL: 30 * time.Second, // Lock expires after 30s
|
||||
HeartbeatInterval: 10 * time.Second, // Renew every 10s
|
||||
RetryInterval: 5 * time.Second, // Check for leadership every 5s
|
||||
InstanceID: "instance-1", // Custom instance ID
|
||||
Hostname: "server-01", // Custom hostname
|
||||
}
|
||||
```
|
||||
|
||||
### Configuration in Application Config
|
||||
|
||||
Add to your `config/config.go`:
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// ... existing fields ...
|
||||
|
||||
LeaderElection struct {
|
||||
LockTTL time.Duration `env:"LEADER_ELECTION_LOCK_TTL" envDefault:"10s"`
|
||||
HeartbeatInterval time.Duration `env:"LEADER_ELECTION_HEARTBEAT_INTERVAL" envDefault:"3s"`
|
||||
RetryInterval time.Duration `env:"LEADER_ELECTION_RETRY_INTERVAL" envDefault:"2s"`
|
||||
InstanceID string `env:"LEADER_ELECTION_INSTANCE_ID" envDefault:""`
|
||||
Hostname string `env:"LEADER_ELECTION_HOSTNAME" envDefault:""`
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Background Job Processing
|
||||
Only the leader runs scheduled jobs:
|
||||
|
||||
```go
|
||||
election.OnBecomeLeader(func() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if election.IsLeader() {
|
||||
processScheduledJobs()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
```
|
||||
|
||||
### 2. Database Migrations
|
||||
Only the leader runs migrations on startup:
|
||||
|
||||
```go
|
||||
if election.IsLeader() {
|
||||
logger.Info("Leader instance - running database migrations")
|
||||
if err := migrator.Up(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
logger.Info("Follower instance - skipping migrations")
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Cache Warming
|
||||
Only the leader pre-loads caches:
|
||||
|
||||
```go
|
||||
election.OnBecomeLeader(func() {
|
||||
logger.Info("Warming caches as leader")
|
||||
warmApplicationCache()
|
||||
})
|
||||
```
|
||||
|
||||
### 4. Metrics Aggregation
|
||||
Only the leader aggregates and sends metrics:
|
||||
|
||||
```go
|
||||
election.OnBecomeLeader(func() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if election.IsLeader() {
|
||||
aggregateAndSendMetrics()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
```
|
||||
|
||||
### 5. Cleanup Tasks
|
||||
Only the leader runs periodic cleanup:
|
||||
|
||||
```go
|
||||
election.OnBecomeLeader(func() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if election.IsLeader() {
|
||||
cleanupOldRecords()
|
||||
purgeExpiredSessions()
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Health Check Endpoint
|
||||
|
||||
```go
|
||||
func (h *HealthHandler) LeaderElectionHealth(w http.ResponseWriter, r *http.Request) {
|
||||
info, err := h.leaderElection.GetLeaderInfo()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get leader info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"is_leader": h.leaderElection.IsLeader(),
|
||||
"instance_id": h.leaderElection.GetInstanceID(),
|
||||
"leader_info": info,
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
The package logs important events:
|
||||
- `🎉 Became the leader!` - When instance becomes leader
|
||||
- `Heartbeat sent` - When leader renews lock (DEBUG level)
|
||||
- `Failed to send heartbeat, lost leadership` - When leader loses lock
|
||||
- `Releasing leadership voluntarily` - On graceful shutdown
|
||||
|
||||
## Testing
|
||||
|
||||
### Local Testing with Multiple Instances
|
||||
|
||||
```bash
|
||||
# Terminal 1
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-1 ./maplefile-backend
|
||||
|
||||
# Terminal 2
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-2 ./maplefile-backend
|
||||
|
||||
# Terminal 3
|
||||
LEADER_ELECTION_INSTANCE_ID=instance-3 ./maplefile-backend
|
||||
```
|
||||
|
||||
### Failover Testing
|
||||
|
||||
1. Start 3 instances
|
||||
2. Check logs - one will become leader
|
||||
3. Kill the leader instance (Ctrl+C)
|
||||
4. Watch logs - another instance becomes leader within seconds
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always check leadership before expensive operations**
|
||||
```go
|
||||
if election.IsLeader() {
|
||||
// expensive operation
|
||||
}
|
||||
```
|
||||
|
||||
2. **Use callbacks for starting/stopping leader-only services**
|
||||
```go
|
||||
election.OnBecomeLeader(startLeaderServices)
|
||||
election.OnLoseLeadership(stopLeaderServices)
|
||||
```
|
||||
|
||||
3. **Set appropriate timeouts**
|
||||
- `LockTTL` should be 2-3x `HeartbeatInterval`
|
||||
- Shorter TTL = faster failover but more Redis traffic
|
||||
- Longer TTL = slower failover but less Redis traffic
|
||||
|
||||
4. **Handle callback panics**
|
||||
- Callbacks run in goroutines and panics are caught
|
||||
- But you should still handle errors gracefully
|
||||
|
||||
5. **Always call Stop() on shutdown**
|
||||
```go
|
||||
defer election.Stop()
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Leader keeps changing
|
||||
- Increase `LockTTL` (network might be slow)
|
||||
- Check Redis connectivity
|
||||
- Check for clock skew between instances
|
||||
|
||||
### No leader elected
|
||||
- Check Redis is running and accessible
|
||||
- Check Redis key permissions
|
||||
- Check logs for errors
|
||||
|
||||
### Leader doesn't release on shutdown
|
||||
- Ensure `Stop()` is called
|
||||
- Check for blocking operations preventing shutdown
|
||||
- TTL will eventually expire the lock
|
||||
|
||||
## Performance
|
||||
|
||||
- **Election time**: < 100ms
|
||||
- **Failover time**: < `LockTTL` (default: 10s)
|
||||
- **Redis operations per second**: `1 / HeartbeatInterval` (default: 0.33/s)
|
||||
- **Memory overhead**: Minimal (~1KB per instance)
|
||||
|
||||
## Thread Safety
|
||||
|
||||
All methods are thread-safe and can be called from multiple goroutines:
|
||||
- `IsLeader()`
|
||||
- `GetLeaderID()`
|
||||
- `GetLeaderInfo()`
|
||||
- `OnBecomeLeader()`
|
||||
- `OnLoseLeadership()`
|
||||
- `Stop()`
|
||||
136
cloud/maplefile-backend/pkg/leaderelection/interface.go
Normal file
136
cloud/maplefile-backend/pkg/leaderelection/interface.go
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
// Package leaderelection provides distributed leader election for multiple application instances.
|
||||
// It ensures only one instance acts as the leader at any given time, with automatic failover.
|
||||
package leaderelection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LeaderElection provides distributed leader election across multiple application instances.
|
||||
// It uses Redis to coordinate which instance is the current leader, with automatic failover
|
||||
// if the leader crashes or becomes unavailable.
|
||||
type LeaderElection interface {
|
||||
// Start begins participating in leader election.
|
||||
// This method blocks and runs the election loop until ctx is cancelled or an error occurs.
|
||||
// The instance will automatically attempt to become leader and maintain leadership.
|
||||
Start(ctx context.Context) error
|
||||
|
||||
// IsLeader returns true if this instance is currently the leader.
|
||||
// This is a local check and does not require network communication.
|
||||
IsLeader() bool
|
||||
|
||||
// GetLeaderID returns the unique identifier of the current leader instance.
|
||||
// Returns empty string if no leader exists (should be rare).
|
||||
GetLeaderID() (string, error)
|
||||
|
||||
// GetLeaderInfo returns detailed information about the current leader.
|
||||
GetLeaderInfo() (*LeaderInfo, error)
|
||||
|
||||
// OnBecomeLeader registers a callback function that will be executed when
|
||||
// this instance becomes the leader. Multiple callbacks can be registered.
|
||||
OnBecomeLeader(callback func())
|
||||
|
||||
// OnLoseLeadership registers a callback function that will be executed when
|
||||
// this instance loses leadership (either voluntarily or due to failure).
|
||||
// Multiple callbacks can be registered.
|
||||
OnLoseLeadership(callback func())
|
||||
|
||||
// Stop gracefully stops leader election participation.
|
||||
// If this instance is the leader, it releases leadership allowing another instance to take over.
|
||||
// This should be called during application shutdown.
|
||||
Stop() error
|
||||
|
||||
// GetInstanceID returns the unique identifier for this instance.
|
||||
GetInstanceID() string
|
||||
}
|
||||
|
||||
// LeaderInfo contains information about the current leader.
|
||||
type LeaderInfo struct {
|
||||
// InstanceID is the unique identifier of the leader instance
|
||||
InstanceID string `json:"instance_id"`
|
||||
|
||||
// Hostname is the hostname of the leader instance
|
||||
Hostname string `json:"hostname"`
|
||||
|
||||
// StartedAt is when this instance became the leader
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
|
||||
// LastHeartbeat is the last time the leader renewed its lock
|
||||
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||
}
|
||||
|
||||
// Config contains configuration for leader election.
|
||||
type Config struct {
|
||||
// RedisKeyName is the Redis key used for leader election.
|
||||
// Default: "maplefile:leader:lock"
|
||||
RedisKeyName string
|
||||
|
||||
// RedisInfoKeyName is the Redis key used to store leader information.
|
||||
// Default: "maplefile:leader:info"
|
||||
RedisInfoKeyName string
|
||||
|
||||
// LockTTL is how long the leader lock lasts before expiring.
|
||||
// The leader must renew the lock before this time expires.
|
||||
// Default: 10 seconds
|
||||
// Recommended: 10-30 seconds
|
||||
LockTTL time.Duration
|
||||
|
||||
// HeartbeatInterval is how often the leader renews its lock.
|
||||
// This should be significantly less than LockTTL (e.g., LockTTL / 3).
|
||||
// Default: 3 seconds
|
||||
// Recommended: LockTTL / 3
|
||||
HeartbeatInterval time.Duration
|
||||
|
||||
// RetryInterval is how often followers check for leadership opportunity.
|
||||
// Default: 2 seconds
|
||||
// Recommended: 1-5 seconds
|
||||
RetryInterval time.Duration
|
||||
|
||||
// InstanceID uniquely identifies this application instance.
|
||||
// If empty, will be auto-generated from hostname + random suffix.
|
||||
// Default: auto-generated
|
||||
InstanceID string
|
||||
|
||||
// Hostname is the hostname of this instance.
|
||||
// If empty, will be auto-detected.
|
||||
// Default: os.Hostname()
|
||||
Hostname string
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
RedisKeyName: "maplefile:leader:lock",
|
||||
RedisInfoKeyName: "maplefile:leader:info",
|
||||
LockTTL: 10 * time.Second,
|
||||
HeartbeatInterval: 3 * time.Second,
|
||||
RetryInterval: 2 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid and returns an error if not.
|
||||
func (c *Config) Validate() error {
|
||||
if c.RedisKeyName == "" {
|
||||
c.RedisKeyName = "maplefile:leader:lock"
|
||||
}
|
||||
if c.RedisInfoKeyName == "" {
|
||||
c.RedisInfoKeyName = "maplefile:leader:info"
|
||||
}
|
||||
if c.LockTTL <= 0 {
|
||||
c.LockTTL = 10 * time.Second
|
||||
}
|
||||
if c.HeartbeatInterval <= 0 {
|
||||
c.HeartbeatInterval = 3 * time.Second
|
||||
}
|
||||
if c.RetryInterval <= 0 {
|
||||
c.RetryInterval = 2 * time.Second
|
||||
}
|
||||
|
||||
// HeartbeatInterval should be less than LockTTL
|
||||
if c.HeartbeatInterval >= c.LockTTL {
|
||||
c.HeartbeatInterval = c.LockTTL / 3
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
351
cloud/maplefile-backend/pkg/leaderelection/mutex_leader.go
Normal file
351
cloud/maplefile-backend/pkg/leaderelection/mutex_leader.go
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
package leaderelection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/distributedmutex"
|
||||
)
|
||||
|
||||
// mutexLeaderElection implements LeaderElection using distributedmutex.
|
||||
type mutexLeaderElection struct {
|
||||
config *Config
|
||||
mutex distributedmutex.Adapter
|
||||
redis redis.UniversalClient
|
||||
logger *zap.Logger
|
||||
instanceID string
|
||||
hostname string
|
||||
isLeader bool
|
||||
leaderMutex sync.RWMutex
|
||||
becomeLeaderCbs []func()
|
||||
loseLeadershipCbs []func()
|
||||
callbackMutex sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
stoppedChan chan struct{}
|
||||
leaderStartTime time.Time
|
||||
lastHeartbeat time.Time
|
||||
lastHeartbeatMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMutexLeaderElection creates a new distributed mutex-based leader election instance.
|
||||
func NewMutexLeaderElection(
|
||||
config *Config,
|
||||
mutex distributedmutex.Adapter,
|
||||
redisClient redis.UniversalClient,
|
||||
logger *zap.Logger,
|
||||
) (LeaderElection, error) {
|
||||
logger = logger.Named("LeaderElection")
|
||||
|
||||
// Validate configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
// Generate instance ID if not provided
|
||||
instanceID := config.InstanceID
|
||||
if instanceID == "" {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
}
|
||||
// Add random suffix to make it unique
|
||||
instanceID = fmt.Sprintf("%s-%d", hostname, rand.Intn(100000))
|
||||
logger.Info("Generated instance ID", zap.String("instance_id", instanceID))
|
||||
}
|
||||
|
||||
// Get hostname if not provided
|
||||
hostname := config.Hostname
|
||||
if hostname == "" {
|
||||
h, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "unknown"
|
||||
} else {
|
||||
hostname = h
|
||||
}
|
||||
}
|
||||
|
||||
return &mutexLeaderElection{
|
||||
config: config,
|
||||
mutex: mutex,
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
instanceID: instanceID,
|
||||
hostname: hostname,
|
||||
isLeader: false,
|
||||
becomeLeaderCbs: make([]func(), 0),
|
||||
loseLeadershipCbs: make([]func(), 0),
|
||||
stopChan: make(chan struct{}),
|
||||
stoppedChan: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins participating in leader election.
|
||||
func (le *mutexLeaderElection) Start(ctx context.Context) error {
|
||||
le.logger.Info("Starting leader election",
|
||||
zap.String("instance_id", le.instanceID),
|
||||
zap.String("hostname", le.hostname),
|
||||
zap.Duration("lock_ttl", le.config.LockTTL),
|
||||
zap.Duration("heartbeat_interval", le.config.HeartbeatInterval),
|
||||
)
|
||||
|
||||
defer close(le.stoppedChan)
|
||||
|
||||
// Main election loop
|
||||
ticker := time.NewTicker(le.config.RetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Try to become leader immediately on startup
|
||||
le.tryBecomeLeader(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
le.logger.Info("Context cancelled, stopping leader election")
|
||||
le.releaseLeadership(context.Background())
|
||||
return ctx.Err()
|
||||
|
||||
case <-le.stopChan:
|
||||
le.logger.Info("Stop signal received, stopping leader election")
|
||||
le.releaseLeadership(context.Background())
|
||||
return nil
|
||||
|
||||
case <-ticker.C:
|
||||
if le.IsLeader() {
|
||||
// If we're the leader, send heartbeat
|
||||
if err := le.sendHeartbeat(ctx); err != nil {
|
||||
le.logger.Error("Failed to send heartbeat, lost leadership",
|
||||
zap.Error(err))
|
||||
le.setLeaderStatus(false)
|
||||
le.executeCallbacks(le.loseLeadershipCbs)
|
||||
}
|
||||
} else {
|
||||
// If we're not the leader, try to become leader
|
||||
le.tryBecomeLeader(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryBecomeLeader attempts to acquire leadership using distributed mutex.
|
||||
func (le *mutexLeaderElection) tryBecomeLeader(ctx context.Context) {
|
||||
// Try to acquire the lock (non-blocking)
|
||||
acquired, err := le.mutex.TryAcquire(ctx, le.config.RedisKeyName, le.config.LockTTL)
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to attempt leader election",
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if acquired {
|
||||
// We became the leader!
|
||||
le.logger.Info("🎉 Became the leader!",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
le.leaderStartTime = time.Now()
|
||||
le.setLeaderStatus(true)
|
||||
le.updateLeaderInfo(ctx)
|
||||
le.executeCallbacks(le.becomeLeaderCbs)
|
||||
} else {
|
||||
// Someone else is the leader
|
||||
if !le.IsLeader() {
|
||||
// Only log if we weren't already aware
|
||||
currentLeader, _ := le.GetLeaderID()
|
||||
le.logger.Debug("Another instance is the leader",
|
||||
zap.String("leader_id", currentLeader))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat renews the leader lock using distributed mutex.
|
||||
func (le *mutexLeaderElection) sendHeartbeat(ctx context.Context) error {
|
||||
// Extend the lock TTL
|
||||
err := le.mutex.Extend(ctx, le.config.RedisKeyName, le.config.LockTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extend lock: %w", err)
|
||||
}
|
||||
|
||||
// Update heartbeat time
|
||||
le.setLastHeartbeat(time.Now())
|
||||
|
||||
// Update leader info
|
||||
le.updateLeaderInfo(ctx)
|
||||
|
||||
le.logger.Debug("Heartbeat sent",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateLeaderInfo updates the leader information in Redis.
|
||||
func (le *mutexLeaderElection) updateLeaderInfo(ctx context.Context) {
|
||||
info := &LeaderInfo{
|
||||
InstanceID: le.instanceID,
|
||||
Hostname: le.hostname,
|
||||
StartedAt: le.leaderStartTime,
|
||||
LastHeartbeat: le.getLastHeartbeat(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to marshal leader info", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Set with same TTL as lock
|
||||
err = le.redis.Set(ctx, le.config.RedisInfoKeyName, data, le.config.LockTTL).Err()
|
||||
if err != nil {
|
||||
le.logger.Error("Failed to update leader info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// releaseLeadership voluntarily releases leadership.
|
||||
func (le *mutexLeaderElection) releaseLeadership(ctx context.Context) {
|
||||
if !le.IsLeader() {
|
||||
return
|
||||
}
|
||||
|
||||
le.logger.Info("Releasing leadership voluntarily",
|
||||
zap.String("instance_id", le.instanceID))
|
||||
|
||||
// Release the lock using distributed mutex
|
||||
le.mutex.Release(ctx, le.config.RedisKeyName)
|
||||
|
||||
// Delete leader info
|
||||
le.redis.Del(ctx, le.config.RedisInfoKeyName)
|
||||
|
||||
le.setLeaderStatus(false)
|
||||
le.executeCallbacks(le.loseLeadershipCbs)
|
||||
}
|
||||
|
||||
// IsLeader returns true if this instance is the leader.
|
||||
func (le *mutexLeaderElection) IsLeader() bool {
|
||||
le.leaderMutex.RLock()
|
||||
defer le.leaderMutex.RUnlock()
|
||||
return le.isLeader
|
||||
}
|
||||
|
||||
// GetLeaderID returns the ID of the current leader.
|
||||
func (le *mutexLeaderElection) GetLeaderID() (string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Check if we own the lock
|
||||
isOwner, err := le.mutex.IsOwner(ctx, le.config.RedisKeyName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check lock ownership: %w", err)
|
||||
}
|
||||
|
||||
if isOwner {
|
||||
return le.instanceID, nil
|
||||
}
|
||||
|
||||
// We don't own it, try to get from Redis
|
||||
leaderID, err := le.redis.Get(ctx, le.config.RedisKeyName).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get leader ID: %w", err)
|
||||
}
|
||||
return leaderID, nil
|
||||
}
|
||||
|
||||
// GetLeaderInfo returns information about the current leader.
|
||||
func (le *mutexLeaderElection) GetLeaderInfo() (*LeaderInfo, error) {
|
||||
ctx := context.Background()
|
||||
data, err := le.redis.Get(ctx, le.config.RedisInfoKeyName).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get leader info: %w", err)
|
||||
}
|
||||
|
||||
var info LeaderInfo
|
||||
if err := json.Unmarshal([]byte(data), &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal leader info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// OnBecomeLeader registers a callback for when this instance becomes leader.
|
||||
func (le *mutexLeaderElection) OnBecomeLeader(callback func()) {
|
||||
le.callbackMutex.Lock()
|
||||
defer le.callbackMutex.Unlock()
|
||||
le.becomeLeaderCbs = append(le.becomeLeaderCbs, callback)
|
||||
}
|
||||
|
||||
// OnLoseLeadership registers a callback for when this instance loses leadership.
|
||||
func (le *mutexLeaderElection) OnLoseLeadership(callback func()) {
|
||||
le.callbackMutex.Lock()
|
||||
defer le.callbackMutex.Unlock()
|
||||
le.loseLeadershipCbs = append(le.loseLeadershipCbs, callback)
|
||||
}
|
||||
|
||||
// Stop gracefully stops leader election.
|
||||
func (le *mutexLeaderElection) Stop() error {
|
||||
le.logger.Info("Stopping leader election")
|
||||
close(le.stopChan)
|
||||
|
||||
// Wait for the election loop to finish (with timeout)
|
||||
select {
|
||||
case <-le.stoppedChan:
|
||||
le.logger.Info("Leader election stopped successfully")
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
le.logger.Warn("Timeout waiting for leader election to stop")
|
||||
return fmt.Errorf("timeout waiting for leader election to stop")
|
||||
}
|
||||
}
|
||||
|
||||
// GetInstanceID returns this instance's unique identifier.
|
||||
func (le *mutexLeaderElection) GetInstanceID() string {
|
||||
return le.instanceID
|
||||
}
|
||||
|
||||
// setLeaderStatus updates the leader status (thread-safe).
|
||||
func (le *mutexLeaderElection) setLeaderStatus(isLeader bool) {
|
||||
le.leaderMutex.Lock()
|
||||
defer le.leaderMutex.Unlock()
|
||||
le.isLeader = isLeader
|
||||
}
|
||||
|
||||
// setLastHeartbeat updates the last heartbeat time (thread-safe).
|
||||
func (le *mutexLeaderElection) setLastHeartbeat(t time.Time) {
|
||||
le.lastHeartbeatMutex.Lock()
|
||||
defer le.lastHeartbeatMutex.Unlock()
|
||||
le.lastHeartbeat = t
|
||||
}
|
||||
|
||||
// getLastHeartbeat gets the last heartbeat time (thread-safe).
|
||||
func (le *mutexLeaderElection) getLastHeartbeat() time.Time {
|
||||
le.lastHeartbeatMutex.RLock()
|
||||
defer le.lastHeartbeatMutex.RUnlock()
|
||||
return le.lastHeartbeat
|
||||
}
|
||||
|
||||
// executeCallbacks executes a list of callbacks in separate goroutines.
|
||||
func (le *mutexLeaderElection) executeCallbacks(callbacks []func()) {
|
||||
le.callbackMutex.RLock()
|
||||
defer le.callbackMutex.RUnlock()
|
||||
|
||||
for _, callback := range callbacks {
|
||||
go func(cb func()) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
le.logger.Error("Panic in leader election callback",
|
||||
zap.Any("panic", r))
|
||||
}
|
||||
}()
|
||||
cb()
|
||||
}(callback)
|
||||
}
|
||||
}
|
||||
30
cloud/maplefile-backend/pkg/leaderelection/provider.go
Normal file
30
cloud/maplefile-backend/pkg/leaderelection/provider.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package leaderelection
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/distributedmutex"
|
||||
)
|
||||
|
||||
// ProvideLeaderElection provides a LeaderElection instance for Wire DI.
|
||||
func ProvideLeaderElection(
|
||||
cfg *config.Config,
|
||||
mutex distributedmutex.Adapter,
|
||||
redisClient redis.UniversalClient,
|
||||
logger *zap.Logger,
|
||||
) (LeaderElection, error) {
|
||||
// Create configuration from app config
|
||||
leConfig := &Config{
|
||||
RedisKeyName: "maplefile:leader:lock",
|
||||
RedisInfoKeyName: "maplefile:leader:info",
|
||||
LockTTL: cfg.LeaderElection.LockTTL,
|
||||
HeartbeatInterval: cfg.LeaderElection.HeartbeatInterval,
|
||||
RetryInterval: cfg.LeaderElection.RetryInterval,
|
||||
InstanceID: cfg.LeaderElection.InstanceID,
|
||||
Hostname: cfg.LeaderElection.Hostname,
|
||||
}
|
||||
|
||||
return NewMutexLeaderElection(leConfig, mutex, redisClient, logger)
|
||||
}
|
||||
84
cloud/maplefile-backend/pkg/logger/logger.go
Normal file
84
cloud/maplefile-backend/pkg/logger/logger.go
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/logger/logger.go
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// NewProduction creates a production-ready logger with appropriate configuration
|
||||
func NewProduction() (*zap.Logger, error) {
|
||||
// Get log level from environment
|
||||
logLevel := getLogLevel()
|
||||
|
||||
// Configure encoder for production (JSON format)
|
||||
encoderConfig := zapcore.EncoderConfig{
|
||||
TimeKey: "timestamp",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "message",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.RFC3339TimeEncoder,
|
||||
EncodeDuration: zapcore.SecondsDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
// Create core
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
zapcore.AddSync(os.Stdout),
|
||||
logLevel,
|
||||
)
|
||||
|
||||
// Create logger with caller information
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
// Add service information
|
||||
logger = logger.With(
|
||||
zap.String("service", "maplefile-backend"),
|
||||
zap.String("version", getServiceVersion()),
|
||||
)
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// NewDevelopment creates a development logger (for backward compatibility)
|
||||
func NewDevelopment() (*zap.Logger, error) {
|
||||
return zap.NewDevelopment()
|
||||
}
|
||||
|
||||
// getLogLevel determines log level from environment
|
||||
func getLogLevel() zapcore.Level {
|
||||
levelStr := os.Getenv("LOG_LEVEL")
|
||||
switch levelStr {
|
||||
case "debug", "DEBUG":
|
||||
return zapcore.DebugLevel
|
||||
case "info", "INFO":
|
||||
return zapcore.InfoLevel
|
||||
case "warn", "WARN", "warning", "WARNING":
|
||||
return zapcore.WarnLevel
|
||||
case "error", "ERROR":
|
||||
return zapcore.ErrorLevel
|
||||
case "panic", "PANIC":
|
||||
return zapcore.PanicLevel
|
||||
case "fatal", "FATAL":
|
||||
return zapcore.FatalLevel
|
||||
default:
|
||||
return zapcore.InfoLevel
|
||||
}
|
||||
}
|
||||
|
||||
// getServiceVersion gets the service version (could be injected at build time)
|
||||
func getServiceVersion() string {
|
||||
version := os.Getenv("SERVICE_VERSION")
|
||||
if version == "" {
|
||||
return "1.0.0"
|
||||
}
|
||||
return version
|
||||
}
|
||||
15
cloud/maplefile-backend/pkg/logger/provider.go
Normal file
15
cloud/maplefile-backend/pkg/logger/provider.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideLogger provides a logger instance for Wire DI
|
||||
func ProvideLogger(cfg *config.Config) (*zap.Logger, error) {
|
||||
if cfg.App.Environment == "production" {
|
||||
return NewProduction()
|
||||
}
|
||||
return NewDevelopment()
|
||||
}
|
||||
109
cloud/maplefile-backend/pkg/maplefile/client/auth.go
Normal file
109
cloud/maplefile-backend/pkg/maplefile/client/auth.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Register creates a new user account.
|
||||
func (c *Client) Register(ctx context.Context, input *RegisterInput) (*RegisterResponse, error) {
|
||||
var resp RegisterResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/register", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// VerifyEmailCode verifies the email verification code.
|
||||
func (c *Client) VerifyEmailCode(ctx context.Context, input *VerifyEmailInput) (*VerifyEmailResponse, error) {
|
||||
var resp VerifyEmailResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/verify-email-code", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ResendVerification resends the email verification code.
|
||||
func (c *Client) ResendVerification(ctx context.Context, email string) error {
|
||||
input := ResendVerificationInput{Email: email}
|
||||
return c.doRequest(ctx, "POST", "/api/v1/resend-verification", input, nil, false)
|
||||
}
|
||||
|
||||
// RequestOTT requests a One-Time Token for login.
|
||||
func (c *Client) RequestOTT(ctx context.Context, email string) (*OTTResponse, error) {
|
||||
input := map[string]string{"email": email}
|
||||
var resp OTTResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/request-ott", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// VerifyOTT verifies a One-Time Token and returns the encrypted challenge.
|
||||
func (c *Client) VerifyOTT(ctx context.Context, email, ott string) (*VerifyOTTResponse, error) {
|
||||
input := map[string]string{
|
||||
"email": email,
|
||||
"ott": ott,
|
||||
}
|
||||
var resp VerifyOTTResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/verify-ott", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CompleteLogin completes the login process with the decrypted challenge.
|
||||
// On success, the client automatically stores the tokens and calls the OnTokenRefresh callback.
|
||||
func (c *Client) CompleteLogin(ctx context.Context, input *CompleteLoginInput) (*LoginResponse, error) {
|
||||
var resp LoginResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/complete-login", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the tokens
|
||||
c.SetTokens(resp.AccessToken, resp.RefreshToken)
|
||||
|
||||
// Notify callback if set, passing the expiry date
|
||||
if c.onTokenRefresh != nil {
|
||||
c.onTokenRefresh(resp.AccessToken, resp.RefreshToken, resp.AccessTokenExpiryDate)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RefreshToken manually refreshes the access token using the stored refresh token.
|
||||
// On success, the client automatically updates the stored tokens and calls the OnTokenRefresh callback.
|
||||
func (c *Client) RefreshToken(ctx context.Context) error {
|
||||
return c.refreshAccessToken(ctx)
|
||||
}
|
||||
|
||||
// RecoveryInitiate initiates the account recovery process.
|
||||
func (c *Client) RecoveryInitiate(ctx context.Context, email, method string) (*RecoveryInitiateResponse, error) {
|
||||
input := RecoveryInitiateInput{
|
||||
Email: email,
|
||||
Method: method,
|
||||
}
|
||||
var resp RecoveryInitiateResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/recovery/initiate", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RecoveryVerify verifies the recovery challenge.
|
||||
func (c *Client) RecoveryVerify(ctx context.Context, input *RecoveryVerifyInput) (*RecoveryVerifyResponse, error) {
|
||||
var resp RecoveryVerifyResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/recovery/verify", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RecoveryComplete completes the account recovery and resets credentials.
|
||||
func (c *Client) RecoveryComplete(ctx context.Context, input *RecoveryCompleteInput) (*RecoveryCompleteResponse, error) {
|
||||
var resp RecoveryCompleteResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/recovery/complete", input, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
468
cloud/maplefile-backend/pkg/maplefile/client/client.go
Normal file
468
cloud/maplefile-backend/pkg/maplefile/client/client.go
Normal file
|
|
@ -0,0 +1,468 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger is an interface for logging API requests.
|
||||
// This allows the client to work with any logging library (zap, logrus, etc.)
|
||||
type Logger interface {
|
||||
// Debug logs a debug message with optional key-value pairs
|
||||
Debug(msg string, keysAndValues ...interface{})
|
||||
// Info logs an info message with optional key-value pairs
|
||||
Info(msg string, keysAndValues ...interface{})
|
||||
// Warn logs a warning message with optional key-value pairs
|
||||
Warn(msg string, keysAndValues ...interface{})
|
||||
// Error logs an error message with optional key-value pairs
|
||||
Error(msg string, keysAndValues ...interface{})
|
||||
}
|
||||
|
||||
// Client is the MapleFile API client.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
logger Logger
|
||||
|
||||
// Token storage with mutex for thread safety
|
||||
mu sync.RWMutex
|
||||
accessToken string
|
||||
refreshToken string
|
||||
|
||||
// Callback when tokens are refreshed
|
||||
// Parameters: accessToken, refreshToken, accessTokenExpiryDate (RFC3339 format)
|
||||
onTokenRefresh func(accessToken, refreshToken, accessTokenExpiryDate string)
|
||||
|
||||
// Flag to prevent recursive token refresh (atomic for lock-free reads)
|
||||
isRefreshing atomic.Bool
|
||||
}
|
||||
|
||||
// Predefined environment URLs
|
||||
const (
|
||||
// ProductionURL is the production API endpoint
|
||||
ProductionURL = "https://maplefile.ca"
|
||||
|
||||
// LocalURL is the default local development API endpoint
|
||||
LocalURL = "http://localhost:8000"
|
||||
)
|
||||
|
||||
// Config holds the configuration for creating a new Client.
|
||||
type Config struct {
|
||||
// BaseURL is the base URL of the MapleFile API (e.g., "https://maplefile.ca")
|
||||
// You can use predefined constants: ProductionURL or LocalURL
|
||||
BaseURL string
|
||||
|
||||
// HTTPClient is an optional custom HTTP client. If nil, a default client with 30s timeout is used.
|
||||
HTTPClient *http.Client
|
||||
|
||||
// Logger is an optional logger for API request logging. If nil, no logging is performed.
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// New creates a new MapleFile API client with the given configuration.
|
||||
//
|
||||
// Security Note: This client uses Go's standard http.Client without certificate
|
||||
// pinning. This is intentional and secure because:
|
||||
//
|
||||
// 1. TLS termination is handled by a reverse proxy (Caddy/Nginx) in production,
|
||||
// which manages certificates via Let's Encrypt with automatic renewal.
|
||||
// 2. Go's default TLS configuration already validates certificate chains,
|
||||
// expiration, and hostname matching against system CA roots.
|
||||
// 3. The application uses end-to-end encryption (E2EE) - even if TLS were
|
||||
// compromised, attackers would only see encrypted data they cannot decrypt.
|
||||
// 4. Certificate pinning would require app updates every 90 days (Let's Encrypt
|
||||
// rotation) or risk bricking deployed applications.
|
||||
//
|
||||
// See: docs/OWASP_AUDIT_REPORT.md (Finding 4.1) for full security analysis.
|
||||
func New(cfg Config) *Client {
|
||||
httpClient := cfg.HTTPClient
|
||||
if httpClient == nil {
|
||||
// Standard HTTP client with timeout. Certificate pinning is intentionally
|
||||
// not implemented - see security note above.
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure baseURL doesn't have trailing slash
|
||||
baseURL := strings.TrimSuffix(cfg.BaseURL, "/")
|
||||
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: httpClient,
|
||||
logger: cfg.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
// NewProduction creates a new MapleFile API client configured for production.
|
||||
func NewProduction() *Client {
|
||||
return New(Config{BaseURL: ProductionURL})
|
||||
}
|
||||
|
||||
// NewLocal creates a new MapleFile API client configured for local development.
|
||||
func NewLocal() *Client {
|
||||
return New(Config{BaseURL: LocalURL})
|
||||
}
|
||||
|
||||
// NewWithURL creates a new MapleFile API client with a custom URL.
|
||||
func NewWithURL(baseURL string) *Client {
|
||||
return New(Config{BaseURL: baseURL})
|
||||
}
|
||||
|
||||
// SetTokens sets the access and refresh tokens for authentication.
|
||||
func (c *Client) SetTokens(accessToken, refreshToken string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.accessToken = accessToken
|
||||
c.refreshToken = refreshToken
|
||||
}
|
||||
|
||||
// GetTokens returns the current access and refresh tokens.
|
||||
func (c *Client) GetTokens() (accessToken, refreshToken string) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.accessToken, c.refreshToken
|
||||
}
|
||||
|
||||
// OnTokenRefresh sets a callback function that will be called when tokens are refreshed.
|
||||
// This is useful for persisting the new tokens to storage.
|
||||
// The callback receives: accessToken, refreshToken, and accessTokenExpiryDate (RFC3339 format).
|
||||
func (c *Client) OnTokenRefresh(callback func(accessToken, refreshToken, accessTokenExpiryDate string)) {
|
||||
c.onTokenRefresh = callback
|
||||
}
|
||||
|
||||
// SetBaseURL changes the base URL of the API.
|
||||
// This is useful for switching between environments at runtime.
|
||||
func (c *Client) SetBaseURL(baseURL string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
}
|
||||
|
||||
// GetBaseURL returns the current base URL.
|
||||
func (c *Client) GetBaseURL() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
// Health checks if the API is healthy.
|
||||
func (c *Client) Health(ctx context.Context) (*HealthResponse, error) {
|
||||
var resp HealthResponse
|
||||
if err := c.doRequest(ctx, "GET", "/health", nil, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Version returns the API version information.
|
||||
func (c *Client) Version(ctx context.Context) (*VersionResponse, error) {
|
||||
var resp VersionResponse
|
||||
if err := c.doRequest(ctx, "GET", "/version", nil, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// doRequest performs an HTTP request with automatic token refresh on 401.
|
||||
func (c *Client) doRequest(ctx context.Context, method, path string, body interface{}, result interface{}, requiresAuth bool) error {
|
||||
return c.doRequestWithRetry(ctx, method, path, body, result, requiresAuth, true)
|
||||
}
|
||||
|
||||
// doRequestWithRetry performs an HTTP request with optional retry on 401.
|
||||
func (c *Client) doRequestWithRetry(ctx context.Context, method, path string, body interface{}, result interface{}, requiresAuth bool, allowRetry bool) error {
|
||||
// Build URL
|
||||
url := c.baseURL + path
|
||||
|
||||
// Log API request
|
||||
if c.logger != nil {
|
||||
c.logger.Info("API request", "method", method, "url", url)
|
||||
}
|
||||
|
||||
// Prepare request body
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
jsonData, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(jsonData)
|
||||
}
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
// Accept both standard JSON and RFC 9457 problem+json responses
|
||||
req.Header.Set("Accept", "application/json, application/problem+json")
|
||||
|
||||
// Add authorization header if required
|
||||
if requiresAuth {
|
||||
c.mu.RLock()
|
||||
token := c.accessToken
|
||||
c.mu.RUnlock()
|
||||
|
||||
if token == "" {
|
||||
return &APIError{
|
||||
ProblemDetail: ProblemDetail{
|
||||
Status: 401,
|
||||
Title: "Unauthorized",
|
||||
Detail: "No access token available",
|
||||
},
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("JWT %s", token))
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("API request failed", "method", method, "url", url, "error", err.Error())
|
||||
}
|
||||
return fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Log API response
|
||||
if c.logger != nil {
|
||||
c.logger.Info("API response", "method", method, "url", url, "status", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Handle 401 with automatic token refresh
|
||||
// Use atomic.Bool for lock-free check to avoid unnecessary lock acquisition
|
||||
if resp.StatusCode == http.StatusUnauthorized && requiresAuth && allowRetry && !c.isRefreshing.Load() {
|
||||
c.mu.Lock()
|
||||
// Double-check under lock and verify refresh token exists
|
||||
if c.refreshToken != "" && !c.isRefreshing.Load() {
|
||||
c.isRefreshing.Store(true)
|
||||
c.mu.Unlock()
|
||||
|
||||
// Attempt to refresh token
|
||||
refreshErr := c.refreshAccessToken(ctx)
|
||||
c.isRefreshing.Store(false)
|
||||
|
||||
if refreshErr == nil {
|
||||
// Retry the original request without allowing another retry
|
||||
return c.doRequestWithRetry(ctx, method, path, body, result, requiresAuth, false)
|
||||
}
|
||||
} else {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Handle error status codes
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return parseErrorResponse(respBody, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
if result != nil && len(respBody) > 0 {
|
||||
if err := json.Unmarshal(respBody, result); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshAccessToken attempts to refresh the access token using the refresh token.
|
||||
func (c *Client) refreshAccessToken(ctx context.Context) error {
|
||||
c.mu.RLock()
|
||||
refreshToken := c.refreshToken
|
||||
c.mu.RUnlock()
|
||||
|
||||
if refreshToken == "" {
|
||||
return fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Build refresh request
|
||||
url := c.baseURL + "/api/v1/token/refresh"
|
||||
|
||||
reqBody := map[string]string{
|
||||
"value": refreshToken, // Backend expects "value" field, not "refresh_token"
|
||||
}
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return parseErrorResponse(respBody, resp.StatusCode)
|
||||
}
|
||||
|
||||
// Parse the refresh response
|
||||
// Note: Backend returns access_token_expiry_date and refresh_token_expiry_date,
|
||||
// but the callback currently only passes tokens. Expiry dates are available
|
||||
// in the LoginResponse type if needed for future enhancements.
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccessTokenExpiryDate string `json:"access_token_expiry_date"`
|
||||
RefreshTokenExpiryDate string `json:"refresh_token_expiry_date"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update stored tokens
|
||||
c.mu.Lock()
|
||||
c.accessToken = tokenResp.AccessToken
|
||||
c.refreshToken = tokenResp.RefreshToken
|
||||
c.mu.Unlock()
|
||||
|
||||
// Notify callback if set, passing the expiry date so callers can track actual expiration
|
||||
if c.onTokenRefresh != nil {
|
||||
c.onTokenRefresh(tokenResp.AccessToken, tokenResp.RefreshToken, tokenResp.AccessTokenExpiryDate)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// doRequestRaw performs an HTTP request and returns the raw response body.
|
||||
// This is useful for endpoints that return non-JSON responses.
|
||||
func (c *Client) doRequestRaw(ctx context.Context, method, path string, body interface{}, requiresAuth bool) ([]byte, error) {
|
||||
return c.doRequestRawWithRetry(ctx, method, path, body, requiresAuth, true)
|
||||
}
|
||||
|
||||
// doRequestRawWithRetry performs an HTTP request with optional retry on 401.
|
||||
func (c *Client) doRequestRawWithRetry(ctx context.Context, method, path string, body interface{}, requiresAuth bool, allowRetry bool) ([]byte, error) {
|
||||
// Build URL
|
||||
url := c.baseURL + path
|
||||
|
||||
// Log API request
|
||||
if c.logger != nil {
|
||||
c.logger.Info("API request", "method", method, "url", url)
|
||||
}
|
||||
|
||||
// Prepare request body - we need to be able to re-read it for retry
|
||||
var bodyData []byte
|
||||
if body != nil {
|
||||
var err error
|
||||
bodyData, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create request
|
||||
var bodyReader io.Reader
|
||||
if bodyData != nil {
|
||||
bodyReader = bytes.NewReader(bodyData)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// Add authorization header if required
|
||||
if requiresAuth {
|
||||
c.mu.RLock()
|
||||
token := c.accessToken
|
||||
c.mu.RUnlock()
|
||||
|
||||
if token == "" {
|
||||
return nil, &APIError{
|
||||
ProblemDetail: ProblemDetail{
|
||||
Status: 401,
|
||||
Title: "Unauthorized",
|
||||
Detail: "No access token available",
|
||||
},
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("JWT %s", token))
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if c.logger != nil {
|
||||
c.logger.Error("API request failed", "method", method, "url", url, "error", err.Error())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Log API response
|
||||
if c.logger != nil {
|
||||
c.logger.Info("API response", "method", method, "url", url, "status", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Handle 401 with automatic token refresh
|
||||
if resp.StatusCode == http.StatusUnauthorized && requiresAuth && allowRetry && !c.isRefreshing.Load() {
|
||||
c.mu.Lock()
|
||||
if c.refreshToken != "" && !c.isRefreshing.Load() {
|
||||
c.isRefreshing.Store(true)
|
||||
c.mu.Unlock()
|
||||
|
||||
// Attempt to refresh token
|
||||
refreshErr := c.refreshAccessToken(ctx)
|
||||
c.isRefreshing.Store(false)
|
||||
|
||||
if refreshErr == nil {
|
||||
// Retry the original request without allowing another retry
|
||||
return c.doRequestRawWithRetry(ctx, method, path, body, requiresAuth, false)
|
||||
}
|
||||
} else {
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Handle error status codes
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, parseErrorResponse(respBody, resp.StatusCode)
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
165
cloud/maplefile-backend/pkg/maplefile/client/collections.go
Normal file
165
cloud/maplefile-backend/pkg/maplefile/client/collections.go
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CreateCollection creates a new collection.
|
||||
func (c *Client) CreateCollection(ctx context.Context, input *CreateCollectionInput) (*Collection, error) {
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/collections", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListCollections returns all collections for the current user.
|
||||
func (c *Client) ListCollections(ctx context.Context) ([]*Collection, error) {
|
||||
var resp struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/collections", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Collections, nil
|
||||
}
|
||||
|
||||
// GetCollection returns a single collection by ID.
|
||||
func (c *Client) GetCollection(ctx context.Context, id string) (*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s", id)
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateCollection updates a collection.
|
||||
func (c *Client) UpdateCollection(ctx context.Context, id string, input *UpdateCollectionInput) (*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s", id)
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "PUT", path, input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteCollection soft-deletes a collection.
|
||||
func (c *Client) DeleteCollection(ctx context.Context, id string) error {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s", id)
|
||||
return c.doRequest(ctx, "DELETE", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// GetRootCollections returns all root-level collections (no parent).
|
||||
func (c *Client) GetRootCollections(ctx context.Context) ([]*Collection, error) {
|
||||
var resp struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/collections/root", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Collections, nil
|
||||
}
|
||||
|
||||
// GetCollectionsByParent returns all collections with the specified parent.
|
||||
func (c *Client) GetCollectionsByParent(ctx context.Context, parentID string) ([]*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/parent/%s", parentID)
|
||||
var resp struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Collections, nil
|
||||
}
|
||||
|
||||
// MoveCollection moves a collection to a new parent.
|
||||
func (c *Client) MoveCollection(ctx context.Context, id string, input *MoveCollectionInput) (*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s/move", id)
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "PUT", path, input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ShareCollection shares a collection with another user.
|
||||
func (c *Client) ShareCollection(ctx context.Context, id string, input *ShareCollectionInput) error {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s/share", id)
|
||||
return c.doRequest(ctx, "POST", path, input, nil, true)
|
||||
}
|
||||
|
||||
// RemoveCollectionMember removes a user from a shared collection.
|
||||
func (c *Client) RemoveCollectionMember(ctx context.Context, collectionID, userID string) error {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s/members/%s", collectionID, userID)
|
||||
return c.doRequest(ctx, "DELETE", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// ListSharedCollections returns all collections shared with the current user.
|
||||
func (c *Client) ListSharedCollections(ctx context.Context) ([]*Collection, error) {
|
||||
var resp struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/collections/shared", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Collections, nil
|
||||
}
|
||||
|
||||
// ArchiveCollection archives a collection.
|
||||
func (c *Client) ArchiveCollection(ctx context.Context, id string) (*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s/archive", id)
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "PUT", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RestoreCollection restores an archived collection.
|
||||
func (c *Client) RestoreCollection(ctx context.Context, id string) (*Collection, error) {
|
||||
path := fmt.Sprintf("/api/v1/collections/%s/restore", id)
|
||||
var resp Collection
|
||||
if err := c.doRequest(ctx, "PUT", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetFilteredCollections returns collections matching the specified filter.
|
||||
func (c *Client) GetFilteredCollections(ctx context.Context, filter *CollectionFilter) ([]*Collection, error) {
|
||||
path := "/api/v1/collections/filtered"
|
||||
if filter != nil {
|
||||
params := ""
|
||||
if filter.State != "" {
|
||||
params += fmt.Sprintf("state=%s", filter.State)
|
||||
}
|
||||
if filter.ParentID != "" {
|
||||
if params != "" {
|
||||
params += "&"
|
||||
}
|
||||
params += fmt.Sprintf("parent_id=%s", filter.ParentID)
|
||||
}
|
||||
if params != "" {
|
||||
path += "?" + params
|
||||
}
|
||||
}
|
||||
var resp struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Collections, nil
|
||||
}
|
||||
|
||||
// SyncCollections fetches collection changes since the given cursor.
|
||||
func (c *Client) SyncCollections(ctx context.Context, input *SyncInput) (*CollectionSyncResponse, error) {
|
||||
var resp CollectionSyncResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/collections/sync", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
157
cloud/maplefile-backend/pkg/maplefile/client/errors.go
Normal file
157
cloud/maplefile-backend/pkg/maplefile/client/errors.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ProblemDetail represents an RFC 9457 problem detail response from the API.
|
||||
type ProblemDetail struct {
|
||||
Type string `json:"type"`
|
||||
Status int `json:"status"`
|
||||
Title string `json:"title"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
Instance string `json:"instance,omitempty"`
|
||||
Errors map[string]string `json:"errors,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
TraceID string `json:"trace_id,omitempty"`
|
||||
}
|
||||
|
||||
// APIError wraps ProblemDetail for the error interface.
|
||||
type APIError struct {
|
||||
ProblemDetail
|
||||
}
|
||||
|
||||
// Error returns a formatted error message from the ProblemDetail.
|
||||
func (e *APIError) Error() string {
|
||||
var errMsg strings.Builder
|
||||
|
||||
if e.Detail != "" {
|
||||
errMsg.WriteString(e.Detail)
|
||||
} else {
|
||||
errMsg.WriteString(e.Title)
|
||||
}
|
||||
|
||||
if len(e.Errors) > 0 {
|
||||
errMsg.WriteString("\n\nValidation errors:")
|
||||
for field, message := range e.Errors {
|
||||
errMsg.WriteString(fmt.Sprintf("\n - %s: %s", field, message))
|
||||
}
|
||||
}
|
||||
|
||||
return errMsg.String()
|
||||
}
|
||||
|
||||
// StatusCode returns the HTTP status code from the error.
|
||||
func (e *APIError) StatusCode() int {
|
||||
return e.Status
|
||||
}
|
||||
|
||||
// GetValidationErrors returns the validation errors map.
|
||||
func (e *APIError) GetValidationErrors() map[string]string {
|
||||
return e.Errors
|
||||
}
|
||||
|
||||
// GetFieldError returns the error message for a specific field, or empty string if not found.
|
||||
func (e *APIError) GetFieldError(field string) string {
|
||||
if e.Errors == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Errors[field]
|
||||
}
|
||||
|
||||
// HasFieldError checks if a specific field has a validation error.
|
||||
func (e *APIError) HasFieldError(field string) bool {
|
||||
if e.Errors == nil {
|
||||
return false
|
||||
}
|
||||
_, exists := e.Errors[field]
|
||||
return exists
|
||||
}
|
||||
|
||||
// IsNotFound checks if the error is a 404 Not Found error.
|
||||
func IsNotFound(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return apiErr.Status == 404
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsUnauthorized checks if the error is a 401 Unauthorized error.
|
||||
func IsUnauthorized(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return apiErr.Status == 401
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsForbidden checks if the error is a 403 Forbidden error.
|
||||
func IsForbidden(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return apiErr.Status == 403
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsValidationError checks if the error has validation errors.
|
||||
func IsValidationError(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return len(apiErr.Errors) > 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsConflict checks if the error is a 409 Conflict error.
|
||||
func IsConflict(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return apiErr.Status == 409
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTooManyRequests checks if the error is a 429 Too Many Requests error.
|
||||
func IsTooManyRequests(err error) bool {
|
||||
if apiErr, ok := err.(*APIError); ok {
|
||||
return apiErr.Status == 429
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseErrorResponse attempts to parse an error response body into an APIError.
|
||||
// It tries RFC 9457 format first, then falls back to legacy format.
|
||||
//
|
||||
// Note: RFC 9457 specifies that error responses should use Content-Type: application/problem+json,
|
||||
// but we parse based on the response structure rather than Content-Type for maximum compatibility.
|
||||
func parseErrorResponse(body []byte, statusCode int) error {
|
||||
// Try to parse as RFC 9457 ProblemDetail
|
||||
// The presence of the "type" field distinguishes RFC 9457 from legacy responses
|
||||
var problem ProblemDetail
|
||||
if err := json.Unmarshal(body, &problem); err == nil && problem.Type != "" {
|
||||
return &APIError{ProblemDetail: problem}
|
||||
}
|
||||
|
||||
// Fallback for non-RFC 9457 errors
|
||||
var errorResponse map[string]interface{}
|
||||
if err := json.Unmarshal(body, &errorResponse); err == nil {
|
||||
if errMsg, ok := errorResponse["message"].(string); ok {
|
||||
return &APIError{
|
||||
ProblemDetail: ProblemDetail{
|
||||
Status: statusCode,
|
||||
Title: errMsg,
|
||||
Detail: errMsg,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: return raw body as error
|
||||
return &APIError{
|
||||
ProblemDetail: ProblemDetail{
|
||||
Status: statusCode,
|
||||
Title: fmt.Sprintf("HTTP %d", statusCode),
|
||||
Detail: string(body),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
)
|
||||
|
||||
// Example of handling RFC 9457 errors with validation details
|
||||
func ExampleAPIError_validation() {
|
||||
c := client.NewLocal()
|
||||
|
||||
// Attempt to register with invalid data
|
||||
_, err := c.Register(context.Background(), &client.RegisterInput{
|
||||
Email: "", // Missing required field
|
||||
FirstName: "", // Missing required field
|
||||
// ... other fields
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Check if it's an API error
|
||||
if apiErr, ok := err.(*client.APIError); ok {
|
||||
fmt.Printf("Error Type: %s\n", apiErr.Type)
|
||||
fmt.Printf("Status: %d\n", apiErr.Status)
|
||||
fmt.Printf("Title: %s\n", apiErr.Title)
|
||||
|
||||
// Check for validation errors
|
||||
if client.IsValidationError(err) {
|
||||
fmt.Println("\nValidation Errors:")
|
||||
for field, message := range apiErr.GetValidationErrors() {
|
||||
fmt.Printf(" %s: %s\n", field, message)
|
||||
}
|
||||
|
||||
// Check for specific field error
|
||||
if apiErr.HasFieldError("email") {
|
||||
fmt.Printf("\nEmail error: %s\n", apiErr.GetFieldError("email"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Example of checking specific error types
|
||||
func ExampleAPIError_statusChecks() {
|
||||
c := client.NewProduction()
|
||||
|
||||
user, err := c.GetMe(context.Background())
|
||||
if err != nil {
|
||||
// Use helper functions to check error types
|
||||
switch {
|
||||
case client.IsUnauthorized(err):
|
||||
fmt.Println("Authentication required - please login")
|
||||
// Redirect to login
|
||||
|
||||
case client.IsNotFound(err):
|
||||
fmt.Println("User not found")
|
||||
// Handle not found
|
||||
|
||||
case client.IsForbidden(err):
|
||||
fmt.Println("Access denied")
|
||||
// Show permission error
|
||||
|
||||
case client.IsTooManyRequests(err):
|
||||
fmt.Println("Rate limit exceeded - please try again later")
|
||||
// Implement backoff
|
||||
|
||||
case client.IsValidationError(err):
|
||||
fmt.Println("Validation failed - please check your input")
|
||||
// Show validation errors
|
||||
|
||||
default:
|
||||
fmt.Printf("Unexpected error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Welcome, %s!\n", user.Name)
|
||||
}
|
||||
|
||||
// Example of extracting error details for logging
|
||||
func ExampleAPIError_logging() {
|
||||
c := client.NewProduction()
|
||||
|
||||
_, err := c.CreateCollection(context.Background(), &client.CreateCollectionInput{
|
||||
Name: "Test Collection",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*client.APIError); ok {
|
||||
// Log structured error details
|
||||
fmt.Printf("API Error Details:\n")
|
||||
fmt.Printf(" Type: %s\n", apiErr.Type)
|
||||
fmt.Printf(" Status: %d\n", apiErr.StatusCode())
|
||||
fmt.Printf(" Title: %s\n", apiErr.Title)
|
||||
fmt.Printf(" Detail: %s\n", apiErr.Detail)
|
||||
fmt.Printf(" Instance: %s\n", apiErr.Instance)
|
||||
fmt.Printf(" TraceID: %s\n", apiErr.TraceID)
|
||||
fmt.Printf(" Timestamp: %s\n", apiErr.Timestamp)
|
||||
|
||||
if len(apiErr.Errors) > 0 {
|
||||
fmt.Println(" Field Errors:")
|
||||
for field, msg := range apiErr.Errors {
|
||||
fmt.Printf(" %s: %s\n", field, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Example of handling errors in a form validation context
|
||||
func ExampleAPIError_formValidation() {
|
||||
c := client.NewLocal()
|
||||
|
||||
type FormData struct {
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
Password string
|
||||
}
|
||||
|
||||
form := FormData{
|
||||
Email: "invalid-email",
|
||||
FirstName: "",
|
||||
LastName: "Doe",
|
||||
Password: "weak",
|
||||
}
|
||||
|
||||
_, err := c.Register(context.Background(), &client.RegisterInput{
|
||||
Email: form.Email,
|
||||
FirstName: form.FirstName,
|
||||
LastName: form.LastName,
|
||||
// ... other fields
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*client.APIError); ok {
|
||||
// Build form error messages
|
||||
formErrors := make(map[string]string)
|
||||
|
||||
if apiErr.HasFieldError("email") {
|
||||
formErrors["email"] = apiErr.GetFieldError("email")
|
||||
}
|
||||
if apiErr.HasFieldError("first_name") {
|
||||
formErrors["first_name"] = apiErr.GetFieldError("first_name")
|
||||
}
|
||||
if apiErr.HasFieldError("last_name") {
|
||||
formErrors["last_name"] = apiErr.GetFieldError("last_name")
|
||||
}
|
||||
|
||||
// Display errors to user
|
||||
for field, msg := range formErrors {
|
||||
fmt.Printf("Form field '%s': %s\n", field, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Example of handling conflict errors
|
||||
func ExampleAPIError_conflict() {
|
||||
c := client.NewProduction()
|
||||
|
||||
_, err := c.Register(context.Background(), &client.RegisterInput{
|
||||
Email: "existing@example.com",
|
||||
// ... other fields
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if client.IsConflict(err) {
|
||||
if apiErr, ok := err.(*client.APIError); ok {
|
||||
// The Detail field contains the conflict explanation
|
||||
fmt.Printf("Registration failed: %s\n", apiErr.Detail)
|
||||
// Output: "Registration failed: User with this email already exists"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
191
cloud/maplefile-backend/pkg/maplefile/client/files.go
Normal file
191
cloud/maplefile-backend/pkg/maplefile/client/files.go
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// CreatePendingFile creates a new file in pending state.
|
||||
func (c *Client) CreatePendingFile(ctx context.Context, input *CreateFileInput) (*PendingFile, error) {
|
||||
var resp PendingFile
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/files/pending", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetFile returns a single file by ID.
|
||||
func (c *Client) GetFile(ctx context.Context, id string) (*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s", id)
|
||||
var resp File
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateFile updates a file's metadata.
|
||||
func (c *Client) UpdateFile(ctx context.Context, id string, input *UpdateFileInput) (*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s", id)
|
||||
var resp File
|
||||
if err := c.doRequest(ctx, "PUT", path, input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteFile soft-deletes a file.
|
||||
func (c *Client) DeleteFile(ctx context.Context, id string) error {
|
||||
path := fmt.Sprintf("/api/v1/file/%s", id)
|
||||
return c.doRequest(ctx, "DELETE", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// DeleteMultipleFiles deletes multiple files at once.
|
||||
func (c *Client) DeleteMultipleFiles(ctx context.Context, fileIDs []string) error {
|
||||
input := DeleteMultipleFilesInput{FileIDs: fileIDs}
|
||||
return c.doRequest(ctx, "POST", "/api/v1/files/delete-multiple", input, nil, true)
|
||||
}
|
||||
|
||||
// GetPresignedUploadURL gets a presigned URL for uploading file content.
|
||||
func (c *Client) GetPresignedUploadURL(ctx context.Context, fileID string) (*PresignedURL, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/upload-url", fileID)
|
||||
var resp PresignedURL
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CompleteFileUpload marks the file upload as complete and transitions it to active state.
|
||||
func (c *Client) CompleteFileUpload(ctx context.Context, fileID string, input *CompleteUploadInput) (*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/complete", fileID)
|
||||
var resp File
|
||||
if err := c.doRequest(ctx, "POST", path, input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetPresignedDownloadURL gets a presigned URL for downloading file content.
|
||||
func (c *Client) GetPresignedDownloadURL(ctx context.Context, fileID string) (*PresignedDownloadResponse, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/download-url", fileID)
|
||||
var resp PresignedDownloadResponse
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ReportDownloadCompleted reports that a file download has completed.
|
||||
func (c *Client) ReportDownloadCompleted(ctx context.Context, fileID string) error {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/download-completed", fileID)
|
||||
return c.doRequest(ctx, "POST", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// ArchiveFile archives a file.
|
||||
func (c *Client) ArchiveFile(ctx context.Context, id string) (*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/archive", id)
|
||||
var resp File
|
||||
if err := c.doRequest(ctx, "PUT", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RestoreFile restores an archived file.
|
||||
func (c *Client) RestoreFile(ctx context.Context, id string) (*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/file/%s/restore", id)
|
||||
var resp File
|
||||
if err := c.doRequest(ctx, "PUT", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListFilesByCollection returns all files in a collection.
|
||||
func (c *Client) ListFilesByCollection(ctx context.Context, collectionID string) ([]*File, error) {
|
||||
path := fmt.Sprintf("/api/v1/collection/%s/files", collectionID)
|
||||
var resp struct {
|
||||
Files []*File `json:"files"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Files, nil
|
||||
}
|
||||
|
||||
// ListRecentFiles returns the user's recent files.
|
||||
func (c *Client) ListRecentFiles(ctx context.Context) ([]*File, error) {
|
||||
var resp struct {
|
||||
Files []*File `json:"files"`
|
||||
}
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/files/recent", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Files, nil
|
||||
}
|
||||
|
||||
// SyncFiles fetches file changes since the given cursor.
|
||||
func (c *Client) SyncFiles(ctx context.Context, input *SyncInput) (*FileSyncResponse, error) {
|
||||
var resp FileSyncResponse
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/files/sync", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UploadToPresignedURL uploads data to an S3 presigned URL.
|
||||
// This is a helper method for uploading encrypted file content directly to S3.
|
||||
func (c *Client) UploadToPresignedURL(ctx context.Context, presignedURL string, data []byte, contentType string) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "PUT", presignedURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create upload request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
req.ContentLength = int64(len(data))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload to presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadFromPresignedURL downloads data from an S3 presigned URL.
|
||||
// This is a helper method for downloading encrypted file content directly from S3.
|
||||
func (c *Client) DownloadFromPresignedURL(ctx context.Context, presignedURL string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", presignedURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create download request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download from presigned URL: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("download failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read download response: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
123
cloud/maplefile-backend/pkg/maplefile/client/tags.go
Normal file
123
cloud/maplefile-backend/pkg/maplefile/client/tags.go
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// CreateTag creates a new tag.
|
||||
func (c *Client) CreateTag(ctx context.Context, input *CreateTagInput) (*Tag, error) {
|
||||
var resp Tag
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/tags", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListTags returns all tags for the current user.
|
||||
func (c *Client) ListTags(ctx context.Context) ([]*Tag, error) {
|
||||
var resp ListTagsResponse
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/tags", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Tags, nil
|
||||
}
|
||||
|
||||
// GetTag returns a single tag by ID.
|
||||
func (c *Client) GetTag(ctx context.Context, id string) (*Tag, error) {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s", id)
|
||||
var resp Tag
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateTag updates a tag.
|
||||
func (c *Client) UpdateTag(ctx context.Context, id string, input *UpdateTagInput) (*Tag, error) {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s", id)
|
||||
var resp Tag
|
||||
if err := c.doRequest(ctx, "PUT", path, input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteTag deletes a tag.
|
||||
func (c *Client) DeleteTag(ctx context.Context, id string) error {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s", id)
|
||||
return c.doRequest(ctx, "DELETE", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// AssignTag assigns a tag to a collection or file.
|
||||
func (c *Client) AssignTag(ctx context.Context, input *CreateTagAssignmentInput) (*TagAssignment, error) {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s/assign", input.TagID)
|
||||
|
||||
// Create request body without TagID (since it's in the URL)
|
||||
requestBody := map[string]string{
|
||||
"entity_id": input.EntityID,
|
||||
"entity_type": input.EntityType,
|
||||
}
|
||||
|
||||
var resp TagAssignment
|
||||
if err := c.doRequest(ctx, "POST", path, requestBody, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UnassignTag removes a tag from a collection or file.
|
||||
func (c *Client) UnassignTag(ctx context.Context, tagID, entityID, entityType string) error {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s/entities/%s?entity_type=%s", tagID, entityID, entityType)
|
||||
return c.doRequest(ctx, "DELETE", path, nil, nil, true)
|
||||
}
|
||||
|
||||
// GetTagsForEntity returns all tags assigned to a specific entity (collection or file).
|
||||
func (c *Client) GetTagsForEntity(ctx context.Context, entityID, entityType string) ([]*Tag, error) {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s/%s", entityType, entityID)
|
||||
var resp ListTagsResponse
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Tags, nil
|
||||
}
|
||||
|
||||
// GetTagAssignments returns all assignments for a specific tag.
|
||||
func (c *Client) GetTagAssignments(ctx context.Context, tagID string) ([]*TagAssignment, error) {
|
||||
path := fmt.Sprintf("/api/v1/tags/%s/assignments", tagID)
|
||||
var resp ListTagAssignmentsResponse
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.TagAssignments, nil
|
||||
}
|
||||
|
||||
// SearchByTags searches for collections and files matching ALL the specified tags.
|
||||
// tagIDs: slice of tag UUIDs to search for
|
||||
// limit: maximum number of results (default 50, max 100 on backend)
|
||||
func (c *Client) SearchByTags(ctx context.Context, tagIDs []string, limit int) (*SearchByTagsResponse, error) {
|
||||
if len(tagIDs) == 0 {
|
||||
return nil, fmt.Errorf("at least one tag ID is required")
|
||||
}
|
||||
|
||||
// Build query string with comma-separated tag IDs
|
||||
tags := ""
|
||||
for i, id := range tagIDs {
|
||||
if i > 0 {
|
||||
tags += ","
|
||||
}
|
||||
tags += id
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/api/v1/tags/search?tags=%s", tags)
|
||||
if limit > 0 {
|
||||
path += fmt.Sprintf("&limit=%d", limit)
|
||||
}
|
||||
|
||||
var resp SearchByTagsResponse
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
598
cloud/maplefile-backend/pkg/maplefile/client/types.go
Normal file
598
cloud/maplefile-backend/pkg/maplefile/client/types.go
Normal file
|
|
@ -0,0 +1,598 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import "time"
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Health & Version Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// HealthResponse represents the health check response.
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// VersionResponse represents the API version response.
|
||||
type VersionResponse struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Authentication Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// RegisterInput represents the registration request.
|
||||
type RegisterInput struct {
|
||||
BetaAccessCode string `json:"beta_access_code"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Phone string `json:"phone"`
|
||||
Country string `json:"country"`
|
||||
Timezone string `json:"timezone"`
|
||||
PasswordSalt string `json:"salt"`
|
||||
KDFAlgorithm string `json:"kdf_algorithm"`
|
||||
KDFIterations int `json:"kdf_iterations"`
|
||||
KDFMemory int `json:"kdf_memory"`
|
||||
KDFParallelism int `json:"kdf_parallelism"`
|
||||
KDFSaltLength int `json:"kdf_salt_length"`
|
||||
KDFKeyLength int `json:"kdf_key_length"`
|
||||
EncryptedMasterKey string `json:"encryptedMasterKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
EncryptedPrivateKey string `json:"encryptedPrivateKey"`
|
||||
EncryptedRecoveryKey string `json:"encryptedRecoveryKey"`
|
||||
MasterKeyEncryptedWithRecoveryKey string `json:"masterKeyEncryptedWithRecoveryKey"`
|
||||
AgreeTermsOfService bool `json:"agree_terms_of_service"`
|
||||
AgreePromotions bool `json:"agree_promotions"`
|
||||
AgreeToTrackingAcrossThirdPartyAppsAndServices bool `json:"agree_to_tracking_across_third_party_apps_and_services"`
|
||||
}
|
||||
|
||||
// RegisterResponse represents the registration response.
|
||||
type RegisterResponse struct {
|
||||
Message string `json:"message"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// VerifyEmailInput represents the email verification request.
|
||||
type VerifyEmailInput struct {
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// VerifyEmailResponse represents the email verification response.
|
||||
type VerifyEmailResponse struct {
|
||||
Message string `json:"message"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// ResendVerificationInput represents the resend verification request.
|
||||
type ResendVerificationInput struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// OTTResponse represents the OTT request response.
|
||||
type OTTResponse struct {
|
||||
Message string `json:"message"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// VerifyOTTResponse represents the OTT verification response.
|
||||
type VerifyOTTResponse struct {
|
||||
Message string `json:"message"`
|
||||
ChallengeID string `json:"challengeId"`
|
||||
EncryptedChallenge string `json:"encryptedChallenge"`
|
||||
Salt string `json:"salt"`
|
||||
EncryptedMasterKey string `json:"encryptedMasterKey"`
|
||||
EncryptedPrivateKey string `json:"encryptedPrivateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
// KDFAlgorithm specifies which key derivation algorithm to use.
|
||||
// Values: "PBKDF2-SHA256" (web frontend) or "argon2id" (native app legacy)
|
||||
KDFAlgorithm string `json:"kdfAlgorithm"`
|
||||
}
|
||||
|
||||
// CompleteLoginInput represents the complete login request.
|
||||
type CompleteLoginInput struct {
|
||||
Email string `json:"email"`
|
||||
ChallengeID string `json:"challengeId"`
|
||||
DecryptedData string `json:"decryptedData"`
|
||||
}
|
||||
|
||||
// LoginResponse represents the login response (from complete-login or token refresh).
|
||||
type LoginResponse struct {
|
||||
Message string `json:"message"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccessTokenExpiryDate string `json:"access_token_expiry_date"`
|
||||
RefreshTokenExpiryDate string `json:"refresh_token_expiry_date"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// RefreshTokenInput represents the token refresh request.
|
||||
type RefreshTokenInput struct {
|
||||
RefreshToken string `json:"value"`
|
||||
}
|
||||
|
||||
// RecoveryInitiateInput represents the recovery initiation request.
|
||||
type RecoveryInitiateInput struct {
|
||||
Email string `json:"email"`
|
||||
Method string `json:"method"` // "recovery_key"
|
||||
}
|
||||
|
||||
// RecoveryInitiateResponse represents the recovery initiation response.
|
||||
type RecoveryInitiateResponse struct {
|
||||
Message string `json:"message"`
|
||||
SessionID string `json:"session_id"`
|
||||
EncryptedChallenge string `json:"encrypted_challenge"`
|
||||
}
|
||||
|
||||
// RecoveryVerifyInput represents the recovery verification request.
|
||||
type RecoveryVerifyInput struct {
|
||||
SessionID string `json:"session_id"`
|
||||
DecryptedChallenge string `json:"decrypted_challenge"`
|
||||
}
|
||||
|
||||
// RecoveryVerifyResponse represents the recovery verification response.
|
||||
type RecoveryVerifyResponse struct {
|
||||
Message string `json:"message"`
|
||||
RecoveryToken string `json:"recovery_token"`
|
||||
CanResetCredentials bool `json:"can_reset_credentials"`
|
||||
}
|
||||
|
||||
// RecoveryCompleteInput represents the recovery completion request.
|
||||
type RecoveryCompleteInput struct {
|
||||
RecoveryToken string `json:"recovery_token"`
|
||||
NewSalt string `json:"new_salt"`
|
||||
NewPublicKey string `json:"new_public_key"`
|
||||
NewEncryptedMasterKey string `json:"new_encrypted_master_key"`
|
||||
NewEncryptedPrivateKey string `json:"new_encrypted_private_key"`
|
||||
NewEncryptedRecoveryKey string `json:"new_encrypted_recovery_key"`
|
||||
NewMasterKeyEncryptedWithRecoveryKey string `json:"new_master_key_encrypted_with_recovery_key"`
|
||||
}
|
||||
|
||||
// RecoveryCompleteResponse represents the recovery completion response.
|
||||
type RecoveryCompleteResponse struct {
|
||||
Message string `json:"message"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// User/Profile Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// User represents a user profile.
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Name string `json:"name"`
|
||||
LexicalName string `json:"lexical_name"`
|
||||
Role int8 `json:"role"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Timezone string `json:"timezone"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
PostalCode string `json:"postal_code,omitempty"`
|
||||
AddressLine1 string `json:"address_line1,omitempty"`
|
||||
AddressLine2 string `json:"address_line2,omitempty"`
|
||||
AgreePromotions bool `json:"agree_promotions,omitempty"`
|
||||
AgreeToTrackingAcrossThirdPartyAppsAndServices bool `json:"agree_to_tracking_across_third_party_apps_and_services,omitempty"`
|
||||
ShareNotificationsEnabled *bool `json:"share_notifications_enabled,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
Status int8 `json:"status"`
|
||||
ProfileVerificationStatus int8 `json:"profile_verification_status,omitempty"`
|
||||
WebsiteURL string `json:"website_url"`
|
||||
Description string `json:"description"`
|
||||
ComicBookStoreName string `json:"comic_book_store_name,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateUserInput represents the user update request.
|
||||
type UpdateUserInput struct {
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Phone string `json:"phone,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Timezone string `json:"timezone"`
|
||||
AgreePromotions bool `json:"agree_promotions,omitempty"`
|
||||
AgreeToTrackingAcrossThirdPartyAppsAndServices bool `json:"agree_to_tracking_across_third_party_apps_and_services,omitempty"`
|
||||
ShareNotificationsEnabled *bool `json:"share_notifications_enabled,omitempty"`
|
||||
}
|
||||
|
||||
// DeleteUserInput represents the user deletion request.
|
||||
type DeleteUserInput struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// PublicUser represents public user information returned from lookup.
|
||||
type PublicUser struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PublicKeyInBase64 string `json:"public_key_in_base64"`
|
||||
VerificationID string `json:"verification_id"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Blocked Email Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// CreateBlockedEmailInput represents the blocked email creation request.
|
||||
type CreateBlockedEmailInput struct {
|
||||
Email string `json:"email"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// BlockedEmail represents a blocked email entry.
|
||||
type BlockedEmail struct {
|
||||
UserID string `json:"user_id"`
|
||||
BlockedEmail string `json:"blocked_email"`
|
||||
BlockedUserID string `json:"blocked_user_id,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ListBlockedEmailsResponse represents the list of blocked emails response.
|
||||
type ListBlockedEmailsResponse struct {
|
||||
BlockedEmails []*BlockedEmail `json:"blocked_emails"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// DeleteBlockedEmailResponse represents the blocked email deletion response.
|
||||
type DeleteBlockedEmailResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Dashboard Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Dashboard represents dashboard data.
|
||||
type Dashboard struct {
|
||||
Summary DashboardSummary `json:"summary"`
|
||||
StorageUsageTrend StorageUsageTrend `json:"storage_usage_trend"`
|
||||
RecentFiles []RecentFileDashboard `json:"recent_files"`
|
||||
CollectionKeys []DashboardCollectionKey `json:"collection_keys,omitempty"`
|
||||
}
|
||||
|
||||
// DashboardCollectionKey contains the encrypted collection key for client-side decryption
|
||||
// This allows clients to decrypt file metadata without making additional API calls
|
||||
type DashboardCollectionKey struct {
|
||||
CollectionID string `json:"collection_id"`
|
||||
EncryptedCollectionKey string `json:"encrypted_collection_key"`
|
||||
EncryptedCollectionKeyNonce string `json:"encrypted_collection_key_nonce"`
|
||||
}
|
||||
|
||||
// DashboardResponse represents the dashboard response.
|
||||
type DashboardResponse struct {
|
||||
Dashboard *Dashboard `json:"dashboard"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// DashboardSummary represents dashboard summary data.
|
||||
type DashboardSummary struct {
|
||||
TotalFiles int `json:"total_files"`
|
||||
TotalFolders int `json:"total_folders"`
|
||||
StorageUsed StorageAmount `json:"storage_used"`
|
||||
StorageLimit StorageAmount `json:"storage_limit"`
|
||||
StorageUsagePercentage int `json:"storage_usage_percentage"`
|
||||
}
|
||||
|
||||
// StorageAmount represents a storage amount with value and unit.
|
||||
type StorageAmount struct {
|
||||
Value float64 `json:"value"`
|
||||
Unit string `json:"unit"`
|
||||
}
|
||||
|
||||
// StorageUsageTrend represents storage usage trend data.
|
||||
type StorageUsageTrend struct {
|
||||
Period string `json:"period"`
|
||||
DataPoints []DataPoint `json:"data_points"`
|
||||
}
|
||||
|
||||
// DataPoint represents a single data point in the usage trend.
|
||||
type DataPoint struct {
|
||||
Date string `json:"date"`
|
||||
Usage StorageAmount `json:"usage"`
|
||||
}
|
||||
|
||||
// RecentFileDashboard represents a recent file in the dashboard.
|
||||
// Note: File metadata is E2EE encrypted. Clients should use locally cached
|
||||
// decrypted data when available, or show placeholder text for cloud-only files.
|
||||
type RecentFileDashboard struct {
|
||||
ID string `json:"id"`
|
||||
CollectionID string `json:"collection_id"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
EncryptedMetadata string `json:"encrypted_metadata"`
|
||||
EncryptedFileKey EncryptedFileKeyData `json:"encrypted_file_key"`
|
||||
EncryptionVersion string `json:"encryption_version"`
|
||||
EncryptedHash string `json:"encrypted_hash"`
|
||||
EncryptedFileSizeInBytes int64 `json:"encrypted_file_size_in_bytes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Version uint64 `json:"version"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Collection Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// EncryptedKeyData represents an encrypted key with its nonce (used for collections and tags)
|
||||
type EncryptedKeyData struct {
|
||||
Ciphertext string `json:"ciphertext"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// Collection represents a file collection (folder).
|
||||
type Collection struct {
|
||||
ID string `json:"id"`
|
||||
ParentID string `json:"parent_id,omitempty"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EncryptedCollectionKey EncryptedKeyData `json:"encrypted_collection_key"`
|
||||
// CustomIcon is the decrypted custom icon for this collection.
|
||||
// Empty string means use default folder/album icon.
|
||||
// Contains either an emoji character (e.g., "📷") or "icon:<identifier>" for predefined icons.
|
||||
CustomIcon string `json:"custom_icon,omitempty"`
|
||||
// EncryptedCustomIcon is the encrypted version of CustomIcon (for sync operations).
|
||||
EncryptedCustomIcon string `json:"encrypted_custom_icon,omitempty"`
|
||||
TotalFiles int `json:"total_files"`
|
||||
TotalSizeInBytes int64 `json:"total_size_in_bytes"`
|
||||
State string `json:"state"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
SharedWith []Share `json:"shared_with,omitempty"`
|
||||
PermissionLevel string `json:"permission_level,omitempty"`
|
||||
IsOwner bool `json:"is_owner"`
|
||||
OwnerName string `json:"owner_name,omitempty"`
|
||||
OwnerEmail string `json:"owner_email,omitempty"`
|
||||
Tags []EmbeddedTag `json:"tags,omitempty"` // Tags assigned to this collection
|
||||
}
|
||||
|
||||
// Share represents a collection sharing entry.
|
||||
type Share struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PermissionLevel string `json:"permission_level"`
|
||||
SharedAt time.Time `json:"shared_at"`
|
||||
}
|
||||
|
||||
// CreateCollectionInput represents the collection creation request.
|
||||
type CreateCollectionInput struct {
|
||||
ParentID string `json:"parent_id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EncryptedCollectionKey string `json:"encrypted_collection_key"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// UpdateCollectionInput represents the collection update request.
|
||||
type UpdateCollectionInput struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// MoveCollectionInput represents the collection move request.
|
||||
type MoveCollectionInput struct {
|
||||
NewParentID string `json:"new_parent_id"`
|
||||
}
|
||||
|
||||
// ShareCollectionInput represents the collection sharing request.
|
||||
type ShareCollectionInput struct {
|
||||
Email string `json:"email"`
|
||||
PermissionLevel string `json:"permission_level"` // "read_only", "read_write", "admin"
|
||||
EncryptedCollectionKey string `json:"encrypted_collection_key"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// CollectionFilter represents filters for listing collections.
|
||||
type CollectionFilter struct {
|
||||
State string `json:"state,omitempty"` // "active", "archived", "trashed"
|
||||
ParentID string `json:"parent_id,omitempty"`
|
||||
}
|
||||
|
||||
// SyncInput represents the sync request.
|
||||
type SyncInput struct {
|
||||
Cursor string `json:"cursor,omitempty"`
|
||||
Limit int64 `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// CollectionSyncResponse represents the collection sync response.
|
||||
type CollectionSyncResponse struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
NextCursor string `json:"next_cursor,omitempty"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// File Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// EncryptedFileKeyData represents the encrypted file key structure returned by the API.
|
||||
type EncryptedFileKeyData struct {
|
||||
Ciphertext string `json:"ciphertext"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
// File represents a file in a collection.
|
||||
type File struct {
|
||||
ID string `json:"id"`
|
||||
CollectionID string `json:"collection_id"`
|
||||
UserID string `json:"user_id"`
|
||||
EncryptedFileKey EncryptedFileKeyData `json:"encrypted_file_key"`
|
||||
FileKeyNonce string `json:"file_key_nonce"`
|
||||
EncryptedMetadata string `json:"encrypted_metadata"`
|
||||
MetadataNonce string `json:"metadata_nonce"`
|
||||
FileNonce string `json:"file_nonce"`
|
||||
EncryptedSizeInBytes int64 `json:"encrypted_file_size_in_bytes"`
|
||||
DecryptedSizeInBytes int64 `json:"decrypted_size_in_bytes,omitempty"`
|
||||
State string `json:"state"`
|
||||
StorageMode string `json:"storage_mode"`
|
||||
Version int `json:"version"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||
Tags []*EmbeddedTag `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// PendingFile represents a file in pending state (awaiting upload).
|
||||
type PendingFile struct {
|
||||
ID string `json:"id"`
|
||||
CollectionID string `json:"collection_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// CreateFileInput represents the file creation request.
|
||||
type CreateFileInput struct {
|
||||
CollectionID string `json:"collection_id"`
|
||||
EncryptedFileKey string `json:"encrypted_file_key"`
|
||||
FileKeyNonce string `json:"file_key_nonce"`
|
||||
EncryptedMetadata string `json:"encrypted_metadata"`
|
||||
MetadataNonce string `json:"metadata_nonce"`
|
||||
FileNonce string `json:"file_nonce"`
|
||||
EncryptedSizeInBytes int64 `json:"encrypted_size_in_bytes"`
|
||||
}
|
||||
|
||||
// UpdateFileInput represents the file update request.
|
||||
type UpdateFileInput struct {
|
||||
EncryptedMetadata string `json:"encrypted_metadata,omitempty"`
|
||||
MetadataNonce string `json:"metadata_nonce,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteUploadInput represents the file upload completion request.
|
||||
type CompleteUploadInput struct {
|
||||
ActualFileSizeInBytes int64 `json:"actual_file_size_in_bytes"`
|
||||
UploadConfirmed bool `json:"upload_confirmed"`
|
||||
}
|
||||
|
||||
// PresignedURL represents a presigned upload URL response.
|
||||
type PresignedURL struct {
|
||||
URL string `json:"url"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
|
||||
// PresignedDownloadResponse represents a presigned download URL response.
|
||||
type PresignedDownloadResponse struct {
|
||||
FileURL string `json:"file_url"`
|
||||
ThumbnailURL string `json:"thumbnail_url,omitempty"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
|
||||
// DeleteMultipleFilesInput represents the multiple files deletion request.
|
||||
type DeleteMultipleFilesInput struct {
|
||||
FileIDs []string `json:"file_ids"`
|
||||
}
|
||||
|
||||
// FileSyncResponse represents the file sync response.
|
||||
type FileSyncResponse struct {
|
||||
Files []*File `json:"files"`
|
||||
NextCursor string `json:"next_cursor,omitempty"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
// ListFilesResponse represents the list files response.
|
||||
type ListFilesResponse struct {
|
||||
Files []*File `json:"files"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Tag Types
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Tag represents a user-defined label with color that can be assigned to collections or files.
|
||||
// All sensitive data (name, color) is encrypted end-to-end.
|
||||
type Tag struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
EncryptedName string `json:"encrypted_name"`
|
||||
EncryptedColor string `json:"encrypted_color"`
|
||||
EncryptedTagKey *EncryptedTagKey `json:"encrypted_tag_key"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Version uint64 `json:"version"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// EncryptedTagKey represents the encrypted tag key data
|
||||
type EncryptedTagKey struct {
|
||||
Ciphertext string `json:"ciphertext"` // Base64 encoded
|
||||
Nonce string `json:"nonce"` // Base64 encoded
|
||||
KeyVersion int `json:"key_version,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddedTag represents tag data that is embedded in collections and files
|
||||
// This eliminates the need for frontend API lookups to get tag colors
|
||||
type EmbeddedTag struct {
|
||||
ID string `json:"id"`
|
||||
EncryptedName string `json:"encrypted_name"`
|
||||
EncryptedColor string `json:"encrypted_color"`
|
||||
EncryptedTagKey *EncryptedTagKey `json:"encrypted_tag_key"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
}
|
||||
|
||||
// CreateTagInput represents the tag creation request
|
||||
type CreateTagInput struct {
|
||||
ID string `json:"id"`
|
||||
EncryptedName string `json:"encrypted_name"`
|
||||
EncryptedColor string `json:"encrypted_color"`
|
||||
EncryptedTagKey *EncryptedTagKey `json:"encrypted_tag_key"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Version uint64 `json:"version"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// UpdateTagInput represents the tag update request
|
||||
type UpdateTagInput struct {
|
||||
EncryptedName string `json:"encrypted_name,omitempty"`
|
||||
EncryptedColor string `json:"encrypted_color,omitempty"`
|
||||
EncryptedTagKey *EncryptedTagKey `json:"encrypted_tag_key"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Version uint64 `json:"version"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// ListTagsResponse represents the list tags response
|
||||
type ListTagsResponse struct {
|
||||
Tags []*Tag `json:"tags"`
|
||||
}
|
||||
|
||||
// TagAssignment represents the assignment of a tag to a collection or file
|
||||
type TagAssignment struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
TagID string `json:"tag_id"`
|
||||
EntityID string `json:"entity_id"`
|
||||
EntityType string `json:"entity_type"` // "collection" or "file"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CreateTagAssignmentInput represents the tag assignment request
|
||||
type CreateTagAssignmentInput struct {
|
||||
TagID string `json:"tag_id"`
|
||||
EntityID string `json:"entity_id"`
|
||||
EntityType string `json:"entity_type"`
|
||||
}
|
||||
|
||||
// ListTagAssignmentsResponse represents the list tag assignments response
|
||||
type ListTagAssignmentsResponse struct {
|
||||
TagAssignments []*TagAssignment `json:"tag_assignments"`
|
||||
}
|
||||
|
||||
// SearchByTagsResponse represents the unified search by tags response
|
||||
type SearchByTagsResponse struct {
|
||||
Collections []*Collection `json:"collections"`
|
||||
Files []*File `json:"files"`
|
||||
TagCount int `json:"tag_count"`
|
||||
CollectionCount int `json:"collection_count"`
|
||||
FileCount int `json:"file_count"`
|
||||
}
|
||||
84
cloud/maplefile-backend/pkg/maplefile/client/user.go
Normal file
84
cloud/maplefile-backend/pkg/maplefile/client/user.go
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// Package client provides a Go SDK for interacting with the MapleFile API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// GetMe returns the current authenticated user's profile.
|
||||
func (c *Client) GetMe(ctx context.Context) (*User, error) {
|
||||
var resp User
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/me", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateMe updates the current user's profile.
|
||||
func (c *Client) UpdateMe(ctx context.Context, input *UpdateUserInput) (*User, error) {
|
||||
var resp User
|
||||
if err := c.doRequest(ctx, "PUT", "/api/v1/me", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteMe deletes the current user's account.
|
||||
func (c *Client) DeleteMe(ctx context.Context, password string) error {
|
||||
input := DeleteUserInput{Password: password}
|
||||
return c.doRequest(ctx, "DELETE", "/api/v1/me", input, nil, true)
|
||||
}
|
||||
|
||||
// PublicUserLookup looks up a user by email (returns public info only).
|
||||
// This endpoint does not require authentication.
|
||||
func (c *Client) PublicUserLookup(ctx context.Context, email string) (*PublicUser, error) {
|
||||
path := fmt.Sprintf("/iam/api/v1/users/lookup?email=%s", url.QueryEscape(email))
|
||||
var resp PublicUser
|
||||
if err := c.doRequest(ctx, "GET", path, nil, &resp, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateBlockedEmail adds an email to the blocked list.
|
||||
func (c *Client) CreateBlockedEmail(ctx context.Context, email, reason string) (*BlockedEmail, error) {
|
||||
input := CreateBlockedEmailInput{
|
||||
Email: email,
|
||||
Reason: reason,
|
||||
}
|
||||
var resp BlockedEmail
|
||||
if err := c.doRequest(ctx, "POST", "/api/v1/me/blocked-emails", input, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListBlockedEmails returns all blocked emails for the current user.
|
||||
func (c *Client) ListBlockedEmails(ctx context.Context) (*ListBlockedEmailsResponse, error) {
|
||||
var resp ListBlockedEmailsResponse
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/me/blocked-emails", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteBlockedEmail removes an email from the blocked list.
|
||||
func (c *Client) DeleteBlockedEmail(ctx context.Context, email string) (*DeleteBlockedEmailResponse, error) {
|
||||
path := fmt.Sprintf("/api/v1/me/blocked-emails/%s", url.PathEscape(email))
|
||||
var resp DeleteBlockedEmailResponse
|
||||
if err := c.doRequest(ctx, "DELETE", path, nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDashboard returns the user's dashboard data.
|
||||
func (c *Client) GetDashboard(ctx context.Context) (*DashboardResponse, error) {
|
||||
var resp DashboardResponse
|
||||
if err := c.doRequest(ctx, "GET", "/api/v1/dashboard", nil, &resp, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
462
cloud/maplefile-backend/pkg/maplefile/e2ee/crypto.go
Normal file
462
cloud/maplefile-backend/pkg/maplefile/e2ee/crypto.go
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
// Package e2ee provides end-to-end encryption operations for the MapleFile SDK.
|
||||
package e2ee
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
// KDF Algorithm identifiers
|
||||
const (
|
||||
Argon2IDAlgorithm = "argon2id"
|
||||
PBKDF2Algorithm = "PBKDF2-SHA256"
|
||||
)
|
||||
|
||||
// Argon2id key derivation parameters
|
||||
const (
|
||||
Argon2MemLimit = 4 * 1024 * 1024 // 4 MB
|
||||
Argon2OpsLimit = 1 // 1 iteration (time cost)
|
||||
Argon2Parallelism = 1 // 1 thread
|
||||
Argon2KeySize = 32 // 256-bit output
|
||||
Argon2SaltSize = 16 // 128-bit salt
|
||||
)
|
||||
|
||||
// PBKDF2 key derivation parameters (matching web frontend)
|
||||
const (
|
||||
PBKDF2Iterations = 100000 // 100,000 iterations (matching web frontend)
|
||||
PBKDF2KeySize = 32 // 256-bit output
|
||||
PBKDF2SaltSize = 16 // 128-bit salt
|
||||
)
|
||||
|
||||
// ChaCha20-Poly1305 constants (IETF variant - 12 byte nonce)
|
||||
const (
|
||||
ChaCha20Poly1305KeySize = 32 // ChaCha20 key size
|
||||
ChaCha20Poly1305NonceSize = 12 // ChaCha20-Poly1305 nonce size
|
||||
ChaCha20Poly1305Overhead = 16 // Poly1305 authentication tag size
|
||||
)
|
||||
|
||||
// XSalsa20-Poly1305 (NaCl secretbox) constants - 24 byte nonce
|
||||
// Used by web frontend (libsodium crypto_secretbox_easy)
|
||||
const (
|
||||
SecretBoxKeySize = 32 // Same as ChaCha20
|
||||
SecretBoxNonceSize = 24 // XSalsa20 uses 24-byte nonce
|
||||
SecretBoxOverhead = secretbox.Overhead // 16 bytes (Poly1305 tag)
|
||||
)
|
||||
|
||||
// Key sizes
|
||||
const (
|
||||
MasterKeySize = 32
|
||||
CollectionKeySize = 32
|
||||
FileKeySize = 32
|
||||
RecoveryKeySize = 32
|
||||
)
|
||||
|
||||
// NaCl Box constants
|
||||
const (
|
||||
BoxPublicKeySize = 32
|
||||
BoxSecretKeySize = 32
|
||||
BoxNonceSize = 24
|
||||
)
|
||||
|
||||
// EncryptedData represents encrypted data with its nonce.
|
||||
type EncryptedData struct {
|
||||
Ciphertext []byte
|
||||
Nonce []byte
|
||||
}
|
||||
|
||||
// DeriveKeyFromPassword derives a key encryption key (KEK) from a password using Argon2id.
|
||||
// This is the legacy function - prefer DeriveKeyFromPasswordWithAlgorithm for new code.
|
||||
func DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) {
|
||||
return DeriveKeyFromPasswordArgon2id(password, salt)
|
||||
}
|
||||
|
||||
// DeriveKeyFromPasswordArgon2id derives a KEK using Argon2id algorithm.
|
||||
// SECURITY: Password bytes are wiped from memory after key derivation.
|
||||
func DeriveKeyFromPasswordArgon2id(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != Argon2SaltSize {
|
||||
return nil, fmt.Errorf("invalid salt size: expected %d, got %d", Argon2SaltSize, len(salt))
|
||||
}
|
||||
|
||||
passwordBytes := []byte(password)
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after use
|
||||
|
||||
key := argon2.IDKey(
|
||||
passwordBytes,
|
||||
salt,
|
||||
Argon2OpsLimit, // time cost = 1
|
||||
Argon2MemLimit, // memory = 4 MB
|
||||
Argon2Parallelism, // parallelism = 1
|
||||
Argon2KeySize, // output size = 32 bytes
|
||||
)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeriveKeyFromPasswordPBKDF2 derives a KEK using PBKDF2-SHA256 algorithm.
|
||||
// This matches the web frontend's implementation.
|
||||
// SECURITY: Password bytes are wiped from memory after key derivation.
|
||||
func DeriveKeyFromPasswordPBKDF2(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != PBKDF2SaltSize {
|
||||
return nil, fmt.Errorf("invalid salt size: expected %d, got %d", PBKDF2SaltSize, len(salt))
|
||||
}
|
||||
|
||||
passwordBytes := []byte(password)
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after use
|
||||
|
||||
key := pbkdf2.Key(
|
||||
passwordBytes,
|
||||
salt,
|
||||
PBKDF2Iterations, // 100,000 iterations
|
||||
PBKDF2KeySize, // 32 bytes output
|
||||
sha256.New, // SHA-256 hash
|
||||
)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeriveKeyFromPasswordWithAlgorithm derives a KEK using the specified algorithm.
|
||||
// algorithm should be one of: Argon2IDAlgorithm, PBKDF2Algorithm
|
||||
func DeriveKeyFromPasswordWithAlgorithm(password string, salt []byte, algorithm string) ([]byte, error) {
|
||||
switch algorithm {
|
||||
case Argon2IDAlgorithm: // "argon2id"
|
||||
return DeriveKeyFromPasswordArgon2id(password, salt)
|
||||
case PBKDF2Algorithm, "pbkdf2", "pbkdf2-sha256":
|
||||
return DeriveKeyFromPasswordPBKDF2(password, salt)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported KDF algorithm: %s", algorithm)
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt encrypts data with a symmetric key using ChaCha20-Poly1305.
|
||||
func Encrypt(data, key []byte) (*EncryptedData, error) {
|
||||
if len(key) != ChaCha20Poly1305KeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", ChaCha20Poly1305KeySize, len(key))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate random nonce (12 bytes for ChaCha20-Poly1305)
|
||||
nonce, err := GenerateRandomBytes(ChaCha20Poly1305NonceSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext := cipher.Seal(nil, nonce, data, nil)
|
||||
|
||||
return &EncryptedData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts data with a symmetric key using ChaCha20-Poly1305.
|
||||
func Decrypt(ciphertext, nonce, key []byte) ([]byte, error) {
|
||||
if len(key) != ChaCha20Poly1305KeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", ChaCha20Poly1305KeySize, len(key))
|
||||
}
|
||||
|
||||
if len(nonce) != ChaCha20Poly1305NonceSize {
|
||||
return nil, fmt.Errorf("invalid nonce size: expected %d, got %d", ChaCha20Poly1305NonceSize, len(nonce))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := cipher.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptWithSecretBox encrypts data with a symmetric key using XSalsa20-Poly1305 (NaCl secretbox).
|
||||
// This is compatible with libsodium's crypto_secretbox_easy used by the web frontend.
|
||||
// SECURITY: Key arrays are wiped from memory after encryption.
|
||||
func EncryptWithSecretBox(data, key []byte) (*EncryptedData, error) {
|
||||
if len(key) != SecretBoxKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", SecretBoxKeySize, len(key))
|
||||
}
|
||||
|
||||
// Generate random nonce (24 bytes for XSalsa20)
|
||||
nonce, err := GenerateRandomBytes(SecretBoxNonceSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Convert to fixed-size arrays for NaCl
|
||||
var keyArray [32]byte
|
||||
var nonceArray [24]byte
|
||||
copy(keyArray[:], key)
|
||||
copy(nonceArray[:], nonce)
|
||||
defer memguard.WipeBytes(keyArray[:]) // SECURITY: Wipe key array
|
||||
|
||||
// Encrypt using secretbox
|
||||
ciphertext := secretbox.Seal(nil, data, &nonceArray, &keyArray)
|
||||
|
||||
return &EncryptedData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecryptWithSecretBox decrypts data with a symmetric key using XSalsa20-Poly1305 (NaCl secretbox).
|
||||
// This is compatible with libsodium's crypto_secretbox_open_easy used by the web frontend.
|
||||
// SECURITY: Key arrays are wiped from memory after decryption.
|
||||
func DecryptWithSecretBox(ciphertext, nonce, key []byte) ([]byte, error) {
|
||||
if len(key) != SecretBoxKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", SecretBoxKeySize, len(key))
|
||||
}
|
||||
|
||||
if len(nonce) != SecretBoxNonceSize {
|
||||
return nil, fmt.Errorf("invalid nonce size: expected %d, got %d", SecretBoxNonceSize, len(nonce))
|
||||
}
|
||||
|
||||
// Convert to fixed-size arrays for NaCl
|
||||
var keyArray [32]byte
|
||||
var nonceArray [24]byte
|
||||
copy(keyArray[:], key)
|
||||
copy(nonceArray[:], nonce)
|
||||
defer memguard.WipeBytes(keyArray[:]) // SECURITY: Wipe key array
|
||||
|
||||
// Decrypt using secretbox
|
||||
plaintext, ok := secretbox.Open(nil, ciphertext, &nonceArray, &keyArray)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to decrypt: invalid key, nonce, or corrupted ciphertext")
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// DecryptWithAlgorithm decrypts data using the appropriate cipher based on nonce size.
|
||||
// - 12-byte nonce: ChaCha20-Poly1305 (IETF variant)
|
||||
// - 24-byte nonce: XSalsa20-Poly1305 (NaCl secretbox)
|
||||
func DecryptWithAlgorithm(ciphertext, nonce, key []byte) ([]byte, error) {
|
||||
switch len(nonce) {
|
||||
case ChaCha20Poly1305NonceSize: // 12 bytes
|
||||
return Decrypt(ciphertext, nonce, key)
|
||||
case SecretBoxNonceSize: // 24 bytes
|
||||
return DecryptWithSecretBox(ciphertext, nonce, key)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid nonce size: %d (expected %d for ChaCha20 or %d for XSalsa20)",
|
||||
len(nonce), ChaCha20Poly1305NonceSize, SecretBoxNonceSize)
|
||||
}
|
||||
}
|
||||
|
||||
// EncryptWithBoxSeal encrypts data anonymously using NaCl sealed box.
|
||||
// The result format is: ephemeral_public_key (32) || nonce (24) || ciphertext + auth_tag.
|
||||
func EncryptWithBoxSeal(message []byte, recipientPublicKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != BoxPublicKeySize {
|
||||
return nil, fmt.Errorf("recipient public key must be %d bytes", BoxPublicKeySize)
|
||||
}
|
||||
|
||||
var recipientPubKey [32]byte
|
||||
copy(recipientPubKey[:], recipientPublicKey)
|
||||
|
||||
// Generate ephemeral keypair
|
||||
ephemeralPubKey, ephemeralPrivKey, err := box.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral keypair: %w", err)
|
||||
}
|
||||
|
||||
// Generate random nonce
|
||||
nonce, err := GenerateRandomBytes(BoxNonceSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
var nonceArray [24]byte
|
||||
copy(nonceArray[:], nonce)
|
||||
|
||||
// Encrypt with ephemeral private key
|
||||
ciphertext := box.Seal(nil, message, &nonceArray, &recipientPubKey, ephemeralPrivKey)
|
||||
|
||||
// Result format: ephemeral_public_key || nonce || ciphertext
|
||||
result := make([]byte, BoxPublicKeySize+BoxNonceSize+len(ciphertext))
|
||||
copy(result[:BoxPublicKeySize], ephemeralPubKey[:])
|
||||
copy(result[BoxPublicKeySize:BoxPublicKeySize+BoxNonceSize], nonce)
|
||||
copy(result[BoxPublicKeySize+BoxNonceSize:], ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptWithBoxSeal decrypts data that was encrypted with EncryptWithBoxSeal.
|
||||
// SECURITY: Key arrays are wiped from memory after decryption.
|
||||
func DecryptWithBoxSeal(sealedData []byte, recipientPublicKey, recipientPrivateKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != BoxPublicKeySize {
|
||||
return nil, fmt.Errorf("recipient public key must be %d bytes", BoxPublicKeySize)
|
||||
}
|
||||
if len(recipientPrivateKey) != BoxSecretKeySize {
|
||||
return nil, fmt.Errorf("recipient private key must be %d bytes", BoxSecretKeySize)
|
||||
}
|
||||
if len(sealedData) < BoxPublicKeySize+BoxNonceSize+box.Overhead {
|
||||
return nil, errors.New("sealed data too short")
|
||||
}
|
||||
|
||||
// Extract components
|
||||
ephemeralPublicKey := sealedData[:BoxPublicKeySize]
|
||||
nonce := sealedData[BoxPublicKeySize : BoxPublicKeySize+BoxNonceSize]
|
||||
ciphertext := sealedData[BoxPublicKeySize+BoxNonceSize:]
|
||||
|
||||
// Create fixed-size arrays
|
||||
var ephemeralPubKey [32]byte
|
||||
var recipientPrivKey [32]byte
|
||||
var nonceArray [24]byte
|
||||
copy(ephemeralPubKey[:], ephemeralPublicKey)
|
||||
copy(recipientPrivKey[:], recipientPrivateKey)
|
||||
copy(nonceArray[:], nonce)
|
||||
defer memguard.WipeBytes(recipientPrivKey[:]) // SECURITY: Wipe private key array
|
||||
|
||||
// Decrypt
|
||||
plaintext, ok := box.Open(nil, ciphertext, &nonceArray, &ephemeralPubKey, &recipientPrivKey)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to decrypt sealed box: invalid keys or corrupted ciphertext")
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// DecryptAnonymousBox decrypts sealed box data (used in login challenges).
|
||||
// SECURITY: Key arrays are wiped from memory after decryption.
|
||||
func DecryptAnonymousBox(encryptedData []byte, recipientPublicKey, recipientPrivateKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != BoxPublicKeySize {
|
||||
return nil, fmt.Errorf("recipient public key must be %d bytes", BoxPublicKeySize)
|
||||
}
|
||||
if len(recipientPrivateKey) != BoxSecretKeySize {
|
||||
return nil, fmt.Errorf("recipient private key must be %d bytes", BoxSecretKeySize)
|
||||
}
|
||||
|
||||
var pubKeyArray, privKeyArray [32]byte
|
||||
copy(pubKeyArray[:], recipientPublicKey)
|
||||
copy(privKeyArray[:], recipientPrivateKey)
|
||||
defer memguard.WipeBytes(privKeyArray[:]) // SECURITY: Wipe private key array
|
||||
|
||||
decryptedData, ok := box.OpenAnonymous(nil, encryptedData, &pubKeyArray, &privKeyArray)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to decrypt anonymous box: invalid keys or corrupted data")
|
||||
}
|
||||
|
||||
return decryptedData, nil
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates cryptographically secure random bytes.
|
||||
func GenerateRandomBytes(size int) ([]byte, error) {
|
||||
if size <= 0 {
|
||||
return nil, errors.New("size must be positive")
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
_, err := io.ReadFull(rand.Reader, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// GenerateKeyPair generates a NaCl box keypair for asymmetric encryption.
|
||||
func GenerateKeyPair() (publicKey []byte, privateKey []byte, err error) {
|
||||
pubKey, privKey, err := box.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate key pair: %w", err)
|
||||
}
|
||||
|
||||
return pubKey[:], privKey[:], nil
|
||||
}
|
||||
|
||||
// ClearBytes overwrites a byte slice with zeros using memguard for secure wiping.
|
||||
// This should be called on sensitive data like keys when they're no longer needed.
|
||||
// SECURITY: Uses memguard.WipeBytes for secure memory wiping that prevents compiler optimizations.
|
||||
func ClearBytes(b []byte) {
|
||||
memguard.WipeBytes(b)
|
||||
}
|
||||
|
||||
// CombineNonceAndCiphertext combines nonce and ciphertext into a single byte slice.
|
||||
func CombineNonceAndCiphertext(nonce, ciphertext []byte) []byte {
|
||||
combined := make([]byte, len(nonce)+len(ciphertext))
|
||||
copy(combined[:len(nonce)], nonce)
|
||||
copy(combined[len(nonce):], ciphertext)
|
||||
return combined
|
||||
}
|
||||
|
||||
// SplitNonceAndCiphertext splits a combined byte slice into nonce and ciphertext.
|
||||
// This function defaults to ChaCha20-Poly1305 nonce size (12 bytes) for backward compatibility.
|
||||
// For XSalsa20-Poly1305 (24-byte nonce), use SplitNonceAndCiphertextSecretBox.
|
||||
func SplitNonceAndCiphertext(combined []byte) (nonce []byte, ciphertext []byte, err error) {
|
||||
if len(combined) < ChaCha20Poly1305NonceSize {
|
||||
return nil, nil, fmt.Errorf("combined data too short: expected at least %d bytes, got %d", ChaCha20Poly1305NonceSize, len(combined))
|
||||
}
|
||||
|
||||
nonce = combined[:ChaCha20Poly1305NonceSize]
|
||||
ciphertext = combined[ChaCha20Poly1305NonceSize:]
|
||||
return nonce, ciphertext, nil
|
||||
}
|
||||
|
||||
// SplitNonceAndCiphertextSecretBox splits a combined byte slice for XSalsa20-Poly1305 (24-byte nonce).
|
||||
// This is compatible with libsodium's secretbox format: nonce (24) || ciphertext || mac (16).
|
||||
func SplitNonceAndCiphertextSecretBox(combined []byte) (nonce []byte, ciphertext []byte, err error) {
|
||||
if len(combined) < SecretBoxNonceSize {
|
||||
return nil, nil, fmt.Errorf("combined data too short: expected at least %d bytes, got %d", SecretBoxNonceSize, len(combined))
|
||||
}
|
||||
|
||||
nonce = combined[:SecretBoxNonceSize]
|
||||
ciphertext = combined[SecretBoxNonceSize:]
|
||||
return nonce, ciphertext, nil
|
||||
}
|
||||
|
||||
// SplitNonceAndCiphertextAuto automatically detects the nonce size based on data length.
|
||||
// It uses heuristics to determine if data is ChaCha20-Poly1305 (12-byte nonce) or XSalsa20 (24-byte nonce).
|
||||
// This function should be used when the cipher type is unknown.
|
||||
func SplitNonceAndCiphertextAuto(combined []byte) (nonce []byte, ciphertext []byte, err error) {
|
||||
// Web frontend uses XSalsa20-Poly1305 with 24-byte nonce
|
||||
// Native app used to use ChaCha20-Poly1305 with 12-byte nonce
|
||||
//
|
||||
// For encrypted master key data:
|
||||
// - Web frontend: nonce (24) + ciphertext (32 + 16 MAC) = 72 bytes
|
||||
// - Native/old: nonce (12) + ciphertext (32 + 16 MAC) = 60 bytes
|
||||
//
|
||||
// We can distinguish by checking if the data length suggests 24-byte nonce
|
||||
// Data encrypted with 24-byte nonce will be 12 bytes longer than 12-byte nonce version
|
||||
|
||||
if len(combined) < ChaCha20Poly1305NonceSize+ChaCha20Poly1305Overhead {
|
||||
return nil, nil, fmt.Errorf("combined data too short: expected at least %d bytes, got %d",
|
||||
ChaCha20Poly1305NonceSize+ChaCha20Poly1305Overhead, len(combined))
|
||||
}
|
||||
|
||||
// If data length is at least 72 bytes (24 nonce + 32 key + 16 MAC for master key),
|
||||
// try XSalsa20 format first. This is the web frontend format.
|
||||
if len(combined) >= SecretBoxNonceSize+SecretBoxOverhead+1 {
|
||||
return SplitNonceAndCiphertextSecretBox(combined)
|
||||
}
|
||||
|
||||
// Default to ChaCha20-Poly1305 (legacy)
|
||||
return SplitNonceAndCiphertext(combined)
|
||||
}
|
||||
|
||||
// EncodeToBase64 encodes bytes to base64 standard encoding.
|
||||
func EncodeToBase64(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// DecodeFromBase64 decodes a base64 standard encoded string to bytes.
|
||||
func DecodeFromBase64(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
235
cloud/maplefile-backend/pkg/maplefile/e2ee/file.go
Normal file
235
cloud/maplefile-backend/pkg/maplefile/e2ee/file.go
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
// Package e2ee provides end-to-end encryption operations for the MapleFile SDK.
|
||||
package e2ee
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// FileMetadata represents decrypted file metadata.
|
||||
type FileMetadata struct {
|
||||
Name string `json:"name"`
|
||||
MimeType string `json:"mime_type"`
|
||||
Size int64 `json:"size"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// EncryptFile encrypts file content using the file key.
|
||||
// Returns the combined nonce + ciphertext.
|
||||
// NOTE: This uses ChaCha20-Poly1305 (12-byte nonce). For web frontend compatibility,
|
||||
// use EncryptFileSecretBox instead.
|
||||
func EncryptFile(plaintext, fileKey []byte) ([]byte, error) {
|
||||
encryptedData, err := Encrypt(plaintext, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt file: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext for storage
|
||||
combined := CombineNonceAndCiphertext(encryptedData.Nonce, encryptedData.Ciphertext)
|
||||
return combined, nil
|
||||
}
|
||||
|
||||
// EncryptFileSecretBox encrypts file content using XSalsa20-Poly1305 (NaCl secretbox).
|
||||
// Returns the combined nonce (24 bytes) + ciphertext.
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func EncryptFileSecretBox(plaintext, fileKey []byte) ([]byte, error) {
|
||||
encryptedData, err := EncryptWithSecretBox(plaintext, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt file: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext for storage (matching web frontend format)
|
||||
combined := CombineNonceAndCiphertext(encryptedData.Nonce, encryptedData.Ciphertext)
|
||||
return combined, nil
|
||||
}
|
||||
|
||||
// DecryptFile decrypts file content using the file key.
|
||||
// The input should be combined nonce + ciphertext.
|
||||
// Auto-detects the cipher based on nonce size:
|
||||
// - 24-byte nonce: XSalsa20-Poly1305 (web frontend / SecretBox)
|
||||
// - 12-byte nonce: ChaCha20-Poly1305 (legacy native app)
|
||||
func DecryptFile(encryptedData, fileKey []byte) ([]byte, error) {
|
||||
// Split nonce and ciphertext (auto-detect nonce size)
|
||||
nonce, ciphertext, err := SplitNonceAndCiphertextAuto(encryptedData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split encrypted data: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt using appropriate algorithm based on nonce size
|
||||
plaintext, err := DecryptWithAlgorithm(ciphertext, nonce, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptFileWithNonce encrypts file content and returns the ciphertext and nonce separately.
|
||||
func EncryptFileWithNonce(plaintext, fileKey []byte) (ciphertext []byte, nonce []byte, err error) {
|
||||
encryptedData, err := Encrypt(plaintext, fileKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to encrypt file: %w", err)
|
||||
}
|
||||
|
||||
return encryptedData.Ciphertext, encryptedData.Nonce, nil
|
||||
}
|
||||
|
||||
// DecryptFileWithNonce decrypts file content using separate ciphertext and nonce.
|
||||
func DecryptFileWithNonce(ciphertext, nonce, fileKey []byte) ([]byte, error) {
|
||||
plaintext, err := Decrypt(ciphertext, nonce, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptMetadata encrypts file metadata using the file key.
|
||||
// Returns base64-encoded combined nonce + ciphertext.
|
||||
// NOTE: This uses ChaCha20-Poly1305 (12-byte nonce). For web frontend compatibility,
|
||||
// use EncryptMetadataSecretBox instead.
|
||||
func EncryptMetadata(metadata *FileMetadata, fileKey []byte) (string, error) {
|
||||
// Convert metadata to JSON
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt metadata
|
||||
encryptedData, err := Encrypt(metadataBytes, fileKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt metadata: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext
|
||||
combined := CombineNonceAndCiphertext(encryptedData.Nonce, encryptedData.Ciphertext)
|
||||
|
||||
// Encode to base64
|
||||
return EncodeToBase64(combined), nil
|
||||
}
|
||||
|
||||
// EncryptMetadataSecretBox encrypts file metadata using XSalsa20-Poly1305 (NaCl secretbox).
|
||||
// Returns base64-encoded combined nonce + ciphertext.
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func EncryptMetadataSecretBox(metadata *FileMetadata, fileKey []byte) (string, error) {
|
||||
// Convert metadata to JSON
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt metadata using SecretBox
|
||||
encryptedData, err := EncryptWithSecretBox(metadataBytes, fileKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt metadata: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext
|
||||
combined := CombineNonceAndCiphertext(encryptedData.Nonce, encryptedData.Ciphertext)
|
||||
|
||||
// Encode to base64
|
||||
return EncodeToBase64(combined), nil
|
||||
}
|
||||
|
||||
// DecryptMetadata decrypts file metadata using the file key.
|
||||
// The input should be base64-encoded combined nonce + ciphertext.
|
||||
func DecryptMetadata(encryptedMetadata string, fileKey []byte) (*FileMetadata, error) {
|
||||
// Decode from base64
|
||||
combined, err := DecodeFromBase64(encryptedMetadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encrypted metadata: %w", err)
|
||||
}
|
||||
|
||||
// Split nonce and ciphertext
|
||||
nonce, ciphertext, err := SplitNonceAndCiphertext(combined)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split encrypted metadata: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decryptedBytes, err := Decrypt(ciphertext, nonce, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt metadata: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var metadata FileMetadata
|
||||
if err := json.Unmarshal(decryptedBytes, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse decrypted metadata: %w", err)
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// EncryptMetadataWithNonce encrypts file metadata and returns nonce separately.
|
||||
func EncryptMetadataWithNonce(metadata *FileMetadata, fileKey []byte) (ciphertext []byte, nonce []byte, err error) {
|
||||
// Convert metadata to JSON
|
||||
metadataBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt metadata
|
||||
encryptedData, err := Encrypt(metadataBytes, fileKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to encrypt metadata: %w", err)
|
||||
}
|
||||
|
||||
return encryptedData.Ciphertext, encryptedData.Nonce, nil
|
||||
}
|
||||
|
||||
// DecryptMetadataWithNonce decrypts file metadata using separate ciphertext and nonce.
|
||||
func DecryptMetadataWithNonce(ciphertext, nonce, fileKey []byte) (*FileMetadata, error) {
|
||||
// Decrypt
|
||||
decryptedBytes, err := Decrypt(ciphertext, nonce, fileKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt metadata: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
var metadata FileMetadata
|
||||
if err := json.Unmarshal(decryptedBytes, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse decrypted metadata: %w", err)
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// EncryptData encrypts arbitrary data using the provided key.
|
||||
// Returns base64-encoded combined nonce + ciphertext.
|
||||
func EncryptData(data, key []byte) (string, error) {
|
||||
encryptedData, err := Encrypt(data, key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext
|
||||
combined := CombineNonceAndCiphertext(encryptedData.Nonce, encryptedData.Ciphertext)
|
||||
|
||||
// Encode to base64
|
||||
return EncodeToBase64(combined), nil
|
||||
}
|
||||
|
||||
// DecryptData decrypts arbitrary data using the provided key.
|
||||
// The input should be base64-encoded combined nonce + ciphertext.
|
||||
func DecryptData(encryptedData string, key []byte) ([]byte, error) {
|
||||
// Decode from base64
|
||||
combined, err := DecodeFromBase64(encryptedData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encrypted data: %w", err)
|
||||
}
|
||||
|
||||
// Split nonce and ciphertext
|
||||
nonce, ciphertext, err := SplitNonceAndCiphertext(combined)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split encrypted data: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := Decrypt(ciphertext, nonce, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt data: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
401
cloud/maplefile-backend/pkg/maplefile/e2ee/keychain.go
Normal file
401
cloud/maplefile-backend/pkg/maplefile/e2ee/keychain.go
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
// Package e2ee provides end-to-end encryption operations for the MapleFile SDK.
|
||||
package e2ee
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// KeyChain holds the key encryption key derived from the user's password.
|
||||
// It provides methods for decrypting keys in the E2EE chain.
|
||||
type KeyChain struct {
|
||||
kek []byte // Key Encryption Key derived from password
|
||||
salt []byte // Password salt used for key derivation
|
||||
kdfAlgorithm string // KDF algorithm used ("argon2id" or "PBKDF2-SHA256")
|
||||
}
|
||||
|
||||
// EncryptedKey represents a key encrypted with another key.
|
||||
type EncryptedKey struct {
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
}
|
||||
|
||||
// NewKeyChain creates a new KeyChain by deriving the KEK from the password and salt.
|
||||
// This function defaults to Argon2id for backward compatibility.
|
||||
// For cross-platform compatibility, use NewKeyChainWithAlgorithm instead.
|
||||
func NewKeyChain(password string, salt []byte) (*KeyChain, error) {
|
||||
return NewKeyChainWithAlgorithm(password, salt, Argon2IDAlgorithm)
|
||||
}
|
||||
|
||||
// NewKeyChainWithAlgorithm creates a new KeyChain using the specified KDF algorithm.
|
||||
// algorithm should be one of: Argon2IDAlgorithm ("argon2id") or PBKDF2Algorithm ("PBKDF2-SHA256").
|
||||
// The web frontend uses PBKDF2-SHA256, while the native app historically used Argon2id.
|
||||
func NewKeyChainWithAlgorithm(password string, salt []byte, algorithm string) (*KeyChain, error) {
|
||||
// Validate salt size (both algorithms use 16-byte salt)
|
||||
if len(salt) != 16 {
|
||||
return nil, fmt.Errorf("invalid salt size: expected 16, got %d", len(salt))
|
||||
}
|
||||
|
||||
// Derive key encryption key from password using specified algorithm
|
||||
kek, err := DeriveKeyFromPasswordWithAlgorithm(password, salt, algorithm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive key from password: %w", err)
|
||||
}
|
||||
|
||||
return &KeyChain{
|
||||
kek: kek,
|
||||
salt: salt,
|
||||
kdfAlgorithm: algorithm,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Clear securely clears the KeyChain's sensitive data from memory.
|
||||
// This should be called when the KeyChain is no longer needed.
|
||||
func (k *KeyChain) Clear() {
|
||||
if k.kek != nil {
|
||||
ClearBytes(k.kek)
|
||||
k.kek = nil
|
||||
}
|
||||
}
|
||||
|
||||
// DecryptMasterKey decrypts the user's master key using the KEK.
|
||||
// This method auto-detects the cipher based on nonce size:
|
||||
// - 12-byte nonce: ChaCha20-Poly1305 (native app)
|
||||
// - 24-byte nonce: XSalsa20-Poly1305 (web frontend)
|
||||
func (k *KeyChain) DecryptMasterKey(encryptedMasterKey *EncryptedKey) ([]byte, error) {
|
||||
if k.kek == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
// Auto-detect cipher based on nonce size
|
||||
masterKey, err := DecryptWithAlgorithm(encryptedMasterKey.Ciphertext, encryptedMasterKey.Nonce, k.kek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt master key: %w", err)
|
||||
}
|
||||
|
||||
return masterKey, nil
|
||||
}
|
||||
|
||||
// DecryptCollectionKey decrypts a collection key using the master key.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptCollectionKey(encryptedCollectionKey *EncryptedKey, masterKey []byte) ([]byte, error) {
|
||||
collectionKey, err := DecryptWithAlgorithm(encryptedCollectionKey.Ciphertext, encryptedCollectionKey.Nonce, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt collection key: %w", err)
|
||||
}
|
||||
|
||||
return collectionKey, nil
|
||||
}
|
||||
|
||||
// DecryptFileKey decrypts a file key using the collection key.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptFileKey(encryptedFileKey *EncryptedKey, collectionKey []byte) ([]byte, error) {
|
||||
fileKey, err := DecryptWithAlgorithm(encryptedFileKey.Ciphertext, encryptedFileKey.Nonce, collectionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt file key: %w", err)
|
||||
}
|
||||
|
||||
return fileKey, nil
|
||||
}
|
||||
|
||||
// DecryptPrivateKey decrypts the user's private key using the master key.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptPrivateKey(encryptedPrivateKey *EncryptedKey, masterKey []byte) ([]byte, error) {
|
||||
privateKey, err := DecryptWithAlgorithm(encryptedPrivateKey.Ciphertext, encryptedPrivateKey.Nonce, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// DecryptRecoveryKey decrypts the user's recovery key using the master key.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptRecoveryKey(encryptedRecoveryKey *EncryptedKey, masterKey []byte) ([]byte, error) {
|
||||
recoveryKey, err := DecryptWithAlgorithm(encryptedRecoveryKey.Ciphertext, encryptedRecoveryKey.Nonce, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt recovery key: %w", err)
|
||||
}
|
||||
|
||||
return recoveryKey, nil
|
||||
}
|
||||
|
||||
// DecryptMasterKeyWithRecoveryKey decrypts the master key using the recovery key.
|
||||
// This is used during account recovery.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptMasterKeyWithRecoveryKey(encryptedMasterKey *EncryptedKey, recoveryKey []byte) ([]byte, error) {
|
||||
masterKey, err := DecryptWithAlgorithm(encryptedMasterKey.Ciphertext, encryptedMasterKey.Nonce, recoveryKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt master key with recovery key: %w", err)
|
||||
}
|
||||
|
||||
return masterKey, nil
|
||||
}
|
||||
|
||||
// GenerateMasterKey generates a new random master key.
|
||||
func GenerateMasterKey() ([]byte, error) {
|
||||
return GenerateRandomBytes(MasterKeySize)
|
||||
}
|
||||
|
||||
// GenerateCollectionKey generates a new random collection key.
|
||||
func GenerateCollectionKey() ([]byte, error) {
|
||||
return GenerateRandomBytes(CollectionKeySize)
|
||||
}
|
||||
|
||||
// GenerateFileKey generates a new random file key.
|
||||
func GenerateFileKey() ([]byte, error) {
|
||||
return GenerateRandomBytes(FileKeySize)
|
||||
}
|
||||
|
||||
// GenerateRecoveryKey generates a new random recovery key.
|
||||
func GenerateRecoveryKey() ([]byte, error) {
|
||||
return GenerateRandomBytes(RecoveryKeySize)
|
||||
}
|
||||
|
||||
// GenerateSalt generates a new random salt for password derivation.
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
return GenerateRandomBytes(Argon2SaltSize)
|
||||
}
|
||||
|
||||
// EncryptMasterKey encrypts a master key with the KEK.
|
||||
func (k *KeyChain) EncryptMasterKey(masterKey []byte) (*EncryptedKey, error) {
|
||||
if k.kek == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
encrypted, err := Encrypt(masterKey, k.kek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptCollectionKey encrypts a collection key with the master key using ChaCha20-Poly1305.
|
||||
// For web frontend compatibility, use EncryptCollectionKeySecretBox instead.
|
||||
func EncryptCollectionKey(collectionKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(collectionKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt collection key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptCollectionKeySecretBox encrypts a collection key with the master key using XSalsa20-Poly1305.
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func EncryptCollectionKeySecretBox(collectionKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(collectionKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt collection key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptFileKey encrypts a file key with the collection key.
|
||||
// NOTE: This uses ChaCha20-Poly1305 (12-byte nonce). For web frontend compatibility,
|
||||
// use EncryptFileKeySecretBox instead.
|
||||
func EncryptFileKey(fileKey, collectionKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(fileKey, collectionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt file key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptFileKeySecretBox encrypts a file key with the collection key using XSalsa20-Poly1305.
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func EncryptFileKeySecretBox(fileKey, collectionKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(fileKey, collectionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt file key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptPrivateKey encrypts a private key with the master key.
|
||||
func EncryptPrivateKey(privateKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(privateKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt private key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptRecoveryKey encrypts a recovery key with the master key.
|
||||
func EncryptRecoveryKey(recoveryKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(recoveryKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt recovery key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptMasterKeyWithRecoveryKey encrypts a master key with the recovery key.
|
||||
// This is used to enable account recovery.
|
||||
func EncryptMasterKeyWithRecoveryKey(masterKey, recoveryKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(masterKey, recoveryKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key with recovery key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SecretBox (XSalsa20-Poly1305) Encryption Functions
|
||||
// These match the web frontend's libsodium crypto_secretbox_easy implementation
|
||||
// =============================================================================
|
||||
|
||||
// EncryptMasterKeySecretBox encrypts a master key with the KEK using XSalsa20-Poly1305.
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func (k *KeyChain) EncryptMasterKeySecretBox(masterKey []byte) (*EncryptedKey, error) {
|
||||
if k.kek == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
encrypted, err := EncryptWithSecretBox(masterKey, k.kek)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptPrivateKeySecretBox encrypts a private key with the master key using XSalsa20-Poly1305.
|
||||
func EncryptPrivateKeySecretBox(privateKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(privateKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt private key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptRecoveryKeySecretBox encrypts a recovery key with the master key using XSalsa20-Poly1305.
|
||||
func EncryptRecoveryKeySecretBox(recoveryKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(recoveryKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt recovery key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptMasterKeyWithRecoveryKeySecretBox encrypts a master key with the recovery key using XSalsa20-Poly1305.
|
||||
func EncryptMasterKeyWithRecoveryKeySecretBox(masterKey, recoveryKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(masterKey, recoveryKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key with recovery key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptCollectionKeyForSharing encrypts a collection key for a recipient using BoxSeal.
|
||||
// This is used when sharing a collection with another user.
|
||||
func EncryptCollectionKeyForSharing(collectionKey, recipientPublicKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != BoxPublicKeySize {
|
||||
return nil, fmt.Errorf("invalid recipient public key size: expected %d, got %d", BoxPublicKeySize, len(recipientPublicKey))
|
||||
}
|
||||
|
||||
return EncryptWithBoxSeal(collectionKey, recipientPublicKey)
|
||||
}
|
||||
|
||||
// DecryptSharedCollectionKey decrypts a collection key that was shared using BoxSeal.
|
||||
// This is used when accessing a shared collection.
|
||||
func DecryptSharedCollectionKey(encryptedCollectionKey, publicKey, privateKey []byte) ([]byte, error) {
|
||||
return DecryptWithBoxSeal(encryptedCollectionKey, publicKey, privateKey)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tag Key Operations
|
||||
// ============================================================================
|
||||
|
||||
// GenerateTagKey generates a new 32-byte tag key for encrypting tag data.
|
||||
func GenerateTagKey() ([]byte, error) {
|
||||
return GenerateRandomBytes(SecretBoxKeySize)
|
||||
}
|
||||
|
||||
// GenerateKey is an alias for GenerateTagKey (convenience function).
|
||||
func GenerateKey() []byte {
|
||||
key, _ := GenerateTagKey()
|
||||
return key
|
||||
}
|
||||
|
||||
// EncryptTagKey encrypts a tag key with the master key using ChaCha20-Poly1305.
|
||||
func EncryptTagKey(tagKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := Encrypt(tagKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt tag key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptTagKeySecretBox encrypts a tag key with the master key using XSalsa20-Poly1305.
|
||||
func EncryptTagKeySecretBox(tagKey, masterKey []byte) (*EncryptedKey, error) {
|
||||
encrypted, err := EncryptWithSecretBox(tagKey, masterKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt tag key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecryptTagKey decrypts a tag key with the master key.
|
||||
func DecryptTagKey(encryptedTagKey *EncryptedKey, masterKey []byte) ([]byte, error) {
|
||||
// Try XSalsa20-Poly1305 first (based on nonce size)
|
||||
if len(encryptedTagKey.Nonce) == SecretBoxNonceSize {
|
||||
return DecryptWithSecretBox(encryptedTagKey.Ciphertext, encryptedTagKey.Nonce, masterKey)
|
||||
}
|
||||
|
||||
// Fall back to ChaCha20-Poly1305
|
||||
return Decrypt(encryptedTagKey.Ciphertext, encryptedTagKey.Nonce, masterKey)
|
||||
}
|
||||
246
cloud/maplefile-backend/pkg/maplefile/e2ee/secure.go
Normal file
246
cloud/maplefile-backend/pkg/maplefile/e2ee/secure.go
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
// Package e2ee provides end-to-end encryption operations for the MapleFile SDK.
|
||||
// This file contains memguard-protected secure memory operations.
|
||||
package e2ee
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureBuffer wraps memguard.LockedBuffer for type safety
|
||||
type SecureBuffer struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureBuffer creates a new secure buffer from bytes
|
||||
func NewSecureBuffer(data []byte) (*SecureBuffer, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, fmt.Errorf("cannot create secure buffer from empty data")
|
||||
}
|
||||
|
||||
buffer := memguard.NewBufferFromBytes(data)
|
||||
return &SecureBuffer{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// NewSecureBufferRandom creates a new secure buffer with random data
|
||||
func NewSecureBufferRandom(size int) (*SecureBuffer, error) {
|
||||
if size <= 0 {
|
||||
return nil, fmt.Errorf("size must be positive")
|
||||
}
|
||||
|
||||
buffer := memguard.NewBuffer(size)
|
||||
return &SecureBuffer{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// Bytes returns the underlying bytes (caller must handle carefully)
|
||||
func (s *SecureBuffer) Bytes() []byte {
|
||||
if s.buffer == nil {
|
||||
return nil
|
||||
}
|
||||
return s.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Size returns the size of the buffer
|
||||
func (s *SecureBuffer) Size() int {
|
||||
if s.buffer == nil {
|
||||
return 0
|
||||
}
|
||||
return s.buffer.Size()
|
||||
}
|
||||
|
||||
// Destroy securely destroys the buffer
|
||||
func (s *SecureBuffer) Destroy() {
|
||||
if s.buffer != nil {
|
||||
s.buffer.Destroy()
|
||||
s.buffer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Copy creates a new SecureBuffer with a copy of the data
|
||||
func (s *SecureBuffer) Copy() (*SecureBuffer, error) {
|
||||
if s.buffer == nil {
|
||||
return nil, fmt.Errorf("cannot copy destroyed buffer")
|
||||
}
|
||||
|
||||
return NewSecureBuffer(s.buffer.Bytes())
|
||||
}
|
||||
|
||||
// SecureKeyChain is a KeyChain that stores the KEK in protected memory
|
||||
type SecureKeyChain struct {
|
||||
kek *SecureBuffer // Key Encryption Key in protected memory
|
||||
salt []byte // Salt (not sensitive, kept in regular memory)
|
||||
kdfAlgorithm string // KDF algorithm used
|
||||
}
|
||||
|
||||
// NewSecureKeyChain creates a new SecureKeyChain with KEK in protected memory.
|
||||
// This function defaults to Argon2id for backward compatibility.
|
||||
// For cross-platform compatibility, use NewSecureKeyChainWithAlgorithm instead.
|
||||
func NewSecureKeyChain(password string, salt []byte) (*SecureKeyChain, error) {
|
||||
return NewSecureKeyChainWithAlgorithm(password, salt, Argon2IDAlgorithm)
|
||||
}
|
||||
|
||||
// NewSecureKeyChainWithAlgorithm creates a new SecureKeyChain using the specified KDF algorithm.
|
||||
// algorithm should be one of: Argon2IDAlgorithm ("argon2id") or PBKDF2Algorithm ("PBKDF2-SHA256").
|
||||
// The web frontend uses PBKDF2-SHA256, while the native app historically used Argon2id.
|
||||
func NewSecureKeyChainWithAlgorithm(password string, salt []byte, algorithm string) (*SecureKeyChain, error) {
|
||||
// Both algorithms use 16-byte salt
|
||||
if len(salt) != 16 {
|
||||
return nil, fmt.Errorf("invalid salt size: expected 16, got %d", len(salt))
|
||||
}
|
||||
|
||||
// Derive KEK from password using specified algorithm
|
||||
kekBytes, err := DeriveKeyFromPasswordWithAlgorithm(password, salt, algorithm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive key from password: %w", err)
|
||||
}
|
||||
|
||||
// Store KEK in secure memory immediately
|
||||
kek, err := NewSecureBuffer(kekBytes)
|
||||
if err != nil {
|
||||
ClearBytes(kekBytes)
|
||||
return nil, fmt.Errorf("failed to create secure buffer for KEK: %w", err)
|
||||
}
|
||||
|
||||
// Clear the temporary KEK bytes
|
||||
ClearBytes(kekBytes)
|
||||
|
||||
return &SecureKeyChain{
|
||||
kek: kek,
|
||||
salt: salt,
|
||||
kdfAlgorithm: algorithm,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Clear securely clears the SecureKeyChain's sensitive data
|
||||
func (k *SecureKeyChain) Clear() {
|
||||
if k.kek != nil {
|
||||
k.kek.Destroy()
|
||||
k.kek = nil
|
||||
}
|
||||
}
|
||||
|
||||
// DecryptMasterKeySecure decrypts the master key and returns it in a SecureBuffer.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func (k *SecureKeyChain) DecryptMasterKeySecure(encryptedMasterKey *EncryptedKey) (*SecureBuffer, error) {
|
||||
if k.kek == nil || k.kek.buffer == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
// Decrypt using KEK from secure memory (auto-detect cipher based on nonce size)
|
||||
masterKeyBytes, err := DecryptWithAlgorithm(encryptedMasterKey.Ciphertext, encryptedMasterKey.Nonce, k.kek.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt master key: %w", err)
|
||||
}
|
||||
|
||||
// Store decrypted master key in secure memory
|
||||
masterKey, err := NewSecureBuffer(masterKeyBytes)
|
||||
if err != nil {
|
||||
ClearBytes(masterKeyBytes)
|
||||
return nil, fmt.Errorf("failed to create secure buffer for master key: %w", err)
|
||||
}
|
||||
|
||||
// Clear temporary bytes
|
||||
ClearBytes(masterKeyBytes)
|
||||
|
||||
return masterKey, nil
|
||||
}
|
||||
|
||||
// DecryptMasterKey provides backward compatibility by returning []byte.
|
||||
// For new code, prefer DecryptMasterKeySecure.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func (k *SecureKeyChain) DecryptMasterKey(encryptedMasterKey *EncryptedKey) ([]byte, error) {
|
||||
if k.kek == nil || k.kek.buffer == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
// Decrypt using KEK from secure memory (auto-detect cipher)
|
||||
return DecryptWithAlgorithm(encryptedMasterKey.Ciphertext, encryptedMasterKey.Nonce, k.kek.Bytes())
|
||||
}
|
||||
|
||||
// EncryptMasterKey encrypts a master key with the KEK using ChaCha20-Poly1305.
|
||||
// For web frontend compatibility, use EncryptMasterKeySecretBox instead.
|
||||
func (k *SecureKeyChain) EncryptMasterKey(masterKey []byte) (*EncryptedKey, error) {
|
||||
if k.kek == nil || k.kek.buffer == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
encrypted, err := Encrypt(masterKey, k.kek.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EncryptMasterKeySecretBox encrypts a master key with the KEK using XSalsa20-Poly1305 (SecretBox).
|
||||
// This is compatible with the web frontend's libsodium implementation.
|
||||
func (k *SecureKeyChain) EncryptMasterKeySecretBox(masterKey []byte) (*EncryptedKey, error) {
|
||||
if k.kek == nil || k.kek.buffer == nil {
|
||||
return nil, fmt.Errorf("keychain has been cleared")
|
||||
}
|
||||
|
||||
encrypted, err := EncryptWithSecretBox(masterKey, k.kek.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt master key: %w", err)
|
||||
}
|
||||
|
||||
return &EncryptedKey{
|
||||
Ciphertext: encrypted.Ciphertext,
|
||||
Nonce: encrypted.Nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecryptPrivateKeySecure decrypts a private key and returns it in a SecureBuffer.
|
||||
// Auto-detects cipher based on nonce size (12 for ChaCha20, 24 for XSalsa20).
|
||||
func DecryptPrivateKeySecure(encryptedPrivateKey *EncryptedKey, masterKey *SecureBuffer) (*SecureBuffer, error) {
|
||||
if masterKey == nil || masterKey.buffer == nil {
|
||||
return nil, fmt.Errorf("master key is nil or destroyed")
|
||||
}
|
||||
|
||||
// Decrypt private key (auto-detect cipher based on nonce size)
|
||||
privateKeyBytes, err := DecryptWithAlgorithm(encryptedPrivateKey.Ciphertext, encryptedPrivateKey.Nonce, masterKey.Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
|
||||
}
|
||||
|
||||
// Store in secure memory
|
||||
privateKey, err := NewSecureBuffer(privateKeyBytes)
|
||||
if err != nil {
|
||||
ClearBytes(privateKeyBytes)
|
||||
return nil, fmt.Errorf("failed to create secure buffer for private key: %w", err)
|
||||
}
|
||||
|
||||
// Clear temporary bytes
|
||||
ClearBytes(privateKeyBytes)
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// WithSecureBuffer provides a callback pattern for temporary use of secure data
|
||||
// The buffer is automatically destroyed after the callback returns
|
||||
func WithSecureBuffer(data []byte, fn func(*SecureBuffer) error) error {
|
||||
buf, err := NewSecureBuffer(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer buf.Destroy()
|
||||
|
||||
return fn(buf)
|
||||
}
|
||||
|
||||
// CopyToSecure copies regular bytes into a new SecureBuffer and clears the source
|
||||
func CopyToSecure(data []byte) (*SecureBuffer, error) {
|
||||
buf, err := NewSecureBuffer(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Clear the source data
|
||||
ClearBytes(data)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
99
cloud/maplefile-backend/pkg/mocks/mock_distributedmutex.go
Normal file
99
cloud/maplefile-backend/pkg/mocks/mock_distributedmutex.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/distributedmutex/distributelocker.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/distributedmutex/distributelocker.go -destination=pkg/mocks/mock_distributedmutex.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockAdapter is a mock of Adapter interface.
|
||||
type MockAdapter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAdapterMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockAdapterMockRecorder is the mock recorder for MockAdapter.
|
||||
type MockAdapterMockRecorder struct {
|
||||
mock *MockAdapter
|
||||
}
|
||||
|
||||
// NewMockAdapter creates a new mock instance.
|
||||
func NewMockAdapter(ctrl *gomock.Controller) *MockAdapter {
|
||||
mock := &MockAdapter{ctrl: ctrl}
|
||||
mock.recorder = &MockAdapterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockAdapter) EXPECT() *MockAdapterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Acquire mocks base method.
|
||||
func (m *MockAdapter) Acquire(ctx context.Context, key string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Acquire", ctx, key)
|
||||
}
|
||||
|
||||
// Acquire indicates an expected call of Acquire.
|
||||
func (mr *MockAdapterMockRecorder) Acquire(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockAdapter)(nil).Acquire), ctx, key)
|
||||
}
|
||||
|
||||
// Acquiref mocks base method.
|
||||
func (m *MockAdapter) Acquiref(ctx context.Context, format string, a ...any) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, format}
|
||||
for _, a_2 := range a {
|
||||
varargs = append(varargs, a_2)
|
||||
}
|
||||
m.ctrl.Call(m, "Acquiref", varargs...)
|
||||
}
|
||||
|
||||
// Acquiref indicates an expected call of Acquiref.
|
||||
func (mr *MockAdapterMockRecorder) Acquiref(ctx, format any, a ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, format}, a...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquiref", reflect.TypeOf((*MockAdapter)(nil).Acquiref), varargs...)
|
||||
}
|
||||
|
||||
// Release mocks base method.
|
||||
func (m *MockAdapter) Release(ctx context.Context, key string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Release", ctx, key)
|
||||
}
|
||||
|
||||
// Release indicates an expected call of Release.
|
||||
func (mr *MockAdapterMockRecorder) Release(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockAdapter)(nil).Release), ctx, key)
|
||||
}
|
||||
|
||||
// Releasef mocks base method.
|
||||
func (m *MockAdapter) Releasef(ctx context.Context, format string, a ...any) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, format}
|
||||
for _, a_2 := range a {
|
||||
varargs = append(varargs, a_2)
|
||||
}
|
||||
m.ctrl.Call(m, "Releasef", varargs...)
|
||||
}
|
||||
|
||||
// Releasef indicates an expected call of Releasef.
|
||||
func (mr *MockAdapterMockRecorder) Releasef(ctx, format any, a ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{ctx, format}, a...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Releasef", reflect.TypeOf((*MockAdapter)(nil).Releasef), varargs...)
|
||||
}
|
||||
125
cloud/maplefile-backend/pkg/mocks/mock_mailgun.go
Normal file
125
cloud/maplefile-backend/pkg/mocks/mock_mailgun.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/emailer/mailgun/interface.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/emailer/mailgun/interface.go -destination=pkg/mocks/mock_mailgun.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockEmailer is a mock of Emailer interface.
|
||||
type MockEmailer struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEmailerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockEmailerMockRecorder is the mock recorder for MockEmailer.
|
||||
type MockEmailerMockRecorder struct {
|
||||
mock *MockEmailer
|
||||
}
|
||||
|
||||
// NewMockEmailer creates a new mock instance.
|
||||
func NewMockEmailer(ctrl *gomock.Controller) *MockEmailer {
|
||||
mock := &MockEmailer{ctrl: ctrl}
|
||||
mock.recorder = &MockEmailerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEmailer) EXPECT() *MockEmailerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetBackendDomainName mocks base method.
|
||||
func (m *MockEmailer) GetBackendDomainName() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetBackendDomainName")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetBackendDomainName indicates an expected call of GetBackendDomainName.
|
||||
func (mr *MockEmailerMockRecorder) GetBackendDomainName() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackendDomainName", reflect.TypeOf((*MockEmailer)(nil).GetBackendDomainName))
|
||||
}
|
||||
|
||||
// GetDomainName mocks base method.
|
||||
func (m *MockEmailer) GetDomainName() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDomainName")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetDomainName indicates an expected call of GetDomainName.
|
||||
func (mr *MockEmailerMockRecorder) GetDomainName() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomainName", reflect.TypeOf((*MockEmailer)(nil).GetDomainName))
|
||||
}
|
||||
|
||||
// GetFrontendDomainName mocks base method.
|
||||
func (m *MockEmailer) GetFrontendDomainName() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetFrontendDomainName")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetFrontendDomainName indicates an expected call of GetFrontendDomainName.
|
||||
func (mr *MockEmailerMockRecorder) GetFrontendDomainName() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFrontendDomainName", reflect.TypeOf((*MockEmailer)(nil).GetFrontendDomainName))
|
||||
}
|
||||
|
||||
// GetMaintenanceEmail mocks base method.
|
||||
func (m *MockEmailer) GetMaintenanceEmail() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetMaintenanceEmail")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetMaintenanceEmail indicates an expected call of GetMaintenanceEmail.
|
||||
func (mr *MockEmailerMockRecorder) GetMaintenanceEmail() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaintenanceEmail", reflect.TypeOf((*MockEmailer)(nil).GetMaintenanceEmail))
|
||||
}
|
||||
|
||||
// GetSenderEmail mocks base method.
|
||||
func (m *MockEmailer) GetSenderEmail() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetSenderEmail")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetSenderEmail indicates an expected call of GetSenderEmail.
|
||||
func (mr *MockEmailerMockRecorder) GetSenderEmail() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSenderEmail", reflect.TypeOf((*MockEmailer)(nil).GetSenderEmail))
|
||||
}
|
||||
|
||||
// Send mocks base method.
|
||||
func (m *MockEmailer) Send(ctx context.Context, sender, subject, recipient, htmlContent string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Send", ctx, sender, subject, recipient, htmlContent)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Send indicates an expected call of Send.
|
||||
func (mr *MockEmailerMockRecorder) Send(ctx, sender, subject, recipient, htmlContent any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockEmailer)(nil).Send), ctx, sender, subject, recipient, htmlContent)
|
||||
}
|
||||
90
cloud/maplefile-backend/pkg/mocks/mock_security_jwt.go
Normal file
90
cloud/maplefile-backend/pkg/mocks/mock_security_jwt.go
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/security/jwt/jwt.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/security/jwt/jwt.go -destination=pkg/mocks/mock_security_jwt.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockJWTProvider is a mock of JWTProvider interface.
|
||||
type MockJWTProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockJWTProviderMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockJWTProviderMockRecorder is the mock recorder for MockJWTProvider.
|
||||
type MockJWTProviderMockRecorder struct {
|
||||
mock *MockJWTProvider
|
||||
}
|
||||
|
||||
// NewMockJWTProvider creates a new mock instance.
|
||||
func NewMockJWTProvider(ctrl *gomock.Controller) *MockJWTProvider {
|
||||
mock := &MockJWTProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockJWTProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockJWTProvider) EXPECT() *MockJWTProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GenerateJWTToken mocks base method.
|
||||
func (m *MockJWTProvider) GenerateJWTToken(uuid string, ad time.Duration) (string, time.Time, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateJWTToken", uuid, ad)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(time.Time)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GenerateJWTToken indicates an expected call of GenerateJWTToken.
|
||||
func (mr *MockJWTProviderMockRecorder) GenerateJWTToken(uuid, ad any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateJWTToken", reflect.TypeOf((*MockJWTProvider)(nil).GenerateJWTToken), uuid, ad)
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair mocks base method.
|
||||
func (m *MockJWTProvider) GenerateJWTTokenPair(uuid string, ad, rd time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateJWTTokenPair", uuid, ad, rd)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(time.Time)
|
||||
ret2, _ := ret[2].(string)
|
||||
ret3, _ := ret[3].(time.Time)
|
||||
ret4, _ := ret[4].(error)
|
||||
return ret0, ret1, ret2, ret3, ret4
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair indicates an expected call of GenerateJWTTokenPair.
|
||||
func (mr *MockJWTProviderMockRecorder) GenerateJWTTokenPair(uuid, ad, rd any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateJWTTokenPair", reflect.TypeOf((*MockJWTProvider)(nil).GenerateJWTTokenPair), uuid, ad, rd)
|
||||
}
|
||||
|
||||
// ProcessJWTToken mocks base method.
|
||||
func (m *MockJWTProvider) ProcessJWTToken(reqToken string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ProcessJWTToken", reqToken)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ProcessJWTToken indicates an expected call of ProcessJWTToken.
|
||||
func (mr *MockJWTProviderMockRecorder) ProcessJWTToken(reqToken any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ProcessJWTToken", reflect.TypeOf((*MockJWTProvider)(nil).ProcessJWTToken), reqToken)
|
||||
}
|
||||
115
cloud/maplefile-backend/pkg/mocks/mock_security_password.go
Normal file
115
cloud/maplefile-backend/pkg/mocks/mock_security_password.go
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/security/password/password.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/security/password/password.go -destination=pkg/mocks/mock_security_password.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
securestring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockPasswordProvider is a mock of PasswordProvider interface.
|
||||
type MockPasswordProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPasswordProviderMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockPasswordProviderMockRecorder is the mock recorder for MockPasswordProvider.
|
||||
type MockPasswordProviderMockRecorder struct {
|
||||
mock *MockPasswordProvider
|
||||
}
|
||||
|
||||
// NewMockPasswordProvider creates a new mock instance.
|
||||
func NewMockPasswordProvider(ctrl *gomock.Controller) *MockPasswordProvider {
|
||||
mock := &MockPasswordProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockPasswordProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockPasswordProvider) EXPECT() *MockPasswordProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AlgorithmName mocks base method.
|
||||
func (m *MockPasswordProvider) AlgorithmName() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AlgorithmName")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AlgorithmName indicates an expected call of AlgorithmName.
|
||||
func (mr *MockPasswordProviderMockRecorder) AlgorithmName() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AlgorithmName", reflect.TypeOf((*MockPasswordProvider)(nil).AlgorithmName))
|
||||
}
|
||||
|
||||
// ComparePasswordAndHash mocks base method.
|
||||
func (m *MockPasswordProvider) ComparePasswordAndHash(password *securestring.SecureString, hash string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ComparePasswordAndHash", password, hash)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ComparePasswordAndHash indicates an expected call of ComparePasswordAndHash.
|
||||
func (mr *MockPasswordProviderMockRecorder) ComparePasswordAndHash(password, hash any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ComparePasswordAndHash", reflect.TypeOf((*MockPasswordProvider)(nil).ComparePasswordAndHash), password, hash)
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword mocks base method.
|
||||
func (m *MockPasswordProvider) GenerateHashFromPassword(password *securestring.SecureString) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateHashFromPassword", password)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword indicates an expected call of GenerateHashFromPassword.
|
||||
func (mr *MockPasswordProviderMockRecorder) GenerateHashFromPassword(password any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateHashFromPassword", reflect.TypeOf((*MockPasswordProvider)(nil).GenerateHashFromPassword), password)
|
||||
}
|
||||
|
||||
// GenerateSecureRandomBytes mocks base method.
|
||||
func (m *MockPasswordProvider) GenerateSecureRandomBytes(length int) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateSecureRandomBytes", length)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GenerateSecureRandomBytes indicates an expected call of GenerateSecureRandomBytes.
|
||||
func (mr *MockPasswordProviderMockRecorder) GenerateSecureRandomBytes(length any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateSecureRandomBytes", reflect.TypeOf((*MockPasswordProvider)(nil).GenerateSecureRandomBytes), length)
|
||||
}
|
||||
|
||||
// GenerateSecureRandomString mocks base method.
|
||||
func (m *MockPasswordProvider) GenerateSecureRandomString(length int) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GenerateSecureRandomString", length)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GenerateSecureRandomString indicates an expected call of GenerateSecureRandomString.
|
||||
func (mr *MockPasswordProviderMockRecorder) GenerateSecureRandomString(length any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateSecureRandomString", reflect.TypeOf((*MockPasswordProvider)(nil).GenerateSecureRandomString), length)
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/cache/cassandracache/cassandracache.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/cache/cassandracache/cassandracache.go -destination=pkg/mocks/mock_storage_cache_cassandracache.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCassandraCacher is a mock of CassandraCacher interface.
|
||||
type MockCassandraCacher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCassandraCacherMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockCassandraCacherMockRecorder is the mock recorder for MockCassandraCacher.
|
||||
type MockCassandraCacherMockRecorder struct {
|
||||
mock *MockCassandraCacher
|
||||
}
|
||||
|
||||
// NewMockCassandraCacher creates a new mock instance.
|
||||
func NewMockCassandraCacher(ctrl *gomock.Controller) *MockCassandraCacher {
|
||||
mock := &MockCassandraCacher{ctrl: ctrl}
|
||||
mock.recorder = &MockCassandraCacherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCassandraCacher) EXPECT() *MockCassandraCacherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockCassandraCacher) Delete(ctx context.Context, key string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, key)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockCassandraCacherMockRecorder) Delete(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockCassandraCacher)(nil).Delete), ctx, key)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockCassandraCacher) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", ctx, key)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockCassandraCacherMockRecorder) Get(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCassandraCacher)(nil).Get), ctx, key)
|
||||
}
|
||||
|
||||
// PurgeExpired mocks base method.
|
||||
func (m *MockCassandraCacher) PurgeExpired(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PurgeExpired", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PurgeExpired indicates an expected call of PurgeExpired.
|
||||
func (mr *MockCassandraCacherMockRecorder) PurgeExpired(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PurgeExpired", reflect.TypeOf((*MockCassandraCacher)(nil).PurgeExpired), ctx)
|
||||
}
|
||||
|
||||
// Set mocks base method.
|
||||
func (m *MockCassandraCacher) Set(ctx context.Context, key string, val []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Set", ctx, key, val)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Set indicates an expected call of Set.
|
||||
func (mr *MockCassandraCacherMockRecorder) Set(ctx, key, val any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCassandraCacher)(nil).Set), ctx, key, val)
|
||||
}
|
||||
|
||||
// SetWithExpiry mocks base method.
|
||||
func (m *MockCassandraCacher) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetWithExpiry", ctx, key, val, expiry)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetWithExpiry indicates an expected call of SetWithExpiry.
|
||||
func (mr *MockCassandraCacherMockRecorder) SetWithExpiry(ctx, key, val, expiry any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWithExpiry", reflect.TypeOf((*MockCassandraCacher)(nil).SetWithExpiry), ctx, key, val, expiry)
|
||||
}
|
||||
|
||||
// Shutdown mocks base method.
|
||||
func (m *MockCassandraCacher) Shutdown() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Shutdown")
|
||||
}
|
||||
|
||||
// Shutdown indicates an expected call of Shutdown.
|
||||
func (mr *MockCassandraCacherMockRecorder) Shutdown() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockCassandraCacher)(nil).Shutdown))
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/cache/twotiercache/twotiercache.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/cache/twotiercache/twotiercache.go -destination=pkg/mocks/mock_storage_cache_twotiercache.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockTwoTierCacher is a mock of TwoTierCacher interface.
|
||||
type MockTwoTierCacher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTwoTierCacherMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockTwoTierCacherMockRecorder is the mock recorder for MockTwoTierCacher.
|
||||
type MockTwoTierCacherMockRecorder struct {
|
||||
mock *MockTwoTierCacher
|
||||
}
|
||||
|
||||
// NewMockTwoTierCacher creates a new mock instance.
|
||||
func NewMockTwoTierCacher(ctrl *gomock.Controller) *MockTwoTierCacher {
|
||||
mock := &MockTwoTierCacher{ctrl: ctrl}
|
||||
mock.recorder = &MockTwoTierCacherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockTwoTierCacher) EXPECT() *MockTwoTierCacherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockTwoTierCacher) Delete(ctx context.Context, key string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, key)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockTwoTierCacherMockRecorder) Delete(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockTwoTierCacher)(nil).Delete), ctx, key)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockTwoTierCacher) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", ctx, key)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockTwoTierCacherMockRecorder) Get(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockTwoTierCacher)(nil).Get), ctx, key)
|
||||
}
|
||||
|
||||
// PurgeExpired mocks base method.
|
||||
func (m *MockTwoTierCacher) PurgeExpired(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PurgeExpired", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PurgeExpired indicates an expected call of PurgeExpired.
|
||||
func (mr *MockTwoTierCacherMockRecorder) PurgeExpired(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PurgeExpired", reflect.TypeOf((*MockTwoTierCacher)(nil).PurgeExpired), ctx)
|
||||
}
|
||||
|
||||
// Set mocks base method.
|
||||
func (m *MockTwoTierCacher) Set(ctx context.Context, key string, val []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Set", ctx, key, val)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Set indicates an expected call of Set.
|
||||
func (mr *MockTwoTierCacherMockRecorder) Set(ctx, key, val any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockTwoTierCacher)(nil).Set), ctx, key, val)
|
||||
}
|
||||
|
||||
// SetWithExpiry mocks base method.
|
||||
func (m *MockTwoTierCacher) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetWithExpiry", ctx, key, val, expiry)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetWithExpiry indicates an expected call of SetWithExpiry.
|
||||
func (mr *MockTwoTierCacherMockRecorder) SetWithExpiry(ctx, key, val, expiry any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWithExpiry", reflect.TypeOf((*MockTwoTierCacher)(nil).SetWithExpiry), ctx, key, val, expiry)
|
||||
}
|
||||
|
||||
// Shutdown mocks base method.
|
||||
func (m *MockTwoTierCacher) Shutdown(ctx context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Shutdown", ctx)
|
||||
}
|
||||
|
||||
// Shutdown indicates an expected call of Shutdown.
|
||||
func (mr *MockTwoTierCacherMockRecorder) Shutdown(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockTwoTierCacher)(nil).Shutdown), ctx)
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/database/cassandradb/cassandradb.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/database/cassandradb/cassandradb.go -destination=pkg/mocks/mock_storage_database_cassandra_db.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/database/cassandradb/migration.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/database/cassandradb/migration.go -destination=pkg/mocks/mock_storage_database_cassandra_migration.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/memory/inmemory/memory.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/memory/inmemory/memory.go -destination=pkg/mocks/mock_storage_memory_inmemory.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
111
cloud/maplefile-backend/pkg/mocks/mock_storage_memory_redis.go
Normal file
111
cloud/maplefile-backend/pkg/mocks/mock_storage_memory_redis.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/memory/redis/redis.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/memory/redis/redis.go -destination=pkg/mocks/mock_storage_memory_redis.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCacher is a mock of Cacher interface.
|
||||
type MockCacher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCacherMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockCacherMockRecorder is the mock recorder for MockCacher.
|
||||
type MockCacherMockRecorder struct {
|
||||
mock *MockCacher
|
||||
}
|
||||
|
||||
// NewMockCacher creates a new mock instance.
|
||||
func NewMockCacher(ctrl *gomock.Controller) *MockCacher {
|
||||
mock := &MockCacher{ctrl: ctrl}
|
||||
mock.recorder = &MockCacherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCacher) EXPECT() *MockCacherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockCacher) Delete(ctx context.Context, key string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", ctx, key)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockCacherMockRecorder) Delete(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockCacher)(nil).Delete), ctx, key)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockCacher) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", ctx, key)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockCacherMockRecorder) Get(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCacher)(nil).Get), ctx, key)
|
||||
}
|
||||
|
||||
// Set mocks base method.
|
||||
func (m *MockCacher) Set(ctx context.Context, key string, val []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Set", ctx, key, val)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Set indicates an expected call of Set.
|
||||
func (mr *MockCacherMockRecorder) Set(ctx, key, val any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCacher)(nil).Set), ctx, key, val)
|
||||
}
|
||||
|
||||
// SetWithExpiry mocks base method.
|
||||
func (m *MockCacher) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetWithExpiry", ctx, key, val, expiry)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SetWithExpiry indicates an expected call of SetWithExpiry.
|
||||
func (mr *MockCacherMockRecorder) SetWithExpiry(ctx, key, val, expiry any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWithExpiry", reflect.TypeOf((*MockCacher)(nil).SetWithExpiry), ctx, key, val, expiry)
|
||||
}
|
||||
|
||||
// Shutdown mocks base method.
|
||||
func (m *MockCacher) Shutdown(ctx context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Shutdown", ctx)
|
||||
}
|
||||
|
||||
// Shutdown indicates an expected call of Shutdown.
|
||||
func (mr *MockCacherMockRecorder) Shutdown(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockCacher)(nil).Shutdown), ctx)
|
||||
}
|
||||
319
cloud/maplefile-backend/pkg/mocks/mock_storage_object_s3.go
Normal file
319
cloud/maplefile-backend/pkg/mocks/mock_storage_object_s3.go
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: pkg/storage/object/s3/s3.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -source=pkg/storage/object/s3/s3.go -destination=pkg/mocks/mock_storage_object_s3.go -package=mocks
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
io "io"
|
||||
multipart "mime/multipart"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
s3 "github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockS3ObjectStorage is a mock of S3ObjectStorage interface.
|
||||
type MockS3ObjectStorage struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockS3ObjectStorageMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockS3ObjectStorageMockRecorder is the mock recorder for MockS3ObjectStorage.
|
||||
type MockS3ObjectStorageMockRecorder struct {
|
||||
mock *MockS3ObjectStorage
|
||||
}
|
||||
|
||||
// NewMockS3ObjectStorage creates a new mock instance.
|
||||
func NewMockS3ObjectStorage(ctrl *gomock.Controller) *MockS3ObjectStorage {
|
||||
mock := &MockS3ObjectStorage{ctrl: ctrl}
|
||||
mock.recorder = &MockS3ObjectStorageMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockS3ObjectStorage) EXPECT() *MockS3ObjectStorageMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BucketExists mocks base method.
|
||||
func (m *MockS3ObjectStorage) BucketExists(ctx context.Context, bucketName string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BucketExists", ctx, bucketName)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BucketExists indicates an expected call of BucketExists.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) BucketExists(ctx, bucketName any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BucketExists", reflect.TypeOf((*MockS3ObjectStorage)(nil).BucketExists), ctx, bucketName)
|
||||
}
|
||||
|
||||
// Copy mocks base method.
|
||||
func (m *MockS3ObjectStorage) Copy(ctx context.Context, sourceObjectKey, destinationObjectKey string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Copy", ctx, sourceObjectKey, destinationObjectKey)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Copy indicates an expected call of Copy.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) Copy(ctx, sourceObjectKey, destinationObjectKey any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockS3ObjectStorage)(nil).Copy), ctx, sourceObjectKey, destinationObjectKey)
|
||||
}
|
||||
|
||||
// CopyWithVisibility mocks base method.
|
||||
func (m *MockS3ObjectStorage) CopyWithVisibility(ctx context.Context, sourceObjectKey, destinationObjectKey string, isPublic bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CopyWithVisibility", ctx, sourceObjectKey, destinationObjectKey, isPublic)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CopyWithVisibility indicates an expected call of CopyWithVisibility.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) CopyWithVisibility(ctx, sourceObjectKey, destinationObjectKey, isPublic any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CopyWithVisibility", reflect.TypeOf((*MockS3ObjectStorage)(nil).CopyWithVisibility), ctx, sourceObjectKey, destinationObjectKey, isPublic)
|
||||
}
|
||||
|
||||
// Cut mocks base method.
|
||||
func (m *MockS3ObjectStorage) Cut(ctx context.Context, sourceObjectKey, destinationObjectKey string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Cut", ctx, sourceObjectKey, destinationObjectKey)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Cut indicates an expected call of Cut.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) Cut(ctx, sourceObjectKey, destinationObjectKey any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cut", reflect.TypeOf((*MockS3ObjectStorage)(nil).Cut), ctx, sourceObjectKey, destinationObjectKey)
|
||||
}
|
||||
|
||||
// CutWithVisibility mocks base method.
|
||||
func (m *MockS3ObjectStorage) CutWithVisibility(ctx context.Context, sourceObjectKey, destinationObjectKey string, isPublic bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CutWithVisibility", ctx, sourceObjectKey, destinationObjectKey, isPublic)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CutWithVisibility indicates an expected call of CutWithVisibility.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) CutWithVisibility(ctx, sourceObjectKey, destinationObjectKey, isPublic any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CutWithVisibility", reflect.TypeOf((*MockS3ObjectStorage)(nil).CutWithVisibility), ctx, sourceObjectKey, destinationObjectKey, isPublic)
|
||||
}
|
||||
|
||||
// DeleteByKeys mocks base method.
|
||||
func (m *MockS3ObjectStorage) DeleteByKeys(ctx context.Context, key []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteByKeys", ctx, key)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteByKeys indicates an expected call of DeleteByKeys.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) DeleteByKeys(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteByKeys", reflect.TypeOf((*MockS3ObjectStorage)(nil).DeleteByKeys), ctx, key)
|
||||
}
|
||||
|
||||
// DownloadToLocalfile mocks base method.
|
||||
func (m *MockS3ObjectStorage) DownloadToLocalfile(ctx context.Context, objectKey, filePath string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DownloadToLocalfile", ctx, objectKey, filePath)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DownloadToLocalfile indicates an expected call of DownloadToLocalfile.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) DownloadToLocalfile(ctx, objectKey, filePath any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadToLocalfile", reflect.TypeOf((*MockS3ObjectStorage)(nil).DownloadToLocalfile), ctx, objectKey, filePath)
|
||||
}
|
||||
|
||||
// FindMatchingObjectKey mocks base method.
|
||||
func (m *MockS3ObjectStorage) FindMatchingObjectKey(s3Objects *s3.ListObjectsOutput, partialKey string) string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindMatchingObjectKey", s3Objects, partialKey)
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindMatchingObjectKey indicates an expected call of FindMatchingObjectKey.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) FindMatchingObjectKey(s3Objects, partialKey any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMatchingObjectKey", reflect.TypeOf((*MockS3ObjectStorage)(nil).FindMatchingObjectKey), s3Objects, partialKey)
|
||||
}
|
||||
|
||||
// GeneratePresignedUploadURL mocks base method.
|
||||
func (m *MockS3ObjectStorage) GeneratePresignedUploadURL(ctx context.Context, key string, duration time.Duration) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GeneratePresignedUploadURL", ctx, key, duration)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GeneratePresignedUploadURL indicates an expected call of GeneratePresignedUploadURL.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) GeneratePresignedUploadURL(ctx, key, duration any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GeneratePresignedUploadURL", reflect.TypeOf((*MockS3ObjectStorage)(nil).GeneratePresignedUploadURL), ctx, key, duration)
|
||||
}
|
||||
|
||||
// GetBinaryData mocks base method.
|
||||
func (m *MockS3ObjectStorage) GetBinaryData(ctx context.Context, objectKey string) (io.ReadCloser, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetBinaryData", ctx, objectKey)
|
||||
ret0, _ := ret[0].(io.ReadCloser)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetBinaryData indicates an expected call of GetBinaryData.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) GetBinaryData(ctx, objectKey any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBinaryData", reflect.TypeOf((*MockS3ObjectStorage)(nil).GetBinaryData), ctx, objectKey)
|
||||
}
|
||||
|
||||
// GetDownloadablePresignedURL mocks base method.
|
||||
func (m *MockS3ObjectStorage) GetDownloadablePresignedURL(ctx context.Context, key string, duration time.Duration) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetDownloadablePresignedURL", ctx, key, duration)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetDownloadablePresignedURL indicates an expected call of GetDownloadablePresignedURL.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) GetDownloadablePresignedURL(ctx, key, duration any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDownloadablePresignedURL", reflect.TypeOf((*MockS3ObjectStorage)(nil).GetDownloadablePresignedURL), ctx, key, duration)
|
||||
}
|
||||
|
||||
// GetObjectSize mocks base method.
|
||||
func (m *MockS3ObjectStorage) GetObjectSize(ctx context.Context, key string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetObjectSize", ctx, key)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetObjectSize indicates an expected call of GetObjectSize.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) GetObjectSize(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectSize", reflect.TypeOf((*MockS3ObjectStorage)(nil).GetObjectSize), ctx, key)
|
||||
}
|
||||
|
||||
// IsPublicBucket mocks base method.
|
||||
func (m *MockS3ObjectStorage) IsPublicBucket() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsPublicBucket")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsPublicBucket indicates an expected call of IsPublicBucket.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) IsPublicBucket() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicBucket", reflect.TypeOf((*MockS3ObjectStorage)(nil).IsPublicBucket))
|
||||
}
|
||||
|
||||
// ListAllObjects mocks base method.
|
||||
func (m *MockS3ObjectStorage) ListAllObjects(ctx context.Context) (*s3.ListObjectsOutput, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAllObjects", ctx)
|
||||
ret0, _ := ret[0].(*s3.ListObjectsOutput)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAllObjects indicates an expected call of ListAllObjects.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) ListAllObjects(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllObjects", reflect.TypeOf((*MockS3ObjectStorage)(nil).ListAllObjects), ctx)
|
||||
}
|
||||
|
||||
// ObjectExists mocks base method.
|
||||
func (m *MockS3ObjectStorage) ObjectExists(ctx context.Context, key string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ObjectExists", ctx, key)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ObjectExists indicates an expected call of ObjectExists.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) ObjectExists(ctx, key any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectExists", reflect.TypeOf((*MockS3ObjectStorage)(nil).ObjectExists), ctx, key)
|
||||
}
|
||||
|
||||
// UploadContent mocks base method.
|
||||
func (m *MockS3ObjectStorage) UploadContent(ctx context.Context, objectKey string, content []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadContent", ctx, objectKey, content)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UploadContent indicates an expected call of UploadContent.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) UploadContent(ctx, objectKey, content any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadContent", reflect.TypeOf((*MockS3ObjectStorage)(nil).UploadContent), ctx, objectKey, content)
|
||||
}
|
||||
|
||||
// UploadContentFromMulipart mocks base method.
|
||||
func (m *MockS3ObjectStorage) UploadContentFromMulipart(ctx context.Context, objectKey string, file multipart.File) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadContentFromMulipart", ctx, objectKey, file)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UploadContentFromMulipart indicates an expected call of UploadContentFromMulipart.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) UploadContentFromMulipart(ctx, objectKey, file any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadContentFromMulipart", reflect.TypeOf((*MockS3ObjectStorage)(nil).UploadContentFromMulipart), ctx, objectKey, file)
|
||||
}
|
||||
|
||||
// UploadContentFromMulipartWithVisibility mocks base method.
|
||||
func (m *MockS3ObjectStorage) UploadContentFromMulipartWithVisibility(ctx context.Context, objectKey string, file multipart.File, isPublic bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadContentFromMulipartWithVisibility", ctx, objectKey, file, isPublic)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UploadContentFromMulipartWithVisibility indicates an expected call of UploadContentFromMulipartWithVisibility.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) UploadContentFromMulipartWithVisibility(ctx, objectKey, file, isPublic any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadContentFromMulipartWithVisibility", reflect.TypeOf((*MockS3ObjectStorage)(nil).UploadContentFromMulipartWithVisibility), ctx, objectKey, file, isPublic)
|
||||
}
|
||||
|
||||
// UploadContentWithVisibility mocks base method.
|
||||
func (m *MockS3ObjectStorage) UploadContentWithVisibility(ctx context.Context, objectKey string, content []byte, isPublic bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UploadContentWithVisibility", ctx, objectKey, content, isPublic)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UploadContentWithVisibility indicates an expected call of UploadContentWithVisibility.
|
||||
func (mr *MockS3ObjectStorageMockRecorder) UploadContentWithVisibility(ctx, objectKey, content, isPublic any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UploadContentWithVisibility", reflect.TypeOf((*MockS3ObjectStorage)(nil).UploadContentWithVisibility), ctx, objectKey, content, isPublic)
|
||||
}
|
||||
453
cloud/maplefile-backend/pkg/observability/health.go
Normal file
453
cloud/maplefile-backend/pkg/observability/health.go
Normal file
|
|
@ -0,0 +1,453 @@
|
|||
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/observability/health.go
|
||||
package observability
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/cache/twotiercache"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/object/s3"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a component
|
||||
type HealthStatus string
|
||||
|
||||
const (
|
||||
HealthStatusHealthy HealthStatus = "healthy"
|
||||
HealthStatusUnhealthy HealthStatus = "unhealthy"
|
||||
HealthStatusDegraded HealthStatus = "degraded"
|
||||
)
|
||||
|
||||
// HealthCheckResult represents the result of a health check
|
||||
type HealthCheckResult struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
Component string `json:"component"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// HealthResponse represents the overall health response
|
||||
type HealthResponse struct {
|
||||
Status HealthStatus `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Services map[string]HealthCheckResult `json:"services"`
|
||||
Version string `json:"version"`
|
||||
Uptime string `json:"uptime"`
|
||||
}
|
||||
|
||||
// HealthChecker manages health checks for various components
|
||||
type HealthChecker struct {
|
||||
checks map[string]HealthCheck
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// HealthCheck represents a health check function
|
||||
type HealthCheck func(ctx context.Context) HealthCheckResult
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(logger *zap.Logger) *HealthChecker {
|
||||
return &HealthChecker{
|
||||
checks: make(map[string]HealthCheck),
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterCheck registers a health check for a service
|
||||
func (hc *HealthChecker) RegisterCheck(name string, check HealthCheck) {
|
||||
hc.mu.Lock()
|
||||
defer hc.mu.Unlock()
|
||||
hc.checks[name] = check
|
||||
}
|
||||
|
||||
// CheckHealth performs all registered health checks
|
||||
func (hc *HealthChecker) CheckHealth(ctx context.Context) HealthResponse {
|
||||
hc.mu.RLock()
|
||||
checks := make(map[string]HealthCheck, len(hc.checks))
|
||||
for name, check := range hc.checks {
|
||||
checks[name] = check
|
||||
}
|
||||
hc.mu.RUnlock()
|
||||
|
||||
results := make(map[string]HealthCheckResult)
|
||||
overallStatus := HealthStatusHealthy
|
||||
|
||||
// Run health checks with timeout
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for name, check := range checks {
|
||||
start := time.Now()
|
||||
result := check(checkCtx)
|
||||
result.Duration = time.Since(start).String()
|
||||
|
||||
results[name] = result
|
||||
|
||||
// Determine overall status
|
||||
if result.Status == HealthStatusUnhealthy {
|
||||
overallStatus = HealthStatusUnhealthy
|
||||
} else if result.Status == HealthStatusDegraded && overallStatus == HealthStatusHealthy {
|
||||
overallStatus = HealthStatusDegraded
|
||||
}
|
||||
}
|
||||
|
||||
return HealthResponse{
|
||||
Status: overallStatus,
|
||||
Timestamp: time.Now(),
|
||||
Services: results,
|
||||
Version: "1.0.0", // Could be injected
|
||||
Uptime: time.Since(hc.startTime).String(),
|
||||
}
|
||||
}
|
||||
|
||||
// HealthHandler creates an HTTP handler for health checks
|
||||
func (hc *HealthChecker) HealthHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
health := hc.CheckHealth(ctx)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// Set appropriate status code
|
||||
switch health.Status {
|
||||
case HealthStatusHealthy:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case HealthStatusDegraded:
|
||||
w.WriteHeader(http.StatusOK) // 200 but degraded
|
||||
case HealthStatusUnhealthy:
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(health); err != nil {
|
||||
hc.logger.Error("Failed to encode health response", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReadinessHandler creates a simple readiness probe
|
||||
func (hc *HealthChecker) ReadinessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
health := hc.CheckHealth(ctx)
|
||||
|
||||
// For readiness, we're more strict - any unhealthy component means not ready
|
||||
if health.Status == HealthStatusUnhealthy {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
w.Write([]byte("NOT READY"))
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("READY"))
|
||||
}
|
||||
}
|
||||
|
||||
// LivenessHandler creates a simple liveness probe
|
||||
func (hc *HealthChecker) LivenessHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// For liveness, we just check if the service can respond
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ALIVE"))
|
||||
}
|
||||
}
|
||||
|
||||
// CassandraHealthCheck creates a health check for Cassandra database connectivity
|
||||
func CassandraHealthCheck(session *gocql.Session, logger *zap.Logger) HealthCheck {
|
||||
return func(ctx context.Context) HealthCheckResult {
|
||||
start := time.Now()
|
||||
|
||||
// Check if session is nil
|
||||
if session == nil {
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cassandra session is nil",
|
||||
Timestamp: time.Now(),
|
||||
Component: "cassandra",
|
||||
Details: map[string]interface{}{"error": "session_nil"},
|
||||
}
|
||||
}
|
||||
|
||||
// Try to execute a simple query with context
|
||||
var result string
|
||||
query := session.Query("SELECT uuid() FROM system.local")
|
||||
|
||||
// Create a channel to handle the query execution
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- query.Scan(&result)
|
||||
}()
|
||||
|
||||
// Wait for either completion or context cancellation
|
||||
select {
|
||||
case err := <-done:
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn("Cassandra health check failed",
|
||||
zap.Error(err),
|
||||
zap.Duration("duration", duration))
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cassandra query failed: " + err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
Component: "cassandra",
|
||||
Details: map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusHealthy,
|
||||
Message: "Cassandra connection healthy",
|
||||
Timestamp: time.Now(),
|
||||
Component: "cassandra",
|
||||
Details: map[string]interface{}{
|
||||
"query_result": result,
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cassandra health check timed out",
|
||||
Timestamp: time.Now(),
|
||||
Component: "cassandra",
|
||||
Details: map[string]interface{}{
|
||||
"error": "timeout",
|
||||
"duration": time.Since(start).String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TwoTierCacheHealthCheck creates a health check for the two-tier cache system
|
||||
func TwoTierCacheHealthCheck(cache twotiercache.TwoTierCacher, logger *zap.Logger) HealthCheck {
|
||||
return func(ctx context.Context) HealthCheckResult {
|
||||
start := time.Now()
|
||||
|
||||
if cache == nil {
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cache instance is nil",
|
||||
Timestamp: time.Now(),
|
||||
Component: "two_tier_cache",
|
||||
Details: map[string]interface{}{"error": "cache_nil"},
|
||||
}
|
||||
}
|
||||
|
||||
// Test cache functionality with a health check key
|
||||
healthKey := "health_check_" + time.Now().Format("20060102150405")
|
||||
testValue := []byte("health_check_value")
|
||||
|
||||
// Test Set operation
|
||||
if err := cache.Set(ctx, healthKey, testValue); err != nil {
|
||||
duration := time.Since(start)
|
||||
logger.Warn("Cache health check SET failed",
|
||||
zap.Error(err),
|
||||
zap.Duration("duration", duration))
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cache SET operation failed: " + err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
Component: "two_tier_cache",
|
||||
Details: map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"operation": "set",
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Test Get operation
|
||||
retrievedValue, err := cache.Get(ctx, healthKey)
|
||||
if err != nil {
|
||||
duration := time.Since(start)
|
||||
logger.Warn("Cache health check GET failed",
|
||||
zap.Error(err),
|
||||
zap.Duration("duration", duration))
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "Cache GET operation failed: " + err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
Component: "two_tier_cache",
|
||||
Details: map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"operation": "get",
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the value
|
||||
if string(retrievedValue) != string(testValue) {
|
||||
duration := time.Since(start)
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusDegraded,
|
||||
Message: "Cache value mismatch",
|
||||
Timestamp: time.Now(),
|
||||
Component: "two_tier_cache",
|
||||
Details: map[string]interface{}{
|
||||
"expected": string(testValue),
|
||||
"actual": string(retrievedValue),
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up test key
|
||||
_ = cache.Delete(ctx, healthKey)
|
||||
|
||||
duration := time.Since(start)
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusHealthy,
|
||||
Message: "Two-tier cache healthy",
|
||||
Timestamp: time.Now(),
|
||||
Component: "two_tier_cache",
|
||||
Details: map[string]interface{}{
|
||||
"operations_tested": []string{"set", "get", "delete"},
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// S3HealthCheck creates a health check for S3 object storage
|
||||
func S3HealthCheck(s3Storage s3.S3ObjectStorage, logger *zap.Logger) HealthCheck {
|
||||
return func(ctx context.Context) HealthCheckResult {
|
||||
start := time.Now()
|
||||
|
||||
if s3Storage == nil {
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "S3 storage instance is nil",
|
||||
Timestamp: time.Now(),
|
||||
Component: "s3_storage",
|
||||
Details: map[string]interface{}{"error": "storage_nil"},
|
||||
}
|
||||
}
|
||||
|
||||
// Test basic S3 connectivity by listing objects (lightweight operation)
|
||||
_, err := s3Storage.ListAllObjects(ctx)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn("S3 health check failed",
|
||||
zap.Error(err),
|
||||
zap.Duration("duration", duration))
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusUnhealthy,
|
||||
Message: "S3 connectivity failed: " + err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
Component: "s3_storage",
|
||||
Details: map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"operation": "list_objects",
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return HealthCheckResult{
|
||||
Status: HealthStatusHealthy,
|
||||
Message: "S3 storage healthy",
|
||||
Timestamp: time.Now(),
|
||||
Component: "s3_storage",
|
||||
Details: map[string]interface{}{
|
||||
"operation": "list_objects",
|
||||
"duration": duration.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRealHealthChecks registers health checks for actual infrastructure components
|
||||
// Note: This function was previously used with Uber FX. It can be called directly
|
||||
// or wired through Google Wire if needed.
|
||||
func RegisterRealHealthChecks(
|
||||
hc *HealthChecker,
|
||||
logger *zap.Logger,
|
||||
cassandraSession *gocql.Session,
|
||||
cache twotiercache.TwoTierCacher,
|
||||
s3Storage s3.S3ObjectStorage,
|
||||
) {
|
||||
// Register Cassandra health check
|
||||
hc.RegisterCheck("cassandra", CassandraHealthCheck(cassandraSession, logger))
|
||||
|
||||
// Register two-tier cache health check
|
||||
hc.RegisterCheck("cache", TwoTierCacheHealthCheck(cache, logger))
|
||||
|
||||
// Register S3 storage health check
|
||||
hc.RegisterCheck("s3_storage", S3HealthCheck(s3Storage, logger))
|
||||
|
||||
logger.Info("Real infrastructure health checks registered",
|
||||
zap.Strings("components", []string{"cassandra", "cache", "s3_storage"}))
|
||||
}
|
||||
|
||||
// StartObservabilityServer starts the observability HTTP server on a separate port
|
||||
// Note: This function was previously integrated with Uber FX lifecycle.
|
||||
// It should now be called manually or integrated with Google Wire if needed.
|
||||
func StartObservabilityServer(
|
||||
hc *HealthChecker,
|
||||
ms *MetricsServer,
|
||||
logger *zap.Logger,
|
||||
) (*http.Server, error) {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Health endpoints
|
||||
mux.HandleFunc("/health", hc.HealthHandler())
|
||||
mux.HandleFunc("/health/ready", hc.ReadinessHandler())
|
||||
mux.HandleFunc("/health/live", hc.LivenessHandler())
|
||||
|
||||
// Metrics endpoint
|
||||
mux.Handle("/metrics", ms.Handler())
|
||||
|
||||
server := &http.Server{
|
||||
Addr: ":8080", // Separate port for observability
|
||||
Handler: mux,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.Info("Starting observability server on :8080")
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Observability server failed", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
return server, nil
|
||||
}
|
||||
89
cloud/maplefile-backend/pkg/observability/metrics.go
Normal file
89
cloud/maplefile-backend/pkg/observability/metrics.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/observability/metrics.go
|
||||
package observability
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MetricsServer provides basic metrics endpoint
|
||||
type MetricsServer struct {
|
||||
logger *zap.Logger
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewMetricsServer creates a new metrics server
|
||||
func NewMetricsServer(logger *zap.Logger) *MetricsServer {
|
||||
return &MetricsServer{
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an HTTP handler that serves basic metrics
|
||||
func (ms *MetricsServer) Handler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
metrics := ms.collectMetrics()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, metric := range metrics {
|
||||
fmt.Fprintf(w, "%s\n", metric)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectMetrics collects basic application metrics
|
||||
func (ms *MetricsServer) collectMetrics() []string {
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
uptime := time.Since(ms.startTime).Seconds()
|
||||
|
||||
metrics := []string{
|
||||
fmt.Sprintf("# HELP mapleopentech_uptime_seconds Total uptime of the service in seconds"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_uptime_seconds counter"),
|
||||
fmt.Sprintf("mapleopentech_uptime_seconds %.2f", uptime),
|
||||
|
||||
fmt.Sprintf("# HELP mapleopentech_memory_alloc_bytes Currently allocated memory in bytes"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_memory_alloc_bytes gauge"),
|
||||
fmt.Sprintf("mapleopentech_memory_alloc_bytes %d", m.Alloc),
|
||||
|
||||
fmt.Sprintf("# HELP mapleopentech_memory_total_alloc_bytes Total allocated memory in bytes"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_memory_total_alloc_bytes counter"),
|
||||
fmt.Sprintf("mapleopentech_memory_total_alloc_bytes %d", m.TotalAlloc),
|
||||
|
||||
fmt.Sprintf("# HELP mapleopentech_memory_sys_bytes Memory obtained from system in bytes"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_memory_sys_bytes gauge"),
|
||||
fmt.Sprintf("mapleopentech_memory_sys_bytes %d", m.Sys),
|
||||
|
||||
fmt.Sprintf("# HELP mapleopentech_gc_runs_total Total number of GC runs"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_gc_runs_total counter"),
|
||||
fmt.Sprintf("mapleopentech_gc_runs_total %d", m.NumGC),
|
||||
|
||||
fmt.Sprintf("# HELP mapleopentech_goroutines Current number of goroutines"),
|
||||
fmt.Sprintf("# TYPE mapleopentech_goroutines gauge"),
|
||||
fmt.Sprintf("mapleopentech_goroutines %d", runtime.NumGoroutine()),
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// RecordMetric records a custom metric (placeholder for future implementation)
|
||||
func (ms *MetricsServer) RecordMetric(name string, value float64, labels map[string]string) {
|
||||
ms.logger.Debug("Recording metric",
|
||||
zap.String("name", name),
|
||||
zap.Float64("value", value),
|
||||
zap.Any("labels", labels),
|
||||
)
|
||||
}
|
||||
6
cloud/maplefile-backend/pkg/observability/module.go
Normal file
6
cloud/maplefile-backend/pkg/observability/module.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/observability/module.go
|
||||
package observability
|
||||
|
||||
// Note: This file previously contained Uber FX module definitions.
|
||||
// The application now uses Google Wire for dependency injection.
|
||||
// Observability components should be wired through Wire providers if needed.
|
||||
92
cloud/maplefile-backend/pkg/observability/routes.go
Normal file
92
cloud/maplefile-backend/pkg/observability/routes.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
// codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/observability/routes.go
|
||||
package observability
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HealthRoute provides detailed health check endpoint
|
||||
type HealthRoute struct {
|
||||
checker *HealthChecker
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewHealthRoute(checker *HealthChecker, logger *zap.Logger) *HealthRoute {
|
||||
return &HealthRoute{
|
||||
checker: checker,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HealthRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.checker.HealthHandler()(w, r)
|
||||
}
|
||||
|
||||
func (h *HealthRoute) Pattern() string {
|
||||
return "/health"
|
||||
}
|
||||
|
||||
// ReadinessRoute provides readiness probe endpoint
|
||||
type ReadinessRoute struct {
|
||||
checker *HealthChecker
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewReadinessRoute(checker *HealthChecker, logger *zap.Logger) *ReadinessRoute {
|
||||
return &ReadinessRoute{
|
||||
checker: checker,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReadinessRoute) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
r.checker.ReadinessHandler()(w, req)
|
||||
}
|
||||
|
||||
func (r *ReadinessRoute) Pattern() string {
|
||||
return "/health/ready"
|
||||
}
|
||||
|
||||
// LivenessRoute provides liveness probe endpoint
|
||||
type LivenessRoute struct {
|
||||
checker *HealthChecker
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewLivenessRoute(checker *HealthChecker, logger *zap.Logger) *LivenessRoute {
|
||||
return &LivenessRoute{
|
||||
checker: checker,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LivenessRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
l.checker.LivenessHandler()(w, r)
|
||||
}
|
||||
|
||||
func (l *LivenessRoute) Pattern() string {
|
||||
return "/health/live"
|
||||
}
|
||||
|
||||
// MetricsRoute provides metrics endpoint
|
||||
type MetricsRoute struct {
|
||||
server *MetricsServer
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewMetricsRoute(server *MetricsServer, logger *zap.Logger) *MetricsRoute {
|
||||
return &MetricsRoute{
|
||||
server: server,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MetricsRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.server.Handler()(w, r)
|
||||
}
|
||||
|
||||
func (m *MetricsRoute) Pattern() string {
|
||||
return "/metrics"
|
||||
}
|
||||
21
cloud/maplefile-backend/pkg/random/numbers.go
Normal file
21
cloud/maplefile-backend/pkg/random/numbers.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package random
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
// GenerateSixDigitCode generates a cryptographically secure random 6-digit number
|
||||
func GenerateSixDigitCode() (string, error) {
|
||||
// Generate a random number between 100000 and 999999
|
||||
max := big.NewInt(900000) // 999999 - 100000 + 1
|
||||
n, err := rand.Int(rand.Reader, max)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Add 100000 to ensure 6 digits
|
||||
n.Add(n, big.NewInt(100000))
|
||||
|
||||
return n.String(), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,366 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuthFailureRateLimiter provides specialized rate limiting for authorization failures
|
||||
// to protect against privilege escalation and unauthorized access attempts
|
||||
type AuthFailureRateLimiter interface {
|
||||
// CheckAuthFailure checks if the user has exceeded authorization failure limits
|
||||
// Returns: allowed (bool), remainingAttempts (int), resetTime (time.Time), error
|
||||
CheckAuthFailure(ctx context.Context, userID string, resourceID string, action string) (bool, int, time.Time, error)
|
||||
|
||||
// RecordAuthFailure records an authorization failure
|
||||
RecordAuthFailure(ctx context.Context, userID string, resourceID string, action string, reason string) error
|
||||
|
||||
// RecordAuthSuccess records a successful authorization (optionally resets counters)
|
||||
RecordAuthSuccess(ctx context.Context, userID string, resourceID string, action string) error
|
||||
|
||||
// IsUserBlocked checks if a user is temporarily blocked from authorization attempts
|
||||
IsUserBlocked(ctx context.Context, userID string) (bool, time.Duration, error)
|
||||
|
||||
// GetFailureCount returns the number of authorization failures for a user
|
||||
GetFailureCount(ctx context.Context, userID string) (int, error)
|
||||
|
||||
// GetResourceFailureCount returns failures for a specific resource
|
||||
GetResourceFailureCount(ctx context.Context, userID string, resourceID string) (int, error)
|
||||
|
||||
// ResetUserFailures manually resets failure counters for a user
|
||||
ResetUserFailures(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// AuthFailureRateLimiterConfig holds configuration for authorization failure rate limiting
|
||||
type AuthFailureRateLimiterConfig struct {
|
||||
// MaxFailuresPerUser is the maximum authorization failures per user before blocking
|
||||
MaxFailuresPerUser int
|
||||
// MaxFailuresPerResource is the maximum failures per resource per user
|
||||
MaxFailuresPerResource int
|
||||
// FailureWindow is the time window for tracking failures
|
||||
FailureWindow time.Duration
|
||||
// BlockDuration is how long to block a user after exceeding limits
|
||||
BlockDuration time.Duration
|
||||
// AlertThreshold is the number of failures before alerting (for monitoring)
|
||||
AlertThreshold int
|
||||
// KeyPrefix is the prefix for Redis keys
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
// DefaultAuthFailureRateLimiterConfig returns recommended configuration
|
||||
// Following OWASP guidelines for authorization failure handling
|
||||
func DefaultAuthFailureRateLimiterConfig() AuthFailureRateLimiterConfig {
|
||||
return AuthFailureRateLimiterConfig{
|
||||
MaxFailuresPerUser: 20, // 20 total auth failures per user
|
||||
MaxFailuresPerResource: 5, // 5 failures per specific resource
|
||||
FailureWindow: 15 * time.Minute, // in 15-minute window
|
||||
BlockDuration: 30 * time.Minute, // block for 30 minutes
|
||||
AlertThreshold: 10, // alert after 10 failures
|
||||
KeyPrefix: "auth_fail_rl",
|
||||
}
|
||||
}
|
||||
|
||||
type authFailureRateLimiter struct {
|
||||
client *redis.Client
|
||||
config AuthFailureRateLimiterConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuthFailureRateLimiter creates a new authorization failure rate limiter
|
||||
func NewAuthFailureRateLimiter(client *redis.Client, config AuthFailureRateLimiterConfig, logger *zap.Logger) AuthFailureRateLimiter {
|
||||
return &authFailureRateLimiter{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger.Named("auth-failure-rate-limiter"),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAuthFailure checks if the user has exceeded authorization failure limits
|
||||
// CWE-307: Protection against authorization brute force attacks
|
||||
// OWASP A01:2021: Broken Access Control - Rate limiting authorization failures
|
||||
func (r *authFailureRateLimiter) CheckAuthFailure(ctx context.Context, userID string, resourceID string, action string) (bool, int, time.Time, error) {
|
||||
// Check if user is blocked
|
||||
blocked, remaining, err := r.IsUserBlocked(ctx, userID)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check user block status",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error (security vs availability trade-off)
|
||||
return true, 0, time.Time{}, err
|
||||
}
|
||||
|
||||
if blocked {
|
||||
resetTime := time.Now().Add(remaining)
|
||||
r.logger.Warn("blocked user attempted authorization",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.String("action", action),
|
||||
zap.Duration("remaining_block", remaining))
|
||||
return false, 0, resetTime, nil
|
||||
}
|
||||
|
||||
// Check per-user failure count
|
||||
userFailures, err := r.GetFailureCount(ctx, userID)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to get user failure count",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, 0, time.Time{}, err
|
||||
}
|
||||
|
||||
// Check per-resource failure count
|
||||
resourceFailures, err := r.GetResourceFailureCount(ctx, userID, resourceID)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to get resource failure count",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, 0, time.Time{}, err
|
||||
}
|
||||
|
||||
// Check if limits exceeded
|
||||
if userFailures >= r.config.MaxFailuresPerUser {
|
||||
r.blockUser(ctx, userID)
|
||||
resetTime := time.Now().Add(r.config.BlockDuration)
|
||||
r.logger.Warn("user exceeded authorization failure limit",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Int("failures", userFailures))
|
||||
return false, 0, resetTime, nil
|
||||
}
|
||||
|
||||
if resourceFailures >= r.config.MaxFailuresPerResource {
|
||||
resetTime := time.Now().Add(r.config.FailureWindow)
|
||||
r.logger.Warn("user exceeded resource-specific failure limit",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.Int("failures", resourceFailures))
|
||||
return false, r.config.MaxFailuresPerUser - userFailures, resetTime, nil
|
||||
}
|
||||
|
||||
remainingAttempts := r.config.MaxFailuresPerUser - userFailures
|
||||
resetTime := time.Now().Add(r.config.FailureWindow)
|
||||
|
||||
return true, remainingAttempts, resetTime, nil
|
||||
}
|
||||
|
||||
// RecordAuthFailure records an authorization failure
|
||||
// CWE-778: Insufficient Logging of security events
|
||||
func (r *authFailureRateLimiter) RecordAuthFailure(ctx context.Context, userID string, resourceID string, action string, reason string) error {
|
||||
now := time.Now()
|
||||
timestamp := now.UnixNano()
|
||||
|
||||
// Record per-user failure
|
||||
userKey := r.getUserFailureKey(userID)
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
// Add to sorted set with timestamp as score (for windowing)
|
||||
pipe.ZAdd(ctx, userKey, redis.Z{
|
||||
Score: float64(timestamp),
|
||||
Member: fmt.Sprintf("%d:%s:%s", timestamp, resourceID, action),
|
||||
})
|
||||
pipe.Expire(ctx, userKey, r.config.FailureWindow)
|
||||
|
||||
// Record per-resource failure
|
||||
if resourceID != "" {
|
||||
resourceKey := r.getResourceFailureKey(userID, resourceID)
|
||||
pipe.ZAdd(ctx, resourceKey, redis.Z{
|
||||
Score: float64(timestamp),
|
||||
Member: fmt.Sprintf("%d:%s", timestamp, action),
|
||||
})
|
||||
pipe.Expire(ctx, resourceKey, r.config.FailureWindow)
|
||||
}
|
||||
|
||||
// Increment total failure counter for metrics
|
||||
metricsKey := r.getMetricsKey(userID)
|
||||
pipe.Incr(ctx, metricsKey)
|
||||
pipe.Expire(ctx, metricsKey, 24*time.Hour) // Keep metrics for 24 hours
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to record authorization failure",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.String("action", action),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we should alert
|
||||
count, _ := r.GetFailureCount(ctx, userID)
|
||||
if count == r.config.AlertThreshold {
|
||||
r.logger.Error("SECURITY ALERT: User reached authorization failure alert threshold",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.String("action", action),
|
||||
zap.String("reason", reason),
|
||||
zap.Int("failure_count", count))
|
||||
}
|
||||
|
||||
r.logger.Warn("authorization failure recorded",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.String("action", action),
|
||||
zap.String("reason", reason),
|
||||
zap.Int("total_failures", count))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordAuthSuccess records a successful authorization
|
||||
func (r *authFailureRateLimiter) RecordAuthSuccess(ctx context.Context, userID string, resourceID string, action string) error {
|
||||
// Optionally, we could reset or reduce failure counts on success
|
||||
// For now, we just log the success for audit purposes
|
||||
r.logger.Debug("authorization success recorded",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.String("resource_id_hash", hashID(resourceID)),
|
||||
zap.String("action", action))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsUserBlocked checks if a user is temporarily blocked
|
||||
func (r *authFailureRateLimiter) IsUserBlocked(ctx context.Context, userID string) (bool, time.Duration, error) {
|
||||
blockKey := r.getBlockKey(userID)
|
||||
|
||||
ttl, err := r.client.TTL(ctx, blockKey).Result()
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// TTL returns -2 if key doesn't exist, -1 if no expiration
|
||||
if ttl < 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
return true, ttl, nil
|
||||
}
|
||||
|
||||
// GetFailureCount returns the number of authorization failures for a user
|
||||
func (r *authFailureRateLimiter) GetFailureCount(ctx context.Context, userID string) (int, error) {
|
||||
userKey := r.getUserFailureKey(userID)
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.FailureWindow)
|
||||
|
||||
// Remove old entries outside the window
|
||||
r.client.ZRemRangeByScore(ctx, userKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
||||
|
||||
// Count current failures in window
|
||||
count, err := r.client.ZCount(ctx, userKey,
|
||||
fmt.Sprintf("%d", windowStart.UnixNano()),
|
||||
"+inf").Result()
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// GetResourceFailureCount returns failures for a specific resource
|
||||
func (r *authFailureRateLimiter) GetResourceFailureCount(ctx context.Context, userID string, resourceID string) (int, error) {
|
||||
if resourceID == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
resourceKey := r.getResourceFailureKey(userID, resourceID)
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.FailureWindow)
|
||||
|
||||
// Remove old entries
|
||||
r.client.ZRemRangeByScore(ctx, resourceKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
||||
|
||||
// Count current failures
|
||||
count, err := r.client.ZCount(ctx, resourceKey,
|
||||
fmt.Sprintf("%d", windowStart.UnixNano()),
|
||||
"+inf").Result()
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(count), nil
|
||||
}
|
||||
|
||||
// ResetUserFailures manually resets failure counters for a user
|
||||
func (r *authFailureRateLimiter) ResetUserFailures(ctx context.Context, userID string) error {
|
||||
pattern := fmt.Sprintf("%s:user:%s:*", r.config.KeyPrefix, hashID(userID))
|
||||
|
||||
// Find all keys for this user
|
||||
keys, err := r.client.Keys(ctx, pattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
pipe := r.client.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to reset user failures",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("user authorization failures reset",
|
||||
zap.String("user_id_hash", hashID(userID)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockUser blocks a user from further authorization attempts
|
||||
func (r *authFailureRateLimiter) blockUser(ctx context.Context, userID string) error {
|
||||
blockKey := r.getBlockKey(userID)
|
||||
err := r.client.Set(ctx, blockKey, "blocked", r.config.BlockDuration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to block user",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Error("SECURITY: User blocked due to excessive authorization failures",
|
||||
zap.String("user_id_hash", hashID(userID)),
|
||||
zap.Duration("block_duration", r.config.BlockDuration))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Key generation helpers
|
||||
func (r *authFailureRateLimiter) getUserFailureKey(userID string) string {
|
||||
return fmt.Sprintf("%s:user:%s:failures", r.config.KeyPrefix, hashID(userID))
|
||||
}
|
||||
|
||||
func (r *authFailureRateLimiter) getResourceFailureKey(userID string, resourceID string) string {
|
||||
return fmt.Sprintf("%s:user:%s:resource:%s:failures", r.config.KeyPrefix, hashID(userID), hashID(resourceID))
|
||||
}
|
||||
|
||||
func (r *authFailureRateLimiter) getBlockKey(userID string) string {
|
||||
return fmt.Sprintf("%s:user:%s:blocked", r.config.KeyPrefix, hashID(userID))
|
||||
}
|
||||
|
||||
func (r *authFailureRateLimiter) getMetricsKey(userID string) string {
|
||||
return fmt.Sprintf("%s:user:%s:metrics", r.config.KeyPrefix, hashID(userID))
|
||||
}
|
||||
|
||||
// hashID creates a consistent hash of an ID for use as a Redis key component
|
||||
// CWE-532: Prevents sensitive IDs in Redis keys
|
||||
func hashID(id string) string {
|
||||
if id == "" {
|
||||
return "empty"
|
||||
}
|
||||
hash := sha256.Sum256([]byte(id))
|
||||
// Return first 16 bytes of hash as hex (32 chars) for shorter keys
|
||||
return hex.EncodeToString(hash[:16])
|
||||
}
|
||||
332
cloud/maplefile-backend/pkg/ratelimit/login_ratelimiter.go
Normal file
332
cloud/maplefile-backend/pkg/ratelimit/login_ratelimiter.go
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// LoginRateLimiter provides specialized rate limiting for login attempts
|
||||
// with account lockout functionality
|
||||
type LoginRateLimiter interface {
|
||||
// CheckAndRecordAttempt checks if login attempt is allowed and records it
|
||||
// Returns: allowed (bool), isLocked (bool), remainingAttempts (int), error
|
||||
CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error)
|
||||
|
||||
// RecordFailedAttempt records a failed login attempt
|
||||
RecordFailedAttempt(ctx context.Context, email string, clientIP string) error
|
||||
|
||||
// RecordSuccessfulLogin records a successful login and resets counters
|
||||
RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error
|
||||
|
||||
// IsAccountLocked checks if an account is locked due to too many failed attempts
|
||||
IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error)
|
||||
|
||||
// UnlockAccount manually unlocks an account (admin function)
|
||||
UnlockAccount(ctx context.Context, email string) error
|
||||
|
||||
// GetFailedAttempts returns the number of failed attempts for an email
|
||||
GetFailedAttempts(ctx context.Context, email string) (int, error)
|
||||
}
|
||||
|
||||
// LoginRateLimiterConfig holds configuration for login rate limiting
|
||||
type LoginRateLimiterConfig struct {
|
||||
// MaxAttemptsPerIP is the maximum login attempts per IP in the window
|
||||
MaxAttemptsPerIP int
|
||||
// IPWindow is the time window for IP-based rate limiting
|
||||
IPWindow time.Duration
|
||||
|
||||
// MaxFailedAttemptsPerAccount is the maximum failed attempts before account lockout
|
||||
MaxFailedAttemptsPerAccount int
|
||||
// AccountLockoutDuration is how long to lock an account after too many failures
|
||||
AccountLockoutDuration time.Duration
|
||||
|
||||
// KeyPrefix is the prefix for Redis keys
|
||||
KeyPrefix string
|
||||
}
|
||||
|
||||
// DefaultLoginRateLimiterConfig returns recommended configuration
|
||||
func DefaultLoginRateLimiterConfig() LoginRateLimiterConfig {
|
||||
return LoginRateLimiterConfig{
|
||||
MaxAttemptsPerIP: 10, // 10 attempts per IP
|
||||
IPWindow: 15 * time.Minute, // in 15 minutes
|
||||
MaxFailedAttemptsPerAccount: 10, // 10 failed attempts per account
|
||||
AccountLockoutDuration: 30 * time.Minute, // lock for 30 minutes
|
||||
KeyPrefix: "login_rl",
|
||||
}
|
||||
}
|
||||
|
||||
type loginRateLimiter struct {
|
||||
client *redis.Client
|
||||
config LoginRateLimiterConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewLoginRateLimiter creates a new login rate limiter
|
||||
func NewLoginRateLimiter(client *redis.Client, config LoginRateLimiterConfig, logger *zap.Logger) LoginRateLimiter {
|
||||
return &loginRateLimiter{
|
||||
client: client,
|
||||
config: config,
|
||||
logger: logger.Named("login-rate-limiter"),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAndRecordAttempt checks if login attempt is allowed
|
||||
// CWE-307: Implements protection against brute force attacks
|
||||
func (r *loginRateLimiter) CheckAndRecordAttempt(ctx context.Context, email string, clientIP string) (bool, bool, int, error) {
|
||||
// Check account lockout first
|
||||
locked, remaining, err := r.IsAccountLocked(ctx, email)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check account lockout",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, false, 0, err
|
||||
}
|
||||
|
||||
if locked {
|
||||
r.logger.Warn("login attempt on locked account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Duration("remaining_lockout", remaining))
|
||||
return false, true, 0, nil
|
||||
}
|
||||
|
||||
// Check IP-based rate limit
|
||||
ipKey := r.getIPKey(clientIP)
|
||||
allowed, err := r.checkIPRateLimit(ctx, ipKey)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to check IP rate limit",
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Error(err))
|
||||
// Fail open on Redis error
|
||||
return true, false, 0, err
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
r.logger.Warn("IP rate limit exceeded",
|
||||
zap.String("ip", validation.MaskIP(clientIP)))
|
||||
return false, false, 0, nil
|
||||
}
|
||||
|
||||
// Record the attempt for IP
|
||||
if err := r.recordIPAttempt(ctx, ipKey); err != nil {
|
||||
r.logger.Error("failed to record IP attempt",
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// Get remaining attempts for account
|
||||
failedAttempts, err := r.GetFailedAttempts(ctx, email)
|
||||
if err != nil {
|
||||
r.logger.Error("failed to get failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
remainingAttempts := r.config.MaxFailedAttemptsPerAccount - failedAttempts
|
||||
if remainingAttempts < 0 {
|
||||
remainingAttempts = 0
|
||||
}
|
||||
|
||||
r.logger.Debug("login attempt check passed",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Int("remaining_attempts", remainingAttempts))
|
||||
|
||||
return true, false, remainingAttempts, nil
|
||||
}
|
||||
|
||||
// RecordFailedAttempt records a failed login attempt
|
||||
// CWE-307: Tracks failed attempts to enable account lockout
|
||||
func (r *loginRateLimiter) RecordFailedAttempt(ctx context.Context, email string, clientIP string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
|
||||
// Increment failed attempt counter
|
||||
count, err := r.client.Incr(ctx, accountKey).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to increment failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Set expiration on first failed attempt
|
||||
if count == 1 {
|
||||
r.client.Expire(ctx, accountKey, r.config.AccountLockoutDuration)
|
||||
}
|
||||
|
||||
// Check if account should be locked
|
||||
if count >= int64(r.config.MaxFailedAttemptsPerAccount) {
|
||||
lockKey := r.getLockKey(email)
|
||||
err := r.client.Set(ctx, lockKey, "locked", r.config.AccountLockoutDuration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to lock account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Warn("account locked due to too many failed attempts",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Int64("failed_attempts", count),
|
||||
zap.Duration("lockout_duration", r.config.AccountLockoutDuration))
|
||||
}
|
||||
|
||||
r.logger.Info("failed login attempt recorded",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", validation.MaskIP(clientIP)),
|
||||
zap.Int64("total_failed_attempts", count))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecordSuccessfulLogin records a successful login and resets counters
|
||||
func (r *loginRateLimiter) RecordSuccessfulLogin(ctx context.Context, email string, clientIP string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
// Delete failed attempt counter
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, accountKey)
|
||||
pipe.Del(ctx, lockKey)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("failed to reset login counters",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("successful login recorded, counters reset",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.String("ip", validation.MaskIP(clientIP)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAccountLocked checks if an account is locked
|
||||
func (r *loginRateLimiter) IsAccountLocked(ctx context.Context, email string) (bool, time.Duration, error) {
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
ttl, err := r.client.TTL(ctx, lockKey).Result()
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
|
||||
// TTL returns -2 if key doesn't exist, -1 if no expiration
|
||||
if ttl < 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
return true, ttl, nil
|
||||
}
|
||||
|
||||
// UnlockAccount manually unlocks an account
|
||||
func (r *loginRateLimiter) UnlockAccount(ctx context.Context, email string) error {
|
||||
accountKey := r.getAccountKey(email)
|
||||
lockKey := r.getLockKey(email)
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, accountKey)
|
||||
pipe.Del(ctx, lockKey)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("failed to unlock account",
|
||||
zap.String("email_hash", hashEmail(email)),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("account unlocked",
|
||||
zap.String("email_hash", hashEmail(email)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFailedAttempts returns the number of failed attempts
|
||||
func (r *loginRateLimiter) GetFailedAttempts(ctx context.Context, email string) (int, error) {
|
||||
accountKey := r.getAccountKey(email)
|
||||
|
||||
count, err := r.client.Get(ctx, accountKey).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// checkIPRateLimit checks if IP has exceeded rate limit
|
||||
func (r *loginRateLimiter) checkIPRateLimit(ctx context.Context, ipKey string) (bool, error) {
|
||||
now := time.Now()
|
||||
windowStart := now.Add(-r.config.IPWindow)
|
||||
|
||||
// Remove old entries
|
||||
r.client.ZRemRangeByScore(ctx, ipKey, "0", fmt.Sprintf("%d", windowStart.UnixNano()))
|
||||
|
||||
// Count current attempts
|
||||
count, err := r.client.ZCount(ctx, ipKey,
|
||||
fmt.Sprintf("%d", windowStart.UnixNano()),
|
||||
"+inf").Result()
|
||||
|
||||
if err != nil && err != redis.Nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count < int64(r.config.MaxAttemptsPerIP), nil
|
||||
}
|
||||
|
||||
// recordIPAttempt records an IP attempt
|
||||
func (r *loginRateLimiter) recordIPAttempt(ctx context.Context, ipKey string) error {
|
||||
now := time.Now()
|
||||
timestamp := now.UnixNano()
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.ZAdd(ctx, ipKey, redis.Z{
|
||||
Score: float64(timestamp),
|
||||
Member: fmt.Sprintf("%d", timestamp),
|
||||
})
|
||||
pipe.Expire(ctx, ipKey, r.config.IPWindow+time.Minute)
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Key generation helpers
|
||||
func (r *loginRateLimiter) getIPKey(ip string) string {
|
||||
return fmt.Sprintf("%s:ip:%s", r.config.KeyPrefix, ip)
|
||||
}
|
||||
|
||||
func (r *loginRateLimiter) getAccountKey(email string) string {
|
||||
return fmt.Sprintf("%s:account:%s:attempts", r.config.KeyPrefix, hashEmail(email))
|
||||
}
|
||||
|
||||
func (r *loginRateLimiter) getLockKey(email string) string {
|
||||
return fmt.Sprintf("%s:account:%s:locked", r.config.KeyPrefix, hashEmail(email))
|
||||
}
|
||||
|
||||
// hashEmail creates a consistent hash of an email for use as a key
|
||||
// CWE-532: Prevents PII in Redis keys
|
||||
// Uses SHA-256 for cryptographically secure hashing
|
||||
func hashEmail(email string) string {
|
||||
// Normalize email to lowercase for consistent hashing
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
|
||||
// Use SHA-256 for secure, collision-resistant hashing
|
||||
hash := sha256.Sum256([]byte(normalized))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
81
cloud/maplefile-backend/pkg/ratelimit/providers.go
Normal file
81
cloud/maplefile-backend/pkg/ratelimit/providers.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideLoginRateLimiter creates a LoginRateLimiter for dependency injection
|
||||
// CWE-307: Implements rate limiting and account lockout protection against brute force attacks
|
||||
func ProvideLoginRateLimiter(redisClient redis.UniversalClient, cfg *config.Configuration, logger *zap.Logger) LoginRateLimiter {
|
||||
// Start with default config
|
||||
loginConfig := DefaultLoginRateLimiterConfig()
|
||||
|
||||
// Override with configuration values if provided
|
||||
if cfg != nil {
|
||||
if cfg.LoginRateLimit.MaxAttemptsPerIP > 0 {
|
||||
loginConfig.MaxAttemptsPerIP = cfg.LoginRateLimit.MaxAttemptsPerIP
|
||||
}
|
||||
if cfg.LoginRateLimit.IPWindow > 0 {
|
||||
loginConfig.IPWindow = cfg.LoginRateLimit.IPWindow
|
||||
}
|
||||
if cfg.LoginRateLimit.MaxFailedAttemptsPerAccount > 0 {
|
||||
loginConfig.MaxFailedAttemptsPerAccount = cfg.LoginRateLimit.MaxFailedAttemptsPerAccount
|
||||
}
|
||||
if cfg.LoginRateLimit.AccountLockoutDuration > 0 {
|
||||
loginConfig.AccountLockoutDuration = cfg.LoginRateLimit.AccountLockoutDuration
|
||||
}
|
||||
}
|
||||
|
||||
// Type assert to *redis.Client since LoginRateLimiter needs it
|
||||
client, ok := redisClient.(*redis.Client)
|
||||
if !ok {
|
||||
// If it's a cluster client or other type, log warning
|
||||
// This shouldn't happen in our standard setup
|
||||
logger.Warn("Redis client is not a standard client, login rate limiter may not work correctly")
|
||||
return NewLoginRateLimiter(nil, loginConfig, logger)
|
||||
}
|
||||
|
||||
logger.Info("Login rate limiter initialized",
|
||||
zap.Int("max_attempts_per_ip", loginConfig.MaxAttemptsPerIP),
|
||||
zap.Duration("ip_window", loginConfig.IPWindow),
|
||||
zap.Int("max_failed_per_account", loginConfig.MaxFailedAttemptsPerAccount),
|
||||
zap.Duration("lockout_duration", loginConfig.AccountLockoutDuration))
|
||||
|
||||
return NewLoginRateLimiter(client, loginConfig, logger)
|
||||
}
|
||||
|
||||
// ProvideAuthFailureRateLimiter creates an AuthFailureRateLimiter for dependency injection
|
||||
// CWE-307: Implements rate limiting for authorization failures to prevent privilege escalation attempts
|
||||
// OWASP A01:2021: Broken Access Control - Rate limiting authorization failures
|
||||
func ProvideAuthFailureRateLimiter(redisClient redis.UniversalClient, cfg *config.Configuration, logger *zap.Logger) AuthFailureRateLimiter {
|
||||
// Use default config with secure defaults for authorization failure protection
|
||||
authConfig := DefaultAuthFailureRateLimiterConfig()
|
||||
|
||||
// Override defaults with configuration if provided
|
||||
// Allow configuration through environment variables for flexibility
|
||||
if cfg != nil {
|
||||
// These values could be configured via environment variables
|
||||
// For now, we use the secure defaults
|
||||
// TODO: Add auth failure rate limiting configuration to SecurityConfig
|
||||
}
|
||||
|
||||
// Type assert to *redis.Client since AuthFailureRateLimiter needs it
|
||||
client, ok := redisClient.(*redis.Client)
|
||||
if !ok {
|
||||
// If it's a cluster client or other type, log warning
|
||||
logger.Warn("Redis client is not a standard client, auth failure rate limiter may not work correctly")
|
||||
return NewAuthFailureRateLimiter(nil, authConfig, logger)
|
||||
}
|
||||
|
||||
logger.Info("Authorization failure rate limiter initialized",
|
||||
zap.Int("max_failures_per_user", authConfig.MaxFailuresPerUser),
|
||||
zap.Int("max_failures_per_resource", authConfig.MaxFailuresPerResource),
|
||||
zap.Duration("failure_window", authConfig.FailureWindow),
|
||||
zap.Duration("block_duration", authConfig.BlockDuration),
|
||||
zap.Int("alert_threshold", authConfig.AlertThreshold))
|
||||
|
||||
return NewAuthFailureRateLimiter(client, authConfig, logger)
|
||||
}
|
||||
96
cloud/maplefile-backend/pkg/security/apikey/generator.go
Normal file
96
cloud/maplefile-backend/pkg/security/apikey/generator.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrefixLive is the prefix for production API keys
|
||||
PrefixLive = "live_sk_"
|
||||
// PrefixTest is the prefix for test/sandbox API keys
|
||||
PrefixTest = "test_sk_"
|
||||
// KeyLength is the length of the random part (40 chars in base64url)
|
||||
KeyLength = 30 // 30 bytes = 40 base64url chars
|
||||
)
|
||||
|
||||
// Generator generates API keys
|
||||
type Generator interface {
|
||||
// Generate creates a new live API key
|
||||
Generate() (string, error)
|
||||
// GenerateTest creates a new test API key
|
||||
GenerateTest() (string, error)
|
||||
}
|
||||
|
||||
type generator struct{}
|
||||
|
||||
// NewGenerator creates a new API key generator
|
||||
func NewGenerator() Generator {
|
||||
return &generator{}
|
||||
}
|
||||
|
||||
// Generate creates a new live API key
|
||||
func (g *generator) Generate() (string, error) {
|
||||
return g.generateWithPrefix(PrefixLive)
|
||||
}
|
||||
|
||||
// GenerateTest creates a new test API key
|
||||
func (g *generator) GenerateTest() (string, error) {
|
||||
return g.generateWithPrefix(PrefixTest)
|
||||
}
|
||||
|
||||
func (g *generator) generateWithPrefix(prefix string) (string, error) {
|
||||
// Generate cryptographically secure random bytes
|
||||
b := make([]byte, KeyLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode to base64url (URL-safe, no padding)
|
||||
key := base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
// Remove any special chars and make lowercase for consistency
|
||||
key = strings.Map(func(r rune) rune {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
|
||||
return r
|
||||
}
|
||||
return -1 // Remove character
|
||||
}, key)
|
||||
|
||||
// Ensure we have at least 40 characters
|
||||
if len(key) < 40 {
|
||||
// Pad with additional random bytes if needed
|
||||
additional := make([]byte, 10)
|
||||
rand.Read(additional)
|
||||
extraKey := base64.RawURLEncoding.EncodeToString(additional)
|
||||
key += extraKey
|
||||
}
|
||||
|
||||
// Trim to exactly 40 characters
|
||||
key = key[:40]
|
||||
|
||||
return prefix + key, nil
|
||||
}
|
||||
|
||||
// ExtractPrefix extracts the prefix from an API key
|
||||
func ExtractPrefix(apiKey string) string {
|
||||
if len(apiKey) < 13 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[:13] // "live_sk_a1b2" or "test_sk_a1b2"
|
||||
}
|
||||
|
||||
// ExtractLastFour extracts the last 4 characters from an API key
|
||||
func ExtractLastFour(apiKey string) string {
|
||||
if len(apiKey) < 4 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[len(apiKey)-4:]
|
||||
}
|
||||
|
||||
// IsValid checks if an API key has a valid format
|
||||
func IsValid(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, PrefixLive) || strings.HasPrefix(apiKey, PrefixTest)
|
||||
}
|
||||
35
cloud/maplefile-backend/pkg/security/apikey/hasher.go
Normal file
35
cloud/maplefile-backend/pkg/security/apikey/hasher.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Hasher hashes and verifies API keys using SHA-256
|
||||
type Hasher interface {
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
Hash(apiKey string) string
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
Verify(apiKey string, hash string) bool
|
||||
}
|
||||
|
||||
type hasher struct{}
|
||||
|
||||
// NewHasher creates a new API key hasher
|
||||
func NewHasher() Hasher {
|
||||
return &hasher{}
|
||||
}
|
||||
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
func (h *hasher) Hash(apiKey string) string {
|
||||
hash := sha256.Sum256([]byte(apiKey))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
// This prevents timing attacks
|
||||
func (h *hasher) Verify(apiKey string, expectedHash string) bool {
|
||||
actualHash := h.Hash(apiKey)
|
||||
return subtle.ConstantTimeCompare([]byte(actualHash), []byte(expectedHash)) == 1
|
||||
}
|
||||
11
cloud/maplefile-backend/pkg/security/apikey/provider.go
Normal file
11
cloud/maplefile-backend/pkg/security/apikey/provider.go
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package apikey
|
||||
|
||||
// ProvideGenerator provides an API key generator for dependency injection
|
||||
func ProvideGenerator() Generator {
|
||||
return NewGenerator()
|
||||
}
|
||||
|
||||
// ProvideHasher provides an API key hasher for dependency injection
|
||||
func ProvideHasher() Hasher {
|
||||
return NewHasher()
|
||||
}
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
// Package benchmark provides performance benchmarks for memguard security operations.
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// BenchmarkPlainStringAllocation benchmarks plain string allocation.
|
||||
func BenchmarkPlainStringAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s := "this is a test string with sensitive data"
|
||||
_ = s
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSecureStringAllocation benchmarks SecureString allocation and cleanup.
|
||||
func BenchmarkSecureStringAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s, err := securestring.NewSecureString("this is a test string with sensitive data")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPlainBytesAllocation benchmarks plain byte slice allocation.
|
||||
func BenchmarkPlainBytesAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
_ = data
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSecureBytesAllocation benchmarks SecureBytes allocation and cleanup.
|
||||
func BenchmarkSecureBytesAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
sb, err := securebytes.NewSecureBytes(data)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
sb.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPasswordHashing_Plain benchmarks password hashing without memguard.
|
||||
func BenchmarkPasswordHashing_Plain(b *testing.B) {
|
||||
password := []byte("test_password_12345")
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = argon2.IDKey(password, salt, 3, 64*1024, 4, 32)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPasswordHashing_Secure benchmarks password hashing with memguard wiping.
|
||||
func BenchmarkPasswordHashing_Secure(b *testing.B) {
|
||||
password, err := securestring.NewSecureString("test_password_12345")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer password.Wipe()
|
||||
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
passwordBytes := password.Bytes()
|
||||
hash := argon2.IDKey(passwordBytes, salt, 3, 64*1024, 4, 32)
|
||||
memguard.WipeBytes(hash)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMemguardWipeBytes benchmarks the memguard.WipeBytes operation.
|
||||
func BenchmarkMemguardWipeBytes(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMemguardWipeBytes_Large benchmarks wiping larger byte slices.
|
||||
func BenchmarkMemguardWipeBytes_Large(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 4096)
|
||||
rand.Read(data)
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLockedBuffer_Create benchmarks creating a memguard LockedBuffer.
|
||||
func BenchmarkLockedBuffer_Create(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := memguard.NewBuffer(32)
|
||||
buf.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLockedBuffer_FromBytes benchmarks creating a LockedBuffer from bytes.
|
||||
func BenchmarkLockedBuffer_FromBytes(b *testing.B) {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := memguard.NewBufferFromBytes(data)
|
||||
buf.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkJWTTokenGeneration_Plain simulates JWT token generation without security.
|
||||
func BenchmarkJWTTokenGeneration_Plain(b *testing.B) {
|
||||
secret := make([]byte, 32)
|
||||
rand.Read(secret)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate token signing
|
||||
_ = secret
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkJWTTokenGeneration_Secure simulates JWT token generation with memguard.
|
||||
func BenchmarkJWTTokenGeneration_Secure(b *testing.B) {
|
||||
secret := make([]byte, 32)
|
||||
rand.Read(secret)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
secretCopy := make([]byte, len(secret))
|
||||
copy(secretCopy, secret)
|
||||
// Simulate token signing
|
||||
_ = secretCopy
|
||||
memguard.WipeBytes(secretCopy)
|
||||
}
|
||||
}
|
||||
|
||||
// Run benchmarks with:
|
||||
// go test -bench=. -benchmem ./pkg/security/benchmark/
|
||||
76
cloud/maplefile-backend/pkg/security/blacklist/blacklist.go
Normal file
76
cloud/maplefile-backend/pkg/security/blacklist/blacklist.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package blacklist
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Provider provides an interface for abstracting time.
|
||||
type Provider interface {
|
||||
IsBannedIPAddress(ipAddress string) bool
|
||||
IsBannedURL(url string) bool
|
||||
}
|
||||
|
||||
type blacklistProvider struct {
|
||||
bannedIPAddresses map[string]bool
|
||||
bannedURLs map[string]bool
|
||||
}
|
||||
|
||||
// readBlacklistFileContent reads the contents of the blacklist file and returns
|
||||
// the list of banned items (ex: IP, URLs, etc).
|
||||
func readBlacklistFileContent(filePath string) ([]string, error) {
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("file %s does not exist", filePath)
|
||||
}
|
||||
|
||||
// Read the file contents
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
// Parse the JSON content as a list of IPs
|
||||
var ips []string
|
||||
if err := json.Unmarshal(data, &ips); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// NewProvider Provider contructor that returns the default time provider.
|
||||
func NewProvider() Provider {
|
||||
bannedIPAddresses := make(map[string]bool)
|
||||
bannedIPAddressesFilePath := "static/blacklist/ips.json"
|
||||
ips, err := readBlacklistFileContent(bannedIPAddressesFilePath)
|
||||
if err == nil { // Aka: if the file exists...
|
||||
for _, ip := range ips {
|
||||
bannedIPAddresses[ip] = true
|
||||
}
|
||||
}
|
||||
|
||||
bannedURLs := make(map[string]bool)
|
||||
bannedURLsFilePath := "static/blacklist/urls.json"
|
||||
urls, err := readBlacklistFileContent(bannedURLsFilePath)
|
||||
if err == nil { // Aka: if the file exists...
|
||||
for _, url := range urls {
|
||||
bannedURLs[url] = true
|
||||
}
|
||||
}
|
||||
|
||||
return blacklistProvider{
|
||||
bannedIPAddresses: bannedIPAddresses,
|
||||
bannedURLs: bannedURLs,
|
||||
}
|
||||
}
|
||||
|
||||
func (p blacklistProvider) IsBannedIPAddress(ipAddress string) bool {
|
||||
return p.bannedIPAddresses[ipAddress]
|
||||
}
|
||||
|
||||
func (p blacklistProvider) IsBannedURL(url string) bool {
|
||||
return p.bannedURLs[url]
|
||||
}
|
||||
132
cloud/maplefile-backend/pkg/security/blacklist/blacklist_test.go
Normal file
132
cloud/maplefile-backend/pkg/security/blacklist/blacklist_test.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package blacklist
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createTempFile(t *testing.T, content string) string {
|
||||
tmpfile, err := os.CreateTemp("", "blacklist*.json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte(content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tmpfile.Name()
|
||||
}
|
||||
|
||||
func TestReadBlacklistFileContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantItems []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid json",
|
||||
content: `["192.168.1.1", "10.0.0.1"]`,
|
||||
wantItems: []string{"192.168.1.1", "10.0.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
content: `[]`,
|
||||
wantItems: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
content: `invalid json`,
|
||||
wantItems: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpfile := createTempFile(t, tt.content)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
items, err := readBlacklistFileContent(tmpfile)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, items)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantItems, items)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("nonexistent file", func(t *testing.T) {
|
||||
_, err := readBlacklistFileContent("nonexistent.json")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
// Create temporary blacklist files
|
||||
ipsContent := `["192.168.1.1", "10.0.0.1"]`
|
||||
urlsContent := `["example.com", "malicious.com"]`
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "blacklist")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = os.MkdirAll(filepath.Join(tmpDir, "static/blacklist"), 0755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "static/blacklist/ips.json"), []byte(ipsContent), 0644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "static/blacklist/urls.json"), []byte(urlsContent), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Change working directory temporarily
|
||||
originalWd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
assert.NoError(t, err)
|
||||
defer os.Chdir(originalWd)
|
||||
|
||||
provider := NewProvider()
|
||||
assert.NotNil(t, provider)
|
||||
|
||||
// Test IP blacklist
|
||||
assert.True(t, provider.IsBannedIPAddress("192.168.1.1"))
|
||||
assert.True(t, provider.IsBannedIPAddress("10.0.0.1"))
|
||||
assert.False(t, provider.IsBannedIPAddress("172.16.0.1"))
|
||||
|
||||
// Test URL blacklist
|
||||
assert.True(t, provider.IsBannedURL("example.com"))
|
||||
assert.True(t, provider.IsBannedURL("malicious.com"))
|
||||
assert.False(t, provider.IsBannedURL("safe.com"))
|
||||
}
|
||||
|
||||
func TestIsBannedIPAddress(t *testing.T) {
|
||||
provider := blacklistProvider{
|
||||
bannedIPAddresses: map[string]bool{
|
||||
"192.168.1.1": true,
|
||||
"10.0.0.1": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, provider.IsBannedIPAddress("192.168.1.1"))
|
||||
assert.True(t, provider.IsBannedIPAddress("10.0.0.1"))
|
||||
assert.False(t, provider.IsBannedIPAddress("172.16.0.1"))
|
||||
}
|
||||
|
||||
func TestIsBannedURL(t *testing.T) {
|
||||
provider := blacklistProvider{
|
||||
bannedURLs: map[string]bool{
|
||||
"example.com": true,
|
||||
"malicious.com": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, provider.IsBannedURL("example.com"))
|
||||
assert.True(t, provider.IsBannedURL("malicious.com"))
|
||||
assert.False(t, provider.IsBannedURL("safe.com"))
|
||||
}
|
||||
170
cloud/maplefile-backend/pkg/security/clientip/extractor.go
Normal file
170
cloud/maplefile-backend/pkg/security/clientip/extractor.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// Extractor provides secure client IP address extraction
|
||||
// CWE-348: Prevents X-Forwarded-For header spoofing by validating trusted proxies
|
||||
type Extractor struct {
|
||||
trustedProxies []*net.IPNet
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExtractor creates a new IP extractor with trusted proxy configuration
|
||||
// trustedProxyCIDRs should contain CIDR blocks of trusted reverse proxies
|
||||
// Example: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
|
||||
func NewExtractor(trustedProxyCIDRs []string, logger *zap.Logger) (*Extractor, error) {
|
||||
var trustedProxies []*net.IPNet
|
||||
|
||||
for _, cidr := range trustedProxyCIDRs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse trusted proxy CIDR",
|
||||
zap.String("cidr", cidr),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
trustedProxies = append(trustedProxies, ipNet)
|
||||
}
|
||||
|
||||
logger.Info("client IP extractor initialized",
|
||||
zap.Int("trusted_proxy_ranges", len(trustedProxies)))
|
||||
|
||||
return &Extractor{
|
||||
trustedProxies: trustedProxies,
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewDefaultExtractor creates an extractor with no trusted proxies
|
||||
// This is safe for direct connections but will ignore X-Forwarded-For headers
|
||||
func NewDefaultExtractor(logger *zap.Logger) *Extractor {
|
||||
logger.Warn("client IP extractor initialized with NO trusted proxies - X-Forwarded-For will be ignored")
|
||||
return &Extractor{
|
||||
trustedProxies: []*net.IPNet{},
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}
|
||||
}
|
||||
|
||||
// Extract extracts the real client IP address from the HTTP request
|
||||
// CWE-348: Secure implementation that prevents header spoofing
|
||||
func (e *Extractor) Extract(r *http.Request) string {
|
||||
// Step 1: Get the immediate connection's remote address
|
||||
remoteAddr := r.RemoteAddr
|
||||
|
||||
// Remove port from RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||
remoteIP := e.stripPort(remoteAddr)
|
||||
|
||||
// Step 2: Parse the remote IP
|
||||
parsedRemoteIP := net.ParseIP(remoteIP)
|
||||
if parsedRemoteIP == nil {
|
||||
e.logger.Warn("failed to parse remote IP address",
|
||||
zap.String("remote_addr", validation.MaskIP(remoteAddr)))
|
||||
return remoteIP // Return as-is if we can't parse it
|
||||
}
|
||||
|
||||
// Step 3: Check if the immediate connection is from a trusted proxy
|
||||
if !e.isTrustedProxy(parsedRemoteIP) {
|
||||
// NOT from a trusted proxy - do NOT trust X-Forwarded-For header
|
||||
// This prevents clients from spoofing their IP by setting the header
|
||||
e.logger.Debug("remote IP is not a trusted proxy, using RemoteAddr",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 4: Remote IP is trusted, check X-Forwarded-For header
|
||||
// Format: "client, proxy1, proxy2" (leftmost is original client)
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
// No X-Forwarded-For header, use RemoteAddr
|
||||
e.logger.Debug("no X-Forwarded-For header from trusted proxy",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 5: Parse X-Forwarded-For header
|
||||
// Take the FIRST IP (leftmost) which should be the original client
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) == 0 {
|
||||
e.logger.Debug("empty X-Forwarded-For header",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Get the first IP and trim whitespace
|
||||
clientIP := strings.TrimSpace(ips[0])
|
||||
|
||||
// Step 6: Validate the client IP
|
||||
parsedClientIP := net.ParseIP(clientIP)
|
||||
if parsedClientIP == nil {
|
||||
e.logger.Warn("invalid IP in X-Forwarded-For header",
|
||||
zap.String("xff", xff),
|
||||
zap.String("client_ip", validation.MaskIP(clientIP)))
|
||||
return remoteIP // Fall back to RemoteAddr
|
||||
}
|
||||
|
||||
e.logger.Debug("extracted client IP from X-Forwarded-For",
|
||||
zap.String("client_ip", validation.MaskIP(clientIP)),
|
||||
zap.String("remote_proxy", validation.MaskIP(remoteIP)),
|
||||
zap.String("xff_chain", xff))
|
||||
|
||||
return clientIP
|
||||
}
|
||||
|
||||
// ExtractOrDefault extracts the client IP or returns a default value
|
||||
func (e *Extractor) ExtractOrDefault(r *http.Request, defaultIP string) string {
|
||||
ip := e.Extract(r)
|
||||
if ip == "" {
|
||||
return defaultIP
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isTrustedProxy checks if an IP is in the trusted proxy list
|
||||
func (e *Extractor) isTrustedProxy(ip net.IP) bool {
|
||||
for _, ipNet := range e.trustedProxies {
|
||||
if ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort removes the port from an address string
|
||||
// Handles both IPv4 (192.168.1.1:8080) and IPv6 ([::1]:8080) formats
|
||||
func (e *Extractor) stripPort(addr string) string {
|
||||
// For IPv6, check for bracket format [IP]:port
|
||||
if strings.HasPrefix(addr, "[") {
|
||||
// IPv6 format: [::1]:8080
|
||||
if idx := strings.LastIndex(addr, "]:"); idx != -1 {
|
||||
return addr[1:idx] // Extract IP between [ and ]
|
||||
}
|
||||
// Malformed IPv6 address
|
||||
return addr
|
||||
}
|
||||
|
||||
// For IPv4, split on last colon
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
}
|
||||
|
||||
// No port found
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetTrustedProxyCount returns the number of configured trusted proxy ranges
|
||||
func (e *Extractor) GetTrustedProxyCount() int {
|
||||
return len(e.trustedProxies)
|
||||
}
|
||||
|
||||
// HasTrustedProxies returns true if any trusted proxies are configured
|
||||
func (e *Extractor) HasTrustedProxies() bool {
|
||||
return len(e.trustedProxies) > 0
|
||||
}
|
||||
19
cloud/maplefile-backend/pkg/security/clientip/provider.go
Normal file
19
cloud/maplefile-backend/pkg/security/clientip/provider.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideExtractor provides a client IP extractor configured from the application config
|
||||
func ProvideExtractor(cfg *config.Config, logger *zap.Logger) (*Extractor, error) {
|
||||
// If no trusted proxies configured, use default (no X-Forwarded-For trust)
|
||||
if len(cfg.Security.TrustedProxies) == 0 {
|
||||
logger.Info("no trusted proxies configured - X-Forwarded-For headers will be ignored for security")
|
||||
return NewDefaultExtractor(logger), nil
|
||||
}
|
||||
|
||||
// Create extractor with trusted proxies
|
||||
return NewExtractor(cfg.Security.TrustedProxies, logger)
|
||||
}
|
||||
32
cloud/maplefile-backend/pkg/security/crypto/constants.go
Normal file
32
cloud/maplefile-backend/pkg/security/crypto/constants.go
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package crypto
|
||||
|
||||
// Constants to ensure compatibility between Go and JavaScript
|
||||
const (
|
||||
// Key sizes
|
||||
MasterKeySize = 32 // 256-bit
|
||||
KeyEncryptionKeySize = 32
|
||||
CollectionKeySize = 32
|
||||
FileKeySize = 32
|
||||
RecoveryKeySize = 32
|
||||
|
||||
// ChaCha20-Poly1305 constants (updated from XSalsa20-Poly1305)
|
||||
NonceSize = 12 // ChaCha20-Poly1305 nonce size (changed from 24)
|
||||
PublicKeySize = 32
|
||||
PrivateKeySize = 32
|
||||
SealedBoxOverhead = 16
|
||||
|
||||
// Legacy naming for backward compatibility
|
||||
SecretBoxNonceSize = NonceSize
|
||||
|
||||
// Argon2 parameters - must match between platforms
|
||||
Argon2IDAlgorithm = "argon2id"
|
||||
Argon2MemLimit = 67108864 // 64 MB
|
||||
Argon2OpsLimit = 4
|
||||
Argon2Parallelism = 1
|
||||
Argon2KeySize = 32
|
||||
Argon2SaltSize = 16
|
||||
|
||||
// Encryption algorithm identifiers
|
||||
ChaCha20Poly1305Algorithm = "chacha20poly1305" // Primary algorithm
|
||||
XSalsa20Poly1305Algorithm = "xsalsa20poly1305" // Legacy algorithm (deprecated)
|
||||
)
|
||||
174
cloud/maplefile-backend/pkg/security/crypto/encrypt.go
Normal file
174
cloud/maplefile-backend/pkg/security/crypto/encrypt.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
// EncryptData represents encrypted data with its nonce
|
||||
type EncryptData struct {
|
||||
Ciphertext []byte
|
||||
Nonce []byte
|
||||
}
|
||||
|
||||
// EncryptWithSecretKey encrypts data with a symmetric key using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretbox_easy() but using ChaCha20-Poly1305
|
||||
func EncryptWithSecretKey(data, key []byte) (*EncryptData, error) {
|
||||
if len(key) != MasterKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", MasterKeySize, len(key))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate nonce
|
||||
nonce, err := GenerateRandomNonce()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext := cipher.Seal(nil, nonce, data, nil)
|
||||
|
||||
return &EncryptData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecryptWithSecretKey decrypts data with a symmetric key using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretbox_open_easy() but using ChaCha20-Poly1305
|
||||
func DecryptWithSecretKey(encryptedData *EncryptData, key []byte) ([]byte, error) {
|
||||
if len(key) != MasterKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", MasterKeySize, len(key))
|
||||
}
|
||||
|
||||
if len(encryptedData.Nonce) != NonceSize {
|
||||
return nil, fmt.Errorf("invalid nonce size: expected %d, got %d", NonceSize, len(encryptedData.Nonce))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := cipher.Open(nil, encryptedData.Nonce, encryptedData.Ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptWithPublicKey encrypts data with a public key using NaCl box (XSalsa20-Poly1305)
|
||||
// Note: Asymmetric encryption still uses NaCl box for compatibility
|
||||
// JavaScript equivalent: sodium.crypto_box_seal()
|
||||
func EncryptWithPublicKey(data, recipientPublicKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid public key size: expected %d, got %d", PublicKeySize, len(recipientPublicKey))
|
||||
}
|
||||
|
||||
// Convert to fixed-size array
|
||||
var pubKeyArray [32]byte
|
||||
copy(pubKeyArray[:], recipientPublicKey)
|
||||
|
||||
// Generate nonce for box encryption (24 bytes for NaCl box)
|
||||
var nonce [24]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// For sealed box, we need to use SealAnonymous
|
||||
sealed, err := box.SealAnonymous(nil, data, &pubKeyArray, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to seal data: %w", err)
|
||||
}
|
||||
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// DecryptWithPrivateKey decrypts data with a private key using NaCl box
|
||||
// Note: Asymmetric encryption still uses NaCl box for compatibility
|
||||
// JavaScript equivalent: sodium.crypto_box_seal_open()
|
||||
// SECURITY: Key arrays are wiped from memory after use to prevent key extraction via memory dumps.
|
||||
func DecryptWithPrivateKey(encryptedData, publicKey, privateKey []byte) ([]byte, error) {
|
||||
if len(privateKey) != PrivateKeySize {
|
||||
return nil, fmt.Errorf("invalid private key size: expected %d, got %d", PrivateKeySize, len(privateKey))
|
||||
}
|
||||
if len(publicKey) != PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid public key size: expected %d, got %d", PublicKeySize, len(publicKey))
|
||||
}
|
||||
|
||||
// Convert to fixed-size arrays
|
||||
var pubKeyArray [32]byte
|
||||
copy(pubKeyArray[:], publicKey)
|
||||
defer memguard.WipeBytes(pubKeyArray[:]) // SECURITY: Wipe public key array
|
||||
|
||||
var privKeyArray [32]byte
|
||||
copy(privKeyArray[:], privateKey)
|
||||
defer memguard.WipeBytes(privKeyArray[:]) // SECURITY: Wipe private key array
|
||||
|
||||
// Decrypt using OpenAnonymous for sealed box
|
||||
plaintext, ok := box.OpenAnonymous(nil, encryptedData, &pubKeyArray, &privKeyArray)
|
||||
if !ok {
|
||||
return nil, errors.New("decryption failed: invalid keys or corrupted data")
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptFileChunked encrypts a file in chunks using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretstream_* but using ChaCha20-Poly1305
|
||||
// SECURITY: Plaintext data is wiped from memory after encryption.
|
||||
func EncryptFileChunked(reader io.Reader, key []byte) ([]byte, error) {
|
||||
// This would be a more complex implementation using
|
||||
// chunked encryption. For brevity, we'll use a simpler approach
|
||||
// that reads the entire file into memory first.
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read data: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(data) // SECURITY: Wipe plaintext after encryption
|
||||
|
||||
encData, err := EncryptWithSecretKey(data, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext
|
||||
result := make([]byte, len(encData.Nonce)+len(encData.Ciphertext))
|
||||
copy(result, encData.Nonce)
|
||||
copy(result[len(encData.Nonce):], encData.Ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptFileChunked decrypts a chunked encrypted file using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretstream_* but using ChaCha20-Poly1305
|
||||
func DecryptFileChunked(encryptedData, key []byte) ([]byte, error) {
|
||||
// Split nonce and ciphertext
|
||||
if len(encryptedData) < NonceSize {
|
||||
return nil, fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", NonceSize, len(encryptedData))
|
||||
}
|
||||
|
||||
nonce := encryptedData[:NonceSize]
|
||||
ciphertext := encryptedData[NonceSize:]
|
||||
|
||||
// Decrypt
|
||||
return DecryptWithSecretKey(&EncryptData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, key)
|
||||
}
|
||||
117
cloud/maplefile-backend/pkg/security/crypto/keys.go
Normal file
117
cloud/maplefile-backend/pkg/security/crypto/keys.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/tyler-smith/go-bip39"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
// GenerateRandomKey generates a new random key using crypto_secretbox_keygen
|
||||
// JavaScript equivalent: sodium.randombytes_buf(crypto.MasterKeySize)
|
||||
func GenerateRandomKey(size int) ([]byte, error) {
|
||||
if size <= 0 {
|
||||
return nil, errors.New("key size must be positive")
|
||||
}
|
||||
|
||||
key := make([]byte, size)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateKeyPair generates a public/private key pair using NaCl box
|
||||
// JavaScript equivalent: sodium.crypto_box_keypair()
|
||||
func GenerateKeyPair() (publicKey, privateKey []byte, verificationID string, err error) {
|
||||
pubKey, privKey, err := box.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to generate key pair: %w", err)
|
||||
}
|
||||
|
||||
// Convert from fixed-size arrays to slices
|
||||
publicKey = pubKey[:]
|
||||
privateKey = privKey[:]
|
||||
|
||||
// Generate deterministic verification ID
|
||||
verificationID, err = GenerateVerificationID(publicKey[:])
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to generate verification ID: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, privateKey, verificationID, nil
|
||||
}
|
||||
|
||||
// DeriveKeyFromPassword derives a key encryption key from a password using Argon2id
|
||||
// JavaScript equivalent: sodium.crypto_pwhash()
|
||||
// SECURITY: Password bytes are wiped from memory after key derivation.
|
||||
func DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != Argon2SaltSize {
|
||||
return nil, fmt.Errorf("invalid salt size: expected %d, got %d", Argon2SaltSize, len(salt))
|
||||
}
|
||||
|
||||
// Convert password to bytes for wiping
|
||||
passwordBytes := []byte(password)
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after use
|
||||
|
||||
// These parameters must match between Go and JavaScript
|
||||
key := argon2.IDKey(
|
||||
passwordBytes,
|
||||
salt,
|
||||
Argon2OpsLimit,
|
||||
Argon2MemLimit,
|
||||
Argon2Parallelism,
|
||||
Argon2KeySize,
|
||||
)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateRandomNonce generates a random nonce for ChaCha20-Poly1305 encryption operations
|
||||
// JavaScript equivalent: sodium.randombytes_buf(crypto.NonceSize)
|
||||
func GenerateRandomNonce() ([]byte, error) {
|
||||
nonce := make([]byte, NonceSize) // NonceSize is now 12 for ChaCha20-Poly1305
|
||||
_, err := io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
// GenerateVerificationID creates a human-readable representation of a public key
|
||||
// JavaScript equivalent: The same BIP39 mnemonic implementation
|
||||
// Generate VerificationID from public key (deterministic)
|
||||
func GenerateVerificationID(publicKey []byte) (string, error) {
|
||||
if len(publicKey) == 0 {
|
||||
return "", errors.New("public key cannot be empty")
|
||||
}
|
||||
|
||||
// 1. Hash the public key with SHA256
|
||||
hash := sha256.Sum256(publicKey)
|
||||
|
||||
// 2. Use the hash as entropy for BIP39
|
||||
mnemonic, err := bip39.NewMnemonic(hash[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate verification ID: %w", err)
|
||||
}
|
||||
|
||||
return mnemonic, nil
|
||||
}
|
||||
|
||||
// VerifyVerificationID checks if a verification ID matches a public key
|
||||
func VerifyVerificationID(publicKey []byte, verificationID string) bool {
|
||||
expectedID, err := GenerateVerificationID(publicKey)
|
||||
if err != nil {
|
||||
log.Printf("pkg.crypto.VerifyVerificationID - Failed to generate verification ID with error: %v\n", err)
|
||||
return false
|
||||
}
|
||||
return expectedID == verificationID
|
||||
}
|
||||
45
cloud/maplefile-backend/pkg/security/hash/hash.go
Normal file
45
cloud/maplefile-backend/pkg/security/hash/hash.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Package hash provides secure hashing utilities for tokens and sensitive data.
|
||||
// These utilities are used to hash tokens before storing them as cache keys,
|
||||
// preventing token leakage through cache key inspection.
|
||||
package hash
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// HashToken creates a SHA-256 hash of a token for use as a cache key.
|
||||
// This prevents token leakage via cache key inspection.
|
||||
// The input token bytes are wiped after hashing.
|
||||
func HashToken(token string) string {
|
||||
tokenBytes := []byte(token)
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
hash := sha256.Sum256(tokenBytes)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// HashBytes creates a SHA-256 hash of byte data.
|
||||
// If wipeInput is true, the input bytes are wiped after hashing.
|
||||
func HashBytes(data []byte, wipeInput bool) string {
|
||||
if wipeInput {
|
||||
defer memguard.WipeBytes(data)
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// HashTokenToBytes creates a SHA-256 hash and returns the raw bytes.
|
||||
// The input token bytes are wiped after hashing.
|
||||
func HashTokenToBytes(token string) []byte {
|
||||
tokenBytes := []byte(token)
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
hash := sha256.Sum256(tokenBytes)
|
||||
result := make([]byte, len(hash))
|
||||
copy(result, hash[:])
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/ipcountryblocker/ipcountryblocker.go
|
||||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// Provider defines the interface for IP-based country blocking operations.
|
||||
// It provides methods to check if an IP or country is blocked and to retrieve
|
||||
// country codes for given IP addresses.
|
||||
type Provider interface {
|
||||
// IsBlockedCountry checks if a country is in the blocked list.
|
||||
// isoCode must be an ISO 3166-1 alpha-2 country code.
|
||||
IsBlockedCountry(isoCode string) bool
|
||||
|
||||
// IsBlockedIP determines if an IP address originates from a blocked country.
|
||||
// Returns false for nil IP addresses or if country lookup fails.
|
||||
IsBlockedIP(ctx context.Context, ip net.IP) bool
|
||||
|
||||
// GetCountryCode returns the ISO 3166-1 alpha-2 country code for an IP address.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
GetCountryCode(ctx context.Context, ip net.IP) (string, error)
|
||||
|
||||
// Close releases resources associated with the provider.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// provider implements the Provider interface using MaxMind's GeoIP2 database.
|
||||
type provider struct {
|
||||
db *geoip2.Reader
|
||||
blockedCountries map[string]struct{} // Uses empty struct to optimize memory
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex // Protects concurrent access to blockedCountries
|
||||
}
|
||||
|
||||
// NewProvider creates a new IP country blocking provider using the provided configuration.
|
||||
// It initializes the GeoIP2 database and sets up the blocked countries list.
|
||||
// Fatally crashes the entire application if the database cannot be opened.
|
||||
func NewProvider(cfg *config.Configuration, logger *zap.Logger) Provider {
|
||||
db, err := geoip2.Open(cfg.Security.GeoLiteDBPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open GeoLite2 DB: %v", err)
|
||||
}
|
||||
|
||||
blocked := make(map[string]struct{}, len(cfg.Security.BannedCountries))
|
||||
for _, country := range cfg.Security.BannedCountries {
|
||||
blocked[country] = struct{}{}
|
||||
}
|
||||
|
||||
logger.Debug("ip blocker initialized",
|
||||
zap.String("db_path", cfg.Security.GeoLiteDBPath),
|
||||
zap.Any("blocked_countries", cfg.Security.BannedCountries))
|
||||
|
||||
return &provider{
|
||||
db: db,
|
||||
blockedCountries: blocked,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedCountry checks if a country code exists in the blocked countries map.
|
||||
// Thread-safe through RLock.
|
||||
func (p *provider) IsBlockedCountry(isoCode string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
_, exists := p.blockedCountries[isoCode]
|
||||
return exists
|
||||
}
|
||||
|
||||
// IsBlockedIP performs a country lookup for the IP and checks if it's blocked.
|
||||
// Returns false for nil IPs or failed lookups to fail safely.
|
||||
func (p *provider) IsBlockedIP(ctx context.Context, ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
code, err := p.GetCountryCode(ctx, ip)
|
||||
if err != nil {
|
||||
// Developers Note:
|
||||
// Comment this console log as it contributes a `noisy` server log.
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
// p.logger.WarnContext(ctx, "failed to get country code",
|
||||
// zap.Any("ip", ip),
|
||||
// zap.Any("error", err))
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
// Developers Note:
|
||||
// If the country d.n.e. exist that means we will return with `false`
|
||||
// indicating this IP address is allowed to access our server. If this
|
||||
// is concerning then you might set this to `true` to block on all
|
||||
// IP address which are not categorized by country.
|
||||
return false
|
||||
}
|
||||
|
||||
return p.IsBlockedCountry(code)
|
||||
}
|
||||
|
||||
// GetCountryCode performs a GeoIP2 database lookup to determine an IP's country.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
func (p *provider) GetCountryCode(ctx context.Context, ip net.IP) (string, error) {
|
||||
record, err := p.db.Country(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup country: %w", err)
|
||||
}
|
||||
|
||||
if record == nil || record.Country.IsoCode == "" {
|
||||
return "", fmt.Errorf("no country found for IP: %s", validation.MaskIP(ip.String()))
|
||||
}
|
||||
|
||||
return record.Country.IsoCode, nil
|
||||
}
|
||||
|
||||
// Close cleanly shuts down the GeoIP2 database connection.
|
||||
func (p *provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/ipcountryblocker/ipcountryblocker_test.go
|
||||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// testProvider is a test-specific wrapper that allows access to internal fields
|
||||
// of the provider struct for verification in tests. This is a common pattern
|
||||
// when you need to test internal state while keeping the production interface clean.
|
||||
type testProvider struct {
|
||||
Provider // Embedded interface for normal operations
|
||||
internal *provider // Access to internal fields for testing
|
||||
}
|
||||
|
||||
// newTestProvider creates a test provider instance with access to internal fields.
|
||||
// This allows us to verify the internal state in our tests while maintaining
|
||||
// encapsulation in production code.
|
||||
func newTestProvider(cfg *config.Configuration, logger *zap.Logger) testProvider {
|
||||
p := NewProvider(cfg, logger)
|
||||
return testProvider{
|
||||
Provider: p,
|
||||
internal: p.(*provider), // Type assertion to get access to internal fields
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewProvider verifies that the provider is properly initialized with all
|
||||
// required components (database connection, blocked countries map, logger).
|
||||
func TestNewProvider(t *testing.T) {
|
||||
// Setup test configuration with path to test database
|
||||
cfg := &config.Configuration{
|
||||
Security: config.SecurityConfig{
|
||||
GeoLiteDBPath: "../../../static/GeoLite2-Country.mmdb",
|
||||
BannedCountries: []string{"US", "CN"},
|
||||
},
|
||||
}
|
||||
// Initialize logger with JSON output for structured test logs
|
||||
logger, _ := zap.NewDevelopment()
|
||||
|
||||
// Create test provider and verify internal components
|
||||
p := newTestProvider(cfg, logger)
|
||||
assert.NotNil(t, p.Provider, "Provider should not be nil")
|
||||
assert.NotEmpty(t, p.internal.blockedCountries, "Blocked countries map should be initialized")
|
||||
assert.NotNil(t, p.internal.logger, "Logger should be initialized")
|
||||
assert.NotNil(t, p.internal.db, "Database connection should be initialized")
|
||||
defer p.Close() // Ensure cleanup after test
|
||||
}
|
||||
|
||||
// TestProvider_IsBlockedCountry tests the country blocking functionality with
|
||||
// various country codes including edge cases like empty and invalid codes.
|
||||
func TestProvider_IsBlockedCountry(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
// Table-driven test cases covering various scenarios
|
||||
tests := []struct {
|
||||
name string
|
||||
country string
|
||||
expected bool
|
||||
}{
|
||||
// Positive test cases - blocked countries
|
||||
{
|
||||
name: "blocked country US",
|
||||
country: "US",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked country CN",
|
||||
country: "CN",
|
||||
expected: true,
|
||||
},
|
||||
// Negative test cases - allowed countries
|
||||
{
|
||||
name: "non-blocked country GB",
|
||||
country: "GB",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-blocked country JP",
|
||||
country: "JP",
|
||||
expected: false,
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty country code",
|
||||
country: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid country code",
|
||||
country: "XX",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "lowercase country code", // Tests case sensitivity
|
||||
country: "us",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Run each test case
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := provider.IsBlockedCountry(tt.country)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_IsBlockedIP verifies IP blocking functionality using real-world
|
||||
// IP addresses, including IPv4, IPv6, and various edge cases.
|
||||
func TestProvider_IsBlockedIP(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
expected bool
|
||||
}{
|
||||
// Known IP addresses from blocked countries
|
||||
{
|
||||
name: "blocked IP (US - Google DNS)",
|
||||
ip: net.ParseIP("8.8.8.8"), // Google's primary DNS
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked IP (US - Google DNS 2)",
|
||||
ip: net.ParseIP("8.8.4.4"), // Google's secondary DNS
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked IP (CN - Alibaba)",
|
||||
ip: net.ParseIP("223.5.5.5"), // Alibaba DNS
|
||||
expected: true,
|
||||
},
|
||||
// Non-blocked country IPs
|
||||
{
|
||||
name: "non-blocked IP (GB)",
|
||||
ip: net.ParseIP("178.62.1.1"),
|
||||
expected: false,
|
||||
},
|
||||
// Edge cases and special scenarios
|
||||
{
|
||||
name: "nil IP",
|
||||
ip: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid IP format",
|
||||
ip: net.ParseIP("invalid"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
ip: net.ParseIP("2001:4860:4860::8888"), // Google's IPv6 DNS
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := provider.IsBlockedIP(ctx, tt.ip)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_GetCountryCode verifies the country code lookup functionality
|
||||
// for various IP addresses, including error cases.
|
||||
func TestProvider_GetCountryCode(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
expected string
|
||||
expectError bool
|
||||
}{
|
||||
// Valid IP addresses with known countries
|
||||
{
|
||||
name: "US IP (Google DNS)",
|
||||
ip: net.ParseIP("8.8.8.8"),
|
||||
expected: "US",
|
||||
expectError: false,
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
name: "nil IP",
|
||||
ip: nil,
|
||||
expected: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "private IP", // RFC 1918 address
|
||||
ip: net.ParseIP("192.168.1.1"),
|
||||
expected: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, err := provider.GetCountryCode(ctx, tt.ip)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "Should return error for invalid IP")
|
||||
assert.Empty(t, code, "Should return empty code on error")
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err, "Should not return error for valid IP")
|
||||
assert.Equal(t, tt.expected, code, "Should return correct country code")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_Close verifies that the provider properly closes its resources
|
||||
// and subsequent operations fail as expected.
|
||||
func TestProvider_Close(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
|
||||
// Verify initial close succeeds
|
||||
err := provider.Close()
|
||||
assert.NoError(t, err, "Initial close should succeed")
|
||||
|
||||
// Verify operations fail after close
|
||||
code, err := provider.GetCountryCode(context.Background(), net.ParseIP("8.8.8.8"))
|
||||
assert.Error(t, err, "Operations should fail after close")
|
||||
assert.Empty(t, code, "No data should be returned after close")
|
||||
}
|
||||
|
||||
// setupTestProvider is a helper function that creates a properly configured
|
||||
// provider instance for testing, using the test database path.
|
||||
func setupTestProvider(t *testing.T) Provider {
|
||||
cfg := &config.Configuration{
|
||||
Security: config.SecurityConfig{
|
||||
GeoLiteDBPath: "../../../static/GeoLite2-Country.mmdb",
|
||||
BannedCountries: []string{"US", "CN"},
|
||||
},
|
||||
}
|
||||
logger, _ := zap.NewDevelopment()
|
||||
return NewProvider(cfg, logger)
|
||||
}
|
||||
223
cloud/maplefile-backend/pkg/security/ipcrypt/encryptor.go
Normal file
223
cloud/maplefile-backend/pkg/security/ipcrypt/encryptor.go
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// IPEncryptor provides secure IP address encryption for GDPR compliance
|
||||
// Uses AES-GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// Encrypts IP addresses before storage and provides expiration checking
|
||||
type IPEncryptor struct {
|
||||
gcm cipher.AEAD
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewIPEncryptor creates a new IP encryptor with the given encryption key
|
||||
// keyHex should be a 32-character hex string (16 bytes for AES-128)
|
||||
// or 64-character hex string (32 bytes for AES-256)
|
||||
// Example: "0123456789abcdef0123456789abcdef" (AES-128)
|
||||
// Recommended: Use AES-256 with 64-character hex key
|
||||
func NewIPEncryptor(keyHex string, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
// Decode hex key to bytes
|
||||
keyBytes, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid hex key: %w", err)
|
||||
}
|
||||
|
||||
// AES requires exactly 16, 24, or 32 bytes
|
||||
if len(keyBytes) != 16 && len(keyBytes) != 24 && len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("key must be 16, 24, or 32 bytes (32, 48, or 64 hex characters), got %d bytes", len(keyBytes))
|
||||
}
|
||||
|
||||
// Create AES cipher block
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// GCM provides both confidentiality and integrity
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("IP encryptor initialized with AES-GCM",
|
||||
zap.Int("key_length_bytes", len(keyBytes)),
|
||||
zap.Int("nonce_size", gcm.NonceSize()),
|
||||
zap.Int("overhead", gcm.Overhead()))
|
||||
|
||||
return &IPEncryptor{
|
||||
gcm: gcm,
|
||||
logger: logger.Named("ip-encryptor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts an IP address for secure storage using AES-GCM
|
||||
// Returns base64-encoded encrypted IP address with embedded nonce
|
||||
// Format: base64(nonce + ciphertext + auth_tag)
|
||||
// Supports both IPv4 and IPv6 addresses
|
||||
//
|
||||
// Security Properties:
|
||||
// - Semantic security: same IP address produces different ciphertext each time
|
||||
// - Authentication: tampering with ciphertext is detected
|
||||
// - Unique nonce per encryption prevents pattern analysis
|
||||
func (e *IPEncryptor) Encrypt(ipAddress string) (string, error) {
|
||||
if ipAddress == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Parse IP address to validate format
|
||||
ip := net.ParseIP(ipAddress)
|
||||
if ip == nil {
|
||||
e.logger.Warn("invalid IP address format",
|
||||
zap.String("ip", validation.MaskIP(ipAddress)))
|
||||
return "", fmt.Errorf("invalid IP address: %s", validation.MaskIP(ipAddress))
|
||||
}
|
||||
|
||||
// Convert to 16-byte representation (IPv4 gets converted to IPv6 format)
|
||||
ipBytes := ip.To16()
|
||||
if ipBytes == nil {
|
||||
return "", fmt.Errorf("failed to convert IP to 16-byte format")
|
||||
}
|
||||
|
||||
// Generate a random nonce (number used once)
|
||||
// GCM requires a unique nonce for each encryption operation
|
||||
nonce := make([]byte, e.gcm.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
e.logger.Error("failed to generate nonce", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the IP bytes using AES-GCM
|
||||
// GCM appends the authentication tag to the ciphertext
|
||||
// nil additional data means no associated data
|
||||
ciphertext := e.gcm.Seal(nil, nonce, ipBytes, nil)
|
||||
|
||||
// Prepend nonce to ciphertext for storage
|
||||
// Format: nonce || ciphertext+tag
|
||||
encryptedData := append(nonce, ciphertext...)
|
||||
|
||||
// Encode to base64 for database storage (text-safe)
|
||||
encryptedBase64 := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
e.logger.Debug("IP address encrypted with AES-GCM",
|
||||
zap.Int("plaintext_length", len(ipBytes)),
|
||||
zap.Int("nonce_length", len(nonce)),
|
||||
zap.Int("ciphertext_length", len(ciphertext)),
|
||||
zap.Int("total_encrypted_length", len(encryptedData)),
|
||||
zap.Int("base64_length", len(encryptedBase64)))
|
||||
|
||||
return encryptedBase64, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts an encrypted IP address
|
||||
// Takes base64-encoded encrypted IP and returns original IP address string
|
||||
// Verifies authentication tag to detect tampering
|
||||
func (e *IPEncryptor) Decrypt(encryptedBase64 string) (string, error) {
|
||||
if encryptedBase64 == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Decode base64 to bytes
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
e.logger.Warn("invalid base64-encoded encrypted IP",
|
||||
zap.String("base64", encryptedBase64),
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
// Extract nonce from the beginning
|
||||
nonceSize := e.gcm.NonceSize()
|
||||
if len(encryptedData) < nonceSize {
|
||||
return "", fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", nonceSize, len(encryptedData))
|
||||
}
|
||||
|
||||
nonce := encryptedData[:nonceSize]
|
||||
ciphertext := encryptedData[nonceSize:]
|
||||
|
||||
// Decrypt and verify authentication tag using AES-GCM
|
||||
ipBytes, err := e.gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to decrypt IP address (authentication failed or corrupted data)",
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to IP address
|
||||
ip := net.IP(ipBytes)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("failed to parse decrypted IP bytes")
|
||||
}
|
||||
|
||||
// Convert to string
|
||||
ipString := ip.String()
|
||||
|
||||
e.logger.Debug("IP address decrypted with AES-GCM",
|
||||
zap.Int("encrypted_length", len(encryptedData)),
|
||||
zap.Int("decrypted_length", len(ipBytes)))
|
||||
|
||||
return ipString, nil
|
||||
}
|
||||
|
||||
// IsExpired checks if an IP address timestamp has expired (> 90 days old)
|
||||
// GDPR compliance: IP addresses must be deleted after 90 days
|
||||
func (e *IPEncryptor) IsExpired(timestamp time.Time) bool {
|
||||
if timestamp.IsZero() {
|
||||
return false // No timestamp means not expired (will be cleaned up later)
|
||||
}
|
||||
|
||||
// Calculate age in days
|
||||
age := time.Since(timestamp)
|
||||
ageInDays := int(age.Hours() / 24)
|
||||
|
||||
expired := ageInDays > 90
|
||||
|
||||
if expired {
|
||||
e.logger.Debug("IP timestamp expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Int("age_days", ageInDays))
|
||||
}
|
||||
|
||||
return expired
|
||||
}
|
||||
|
||||
// ShouldCleanup checks if an IP address should be cleaned up based on timestamp
|
||||
// Returns true if timestamp is older than 90 days OR if timestamp is zero (unset)
|
||||
func (e *IPEncryptor) ShouldCleanup(timestamp time.Time) bool {
|
||||
// Always cleanup if timestamp is not set (backwards compatibility)
|
||||
if timestamp.IsZero() {
|
||||
return false // Don't cleanup unset timestamps immediately
|
||||
}
|
||||
|
||||
return e.IsExpired(timestamp)
|
||||
}
|
||||
|
||||
// ValidateKey validates that a key is properly formatted for IP encryption
|
||||
// Returns true if key is valid 32-character hex string (AES-128) or 64-character (AES-256)
|
||||
func ValidateKey(keyHex string) error {
|
||||
// Check length (must be 16, 24, or 32 bytes = 32, 48, or 64 hex chars)
|
||||
if len(keyHex) != 32 && len(keyHex) != 48 && len(keyHex) != 64 {
|
||||
return fmt.Errorf("key must be 32, 48, or 64 hex characters, got %d characters", len(keyHex))
|
||||
}
|
||||
|
||||
// Check if valid hex
|
||||
_, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key must be valid hex string: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
cloud/maplefile-backend/pkg/security/ipcrypt/provider.go
Normal file
13
cloud/maplefile-backend/pkg/security/ipcrypt/provider.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideIPEncryptor provides an IP encryptor instance
|
||||
// CWE-359: GDPR compliance for IP address storage
|
||||
func ProvideIPEncryptor(cfg *config.Config, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
return NewIPEncryptor(cfg.Security.IPEncryptionKey, logger)
|
||||
}
|
||||
47
cloud/maplefile-backend/pkg/security/jwt/jwt.go
Normal file
47
cloud/maplefile-backend/pkg/security/jwt/jwt.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/jwt_utils"
|
||||
sbytes "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
)
|
||||
|
||||
// JWTProvider provides interface for abstracting JWT generation.
|
||||
type JWTProvider interface {
|
||||
GenerateJWTToken(uuid string, ad time.Duration) (string, time.Time, error)
|
||||
GenerateJWTTokenPair(uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error)
|
||||
ProcessJWTToken(reqToken string) (string, error)
|
||||
}
|
||||
|
||||
type jwtProvider struct {
|
||||
hmacSecret *sbytes.SecureBytes
|
||||
}
|
||||
|
||||
// NewProvider Constructor that returns the JWT generator.
|
||||
func NewJWTProvider(cfg *config.Configuration) JWTProvider {
|
||||
// Convert JWT secret string to SecureBytes
|
||||
secret, _ := sbytes.NewSecureBytes([]byte(cfg.JWT.Secret))
|
||||
return jwtProvider{
|
||||
hmacSecret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateJWTToken generates a single JWT token.
|
||||
func (p jwtProvider) GenerateJWTToken(uuid string, ad time.Duration) (string, time.Time, error) {
|
||||
return jwt_utils.GenerateJWTToken(p.hmacSecret.Bytes(), uuid, ad)
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair Generate the `access token` and `refresh token` for the secret key.
|
||||
func (p jwtProvider) GenerateJWTTokenPair(uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
return jwt_utils.GenerateJWTTokenPair(p.hmacSecret.Bytes(), uuid, ad, rd)
|
||||
}
|
||||
|
||||
func (p jwtProvider) ProcessJWTToken(reqToken string) (string, error) {
|
||||
if p.hmacSecret == nil {
|
||||
return "", errors.New("HMAC secret is required")
|
||||
}
|
||||
return jwt_utils.ProcessJWTToken(p.hmacSecret.Bytes(), reqToken)
|
||||
}
|
||||
98
cloud/maplefile-backend/pkg/security/jwt/jwt_test.go
Normal file
98
cloud/maplefile-backend/pkg/security/jwt/jwt_test.go
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
func setupTestProvider(t *testing.T) JWTProvider {
|
||||
cfg := &config.Configuration{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
return NewJWTProvider(cfg)
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
|
||||
func TestGenerateJWTToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
token, expiry, err := provider.GenerateJWTToken(uuid, duration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiry.After(time.Now()))
|
||||
assert.True(t, expiry.Before(time.Now().Add(duration).Add(time.Second)))
|
||||
}
|
||||
|
||||
func TestGenerateJWTTokenPair(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
accessDuration := time.Hour
|
||||
refreshDuration := time.Hour * 24
|
||||
|
||||
accessToken, accessExpiry, refreshToken, refreshExpiry, err := provider.GenerateJWTTokenPair(uuid, accessDuration, refreshDuration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, refreshToken)
|
||||
assert.True(t, accessExpiry.After(time.Now()))
|
||||
assert.True(t, refreshExpiry.After(time.Now()))
|
||||
assert.True(t, accessExpiry.Before(time.Now().Add(accessDuration).Add(time.Second)))
|
||||
assert.True(t, refreshExpiry.Before(time.Now().Add(refreshDuration).Add(time.Second)))
|
||||
}
|
||||
|
||||
func TestProcessJWTToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
// Generate a token first
|
||||
token, _, err := provider.GenerateJWTToken(uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Process the generated token
|
||||
processedUUID, err := provider.ProcessJWTToken(token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedUUID)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_InvalidToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
|
||||
_, err := provider.ProcessJWTToken("invalid-token")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_NilSecret(t *testing.T) {
|
||||
provider := jwtProvider{
|
||||
hmacSecret: nil,
|
||||
}
|
||||
|
||||
_, err := provider.ProcessJWTToken("any-token")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "HMAC secret is required", err.Error())
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_ExpiredToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := -time.Hour // negative duration for expired token
|
||||
|
||||
token, _, err := provider.GenerateJWTToken(uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = provider.ProcessJWTToken(token)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
10
cloud/maplefile-backend/pkg/security/jwt/provider.go
Normal file
10
cloud/maplefile-backend/pkg/security/jwt/provider.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideJWTProvider provides a JWT provider instance for Wire DI
|
||||
func ProvideJWTProvider(cfg *config.Config) JWTProvider {
|
||||
return NewJWTProvider(cfg)
|
||||
}
|
||||
130
cloud/maplefile-backend/pkg/security/jwt_utils/jwt.go
Normal file
130
cloud/maplefile-backend/pkg/security/jwt_utils/jwt.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package jwt_utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// GenerateJWTToken Generate the `access token` for the secret key.
|
||||
// SECURITY: HMAC secret is wiped from memory after signing to prevent memory dump attacks.
|
||||
func GenerateJWTToken(hmacSecret []byte, uuid string, ad time.Duration) (string, time.Time, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
// Note: The original hmacSecret is owned by the caller
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after signing
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
expiresIn := time.Now().Add(ad)
|
||||
|
||||
// CWE-391: Safe type assertion even though we just created the token
|
||||
// Defensive programming to prevent future panics if jwt library changes
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", expiresIn, jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
claims["session_uuid"] = uuid
|
||||
claims["exp"] = expiresIn.Unix()
|
||||
|
||||
tokenString, err := token.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", expiresIn, err
|
||||
}
|
||||
|
||||
return tokenString, expiresIn, nil
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair Generate the `access token` and `refresh token` for the secret key.
|
||||
// SECURITY: HMAC secret is wiped from memory after signing to prevent memory dump attacks.
|
||||
func GenerateJWTTokenPair(hmacSecret []byte, uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after signing
|
||||
|
||||
//
|
||||
// Generate token.
|
||||
//
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
expiresIn := time.Now().Add(ad)
|
||||
|
||||
// CWE-391: Safe type assertion even though we just created the token
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", time.Now(), "", time.Now(), jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
claims["session_uuid"] = uuid
|
||||
claims["exp"] = expiresIn.Unix()
|
||||
|
||||
tokenString, err := token.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", time.Now(), "", time.Now(), err
|
||||
}
|
||||
|
||||
//
|
||||
// Generate refresh token.
|
||||
//
|
||||
refreshToken := jwt.New(jwt.SigningMethodHS256)
|
||||
refreshExpiresIn := time.Now().Add(rd)
|
||||
|
||||
// CWE-391: Safe type assertion for refresh token
|
||||
rtClaims, ok := refreshToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", time.Now(), "", time.Now(), jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
rtClaims["session_uuid"] = uuid
|
||||
rtClaims["exp"] = refreshExpiresIn.Unix()
|
||||
|
||||
refreshTokenString, err := refreshToken.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", time.Now(), "", time.Now(), err
|
||||
}
|
||||
|
||||
return tokenString, expiresIn, refreshTokenString, refreshExpiresIn, nil
|
||||
}
|
||||
|
||||
// ProcessJWTToken validates either the `access token` or `refresh token` and returns either the `uuid` if success or error on failure.
|
||||
// CWE-347: Implements proper algorithm validation to prevent JWT algorithm confusion attacks
|
||||
// OWASP A02:2021: Cryptographic Failures - Prevents token forgery through algorithm switching
|
||||
// SECURITY: HMAC secret copy is wiped from memory after validation.
|
||||
func ProcessJWTToken(hmacSecret []byte, reqToken string) (string, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after validation
|
||||
|
||||
token, err := jwt.Parse(reqToken, func(t *jwt.Token) (any, error) {
|
||||
// CRITICAL SECURITY FIX: Validate signing method to prevent algorithm confusion attacks
|
||||
// Protects against:
|
||||
// 1. "none" algorithm bypass (CVE-2015-9235)
|
||||
// 2. HS256/RS256 algorithm confusion (CVE-2016-5431)
|
||||
// 3. Token forgery through algorithm switching
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
|
||||
// Additional check: Ensure it's specifically HS256
|
||||
if t.Method.Alg() != "HS256" {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
|
||||
return secretCopy, nil
|
||||
})
|
||||
if err == nil && token.Valid {
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Safe type assertion with validation
|
||||
sessionUUID, ok := claims["session_uuid"].(string)
|
||||
if !ok {
|
||||
return "", jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
return sessionUUID, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
194
cloud/maplefile-backend/pkg/security/jwt_utils/jwt_test.go
Normal file
194
cloud/maplefile-backend/pkg/security/jwt_utils/jwt_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
package jwt_utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testSecret = []byte("test-secret-key")
|
||||
|
||||
func TestGenerateJWTToken(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
token, expiry, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiry.After(time.Now()))
|
||||
assert.True(t, expiry.Before(time.Now().Add(duration).Add(time.Second)))
|
||||
|
||||
// Verify token can be processed
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedUUID)
|
||||
}
|
||||
|
||||
func TestGenerateJWTTokenPair(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
accessDuration := time.Hour
|
||||
refreshDuration := time.Hour * 24
|
||||
|
||||
accessToken, accessExpiry, refreshToken, refreshExpiry, err := GenerateJWTTokenPair(
|
||||
testSecret,
|
||||
uuid,
|
||||
accessDuration,
|
||||
refreshDuration,
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, refreshToken)
|
||||
assert.True(t, accessExpiry.After(time.Now()))
|
||||
assert.True(t, refreshExpiry.After(time.Now()))
|
||||
assert.True(t, accessExpiry.Before(time.Now().Add(accessDuration).Add(time.Second)))
|
||||
assert.True(t, refreshExpiry.Before(time.Now().Add(refreshDuration).Add(time.Second)))
|
||||
|
||||
// Verify both tokens can be processed
|
||||
processedAccessUUID, err := ProcessJWTToken(testSecret, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedAccessUUID)
|
||||
|
||||
processedRefreshUUID, err := ProcessJWTToken(testSecret, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedRefreshUUID)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed token",
|
||||
token: "not.a.token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong signature",
|
||||
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJ0ZXN0LXV1aWQiLCJleHAiOjE3MDQwNjc1NTF9.wrong",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uuid, err := ProcessJWTToken(testSecret, tt.token)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, uuid)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, uuid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_Expired(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
duration := -time.Hour // negative duration for expired token
|
||||
|
||||
token, _, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, processedUUID)
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_AlgorithmConfusion tests protection against JWT algorithm confusion attacks
|
||||
// CVE-2015-9235: None algorithm bypass
|
||||
// CVE-2016-5431: HS256/RS256 algorithm confusion
|
||||
// CWE-347: Improper Verification of Cryptographic Signature
|
||||
func TestProcessJWTToken_AlgorithmConfusion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
description string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "none algorithm bypass attempt",
|
||||
// Token with "alg": "none" - should be rejected
|
||||
token: "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.",
|
||||
description: "Attacker tries to bypass signature verification using 'none' algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "RS256 algorithm confusion attempt",
|
||||
// Token with "alg": "RS256" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use RS256 to confuse HMAC validation",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "HS384 algorithm attempt",
|
||||
// Token with "alg": "HS384" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use different HMAC algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "HS512 algorithm attempt",
|
||||
// Token with "alg": "HS512" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use different HMAC algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Logf("Testing: %s", tt.description)
|
||||
uuid, err := ProcessJWTToken(testSecret, tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Expected error for security vulnerability: %s", tt.description)
|
||||
assert.Empty(t, uuid, "UUID should be empty when algorithm validation fails")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, uuid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_ValidHS256Only tests that only valid HS256 tokens are accepted
|
||||
func TestProcessJWTToken_ValidHS256Only(t *testing.T) {
|
||||
uuid := "valid-test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
// Generate a valid HS256 token
|
||||
token, _, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
assert.NoError(t, err, "Should generate valid token")
|
||||
|
||||
// Verify it's accepted
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.NoError(t, err, "Valid HS256 token should be accepted")
|
||||
assert.Equal(t, uuid, processedUUID, "UUID should match")
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_MissingSessionUUID tests protection against missing session_uuid claim
|
||||
func TestProcessJWTToken_MissingSessionUUID(t *testing.T) {
|
||||
// This test verifies the safe type assertion fix for CWE-391
|
||||
// A token without session_uuid claim should return an error, not panic
|
||||
|
||||
// Note: We can't easily create such a token with our GenerateJWTToken function
|
||||
// as it always includes session_uuid. In a real attack scenario, an attacker
|
||||
// would craft such a token manually. This test documents the expected behavior.
|
||||
|
||||
// For now, we verify that a malformed token is properly rejected
|
||||
malformedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.invalid"
|
||||
uuid, err := ProcessJWTToken(testSecret, malformedToken)
|
||||
assert.Error(t, err, "Token without session_uuid should be rejected")
|
||||
assert.Empty(t, uuid, "UUID should be empty for invalid token")
|
||||
}
|
||||
96
cloud/maplefile-backend/pkg/security/memutil/memutil.go
Normal file
96
cloud/maplefile-backend/pkg/security/memutil/memutil.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// Package memutil provides utilities for secure memory handling.
|
||||
// These utilities help prevent sensitive data from remaining in memory
|
||||
// after use, protecting against memory dump attacks.
|
||||
package memutil
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
sbytes "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// WipeString overwrites a string's backing array with zeros and clears the string.
|
||||
// Note: This only works if the string variable is the only reference to the data.
|
||||
// For better security, use SecureString instead of plain strings for sensitive data.
|
||||
func WipeString(s *string) {
|
||||
if s == nil || *s == "" {
|
||||
return
|
||||
}
|
||||
// Convert to byte slice and wipe
|
||||
// Note: This creates a copy, but we wipe what we can
|
||||
bytes := []byte(*s)
|
||||
memguard.WipeBytes(bytes)
|
||||
*s = ""
|
||||
}
|
||||
|
||||
// SecureCompareStrings performs constant-time comparison of two strings.
|
||||
// This prevents timing attacks when comparing secrets.
|
||||
func SecureCompareStrings(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// SecureCompareBytes performs constant-time comparison of two byte slices.
|
||||
// If wipeAfter is true, both slices are wiped after comparison.
|
||||
func SecureCompareBytes(a, b []byte, wipeAfter bool) bool {
|
||||
if wipeAfter {
|
||||
defer memguard.WipeBytes(a)
|
||||
defer memguard.WipeBytes(b)
|
||||
}
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// WithSecureBytes executes a function with secure byte handling.
|
||||
// The bytes are automatically wiped after the function returns.
|
||||
func WithSecureBytes(data []byte, fn func([]byte) error) error {
|
||||
defer memguard.WipeBytes(data)
|
||||
return fn(data)
|
||||
}
|
||||
|
||||
// WithSecureString executes a function with secure string handling.
|
||||
// The SecureString is automatically wiped after the function returns.
|
||||
func WithSecureString(str string, fn func(*sstring.SecureString) error) error {
|
||||
secure, err := sstring.NewSecureString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer secure.Wipe()
|
||||
return fn(secure)
|
||||
}
|
||||
|
||||
// CloneAndWipe creates a copy of data and wipes the original.
|
||||
// Useful when you need to pass data to a function that will store it,
|
||||
// but want to ensure the original is wiped.
|
||||
func CloneAndWipe(data []byte) []byte {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
clone := make([]byte, len(data))
|
||||
copy(clone, data)
|
||||
memguard.WipeBytes(data)
|
||||
return clone
|
||||
}
|
||||
|
||||
// SecureZero overwrites memory with zeros.
|
||||
// This is a convenience wrapper around memguard.WipeBytes.
|
||||
func SecureZero(data []byte) {
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
|
||||
// WipeSecureString wipes a SecureString if it's not nil.
|
||||
// This is a nil-safe convenience wrapper.
|
||||
func WipeSecureString(s *sstring.SecureString) {
|
||||
if s != nil {
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// WipeSecureBytes wipes a SecureBytes if it's not nil.
|
||||
// This is a nil-safe convenience wrapper.
|
||||
func WipeSecureBytes(s *sbytes.SecureBytes) {
|
||||
if s != nil {
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
186
cloud/maplefile-backend/pkg/security/password/password.go
Normal file
186
cloud/maplefile-backend/pkg/security/password/password.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/password/password.go
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
|
||||
ErrIncompatibleVersion = errors.New("incompatible version of argon2")
|
||||
)
|
||||
|
||||
type PasswordProvider interface {
|
||||
GenerateHashFromPassword(password *sstring.SecureString) (string, error)
|
||||
ComparePasswordAndHash(password *sstring.SecureString, hash string) (bool, error)
|
||||
AlgorithmName() string
|
||||
GenerateSecureRandomBytes(length int) ([]byte, error)
|
||||
GenerateSecureRandomString(length int) (string, error)
|
||||
}
|
||||
|
||||
type passwordProvider struct {
|
||||
memory uint32
|
||||
iterations uint32
|
||||
parallelism uint8
|
||||
saltLength uint32
|
||||
keyLength uint32
|
||||
}
|
||||
|
||||
func NewPasswordProvider() PasswordProvider {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Establish the parameters to use for Argon2.
|
||||
return &passwordProvider{
|
||||
memory: 64 * 1024,
|
||||
iterations: 3,
|
||||
parallelism: 2,
|
||||
saltLength: 16,
|
||||
keyLength: 32,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword function takes the plaintext string and returns an Argon2 hashed string.
|
||||
// SECURITY: Password bytes are wiped from memory after hashing to prevent memory dump attacks.
|
||||
func (p *passwordProvider) GenerateHashFromPassword(password *sstring.SecureString) (string, error) {
|
||||
salt, err := generateRandomBytes(p.saltLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer memguard.WipeBytes(salt) // SECURITY: Wipe salt after use
|
||||
|
||||
passwordBytes := password.Bytes()
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after hashing
|
||||
|
||||
hash := argon2.IDKey(passwordBytes, salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
defer memguard.WipeBytes(hash) // SECURITY: Wipe raw hash after encoding
|
||||
|
||||
// Base64 encode the salt and hashed password.
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return a string using the standard encoded hash representation.
|
||||
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, p.memory, p.iterations, p.parallelism, b64Salt, b64Hash)
|
||||
|
||||
return encodedHash, nil
|
||||
}
|
||||
|
||||
// CheckPasswordHash function checks the plaintext string and hash string and returns either true
|
||||
// or false depending.
|
||||
// SECURITY: All sensitive bytes (password, salt, hashes) are wiped from memory after comparison.
|
||||
func (p *passwordProvider) ComparePasswordAndHash(password *sstring.SecureString, encodedHash string) (match bool, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Extract the parameters, salt and derived key from the encoded password
|
||||
// hash.
|
||||
p, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer memguard.WipeBytes(salt) // SECURITY: Wipe salt after use
|
||||
defer memguard.WipeBytes(hash) // SECURITY: Wipe stored hash after comparison
|
||||
|
||||
// Get password bytes and ensure they're wiped after use
|
||||
passwordBytes := password.Bytes()
|
||||
defer memguard.WipeBytes(passwordBytes)
|
||||
|
||||
// Derive the key from the other password using the same parameters.
|
||||
otherHash := argon2.IDKey(passwordBytes, salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
defer memguard.WipeBytes(otherHash) // SECURITY: Wipe computed hash after comparison
|
||||
|
||||
// Check that the contents of the hashed passwords are identical. Note
|
||||
// that we are using the subtle.ConstantTimeCompare() function for this
|
||||
// to help prevent timing attacks.
|
||||
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AlgorithmName function returns the algorithm used for hashing.
|
||||
func (p *passwordProvider) AlgorithmName() string {
|
||||
return "argon2id"
|
||||
}
|
||||
|
||||
func generateRandomBytes(n uint32) ([]byte, error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func decodeHash(encodedHash string) (p *passwordProvider, salt, hash []byte, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
vals := strings.Split(encodedHash, "$")
|
||||
if len(vals) != 6 {
|
||||
return nil, nil, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, ErrIncompatibleVersion
|
||||
}
|
||||
|
||||
p = &passwordProvider{}
|
||||
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.saltLength = uint32(len(salt))
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(vals[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.keyLength = uint32(len(hash))
|
||||
|
||||
return p, salt, hash, nil
|
||||
}
|
||||
|
||||
// GenerateSecureRandomBytes generates a secure random byte slice of the specified length.
|
||||
func (p *passwordProvider) GenerateSecureRandomBytes(length int) ([]byte, error) {
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate secure random bytes: %v", err)
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
// GenerateSecureRandomString generates a secure random string of the specified length.
|
||||
func (p *passwordProvider) GenerateSecureRandomString(length int) (string, error) {
|
||||
bytes, err := p.GenerateSecureRandomBytes(length)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package password
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
t.Log("TestPasswordHashing: Starting")
|
||||
|
||||
provider := NewPasswordProvider()
|
||||
t.Log("TestPasswordHashing: Provider created")
|
||||
|
||||
password, err := sstring.NewSecureString("test-password")
|
||||
require.NoError(t, err)
|
||||
t.Log("TestPasswordHashing: Password SecureString created")
|
||||
fmt.Println("TestPasswordHashing: Password SecureString created")
|
||||
|
||||
// Let's add a timeout to see if we can pinpoint the issue
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
fmt.Println("TestPasswordHashing: Generating hash...")
|
||||
hash, err := provider.GenerateHashFromPassword(password)
|
||||
fmt.Printf("TestPasswordHashing: Hash generated: %v, error: %v\n", hash != "", err)
|
||||
|
||||
if err == nil {
|
||||
fmt.Println("TestPasswordHashing: Comparing password and hash...")
|
||||
match, err := provider.ComparePasswordAndHash(password, hash)
|
||||
fmt.Printf("TestPasswordHashing: Comparison done: match=%v, error=%v\n", match, err)
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
fmt.Println("TestPasswordHashing: Test completed successfully")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Test timed out after 10 seconds")
|
||||
}
|
||||
|
||||
fmt.Println("TestPasswordHashing: Cleaning up password...")
|
||||
password.Wipe()
|
||||
fmt.Println("TestPasswordHashing: Done")
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package password
|
||||
|
||||
// ProvidePasswordProvider provides a password provider instance for Wire DI
|
||||
func ProvidePasswordProvider() PasswordProvider {
|
||||
return NewPasswordProvider()
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securebytes.go
|
||||
package securebytes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureBytes is used to store a byte slice securely in memory.
|
||||
type SecureBytes struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureBytes creates a new SecureBytes instance from the given byte slice.
|
||||
func NewSecureBytes(b []byte) (*SecureBytes, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, errors.New("byte slice cannot be empty")
|
||||
}
|
||||
|
||||
buffer := memguard.NewBuffer(len(b))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
copy(buffer.Bytes(), b)
|
||||
|
||||
return &SecureBytes{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// Bytes returns the securely stored byte slice.
|
||||
func (sb *SecureBytes) Bytes() []byte {
|
||||
return sb.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the byte slice from memory and makes it unrecoverable.
|
||||
func (sb *SecureBytes) Wipe() error {
|
||||
sb.buffer.Wipe()
|
||||
sb.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securebytes_test.go
|
||||
package securebytes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSecureBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
input: []byte("test-data"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sb, err := NewSecureBytes(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, sb)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sb)
|
||||
assert.NotNil(t, sb.buffer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureBytes_Bytes(t *testing.T) {
|
||||
input := []byte("test-data")
|
||||
sb, err := NewSecureBytes(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure the SecureBytes object is properly closed after the test
|
||||
defer sb.Wipe()
|
||||
|
||||
output := sb.Bytes()
|
||||
assert.Equal(t, input, output)
|
||||
assert.NotSame(t, &input, &output) // Verify different memory addresses
|
||||
}
|
||||
|
||||
func TestSecureBytes_Wipe(t *testing.T) {
|
||||
sb, err := NewSecureBytes([]byte("test-data"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = sb.Wipe()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// After wiping, the internal buffer should be nil
|
||||
assert.Nil(t, sb.buffer)
|
||||
|
||||
// Attempting to access bytes after wiping might panic or return nil/empty slice
|
||||
// Based on the panic, calling Bytes() on a wiped buffer is unsafe.
|
||||
// We verify the buffer is nil instead of calling Bytes().
|
||||
}
|
||||
|
||||
func TestSecureBytes_DataIsolation(t *testing.T) {
|
||||
original := []byte("test-data")
|
||||
sb, err := NewSecureBytes(original)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure the SecureBytes object is properly closed after the test
|
||||
defer sb.Wipe()
|
||||
|
||||
// Modify original data
|
||||
original[0] = 'x'
|
||||
|
||||
// Verify secure bytes remains unchanged
|
||||
stored := sb.Bytes()
|
||||
assert.NotEqual(t, original, stored)
|
||||
assert.Equal(t, []byte("test-data"), stored)
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
package secureconfig
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideSecureConfigProvider provides a SecureConfigProvider for Wire DI.
|
||||
func ProvideSecureConfigProvider(cfg *config.Config) *SecureConfigProvider {
|
||||
return NewSecureConfigProvider(cfg)
|
||||
}
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
// Package secureconfig provides secure access to configuration secrets.
|
||||
// It wraps sensitive configuration values in memguard-protected buffers
|
||||
// to prevent secret leakage through memory dumps.
|
||||
package secureconfig
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// SecureConfigProvider provides secure access to configuration secrets.
|
||||
// Secrets are stored in memguard LockedBuffers and wiped when no longer needed.
|
||||
type SecureConfigProvider struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Cached secure buffers - created on first access
|
||||
jwtSecret *memguard.LockedBuffer
|
||||
dbPassword *memguard.LockedBuffer
|
||||
cachePassword *memguard.LockedBuffer
|
||||
s3AccessKey *memguard.LockedBuffer
|
||||
s3SecretKey *memguard.LockedBuffer
|
||||
mailgunAPIKey *memguard.LockedBuffer
|
||||
|
||||
// Original config for initial loading
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSecureConfigProvider creates a new secure config provider from the given config.
|
||||
// The original config secrets are copied to secure buffers and should be cleared
|
||||
// from the original config after this call.
|
||||
func NewSecureConfigProvider(cfg *config.Config) *SecureConfigProvider {
|
||||
provider := &SecureConfigProvider{
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
// Pre-load secrets into secure buffers
|
||||
provider.loadSecrets()
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// loadSecrets copies secrets from config into memguard buffers.
|
||||
// SECURITY: Original config strings remain in memory but secure buffers provide
|
||||
// additional protection for long-lived secret access.
|
||||
func (p *SecureConfigProvider) loadSecrets() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// JWT Secret
|
||||
if p.cfg.JWT.Secret != "" {
|
||||
p.jwtSecret = memguard.NewBufferFromBytes([]byte(p.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// Database Password
|
||||
if p.cfg.Database.Password != "" {
|
||||
p.dbPassword = memguard.NewBufferFromBytes([]byte(p.cfg.Database.Password))
|
||||
}
|
||||
|
||||
// Cache Password
|
||||
if p.cfg.Cache.Password != "" {
|
||||
p.cachePassword = memguard.NewBufferFromBytes([]byte(p.cfg.Cache.Password))
|
||||
}
|
||||
|
||||
// S3 Access Key
|
||||
if p.cfg.S3.AccessKey != "" {
|
||||
p.s3AccessKey = memguard.NewBufferFromBytes([]byte(p.cfg.S3.AccessKey))
|
||||
}
|
||||
|
||||
// S3 Secret Key
|
||||
if p.cfg.S3.SecretKey != "" {
|
||||
p.s3SecretKey = memguard.NewBufferFromBytes([]byte(p.cfg.S3.SecretKey))
|
||||
}
|
||||
|
||||
// Mailgun API Key
|
||||
if p.cfg.Mailgun.APIKey != "" {
|
||||
p.mailgunAPIKey = memguard.NewBufferFromBytes([]byte(p.cfg.Mailgun.APIKey))
|
||||
}
|
||||
}
|
||||
|
||||
// JWTSecret returns the JWT secret as a secure byte slice.
|
||||
// The returned bytes should not be stored - use immediately and let GC collect.
|
||||
func (p *SecureConfigProvider) JWTSecret() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.jwtSecret == nil || !p.jwtSecret.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.jwtSecret.Bytes()
|
||||
}
|
||||
|
||||
// DatabasePassword returns the database password as a secure byte slice.
|
||||
func (p *SecureConfigProvider) DatabasePassword() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.dbPassword == nil || !p.dbPassword.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.dbPassword.Bytes()
|
||||
}
|
||||
|
||||
// CachePassword returns the cache password as a secure byte slice.
|
||||
func (p *SecureConfigProvider) CachePassword() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.cachePassword == nil || !p.cachePassword.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.cachePassword.Bytes()
|
||||
}
|
||||
|
||||
// S3AccessKey returns the S3 access key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) S3AccessKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.s3AccessKey == nil || !p.s3AccessKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.s3AccessKey.Bytes()
|
||||
}
|
||||
|
||||
// S3SecretKey returns the S3 secret key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) S3SecretKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.s3SecretKey == nil || !p.s3SecretKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.s3SecretKey.Bytes()
|
||||
}
|
||||
|
||||
// MailgunAPIKey returns the Mailgun API key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) MailgunAPIKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.mailgunAPIKey == nil || !p.mailgunAPIKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.mailgunAPIKey.Bytes()
|
||||
}
|
||||
|
||||
// Destroy securely wipes all cached secrets from memory.
|
||||
// Should be called during application shutdown.
|
||||
func (p *SecureConfigProvider) Destroy() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.jwtSecret != nil && p.jwtSecret.IsAlive() {
|
||||
p.jwtSecret.Destroy()
|
||||
}
|
||||
if p.dbPassword != nil && p.dbPassword.IsAlive() {
|
||||
p.dbPassword.Destroy()
|
||||
}
|
||||
if p.cachePassword != nil && p.cachePassword.IsAlive() {
|
||||
p.cachePassword.Destroy()
|
||||
}
|
||||
if p.s3AccessKey != nil && p.s3AccessKey.IsAlive() {
|
||||
p.s3AccessKey.Destroy()
|
||||
}
|
||||
if p.s3SecretKey != nil && p.s3SecretKey.IsAlive() {
|
||||
p.s3SecretKey.Destroy()
|
||||
}
|
||||
if p.mailgunAPIKey != nil && p.mailgunAPIKey.IsAlive() {
|
||||
p.mailgunAPIKey.Destroy()
|
||||
}
|
||||
|
||||
p.jwtSecret = nil
|
||||
p.dbPassword = nil
|
||||
p.cachePassword = nil
|
||||
p.s3AccessKey = nil
|
||||
p.s3SecretKey = nil
|
||||
p.mailgunAPIKey = nil
|
||||
}
|
||||
|
||||
// Config returns the underlying config for non-secret access.
|
||||
// Prefer using the specific secret accessor methods for sensitive data.
|
||||
func (p *SecureConfigProvider) Config() *config.Config {
|
||||
return p.cfg
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securestring.go
|
||||
package securestring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureString is used to store a string securely in memory.
|
||||
type SecureString struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureString creates a new SecureString instance from the given string.
|
||||
func NewSecureString(s string) (*SecureString, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, errors.New("string cannot be empty")
|
||||
}
|
||||
|
||||
// Use memguard's built-in method for creating from bytes
|
||||
buffer := memguard.NewBufferFromBytes([]byte(s))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
return &SecureString{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// String returns the securely stored string.
|
||||
func (ss *SecureString) String() string {
|
||||
if ss.buffer == nil {
|
||||
fmt.Println("String(): buffer is nil")
|
||||
return ""
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
fmt.Println("String(): buffer is not alive")
|
||||
return ""
|
||||
}
|
||||
return ss.buffer.String()
|
||||
}
|
||||
|
||||
func (ss *SecureString) Bytes() []byte {
|
||||
if ss.buffer == nil {
|
||||
fmt.Println("Bytes(): buffer is nil")
|
||||
return nil
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
fmt.Println("Bytes(): buffer is not alive")
|
||||
return nil
|
||||
}
|
||||
return ss.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the string from memory and makes it unrecoverable.
|
||||
func (ss *SecureString) Wipe() error {
|
||||
|
||||
if ss.buffer != nil {
|
||||
if ss.buffer.IsAlive() {
|
||||
ss.buffer.Destroy()
|
||||
}
|
||||
} else {
|
||||
// fmt.Println("Wipe(): Buffer is nil")
|
||||
}
|
||||
ss.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securestring_test.go
|
||||
package securestring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSecureString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid string",
|
||||
input: "test-string",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ss, err := NewSecureString(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss)
|
||||
assert.NotNil(t, ss.buffer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureString_String(t *testing.T) {
|
||||
input := "test-string"
|
||||
ss, err := NewSecureString(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := ss.String()
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
|
||||
func TestSecureString_Wipe(t *testing.T) {
|
||||
ss, err := NewSecureString("test-string")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ss.Wipe()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ss.buffer)
|
||||
|
||||
// Verify string is wiped
|
||||
output := ss.String()
|
||||
assert.Empty(t, output)
|
||||
}
|
||||
|
||||
func TestSecureString_DataIsolation(t *testing.T) {
|
||||
original := "test-string"
|
||||
ss, err := NewSecureString(original)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Attempt to modify original
|
||||
original = "modified"
|
||||
|
||||
// Verify secure string remains unchanged
|
||||
stored := ss.String()
|
||||
assert.NotEqual(t, original, stored)
|
||||
assert.Equal(t, "test-string", stored)
|
||||
}
|
||||
|
||||
func TestSecureString_StringConsistency(t *testing.T) {
|
||||
input := "test-string"
|
||||
ss, err := NewSecureString(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Multiple calls should return same value
|
||||
assert.Equal(t, ss.String(), ss.String())
|
||||
}
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinJWTSecretLength is the minimum required length for JWT secrets (256 bits)
|
||||
MinJWTSecretLength = 32
|
||||
|
||||
// RecommendedJWTSecretLength is the recommended length for JWT secrets (512 bits)
|
||||
RecommendedJWTSecretLength = 64
|
||||
|
||||
// MinEntropyBits is the minimum Shannon entropy in bits per character
|
||||
// For reference: random base64 has ~6 bits/char, we require minimum 4.0
|
||||
MinEntropyBits = 4.0
|
||||
|
||||
// MinProductionEntropyBits is the minimum entropy required for production
|
||||
MinProductionEntropyBits = 4.5
|
||||
|
||||
// MaxRepeatingCharacters is the maximum allowed consecutive repeating characters
|
||||
MaxRepeatingCharacters = 3
|
||||
)
|
||||
|
||||
// WeakSecrets contains common weak/default secrets that should never be used
|
||||
var WeakSecrets = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"changeme",
|
||||
"change-me",
|
||||
"change_me",
|
||||
"12345",
|
||||
"123456",
|
||||
"1234567",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"default",
|
||||
"test",
|
||||
"testing",
|
||||
"admin",
|
||||
"administrator",
|
||||
"root",
|
||||
"qwerty",
|
||||
"qwertyuiop",
|
||||
"letmein",
|
||||
"welcome",
|
||||
"monkey",
|
||||
"dragon",
|
||||
"master",
|
||||
"sunshine",
|
||||
"princess",
|
||||
"football",
|
||||
"starwars",
|
||||
"baseball",
|
||||
"superman",
|
||||
"iloveyou",
|
||||
"trustno1",
|
||||
"hello",
|
||||
"abc123",
|
||||
"password123",
|
||||
"admin123",
|
||||
"guest",
|
||||
"user",
|
||||
"demo",
|
||||
"sample",
|
||||
"example",
|
||||
}
|
||||
|
||||
// DangerousPatterns contains patterns that indicate a secret should be changed
|
||||
var DangerousPatterns = []string{
|
||||
"change",
|
||||
"replace",
|
||||
"update",
|
||||
"modify",
|
||||
"sample",
|
||||
"example",
|
||||
"todo",
|
||||
"fixme",
|
||||
"temp",
|
||||
"temporary",
|
||||
}
|
||||
|
||||
// CredentialValidator validates credentials and secrets for security issues
|
||||
type CredentialValidator interface {
|
||||
ValidateJWTSecret(secret string, environment string) error
|
||||
ValidateAllCredentials(cfg *config.Config) error
|
||||
}
|
||||
|
||||
type credentialValidator struct{}
|
||||
|
||||
// NewCredentialValidator creates a new credential validator
|
||||
func NewCredentialValidator() CredentialValidator {
|
||||
return &credentialValidator{}
|
||||
}
|
||||
|
||||
// ValidateJWTSecret validates JWT secret strength and security
|
||||
// CWE-798: Comprehensive validation to prevent hard-coded/weak credentials
|
||||
func (v *credentialValidator) ValidateJWTSecret(secret string, environment string) error {
|
||||
// Check minimum length
|
||||
if len(secret) < MinJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short (%d characters). Minimum required: %d characters (256 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
MinJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for common weak secrets (case-insensitive)
|
||||
secretLower := strings.ToLower(secret)
|
||||
for _, weak := range WeakSecrets {
|
||||
if secretLower == weak || strings.Contains(secretLower, weak) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret cannot contain common weak value: '%s'. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
weak,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for dangerous patterns indicating default/placeholder values
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(secretLower, pattern) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains suspicious pattern '%s' which suggests it's a placeholder. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
pattern,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for repeating character patterns (e.g., "aaaa", "1111")
|
||||
if err := checkRepeatingPatterns(secret); err != nil {
|
||||
return fmt.Errorf(
|
||||
"JWT secret validation failed: %s. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sequential patterns (e.g., "abcd", "1234")
|
||||
if hasSequentialPattern(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains sequential patterns (e.g., 'abcd', '1234') which reduces entropy. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Calculate Shannon entropy
|
||||
entropy := calculateShannonEntropy(secret)
|
||||
minEntropy := MinEntropyBits
|
||||
if environment == "production" {
|
||||
minEntropy = MinProductionEntropyBits
|
||||
}
|
||||
|
||||
if entropy < minEntropy {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient entropy: %.2f bits/char (minimum: %.1f bits/char for %s). "+
|
||||
"The secret appears to have low randomness. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
entropy,
|
||||
minEntropy,
|
||||
environment,
|
||||
)
|
||||
}
|
||||
|
||||
// In production, enforce stricter requirements
|
||||
if environment == "production" {
|
||||
// Check recommended length for production
|
||||
if len(secret) < RecommendedJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short for production environment (%d characters). "+
|
||||
"Recommended: %d characters (512 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
RecommendedJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sufficient character complexity
|
||||
if !hasSufficientComplexity(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient complexity for production. It should contain a mix of uppercase, lowercase, " +
|
||||
"digits, and special characters (at least 3 types). Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Validate base64-like characteristics (recommended generation method)
|
||||
if !looksLikeBase64(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret does not appear to be randomly generated (expected base64-like characteristics). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAllCredentials validates all credentials in the configuration
|
||||
func (v *credentialValidator) ValidateAllCredentials(cfg *config.Config) error {
|
||||
var errors []string
|
||||
|
||||
// Validate JWT Secret
|
||||
if err := v.ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("JWT Secret validation failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
// In production, ensure other critical configs are not using defaults/placeholders
|
||||
if cfg.App.Environment == "production" {
|
||||
// Check Meilisearch API key
|
||||
if cfg.Meilisearch.APIKey == "" {
|
||||
errors = append(errors, "Meilisearch API key must be set in production")
|
||||
} else if containsDangerousPattern(cfg.Meilisearch.APIKey) {
|
||||
errors = append(errors, "Meilisearch API key appears to be a placeholder/default value")
|
||||
}
|
||||
|
||||
// Check database hosts are not using localhost
|
||||
for _, host := range cfg.Database.Hosts {
|
||||
if strings.Contains(strings.ToLower(host), "localhost") || host == "127.0.0.1" {
|
||||
errors = append(errors, "Database hosts should not use localhost in production")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache host is not localhost
|
||||
if strings.Contains(strings.ToLower(cfg.Cache.Host), "localhost") || cfg.Cache.Host == "127.0.0.1" {
|
||||
errors = append(errors, "Cache host should not use localhost in production")
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("credential validation failed:\n - %s", strings.Join(errors, "\n - "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShannonEntropy calculates the Shannon entropy of a string in bits per character
|
||||
// Shannon entropy measures the randomness/unpredictability of data
|
||||
// Formula: H(X) = -Σ(p(x) * log2(p(x))) where p(x) is the probability of character x
|
||||
func calculateShannonEntropy(s string) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Count character frequencies
|
||||
frequencies := make(map[rune]int)
|
||||
for _, char := range s {
|
||||
frequencies[char]++
|
||||
}
|
||||
|
||||
// Calculate entropy
|
||||
var entropy float64
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range frequencies {
|
||||
probability := float64(count) / length
|
||||
entropy -= probability * math.Log2(probability)
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// hasSufficientComplexity checks if the secret has a good mix of character types
|
||||
// Requires at least 3 out of 4 character types for production
|
||||
func hasSufficientComplexity(secret string) bool {
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
hasSpecial bool
|
||||
)
|
||||
|
||||
for _, char := range secret {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(char):
|
||||
hasDigit = true
|
||||
default:
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
// Require at least 3 out of 4 character types
|
||||
count := 0
|
||||
if hasUpper {
|
||||
count++
|
||||
}
|
||||
if hasLower {
|
||||
count++
|
||||
}
|
||||
if hasDigit {
|
||||
count++
|
||||
}
|
||||
if hasSpecial {
|
||||
count++
|
||||
}
|
||||
|
||||
return count >= 3
|
||||
}
|
||||
|
||||
// checkRepeatingPatterns checks for excessive repeating characters
|
||||
func checkRepeatingPatterns(s string) error {
|
||||
if len(s) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repeatCount := 1
|
||||
lastChar := rune(s[0])
|
||||
|
||||
for _, char := range s[1:] {
|
||||
if char == lastChar {
|
||||
repeatCount++
|
||||
if repeatCount > MaxRepeatingCharacters {
|
||||
return fmt.Errorf(
|
||||
"contains %d consecutive repeating characters ('%c'), maximum allowed: %d",
|
||||
repeatCount,
|
||||
lastChar,
|
||||
MaxRepeatingCharacters,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
repeatCount = 1
|
||||
lastChar = char
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasSequentialPattern detects common sequential patterns
|
||||
func hasSequentialPattern(s string) bool {
|
||||
if len(s) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for at least 4 consecutive sequential characters
|
||||
for i := 0; i < len(s)-3; i++ {
|
||||
// Check ascending sequence (e.g., "abcd", "1234")
|
||||
if s[i+1] == s[i]+1 && s[i+2] == s[i]+2 && s[i+3] == s[i]+3 {
|
||||
return true
|
||||
}
|
||||
// Check descending sequence (e.g., "dcba", "4321")
|
||||
if s[i+1] == s[i]-1 && s[i+2] == s[i]-2 && s[i+3] == s[i]-3 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// looksLikeBase64 checks if the string has base64-like characteristics
|
||||
// Base64 uses: A-Z, a-z, 0-9, +, /, and = for padding
|
||||
func looksLikeBase64(s string) bool {
|
||||
if len(s) < MinJWTSecretLength {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
validChars int
|
||||
)
|
||||
|
||||
// Base64 valid characters
|
||||
for _, char := range s {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
validChars++
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
validChars++
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
validChars++
|
||||
case char == '+' || char == '/' || char == '=' || char == '-' || char == '_':
|
||||
validChars++
|
||||
default:
|
||||
// Invalid character for base64
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Should have good mix of character types typical of base64
|
||||
charTypesCount := 0
|
||||
if hasUpper {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasLower {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasDigit {
|
||||
charTypesCount++
|
||||
}
|
||||
|
||||
// Base64 typically has at least uppercase, lowercase, and digits
|
||||
// Also check that it doesn't look like a repeated pattern
|
||||
if charTypesCount < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for repeated patterns (e.g., "AbCd12!@" repeated)
|
||||
// If the string has low unique character count relative to its length, it's probably not random
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range s {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
|
||||
// Random base64 should have at least 50% unique characters for strings over 32 chars
|
||||
uniqueRatio := float64(len(uniqueChars)) / float64(len(s))
|
||||
return uniqueRatio >= 0.4 // At least 40% unique characters
|
||||
}
|
||||
|
||||
// containsDangerousPattern checks if a string contains any dangerous patterns
|
||||
func containsDangerousPattern(value string) bool {
|
||||
valueLower := strings.ToLower(value)
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(valueLower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Simplified comprehensive test for JWT secret validation
|
||||
func TestJWTSecretValidation(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
// Good secrets - these should pass
|
||||
goodSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
}{
|
||||
{
|
||||
name: "Good 32-char for dev",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "development",
|
||||
},
|
||||
{
|
||||
name: "Good 64-char for prod",
|
||||
secret: "1WDduocStecRuIv+Us1t/RnYDoW1ZcEEbU+H+WykJG+IT5WnijzBb8uUPzGKju+D",
|
||||
env: "production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range goodSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for valid secret, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Bad secrets - these should fail
|
||||
badSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
mustContain string
|
||||
}{
|
||||
{
|
||||
name: "Too short",
|
||||
secret: "short",
|
||||
env: "development",
|
||||
mustContain: "too short",
|
||||
},
|
||||
{
|
||||
name: "Common weak - password",
|
||||
secret: "password-is-not-secure-but-32char",
|
||||
env: "development",
|
||||
mustContain: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern",
|
||||
secret: "please-change-this-ima7xR+9nT0Yz",
|
||||
env: "development",
|
||||
mustContain: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
env: "development",
|
||||
mustContain: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
env: "development",
|
||||
mustContain: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy",
|
||||
secret: "abababababababababababababababab",
|
||||
env: "development",
|
||||
mustContain: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Prod too short",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "production",
|
||||
mustContain: "too short for production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range badSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got no error", tt.mustContain)
|
||||
} else if !contains(err.Error(), tt.mustContain) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.mustContain, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && len(substr) > 0 && findSubstring(s, substr)))
|
||||
}
|
||||
|
||||
func findSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,535 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCalculateShannonEntropy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
minBits float64
|
||||
maxBits float64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have 0 entropy",
|
||||
},
|
||||
{
|
||||
name: "All same character",
|
||||
input: "aaaaaaaaaa",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have very low entropy",
|
||||
},
|
||||
{
|
||||
name: "Low entropy - repeated pattern",
|
||||
input: "abcabcabcabc",
|
||||
minBits: 1.5,
|
||||
maxBits: 2.0,
|
||||
expected: "should have low entropy",
|
||||
},
|
||||
{
|
||||
name: "Medium entropy - simple password",
|
||||
input: "Password123",
|
||||
minBits: 3.0,
|
||||
maxBits: 4.5,
|
||||
expected: "should have medium entropy",
|
||||
},
|
||||
{
|
||||
name: "High entropy - random base64",
|
||||
input: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
minBits: 4.0,
|
||||
maxBits: 6.0,
|
||||
expected: "should have high entropy",
|
||||
},
|
||||
{
|
||||
name: "Very high entropy - long random base64",
|
||||
input: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
minBits: 4.5,
|
||||
maxBits: 6.5,
|
||||
expected: "should have very high entropy",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
entropy := calculateShannonEntropy(tt.input)
|
||||
if entropy < tt.minBits || entropy > tt.maxBits {
|
||||
t.Errorf("%s: got %.2f bits/char, expected between %.1f and %.1f", tt.expected, entropy, tt.minBits, tt.maxBits)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSufficientComplexity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnop",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only uppercase",
|
||||
input: "ABCDEFGHIJKLMNOP",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only digits",
|
||||
input: "1234567890",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase",
|
||||
input: "AbCdEfGhIjKl",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits",
|
||||
input: "abc123def456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Uppercase + digits",
|
||||
input: "ABC123DEF456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + digits",
|
||||
input: "Abc123Def456",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + special",
|
||||
input: "AbC+DeF/GhI=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits + special",
|
||||
input: "abc123+def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "All four types",
|
||||
input: "Abc123+Def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 string",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b+xK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSufficientComplexity(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSufficientComplexity(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRepeatingPatterns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Single character",
|
||||
input: "a",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "No repeating",
|
||||
input: "abcdefgh",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Two repeating (ok)",
|
||||
input: "aabcdeef",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Three repeating (ok)",
|
||||
input: "aaabcdeee",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Four repeating (error)",
|
||||
input: "aaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Five repeating (error)",
|
||||
input: "aaaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple groups of three (ok)",
|
||||
input: "aaabbbccc",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Repeating in middle (error)",
|
||||
input: "abcdddddef",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Repeating at end (error)",
|
||||
input: "abcdefgggg",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkRepeatingPatterns(tt.input)
|
||||
if (err != nil) != tt.shouldErr {
|
||||
t.Errorf("checkRepeatingPatterns(%q) error = %v, shouldErr = %v", tt.input, err, tt.shouldErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSequentialPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No sequential",
|
||||
input: "acegikmo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Ascending sequence - abcd",
|
||||
input: "xyzabcdefg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending sequence - dcba",
|
||||
input: "xyzdcbafg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Ascending digits - 1234",
|
||||
input: "abc1234def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending digits - 4321",
|
||||
input: "abc4321def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Random characters",
|
||||
input: "xK8vN2mP9sQ4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Base64-like",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSequentialPattern(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSequentialPattern(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeBase64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Real base64",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 without padding",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 with URL-safe chars",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b-_",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Generated secret",
|
||||
input: "xK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Simple password",
|
||||
input: "Password123!Password123!Password123!",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := looksLikeBase64(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("looksLikeBase64(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
secret string
|
||||
environment string
|
||||
shouldErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Too short - 20 chars",
|
||||
secret: "12345678901234567890",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "too short",
|
||||
},
|
||||
{
|
||||
name: "Minimum length - 32 chars (acceptable for dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains password",
|
||||
secret: "my-password-is-secure-123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - secret",
|
||||
secret: "secretsecretsecretsecretsecretsec",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains 12345",
|
||||
secret: "abcd12345efghijklmnopqrstuvwxyz",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - change",
|
||||
secret: "please-change-this-j8EJm9ZKnuTYxcVK",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - sample",
|
||||
secret: "sample-secret-j8EJm9ZKnuTYxcVKQ",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - abcd",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - 1234",
|
||||
secret: "12345678901234567890123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy secret",
|
||||
secret: "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpP",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Good secret - base64 style (dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Good secret - longer (dev)",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - too short (32 chars)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "too short for production",
|
||||
},
|
||||
{
|
||||
name: "Production - insufficient complexity",
|
||||
secret: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient complexity",
|
||||
},
|
||||
{
|
||||
name: "Production - low entropy pattern",
|
||||
secret: strings.Repeat("AbCd12!@", 8), // 64 chars but repetitive
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Production - good secret",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - excellent secret with padding",
|
||||
secret: "7mK2nP8sR4wT6xZ3bA5cxK7mN1oQ9uS4vY2zA6bxK7mN1oQ9uS4vY2zA6b+W0E=",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.environment)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Errorf("ValidateJWTSecret() expected error containing %q, got no error", tt.errContains)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateJWTSecret() error = %q, should contain %q", err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("ValidateJWTSecret() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret_EdgeCases(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
t.Run("Secret with mixed weak patterns", func(t *testing.T) {
|
||||
secret := "password123admin" // Contains multiple weak patterns
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err == nil {
|
||||
t.Error("Expected error for secret containing weak patterns, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at minimum length", func(t *testing.T) {
|
||||
// 32 characters exactly
|
||||
secret := "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx"
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 32-char secret in development, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at recommended length", func(t *testing.T) {
|
||||
// 64 characters exactly - using real random base64
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFir"
|
||||
err := validator.ValidateJWTSecret(secret, "production")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 64-char secret in production, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure validation is performant
|
||||
func BenchmarkCalculateShannonEntropy(b *testing.B) {
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
calculateShannonEntropy(secret)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateJWTSecret(b *testing.B) {
|
||||
validator := NewCredentialValidator()
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateJWTSecret(secret, "production")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package validator
|
||||
|
||||
// ProvideCredentialValidator provides a credential validator for dependency injection
|
||||
func ProvideCredentialValidator() CredentialValidator {
|
||||
return NewCredentialValidator()
|
||||
}
|
||||
108
cloud/maplefile-backend/pkg/storage/cache/cassandracache/cassandracache.go
vendored
Normal file
108
cloud/maplefile-backend/pkg/storage/cache/cassandracache/cassandracache.go
vendored
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// monorepo/cloud/maplefileapps-backend/pkg/storage/cache/cassandracache/cassandaracache.go
|
||||
package cassandracache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type CassandraCacher interface {
|
||||
Shutdown()
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
Session *gocql.Session
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewCassandraCacher(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
logger = logger.Named("CassandraCache")
|
||||
logger.Info("cassandra cache initialized")
|
||||
return &cache{
|
||||
Session: session,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *cache) Shutdown() {
|
||||
s.Logger.Info("cassandra cache shutting down...")
|
||||
s.Session.Close()
|
||||
}
|
||||
|
||||
func (s *cache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
var value []byte
|
||||
var expiresAt time.Time
|
||||
|
||||
query := `SELECT value, expires_at FROM pkg_cache_by_key_with_asc_expire_at WHERE key=?`
|
||||
err := s.Session.Query(query, key).WithContext(ctx).Consistency(gocql.LocalQuorum).Scan(&value, &expiresAt)
|
||||
|
||||
if err == gocql.ErrNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if expired in application code
|
||||
if time.Now().After(expiresAt) {
|
||||
// Entry is expired, delete it and return nil
|
||||
_ = s.Delete(ctx, key) // Clean up expired entry
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *cache) Set(ctx context.Context, key string, val []byte) error {
|
||||
expiresAt := time.Now().Add(24 * time.Hour) // Default 24 hour expiry
|
||||
return s.Session.Query(`INSERT INTO pkg_cache_by_key_with_asc_expire_at (key, expires_at, value) VALUES (?, ?, ?)`,
|
||||
key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (s *cache) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
expiresAt := time.Now().Add(expiry)
|
||||
return s.Session.Query(`INSERT INTO pkg_cache_by_key_with_asc_expire_at (key, expires_at, value) VALUES (?, ?, ?)`,
|
||||
key, expiresAt, val).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (s *cache) Delete(ctx context.Context, key string) error {
|
||||
return s.Session.Query(`DELETE FROM pkg_cache_by_key_with_asc_expire_at WHERE key=?`,
|
||||
key).WithContext(ctx).Consistency(gocql.LocalQuorum).Exec()
|
||||
}
|
||||
|
||||
func (s *cache) PurgeExpired(ctx context.Context) error {
|
||||
now := time.Now()
|
||||
|
||||
// Thanks to the index on expires_at, this query is efficient
|
||||
iter := s.Session.Query(`SELECT key FROM pkg_cache_by_key_with_asc_expire_at WHERE expires_at < ? ALLOW FILTERING`,
|
||||
now).WithContext(ctx).Iter()
|
||||
|
||||
var expiredKeys []string
|
||||
var key string
|
||||
for iter.Scan(&key) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete expired keys in batch
|
||||
if len(expiredKeys) > 0 {
|
||||
batch := s.Session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
|
||||
for _, expiredKey := range expiredKeys {
|
||||
batch.Query(`DELETE FROM pkg_cache_by_key_with_asc_expire_at WHERE key=?`, expiredKey)
|
||||
}
|
||||
return s.Session.ExecuteBatch(batch)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
11
cloud/maplefile-backend/pkg/storage/cache/cassandracache/provider.go
vendored
Normal file
11
cloud/maplefile-backend/pkg/storage/cache/cassandracache/provider.go
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package cassandracache
|
||||
|
||||
import (
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProvideCassandraCacher provides a Cassandra cache instance for Wire DI
|
||||
func ProvideCassandraCacher(session *gocql.Session, logger *zap.Logger) CassandraCacher {
|
||||
return NewCassandraCacher(session, logger)
|
||||
}
|
||||
17
cloud/maplefile-backend/pkg/storage/cache/twotiercache/provider.go
vendored
Normal file
17
cloud/maplefile-backend/pkg/storage/cache/twotiercache/provider.go
vendored
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
package twotiercache
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/cache/cassandracache"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/memory/redis"
|
||||
)
|
||||
|
||||
// ProvideTwoTierCache provides a two-tier cache instance for Wire DI
|
||||
func ProvideTwoTierCache(
|
||||
redisCache redis.Cacher,
|
||||
cassandraCache cassandracache.CassandraCacher,
|
||||
logger *zap.Logger,
|
||||
) TwoTierCacher {
|
||||
return NewTwoTierCache(redisCache, cassandraCache, logger)
|
||||
}
|
||||
106
cloud/maplefile-backend/pkg/storage/cache/twotiercache/twotiercache.go
vendored
Normal file
106
cloud/maplefile-backend/pkg/storage/cache/twotiercache/twotiercache.go
vendored
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
// monorepo/cloud/maplefileapps-backend/pkg/storage/cache/twotiercache/twotiercache.go
|
||||
package twotiercache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/cache/cassandracache"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/storage/memory/redis"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type TwoTierCacher interface {
|
||||
Shutdown(ctx context.Context)
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Set(ctx context.Context, key string, val []byte) error
|
||||
SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
PurgeExpired(ctx context.Context) error
|
||||
}
|
||||
|
||||
// twoTierCacheImpl: clean 2-layer (read-through write-through) cache
|
||||
//
|
||||
// L1: Redis (fast, in-memory)
|
||||
// L2: Cassandra (persistent)
|
||||
//
|
||||
// On Get: check Redis → then Cassandra → if found in Cassandra → populate Redis
|
||||
// On Set: write to both
|
||||
// On SetWithExpiry: write to both with expiry
|
||||
// On Delete: remove from both
|
||||
type twoTierCacheImpl struct {
|
||||
RedisCache redis.Cacher
|
||||
CassandraCache cassandracache.CassandraCacher
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewTwoTierCache(redisCache redis.Cacher, cassandraCache cassandracache.CassandraCacher, logger *zap.Logger) TwoTierCacher {
|
||||
logger = logger.Named("TwoTierCache")
|
||||
return &twoTierCacheImpl{
|
||||
RedisCache: redisCache,
|
||||
CassandraCache: cassandraCache,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.RedisCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
c.Logger.Debug("cache hit from Redis", zap.String("key", key))
|
||||
return val, nil
|
||||
}
|
||||
|
||||
val, err = c.CassandraCache.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val != nil {
|
||||
c.Logger.Debug("cache hit from Cassandra, writing back to Redis", zap.String("key", key))
|
||||
_ = c.RedisCache.Set(ctx, key, val)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) Set(ctx context.Context, key string, val []byte) error {
|
||||
if err := c.RedisCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.CassandraCache.Set(ctx, key, val); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) SetWithExpiry(ctx context.Context, key string, val []byte, expiry time.Duration) error {
|
||||
if err := c.RedisCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.CassandraCache.SetWithExpiry(ctx, key, val, expiry); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) Delete(ctx context.Context, key string) error {
|
||||
if err := c.RedisCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.CassandraCache.Delete(ctx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) PurgeExpired(ctx context.Context) error {
|
||||
return c.CassandraCache.PurgeExpired(ctx)
|
||||
}
|
||||
|
||||
func (c *twoTierCacheImpl) Shutdown(ctx context.Context) {
|
||||
c.Logger.Info("two-tier cache shutting down...")
|
||||
c.RedisCache.Shutdown(ctx)
|
||||
c.CassandraCache.Shutdown()
|
||||
c.Logger.Info("two-tier cache shutdown complete")
|
||||
}
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/storage/database/cassandradb/cassandradb.go
|
||||
package cassandradb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// CassandraDB wraps the gocql session with additional functionality
|
||||
type CassandraDB struct {
|
||||
Session *gocql.Session
|
||||
config config.DatabaseConfig
|
||||
}
|
||||
|
||||
// gocqlLogger wraps zap logger to filter out noisy gocql warnings
|
||||
type gocqlLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Print implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Print(v ...interface{}) {
|
||||
msg := fmt.Sprint(v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
// These are harmless and occur due to Docker networking
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// Printf implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Printf(format string, v ...interface{}) {
|
||||
msg := fmt.Sprintf(format, v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// Println implements gocql's Logger interface
|
||||
func (l *gocqlLogger) Println(v ...interface{}) {
|
||||
msg := fmt.Sprintln(v...)
|
||||
|
||||
// Filter out noisy "invalid peer" warnings from Cassandra gossip
|
||||
if strings.Contains(msg, "Found invalid peer") {
|
||||
return
|
||||
}
|
||||
|
||||
// Log other messages at debug level
|
||||
l.logger.Debug(msg)
|
||||
}
|
||||
|
||||
// NewCassandraConnection establishes a connection to Cassandra cluster
|
||||
// Uses the simplified approach from MaplePress (working code)
|
||||
func NewCassandraConnection(cfg *config.Config, logger *zap.Logger) (*gocql.Session, error) {
|
||||
dbConfig := cfg.Database
|
||||
|
||||
logger.Info("⏳ Connecting to Cassandra...",
|
||||
zap.Strings("hosts", dbConfig.Hosts),
|
||||
zap.String("keyspace", dbConfig.Keyspace))
|
||||
|
||||
// Create cluster configuration - let gocql handle DNS resolution
|
||||
cluster := gocql.NewCluster(dbConfig.Hosts...)
|
||||
cluster.Keyspace = dbConfig.Keyspace
|
||||
cluster.Consistency = parseConsistency(dbConfig.Consistency)
|
||||
cluster.ProtoVersion = 4
|
||||
cluster.ConnectTimeout = dbConfig.ConnectTimeout
|
||||
cluster.Timeout = dbConfig.RequestTimeout
|
||||
cluster.NumConns = 2
|
||||
|
||||
// Set custom logger to filter out noisy warnings
|
||||
cluster.Logger = &gocqlLogger{logger: logger.Named("gocql")}
|
||||
|
||||
// Retry policy
|
||||
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
|
||||
NumRetries: int(dbConfig.MaxRetryAttempts),
|
||||
Min: dbConfig.RetryDelay,
|
||||
Max: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Enable compression for better network efficiency
|
||||
cluster.Compressor = &gocql.SnappyCompressor{}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Cassandra: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✓ Cassandra connected",
|
||||
zap.String("consistency", dbConfig.Consistency),
|
||||
zap.Int("connections", cluster.NumConns))
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Close terminates the database connection
|
||||
func (db *CassandraDB) Close() {
|
||||
if db.Session != nil {
|
||||
db.Session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Health checks if the database connection is still alive
|
||||
func (db *CassandraDB) Health() error {
|
||||
// Quick health check using a simple query
|
||||
var timestamp time.Time
|
||||
err := db.Session.Query("SELECT now() FROM system.local").Scan(×tamp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
|
||||
// Validate that we got a reasonable timestamp (within last minute)
|
||||
now := time.Now()
|
||||
if timestamp.Before(now.Add(-time.Minute)) || timestamp.After(now.Add(time.Minute)) {
|
||||
return fmt.Errorf("health check returned suspicious timestamp: %v (current: %v)", timestamp, now)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseConsistency converts string consistency level to gocql.Consistency
|
||||
func parseConsistency(consistency string) gocql.Consistency {
|
||||
switch consistency {
|
||||
case "ANY":
|
||||
return gocql.Any
|
||||
case "ONE":
|
||||
return gocql.One
|
||||
case "TWO":
|
||||
return gocql.Two
|
||||
case "THREE":
|
||||
return gocql.Three
|
||||
case "QUORUM":
|
||||
return gocql.Quorum
|
||||
case "ALL":
|
||||
return gocql.All
|
||||
case "LOCAL_QUORUM":
|
||||
return gocql.LocalQuorum
|
||||
case "EACH_QUORUM":
|
||||
return gocql.EachQuorum
|
||||
case "LOCAL_ONE":
|
||||
return gocql.LocalOne
|
||||
default:
|
||||
return gocql.Quorum // Default to QUORUM
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/storage/database/cassandradb/migration.go
|
||||
package cassandradb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/cassandra"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// Migrator handles database schema migrations
|
||||
// This encapsulates all migration logic and makes it testable
|
||||
type Migrator struct {
|
||||
config config.DatabaseConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMigrator creates a new migration manager that works with fx dependency injection
|
||||
func NewMigrator(cfg *config.Configuration, logger *zap.Logger) *Migrator {
|
||||
return &Migrator{
|
||||
config: cfg.Database,
|
||||
logger: logger.Named("Migrator"),
|
||||
}
|
||||
}
|
||||
|
||||
// Up runs all pending migrations with dirty state recovery
|
||||
func (m *Migrator) Up() error {
|
||||
m.logger.Info("Creating migrator")
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
m.logger.Info("Checking migration version")
|
||||
version, dirty, err := migrateInstance.Version()
|
||||
if err != nil && err != migrate.ErrNilVersion {
|
||||
return fmt.Errorf("failed to get migration version: %w", err)
|
||||
}
|
||||
|
||||
if dirty {
|
||||
m.logger.Warn("Database is in dirty state, attempting to force clean state",
|
||||
zap.Uint("version", version))
|
||||
if err := migrateInstance.Force(int(version)); err != nil {
|
||||
return fmt.Errorf("failed to force clean migration state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
if err := migrateInstance.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
// Get final version
|
||||
finalVersion, _, err := migrateInstance.Version()
|
||||
if err != nil && err != migrate.ErrNilVersion {
|
||||
m.logger.Warn("Could not get final migration version",
|
||||
zap.Error(err))
|
||||
} else if err != migrate.ErrNilVersion {
|
||||
m.logger.Info("Database migrations completed successfully",
|
||||
zap.Uint("version", finalVersion))
|
||||
} else {
|
||||
m.logger.Info("Database migrations completed successfully (no migrations applied)")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Down rolls back the last migration
|
||||
// Useful for development and rollback scenarios
|
||||
func (m *Migrator) Down() error {
|
||||
migrate, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrate.Close()
|
||||
|
||||
if err := migrate.Steps(-1); err != nil {
|
||||
return fmt.Errorf("failed to rollback migration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Version returns the current migration version
|
||||
func (m *Migrator) Version() (uint, bool, error) {
|
||||
migrate, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return 0, false, fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrate.Close()
|
||||
|
||||
return migrate.Version()
|
||||
}
|
||||
|
||||
// ForceVersion forces the migration version (useful for fixing dirty states)
|
||||
func (m *Migrator) ForceVersion(version int) error {
|
||||
migrateInstance, err := m.createMigrate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
defer migrateInstance.Close()
|
||||
|
||||
if err := migrateInstance.Force(version); err != nil {
|
||||
return fmt.Errorf("failed to force version %d: %w", version, err)
|
||||
}
|
||||
|
||||
m.logger.Info("Successfully forced migration version",
|
||||
zap.Int("version", version))
|
||||
return nil
|
||||
}
|
||||
|
||||
// createMigrate creates a migrate instance with proper configuration
|
||||
func (m *Migrator) createMigrate() (*migrate.Migrate, error) {
|
||||
// Build Cassandra connection string
|
||||
// Format: cassandra://host:port/keyspace?consistency=level
|
||||
databaseURL := fmt.Sprintf("cassandra://%s/%s?consistency=%s",
|
||||
m.config.Hosts[0], // Use first host for migrations
|
||||
m.config.Keyspace,
|
||||
m.config.Consistency,
|
||||
)
|
||||
|
||||
// Add authentication if configured
|
||||
if m.config.Username != "" && m.config.Password != "" {
|
||||
databaseURL = fmt.Sprintf("cassandra://%s:%s@%s/%s?consistency=%s",
|
||||
m.config.Username,
|
||||
m.config.Password,
|
||||
m.config.Hosts[0],
|
||||
m.config.Keyspace,
|
||||
m.config.Consistency,
|
||||
)
|
||||
}
|
||||
|
||||
// Create migrate instance
|
||||
migrate, err := migrate.New(m.config.MigrationsPath, databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize migrate: %w", err)
|
||||
}
|
||||
|
||||
return migrate, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue