Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
96
cloud/maplefile-backend/pkg/security/apikey/generator.go
Normal file
96
cloud/maplefile-backend/pkg/security/apikey/generator.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrefixLive is the prefix for production API keys
|
||||
PrefixLive = "live_sk_"
|
||||
// PrefixTest is the prefix for test/sandbox API keys
|
||||
PrefixTest = "test_sk_"
|
||||
// KeyLength is the length of the random part (40 chars in base64url)
|
||||
KeyLength = 30 // 30 bytes = 40 base64url chars
|
||||
)
|
||||
|
||||
// Generator generates API keys
|
||||
type Generator interface {
|
||||
// Generate creates a new live API key
|
||||
Generate() (string, error)
|
||||
// GenerateTest creates a new test API key
|
||||
GenerateTest() (string, error)
|
||||
}
|
||||
|
||||
type generator struct{}
|
||||
|
||||
// NewGenerator creates a new API key generator
|
||||
func NewGenerator() Generator {
|
||||
return &generator{}
|
||||
}
|
||||
|
||||
// Generate creates a new live API key
|
||||
func (g *generator) Generate() (string, error) {
|
||||
return g.generateWithPrefix(PrefixLive)
|
||||
}
|
||||
|
||||
// GenerateTest creates a new test API key
|
||||
func (g *generator) GenerateTest() (string, error) {
|
||||
return g.generateWithPrefix(PrefixTest)
|
||||
}
|
||||
|
||||
func (g *generator) generateWithPrefix(prefix string) (string, error) {
|
||||
// Generate cryptographically secure random bytes
|
||||
b := make([]byte, KeyLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// Encode to base64url (URL-safe, no padding)
|
||||
key := base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
// Remove any special chars and make lowercase for consistency
|
||||
key = strings.Map(func(r rune) rune {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
|
||||
return r
|
||||
}
|
||||
return -1 // Remove character
|
||||
}, key)
|
||||
|
||||
// Ensure we have at least 40 characters
|
||||
if len(key) < 40 {
|
||||
// Pad with additional random bytes if needed
|
||||
additional := make([]byte, 10)
|
||||
rand.Read(additional)
|
||||
extraKey := base64.RawURLEncoding.EncodeToString(additional)
|
||||
key += extraKey
|
||||
}
|
||||
|
||||
// Trim to exactly 40 characters
|
||||
key = key[:40]
|
||||
|
||||
return prefix + key, nil
|
||||
}
|
||||
|
||||
// ExtractPrefix extracts the prefix from an API key
|
||||
func ExtractPrefix(apiKey string) string {
|
||||
if len(apiKey) < 13 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[:13] // "live_sk_a1b2" or "test_sk_a1b2"
|
||||
}
|
||||
|
||||
// ExtractLastFour extracts the last 4 characters from an API key
|
||||
func ExtractLastFour(apiKey string) string {
|
||||
if len(apiKey) < 4 {
|
||||
return ""
|
||||
}
|
||||
return apiKey[len(apiKey)-4:]
|
||||
}
|
||||
|
||||
// IsValid checks if an API key has a valid format
|
||||
func IsValid(apiKey string) bool {
|
||||
return strings.HasPrefix(apiKey, PrefixLive) || strings.HasPrefix(apiKey, PrefixTest)
|
||||
}
|
||||
35
cloud/maplefile-backend/pkg/security/apikey/hasher.go
Normal file
35
cloud/maplefile-backend/pkg/security/apikey/hasher.go
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Hasher hashes and verifies API keys using SHA-256
|
||||
type Hasher interface {
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
Hash(apiKey string) string
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
Verify(apiKey string, hash string) bool
|
||||
}
|
||||
|
||||
type hasher struct{}
|
||||
|
||||
// NewHasher creates a new API key hasher
|
||||
func NewHasher() Hasher {
|
||||
return &hasher{}
|
||||
}
|
||||
|
||||
// Hash creates a deterministic SHA-256 hash of the API key
|
||||
func (h *hasher) Hash(apiKey string) string {
|
||||
hash := sha256.Sum256([]byte(apiKey))
|
||||
return base64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// Verify checks if the API key matches the hash using constant-time comparison
|
||||
// This prevents timing attacks
|
||||
func (h *hasher) Verify(apiKey string, expectedHash string) bool {
|
||||
actualHash := h.Hash(apiKey)
|
||||
return subtle.ConstantTimeCompare([]byte(actualHash), []byte(expectedHash)) == 1
|
||||
}
|
||||
11
cloud/maplefile-backend/pkg/security/apikey/provider.go
Normal file
11
cloud/maplefile-backend/pkg/security/apikey/provider.go
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
package apikey
|
||||
|
||||
// ProvideGenerator provides an API key generator for dependency injection
|
||||
func ProvideGenerator() Generator {
|
||||
return NewGenerator()
|
||||
}
|
||||
|
||||
// ProvideHasher provides an API key hasher for dependency injection
|
||||
func ProvideHasher() Hasher {
|
||||
return NewHasher()
|
||||
}
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
// Package benchmark provides performance benchmarks for memguard security operations.
|
||||
package benchmark
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// BenchmarkPlainStringAllocation benchmarks plain string allocation.
|
||||
func BenchmarkPlainStringAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s := "this is a test string with sensitive data"
|
||||
_ = s
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSecureStringAllocation benchmarks SecureString allocation and cleanup.
|
||||
func BenchmarkSecureStringAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
s, err := securestring.NewSecureString("this is a test string with sensitive data")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPlainBytesAllocation benchmarks plain byte slice allocation.
|
||||
func BenchmarkPlainBytesAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
_ = data
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSecureBytesAllocation benchmarks SecureBytes allocation and cleanup.
|
||||
func BenchmarkSecureBytesAllocation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
sb, err := securebytes.NewSecureBytes(data)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
sb.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPasswordHashing_Plain benchmarks password hashing without memguard.
|
||||
func BenchmarkPasswordHashing_Plain(b *testing.B) {
|
||||
password := []byte("test_password_12345")
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = argon2.IDKey(password, salt, 3, 64*1024, 4, 32)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPasswordHashing_Secure benchmarks password hashing with memguard wiping.
|
||||
func BenchmarkPasswordHashing_Secure(b *testing.B) {
|
||||
password, err := securestring.NewSecureString("test_password_12345")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer password.Wipe()
|
||||
|
||||
salt := make([]byte, 16)
|
||||
rand.Read(salt)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
passwordBytes := password.Bytes()
|
||||
hash := argon2.IDKey(passwordBytes, salt, 3, 64*1024, 4, 32)
|
||||
memguard.WipeBytes(hash)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMemguardWipeBytes benchmarks the memguard.WipeBytes operation.
|
||||
func BenchmarkMemguardWipeBytes(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMemguardWipeBytes_Large benchmarks wiping larger byte slices.
|
||||
func BenchmarkMemguardWipeBytes_Large(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
data := make([]byte, 4096)
|
||||
rand.Read(data)
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLockedBuffer_Create benchmarks creating a memguard LockedBuffer.
|
||||
func BenchmarkLockedBuffer_Create(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := memguard.NewBuffer(32)
|
||||
buf.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLockedBuffer_FromBytes benchmarks creating a LockedBuffer from bytes.
|
||||
func BenchmarkLockedBuffer_FromBytes(b *testing.B) {
|
||||
data := make([]byte, 32)
|
||||
rand.Read(data)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := memguard.NewBufferFromBytes(data)
|
||||
buf.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkJWTTokenGeneration_Plain simulates JWT token generation without security.
|
||||
func BenchmarkJWTTokenGeneration_Plain(b *testing.B) {
|
||||
secret := make([]byte, 32)
|
||||
rand.Read(secret)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate token signing
|
||||
_ = secret
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkJWTTokenGeneration_Secure simulates JWT token generation with memguard.
|
||||
func BenchmarkJWTTokenGeneration_Secure(b *testing.B) {
|
||||
secret := make([]byte, 32)
|
||||
rand.Read(secret)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
secretCopy := make([]byte, len(secret))
|
||||
copy(secretCopy, secret)
|
||||
// Simulate token signing
|
||||
_ = secretCopy
|
||||
memguard.WipeBytes(secretCopy)
|
||||
}
|
||||
}
|
||||
|
||||
// Run benchmarks with:
|
||||
// go test -bench=. -benchmem ./pkg/security/benchmark/
|
||||
76
cloud/maplefile-backend/pkg/security/blacklist/blacklist.go
Normal file
76
cloud/maplefile-backend/pkg/security/blacklist/blacklist.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package blacklist
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Provider provides an interface for abstracting time.
|
||||
type Provider interface {
|
||||
IsBannedIPAddress(ipAddress string) bool
|
||||
IsBannedURL(url string) bool
|
||||
}
|
||||
|
||||
type blacklistProvider struct {
|
||||
bannedIPAddresses map[string]bool
|
||||
bannedURLs map[string]bool
|
||||
}
|
||||
|
||||
// readBlacklistFileContent reads the contents of the blacklist file and returns
|
||||
// the list of banned items (ex: IP, URLs, etc).
|
||||
func readBlacklistFileContent(filePath string) ([]string, error) {
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("file %s does not exist", filePath)
|
||||
}
|
||||
|
||||
// Read the file contents
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
// Parse the JSON content as a list of IPs
|
||||
var ips []string
|
||||
if err := json.Unmarshal(data, &ips); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON file %s: %v", filePath, err)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// NewProvider Provider contructor that returns the default time provider.
|
||||
func NewProvider() Provider {
|
||||
bannedIPAddresses := make(map[string]bool)
|
||||
bannedIPAddressesFilePath := "static/blacklist/ips.json"
|
||||
ips, err := readBlacklistFileContent(bannedIPAddressesFilePath)
|
||||
if err == nil { // Aka: if the file exists...
|
||||
for _, ip := range ips {
|
||||
bannedIPAddresses[ip] = true
|
||||
}
|
||||
}
|
||||
|
||||
bannedURLs := make(map[string]bool)
|
||||
bannedURLsFilePath := "static/blacklist/urls.json"
|
||||
urls, err := readBlacklistFileContent(bannedURLsFilePath)
|
||||
if err == nil { // Aka: if the file exists...
|
||||
for _, url := range urls {
|
||||
bannedURLs[url] = true
|
||||
}
|
||||
}
|
||||
|
||||
return blacklistProvider{
|
||||
bannedIPAddresses: bannedIPAddresses,
|
||||
bannedURLs: bannedURLs,
|
||||
}
|
||||
}
|
||||
|
||||
func (p blacklistProvider) IsBannedIPAddress(ipAddress string) bool {
|
||||
return p.bannedIPAddresses[ipAddress]
|
||||
}
|
||||
|
||||
func (p blacklistProvider) IsBannedURL(url string) bool {
|
||||
return p.bannedURLs[url]
|
||||
}
|
||||
132
cloud/maplefile-backend/pkg/security/blacklist/blacklist_test.go
Normal file
132
cloud/maplefile-backend/pkg/security/blacklist/blacklist_test.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package blacklist
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func createTempFile(t *testing.T, content string) string {
|
||||
tmpfile, err := os.CreateTemp("", "blacklist*.json")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte(content), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return tmpfile.Name()
|
||||
}
|
||||
|
||||
func TestReadBlacklistFileContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantItems []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid json",
|
||||
content: `["192.168.1.1", "10.0.0.1"]`,
|
||||
wantItems: []string{"192.168.1.1", "10.0.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
content: `[]`,
|
||||
wantItems: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
content: `invalid json`,
|
||||
wantItems: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpfile := createTempFile(t, tt.content)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
items, err := readBlacklistFileContent(tmpfile)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, items)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantItems, items)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("nonexistent file", func(t *testing.T) {
|
||||
_, err := readBlacklistFileContent("nonexistent.json")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
// Create temporary blacklist files
|
||||
ipsContent := `["192.168.1.1", "10.0.0.1"]`
|
||||
urlsContent := `["example.com", "malicious.com"]`
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "blacklist")
|
||||
assert.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = os.MkdirAll(filepath.Join(tmpDir, "static/blacklist"), 0755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "static/blacklist/ips.json"), []byte(ipsContent), 0644)
|
||||
assert.NoError(t, err)
|
||||
err = os.WriteFile(filepath.Join(tmpDir, "static/blacklist/urls.json"), []byte(urlsContent), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Change working directory temporarily
|
||||
originalWd, err := os.Getwd()
|
||||
assert.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
assert.NoError(t, err)
|
||||
defer os.Chdir(originalWd)
|
||||
|
||||
provider := NewProvider()
|
||||
assert.NotNil(t, provider)
|
||||
|
||||
// Test IP blacklist
|
||||
assert.True(t, provider.IsBannedIPAddress("192.168.1.1"))
|
||||
assert.True(t, provider.IsBannedIPAddress("10.0.0.1"))
|
||||
assert.False(t, provider.IsBannedIPAddress("172.16.0.1"))
|
||||
|
||||
// Test URL blacklist
|
||||
assert.True(t, provider.IsBannedURL("example.com"))
|
||||
assert.True(t, provider.IsBannedURL("malicious.com"))
|
||||
assert.False(t, provider.IsBannedURL("safe.com"))
|
||||
}
|
||||
|
||||
func TestIsBannedIPAddress(t *testing.T) {
|
||||
provider := blacklistProvider{
|
||||
bannedIPAddresses: map[string]bool{
|
||||
"192.168.1.1": true,
|
||||
"10.0.0.1": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, provider.IsBannedIPAddress("192.168.1.1"))
|
||||
assert.True(t, provider.IsBannedIPAddress("10.0.0.1"))
|
||||
assert.False(t, provider.IsBannedIPAddress("172.16.0.1"))
|
||||
}
|
||||
|
||||
func TestIsBannedURL(t *testing.T) {
|
||||
provider := blacklistProvider{
|
||||
bannedURLs: map[string]bool{
|
||||
"example.com": true,
|
||||
"malicious.com": true,
|
||||
},
|
||||
}
|
||||
|
||||
assert.True(t, provider.IsBannedURL("example.com"))
|
||||
assert.True(t, provider.IsBannedURL("malicious.com"))
|
||||
assert.False(t, provider.IsBannedURL("safe.com"))
|
||||
}
|
||||
170
cloud/maplefile-backend/pkg/security/clientip/extractor.go
Normal file
170
cloud/maplefile-backend/pkg/security/clientip/extractor.go
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// Extractor provides secure client IP address extraction
|
||||
// CWE-348: Prevents X-Forwarded-For header spoofing by validating trusted proxies
|
||||
type Extractor struct {
|
||||
trustedProxies []*net.IPNet
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewExtractor creates a new IP extractor with trusted proxy configuration
|
||||
// trustedProxyCIDRs should contain CIDR blocks of trusted reverse proxies
|
||||
// Example: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
|
||||
func NewExtractor(trustedProxyCIDRs []string, logger *zap.Logger) (*Extractor, error) {
|
||||
var trustedProxies []*net.IPNet
|
||||
|
||||
for _, cidr := range trustedProxyCIDRs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse trusted proxy CIDR",
|
||||
zap.String("cidr", cidr),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
trustedProxies = append(trustedProxies, ipNet)
|
||||
}
|
||||
|
||||
logger.Info("client IP extractor initialized",
|
||||
zap.Int("trusted_proxy_ranges", len(trustedProxies)))
|
||||
|
||||
return &Extractor{
|
||||
trustedProxies: trustedProxies,
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewDefaultExtractor creates an extractor with no trusted proxies
|
||||
// This is safe for direct connections but will ignore X-Forwarded-For headers
|
||||
func NewDefaultExtractor(logger *zap.Logger) *Extractor {
|
||||
logger.Warn("client IP extractor initialized with NO trusted proxies - X-Forwarded-For will be ignored")
|
||||
return &Extractor{
|
||||
trustedProxies: []*net.IPNet{},
|
||||
logger: logger.Named("client-ip-extractor"),
|
||||
}
|
||||
}
|
||||
|
||||
// Extract extracts the real client IP address from the HTTP request
|
||||
// CWE-348: Secure implementation that prevents header spoofing
|
||||
func (e *Extractor) Extract(r *http.Request) string {
|
||||
// Step 1: Get the immediate connection's remote address
|
||||
remoteAddr := r.RemoteAddr
|
||||
|
||||
// Remove port from RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||
remoteIP := e.stripPort(remoteAddr)
|
||||
|
||||
// Step 2: Parse the remote IP
|
||||
parsedRemoteIP := net.ParseIP(remoteIP)
|
||||
if parsedRemoteIP == nil {
|
||||
e.logger.Warn("failed to parse remote IP address",
|
||||
zap.String("remote_addr", validation.MaskIP(remoteAddr)))
|
||||
return remoteIP // Return as-is if we can't parse it
|
||||
}
|
||||
|
||||
// Step 3: Check if the immediate connection is from a trusted proxy
|
||||
if !e.isTrustedProxy(parsedRemoteIP) {
|
||||
// NOT from a trusted proxy - do NOT trust X-Forwarded-For header
|
||||
// This prevents clients from spoofing their IP by setting the header
|
||||
e.logger.Debug("remote IP is not a trusted proxy, using RemoteAddr",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 4: Remote IP is trusted, check X-Forwarded-For header
|
||||
// Format: "client, proxy1, proxy2" (leftmost is original client)
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
// No X-Forwarded-For header, use RemoteAddr
|
||||
e.logger.Debug("no X-Forwarded-For header from trusted proxy",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Step 5: Parse X-Forwarded-For header
|
||||
// Take the FIRST IP (leftmost) which should be the original client
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) == 0 {
|
||||
e.logger.Debug("empty X-Forwarded-For header",
|
||||
zap.String("remote_ip", validation.MaskIP(remoteIP)))
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Get the first IP and trim whitespace
|
||||
clientIP := strings.TrimSpace(ips[0])
|
||||
|
||||
// Step 6: Validate the client IP
|
||||
parsedClientIP := net.ParseIP(clientIP)
|
||||
if parsedClientIP == nil {
|
||||
e.logger.Warn("invalid IP in X-Forwarded-For header",
|
||||
zap.String("xff", xff),
|
||||
zap.String("client_ip", validation.MaskIP(clientIP)))
|
||||
return remoteIP // Fall back to RemoteAddr
|
||||
}
|
||||
|
||||
e.logger.Debug("extracted client IP from X-Forwarded-For",
|
||||
zap.String("client_ip", validation.MaskIP(clientIP)),
|
||||
zap.String("remote_proxy", validation.MaskIP(remoteIP)),
|
||||
zap.String("xff_chain", xff))
|
||||
|
||||
return clientIP
|
||||
}
|
||||
|
||||
// ExtractOrDefault extracts the client IP or returns a default value
|
||||
func (e *Extractor) ExtractOrDefault(r *http.Request, defaultIP string) string {
|
||||
ip := e.Extract(r)
|
||||
if ip == "" {
|
||||
return defaultIP
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isTrustedProxy checks if an IP is in the trusted proxy list
|
||||
func (e *Extractor) isTrustedProxy(ip net.IP) bool {
|
||||
for _, ipNet := range e.trustedProxies {
|
||||
if ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort removes the port from an address string
|
||||
// Handles both IPv4 (192.168.1.1:8080) and IPv6 ([::1]:8080) formats
|
||||
func (e *Extractor) stripPort(addr string) string {
|
||||
// For IPv6, check for bracket format [IP]:port
|
||||
if strings.HasPrefix(addr, "[") {
|
||||
// IPv6 format: [::1]:8080
|
||||
if idx := strings.LastIndex(addr, "]:"); idx != -1 {
|
||||
return addr[1:idx] // Extract IP between [ and ]
|
||||
}
|
||||
// Malformed IPv6 address
|
||||
return addr
|
||||
}
|
||||
|
||||
// For IPv4, split on last colon
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
}
|
||||
|
||||
// No port found
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetTrustedProxyCount returns the number of configured trusted proxy ranges
|
||||
func (e *Extractor) GetTrustedProxyCount() int {
|
||||
return len(e.trustedProxies)
|
||||
}
|
||||
|
||||
// HasTrustedProxies returns true if any trusted proxies are configured
|
||||
func (e *Extractor) HasTrustedProxies() bool {
|
||||
return len(e.trustedProxies) > 0
|
||||
}
|
||||
19
cloud/maplefile-backend/pkg/security/clientip/provider.go
Normal file
19
cloud/maplefile-backend/pkg/security/clientip/provider.go
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
package clientip
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideExtractor provides a client IP extractor configured from the application config
|
||||
func ProvideExtractor(cfg *config.Config, logger *zap.Logger) (*Extractor, error) {
|
||||
// If no trusted proxies configured, use default (no X-Forwarded-For trust)
|
||||
if len(cfg.Security.TrustedProxies) == 0 {
|
||||
logger.Info("no trusted proxies configured - X-Forwarded-For headers will be ignored for security")
|
||||
return NewDefaultExtractor(logger), nil
|
||||
}
|
||||
|
||||
// Create extractor with trusted proxies
|
||||
return NewExtractor(cfg.Security.TrustedProxies, logger)
|
||||
}
|
||||
32
cloud/maplefile-backend/pkg/security/crypto/constants.go
Normal file
32
cloud/maplefile-backend/pkg/security/crypto/constants.go
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
package crypto
|
||||
|
||||
// Constants to ensure compatibility between Go and JavaScript
|
||||
const (
|
||||
// Key sizes
|
||||
MasterKeySize = 32 // 256-bit
|
||||
KeyEncryptionKeySize = 32
|
||||
CollectionKeySize = 32
|
||||
FileKeySize = 32
|
||||
RecoveryKeySize = 32
|
||||
|
||||
// ChaCha20-Poly1305 constants (updated from XSalsa20-Poly1305)
|
||||
NonceSize = 12 // ChaCha20-Poly1305 nonce size (changed from 24)
|
||||
PublicKeySize = 32
|
||||
PrivateKeySize = 32
|
||||
SealedBoxOverhead = 16
|
||||
|
||||
// Legacy naming for backward compatibility
|
||||
SecretBoxNonceSize = NonceSize
|
||||
|
||||
// Argon2 parameters - must match between platforms
|
||||
Argon2IDAlgorithm = "argon2id"
|
||||
Argon2MemLimit = 67108864 // 64 MB
|
||||
Argon2OpsLimit = 4
|
||||
Argon2Parallelism = 1
|
||||
Argon2KeySize = 32
|
||||
Argon2SaltSize = 16
|
||||
|
||||
// Encryption algorithm identifiers
|
||||
ChaCha20Poly1305Algorithm = "chacha20poly1305" // Primary algorithm
|
||||
XSalsa20Poly1305Algorithm = "xsalsa20poly1305" // Legacy algorithm (deprecated)
|
||||
)
|
||||
174
cloud/maplefile-backend/pkg/security/crypto/encrypt.go
Normal file
174
cloud/maplefile-backend/pkg/security/crypto/encrypt.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
// EncryptData represents encrypted data with its nonce
|
||||
type EncryptData struct {
|
||||
Ciphertext []byte
|
||||
Nonce []byte
|
||||
}
|
||||
|
||||
// EncryptWithSecretKey encrypts data with a symmetric key using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretbox_easy() but using ChaCha20-Poly1305
|
||||
func EncryptWithSecretKey(data, key []byte) (*EncryptData, error) {
|
||||
if len(key) != MasterKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", MasterKeySize, len(key))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Generate nonce
|
||||
nonce, err := GenerateRandomNonce()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext := cipher.Seal(nil, nonce, data, nil)
|
||||
|
||||
return &EncryptData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecryptWithSecretKey decrypts data with a symmetric key using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretbox_open_easy() but using ChaCha20-Poly1305
|
||||
func DecryptWithSecretKey(encryptedData *EncryptData, key []byte) ([]byte, error) {
|
||||
if len(key) != MasterKeySize {
|
||||
return nil, fmt.Errorf("invalid key size: expected %d, got %d", MasterKeySize, len(key))
|
||||
}
|
||||
|
||||
if len(encryptedData.Nonce) != NonceSize {
|
||||
return nil, fmt.Errorf("invalid nonce size: expected %d, got %d", NonceSize, len(encryptedData.Nonce))
|
||||
}
|
||||
|
||||
// Create ChaCha20-Poly1305 cipher
|
||||
cipher, err := chacha20poly1305.New(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := cipher.Open(nil, encryptedData.Nonce, encryptedData.Ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptWithPublicKey encrypts data with a public key using NaCl box (XSalsa20-Poly1305)
|
||||
// Note: Asymmetric encryption still uses NaCl box for compatibility
|
||||
// JavaScript equivalent: sodium.crypto_box_seal()
|
||||
func EncryptWithPublicKey(data, recipientPublicKey []byte) ([]byte, error) {
|
||||
if len(recipientPublicKey) != PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid public key size: expected %d, got %d", PublicKeySize, len(recipientPublicKey))
|
||||
}
|
||||
|
||||
// Convert to fixed-size array
|
||||
var pubKeyArray [32]byte
|
||||
copy(pubKeyArray[:], recipientPublicKey)
|
||||
|
||||
// Generate nonce for box encryption (24 bytes for NaCl box)
|
||||
var nonce [24]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// For sealed box, we need to use SealAnonymous
|
||||
sealed, err := box.SealAnonymous(nil, data, &pubKeyArray, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to seal data: %w", err)
|
||||
}
|
||||
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// DecryptWithPrivateKey decrypts data with a private key using NaCl box
|
||||
// Note: Asymmetric encryption still uses NaCl box for compatibility
|
||||
// JavaScript equivalent: sodium.crypto_box_seal_open()
|
||||
// SECURITY: Key arrays are wiped from memory after use to prevent key extraction via memory dumps.
|
||||
func DecryptWithPrivateKey(encryptedData, publicKey, privateKey []byte) ([]byte, error) {
|
||||
if len(privateKey) != PrivateKeySize {
|
||||
return nil, fmt.Errorf("invalid private key size: expected %d, got %d", PrivateKeySize, len(privateKey))
|
||||
}
|
||||
if len(publicKey) != PublicKeySize {
|
||||
return nil, fmt.Errorf("invalid public key size: expected %d, got %d", PublicKeySize, len(publicKey))
|
||||
}
|
||||
|
||||
// Convert to fixed-size arrays
|
||||
var pubKeyArray [32]byte
|
||||
copy(pubKeyArray[:], publicKey)
|
||||
defer memguard.WipeBytes(pubKeyArray[:]) // SECURITY: Wipe public key array
|
||||
|
||||
var privKeyArray [32]byte
|
||||
copy(privKeyArray[:], privateKey)
|
||||
defer memguard.WipeBytes(privKeyArray[:]) // SECURITY: Wipe private key array
|
||||
|
||||
// Decrypt using OpenAnonymous for sealed box
|
||||
plaintext, ok := box.OpenAnonymous(nil, encryptedData, &pubKeyArray, &privKeyArray)
|
||||
if !ok {
|
||||
return nil, errors.New("decryption failed: invalid keys or corrupted data")
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptFileChunked encrypts a file in chunks using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretstream_* but using ChaCha20-Poly1305
|
||||
// SECURITY: Plaintext data is wiped from memory after encryption.
|
||||
func EncryptFileChunked(reader io.Reader, key []byte) ([]byte, error) {
|
||||
// This would be a more complex implementation using
|
||||
// chunked encryption. For brevity, we'll use a simpler approach
|
||||
// that reads the entire file into memory first.
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read data: %w", err)
|
||||
}
|
||||
defer memguard.WipeBytes(data) // SECURITY: Wipe plaintext after encryption
|
||||
|
||||
encData, err := EncryptWithSecretKey(data, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt data: %w", err)
|
||||
}
|
||||
|
||||
// Combine nonce and ciphertext
|
||||
result := make([]byte, len(encData.Nonce)+len(encData.Ciphertext))
|
||||
copy(result, encData.Nonce)
|
||||
copy(result[len(encData.Nonce):], encData.Ciphertext)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptFileChunked decrypts a chunked encrypted file using ChaCha20-Poly1305
|
||||
// JavaScript equivalent: sodium.crypto_secretstream_* but using ChaCha20-Poly1305
|
||||
func DecryptFileChunked(encryptedData, key []byte) ([]byte, error) {
|
||||
// Split nonce and ciphertext
|
||||
if len(encryptedData) < NonceSize {
|
||||
return nil, fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", NonceSize, len(encryptedData))
|
||||
}
|
||||
|
||||
nonce := encryptedData[:NonceSize]
|
||||
ciphertext := encryptedData[NonceSize:]
|
||||
|
||||
// Decrypt
|
||||
return DecryptWithSecretKey(&EncryptData{
|
||||
Ciphertext: ciphertext,
|
||||
Nonce: nonce,
|
||||
}, key)
|
||||
}
|
||||
117
cloud/maplefile-backend/pkg/security/crypto/keys.go
Normal file
117
cloud/maplefile-backend/pkg/security/crypto/keys.go
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"github.com/tyler-smith/go-bip39"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
// GenerateRandomKey generates a new random key using crypto_secretbox_keygen
|
||||
// JavaScript equivalent: sodium.randombytes_buf(crypto.MasterKeySize)
|
||||
func GenerateRandomKey(size int) ([]byte, error) {
|
||||
if size <= 0 {
|
||||
return nil, errors.New("key size must be positive")
|
||||
}
|
||||
|
||||
key := make([]byte, size)
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateKeyPair generates a public/private key pair using NaCl box
|
||||
// JavaScript equivalent: sodium.crypto_box_keypair()
|
||||
func GenerateKeyPair() (publicKey, privateKey []byte, verificationID string, err error) {
|
||||
pubKey, privKey, err := box.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to generate key pair: %w", err)
|
||||
}
|
||||
|
||||
// Convert from fixed-size arrays to slices
|
||||
publicKey = pubKey[:]
|
||||
privateKey = privKey[:]
|
||||
|
||||
// Generate deterministic verification ID
|
||||
verificationID, err = GenerateVerificationID(publicKey[:])
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to generate verification ID: %w", err)
|
||||
}
|
||||
|
||||
return publicKey, privateKey, verificationID, nil
|
||||
}
|
||||
|
||||
// DeriveKeyFromPassword derives a key encryption key from a password using Argon2id
|
||||
// JavaScript equivalent: sodium.crypto_pwhash()
|
||||
// SECURITY: Password bytes are wiped from memory after key derivation.
|
||||
func DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != Argon2SaltSize {
|
||||
return nil, fmt.Errorf("invalid salt size: expected %d, got %d", Argon2SaltSize, len(salt))
|
||||
}
|
||||
|
||||
// Convert password to bytes for wiping
|
||||
passwordBytes := []byte(password)
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after use
|
||||
|
||||
// These parameters must match between Go and JavaScript
|
||||
key := argon2.IDKey(
|
||||
passwordBytes,
|
||||
salt,
|
||||
Argon2OpsLimit,
|
||||
Argon2MemLimit,
|
||||
Argon2Parallelism,
|
||||
Argon2KeySize,
|
||||
)
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateRandomNonce generates a random nonce for ChaCha20-Poly1305 encryption operations
|
||||
// JavaScript equivalent: sodium.randombytes_buf(crypto.NonceSize)
|
||||
func GenerateRandomNonce() ([]byte, error) {
|
||||
nonce := make([]byte, NonceSize) // NonceSize is now 12 for ChaCha20-Poly1305
|
||||
_, err := io.ReadFull(rand.Reader, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
// GenerateVerificationID creates a human-readable representation of a public key
|
||||
// JavaScript equivalent: The same BIP39 mnemonic implementation
|
||||
// Generate VerificationID from public key (deterministic)
|
||||
func GenerateVerificationID(publicKey []byte) (string, error) {
|
||||
if len(publicKey) == 0 {
|
||||
return "", errors.New("public key cannot be empty")
|
||||
}
|
||||
|
||||
// 1. Hash the public key with SHA256
|
||||
hash := sha256.Sum256(publicKey)
|
||||
|
||||
// 2. Use the hash as entropy for BIP39
|
||||
mnemonic, err := bip39.NewMnemonic(hash[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate verification ID: %w", err)
|
||||
}
|
||||
|
||||
return mnemonic, nil
|
||||
}
|
||||
|
||||
// VerifyVerificationID checks if a verification ID matches a public key
|
||||
func VerifyVerificationID(publicKey []byte, verificationID string) bool {
|
||||
expectedID, err := GenerateVerificationID(publicKey)
|
||||
if err != nil {
|
||||
log.Printf("pkg.crypto.VerifyVerificationID - Failed to generate verification ID with error: %v\n", err)
|
||||
return false
|
||||
}
|
||||
return expectedID == verificationID
|
||||
}
|
||||
45
cloud/maplefile-backend/pkg/security/hash/hash.go
Normal file
45
cloud/maplefile-backend/pkg/security/hash/hash.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Package hash provides secure hashing utilities for tokens and sensitive data.
|
||||
// These utilities are used to hash tokens before storing them as cache keys,
|
||||
// preventing token leakage through cache key inspection.
|
||||
package hash
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// HashToken creates a SHA-256 hash of a token for use as a cache key.
|
||||
// This prevents token leakage via cache key inspection.
|
||||
// The input token bytes are wiped after hashing.
|
||||
func HashToken(token string) string {
|
||||
tokenBytes := []byte(token)
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
hash := sha256.Sum256(tokenBytes)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// HashBytes creates a SHA-256 hash of byte data.
|
||||
// If wipeInput is true, the input bytes are wiped after hashing.
|
||||
func HashBytes(data []byte, wipeInput bool) string {
|
||||
if wipeInput {
|
||||
defer memguard.WipeBytes(data)
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// HashTokenToBytes creates a SHA-256 hash and returns the raw bytes.
|
||||
// The input token bytes are wiped after hashing.
|
||||
func HashTokenToBytes(token string) []byte {
|
||||
tokenBytes := []byte(token)
|
||||
defer memguard.WipeBytes(tokenBytes)
|
||||
|
||||
hash := sha256.Sum256(tokenBytes)
|
||||
result := make([]byte, len(hash))
|
||||
copy(result, hash[:])
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/ipcountryblocker/ipcountryblocker.go
|
||||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// Provider defines the interface for IP-based country blocking operations.
|
||||
// It provides methods to check if an IP or country is blocked and to retrieve
|
||||
// country codes for given IP addresses.
|
||||
type Provider interface {
|
||||
// IsBlockedCountry checks if a country is in the blocked list.
|
||||
// isoCode must be an ISO 3166-1 alpha-2 country code.
|
||||
IsBlockedCountry(isoCode string) bool
|
||||
|
||||
// IsBlockedIP determines if an IP address originates from a blocked country.
|
||||
// Returns false for nil IP addresses or if country lookup fails.
|
||||
IsBlockedIP(ctx context.Context, ip net.IP) bool
|
||||
|
||||
// GetCountryCode returns the ISO 3166-1 alpha-2 country code for an IP address.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
GetCountryCode(ctx context.Context, ip net.IP) (string, error)
|
||||
|
||||
// Close releases resources associated with the provider.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// provider implements the Provider interface using MaxMind's GeoIP2 database.
|
||||
type provider struct {
|
||||
db *geoip2.Reader
|
||||
blockedCountries map[string]struct{} // Uses empty struct to optimize memory
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex // Protects concurrent access to blockedCountries
|
||||
}
|
||||
|
||||
// NewProvider creates a new IP country blocking provider using the provided configuration.
|
||||
// It initializes the GeoIP2 database and sets up the blocked countries list.
|
||||
// Fatally crashes the entire application if the database cannot be opened.
|
||||
func NewProvider(cfg *config.Configuration, logger *zap.Logger) Provider {
|
||||
db, err := geoip2.Open(cfg.Security.GeoLiteDBPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open GeoLite2 DB: %v", err)
|
||||
}
|
||||
|
||||
blocked := make(map[string]struct{}, len(cfg.Security.BannedCountries))
|
||||
for _, country := range cfg.Security.BannedCountries {
|
||||
blocked[country] = struct{}{}
|
||||
}
|
||||
|
||||
logger.Debug("ip blocker initialized",
|
||||
zap.String("db_path", cfg.Security.GeoLiteDBPath),
|
||||
zap.Any("blocked_countries", cfg.Security.BannedCountries))
|
||||
|
||||
return &provider{
|
||||
db: db,
|
||||
blockedCountries: blocked,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedCountry checks if a country code exists in the blocked countries map.
|
||||
// Thread-safe through RLock.
|
||||
func (p *provider) IsBlockedCountry(isoCode string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
_, exists := p.blockedCountries[isoCode]
|
||||
return exists
|
||||
}
|
||||
|
||||
// IsBlockedIP performs a country lookup for the IP and checks if it's blocked.
|
||||
// Returns false for nil IPs or failed lookups to fail safely.
|
||||
func (p *provider) IsBlockedIP(ctx context.Context, ip net.IP) bool {
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
code, err := p.GetCountryCode(ctx, ip)
|
||||
if err != nil {
|
||||
// Developers Note:
|
||||
// Comment this console log as it contributes a `noisy` server log.
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
// p.logger.WarnContext(ctx, "failed to get country code",
|
||||
// zap.Any("ip", ip),
|
||||
// zap.Any("error", err))
|
||||
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
|
||||
// Developers Note:
|
||||
// If the country d.n.e. exist that means we will return with `false`
|
||||
// indicating this IP address is allowed to access our server. If this
|
||||
// is concerning then you might set this to `true` to block on all
|
||||
// IP address which are not categorized by country.
|
||||
return false
|
||||
}
|
||||
|
||||
return p.IsBlockedCountry(code)
|
||||
}
|
||||
|
||||
// GetCountryCode performs a GeoIP2 database lookup to determine an IP's country.
|
||||
// Returns an error if the lookup fails or no country is found.
|
||||
func (p *provider) GetCountryCode(ctx context.Context, ip net.IP) (string, error) {
|
||||
record, err := p.db.Country(ip)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("lookup country: %w", err)
|
||||
}
|
||||
|
||||
if record == nil || record.Country.IsoCode == "" {
|
||||
return "", fmt.Errorf("no country found for IP: %s", validation.MaskIP(ip.String()))
|
||||
}
|
||||
|
||||
return record.Country.IsoCode, nil
|
||||
}
|
||||
|
||||
// Close cleanly shuts down the GeoIP2 database connection.
|
||||
func (p *provider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/ipcountryblocker/ipcountryblocker_test.go
|
||||
package ipcountryblocker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// testProvider is a test-specific wrapper that allows access to internal fields
|
||||
// of the provider struct for verification in tests. This is a common pattern
|
||||
// when you need to test internal state while keeping the production interface clean.
|
||||
type testProvider struct {
|
||||
Provider // Embedded interface for normal operations
|
||||
internal *provider // Access to internal fields for testing
|
||||
}
|
||||
|
||||
// newTestProvider creates a test provider instance with access to internal fields.
|
||||
// This allows us to verify the internal state in our tests while maintaining
|
||||
// encapsulation in production code.
|
||||
func newTestProvider(cfg *config.Configuration, logger *zap.Logger) testProvider {
|
||||
p := NewProvider(cfg, logger)
|
||||
return testProvider{
|
||||
Provider: p,
|
||||
internal: p.(*provider), // Type assertion to get access to internal fields
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewProvider verifies that the provider is properly initialized with all
|
||||
// required components (database connection, blocked countries map, logger).
|
||||
func TestNewProvider(t *testing.T) {
|
||||
// Setup test configuration with path to test database
|
||||
cfg := &config.Configuration{
|
||||
Security: config.SecurityConfig{
|
||||
GeoLiteDBPath: "../../../static/GeoLite2-Country.mmdb",
|
||||
BannedCountries: []string{"US", "CN"},
|
||||
},
|
||||
}
|
||||
// Initialize logger with JSON output for structured test logs
|
||||
logger, _ := zap.NewDevelopment()
|
||||
|
||||
// Create test provider and verify internal components
|
||||
p := newTestProvider(cfg, logger)
|
||||
assert.NotNil(t, p.Provider, "Provider should not be nil")
|
||||
assert.NotEmpty(t, p.internal.blockedCountries, "Blocked countries map should be initialized")
|
||||
assert.NotNil(t, p.internal.logger, "Logger should be initialized")
|
||||
assert.NotNil(t, p.internal.db, "Database connection should be initialized")
|
||||
defer p.Close() // Ensure cleanup after test
|
||||
}
|
||||
|
||||
// TestProvider_IsBlockedCountry tests the country blocking functionality with
|
||||
// various country codes including edge cases like empty and invalid codes.
|
||||
func TestProvider_IsBlockedCountry(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
// Table-driven test cases covering various scenarios
|
||||
tests := []struct {
|
||||
name string
|
||||
country string
|
||||
expected bool
|
||||
}{
|
||||
// Positive test cases - blocked countries
|
||||
{
|
||||
name: "blocked country US",
|
||||
country: "US",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked country CN",
|
||||
country: "CN",
|
||||
expected: true,
|
||||
},
|
||||
// Negative test cases - allowed countries
|
||||
{
|
||||
name: "non-blocked country GB",
|
||||
country: "GB",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-blocked country JP",
|
||||
country: "JP",
|
||||
expected: false,
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
name: "empty country code",
|
||||
country: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid country code",
|
||||
country: "XX",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "lowercase country code", // Tests case sensitivity
|
||||
country: "us",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Run each test case
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := provider.IsBlockedCountry(tt.country)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_IsBlockedIP verifies IP blocking functionality using real-world
|
||||
// IP addresses, including IPv4, IPv6, and various edge cases.
|
||||
func TestProvider_IsBlockedIP(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
expected bool
|
||||
}{
|
||||
// Known IP addresses from blocked countries
|
||||
{
|
||||
name: "blocked IP (US - Google DNS)",
|
||||
ip: net.ParseIP("8.8.8.8"), // Google's primary DNS
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked IP (US - Google DNS 2)",
|
||||
ip: net.ParseIP("8.8.4.4"), // Google's secondary DNS
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "blocked IP (CN - Alibaba)",
|
||||
ip: net.ParseIP("223.5.5.5"), // Alibaba DNS
|
||||
expected: true,
|
||||
},
|
||||
// Non-blocked country IPs
|
||||
{
|
||||
name: "non-blocked IP (GB)",
|
||||
ip: net.ParseIP("178.62.1.1"),
|
||||
expected: false,
|
||||
},
|
||||
// Edge cases and special scenarios
|
||||
{
|
||||
name: "nil IP",
|
||||
ip: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid IP format",
|
||||
ip: net.ParseIP("invalid"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
ip: net.ParseIP("2001:4860:4860::8888"), // Google's IPv6 DNS
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := provider.IsBlockedIP(ctx, tt.ip)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_GetCountryCode verifies the country code lookup functionality
|
||||
// for various IP addresses, including error cases.
|
||||
func TestProvider_GetCountryCode(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
defer provider.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip net.IP
|
||||
expected string
|
||||
expectError bool
|
||||
}{
|
||||
// Valid IP addresses with known countries
|
||||
{
|
||||
name: "US IP (Google DNS)",
|
||||
ip: net.ParseIP("8.8.8.8"),
|
||||
expected: "US",
|
||||
expectError: false,
|
||||
},
|
||||
// Error cases
|
||||
{
|
||||
name: "nil IP",
|
||||
ip: nil,
|
||||
expected: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "private IP", // RFC 1918 address
|
||||
ip: net.ParseIP("192.168.1.1"),
|
||||
expected: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, err := provider.GetCountryCode(ctx, tt.ip)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, "Should return error for invalid IP")
|
||||
assert.Empty(t, code, "Should return empty code on error")
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err, "Should not return error for valid IP")
|
||||
assert.Equal(t, tt.expected, code, "Should return correct country code")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProvider_Close verifies that the provider properly closes its resources
|
||||
// and subsequent operations fail as expected.
|
||||
func TestProvider_Close(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
|
||||
// Verify initial close succeeds
|
||||
err := provider.Close()
|
||||
assert.NoError(t, err, "Initial close should succeed")
|
||||
|
||||
// Verify operations fail after close
|
||||
code, err := provider.GetCountryCode(context.Background(), net.ParseIP("8.8.8.8"))
|
||||
assert.Error(t, err, "Operations should fail after close")
|
||||
assert.Empty(t, code, "No data should be returned after close")
|
||||
}
|
||||
|
||||
// setupTestProvider is a helper function that creates a properly configured
|
||||
// provider instance for testing, using the test database path.
|
||||
func setupTestProvider(t *testing.T) Provider {
|
||||
cfg := &config.Configuration{
|
||||
Security: config.SecurityConfig{
|
||||
GeoLiteDBPath: "../../../static/GeoLite2-Country.mmdb",
|
||||
BannedCountries: []string{"US", "CN"},
|
||||
},
|
||||
}
|
||||
logger, _ := zap.NewDevelopment()
|
||||
return NewProvider(cfg, logger)
|
||||
}
|
||||
223
cloud/maplefile-backend/pkg/security/ipcrypt/encryptor.go
Normal file
223
cloud/maplefile-backend/pkg/security/ipcrypt/encryptor.go
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/validation"
|
||||
)
|
||||
|
||||
// IPEncryptor provides secure IP address encryption for GDPR compliance
|
||||
// Uses AES-GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// Encrypts IP addresses before storage and provides expiration checking
|
||||
type IPEncryptor struct {
|
||||
gcm cipher.AEAD
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewIPEncryptor creates a new IP encryptor with the given encryption key
|
||||
// keyHex should be a 32-character hex string (16 bytes for AES-128)
|
||||
// or 64-character hex string (32 bytes for AES-256)
|
||||
// Example: "0123456789abcdef0123456789abcdef" (AES-128)
|
||||
// Recommended: Use AES-256 with 64-character hex key
|
||||
func NewIPEncryptor(keyHex string, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
// Decode hex key to bytes
|
||||
keyBytes, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid hex key: %w", err)
|
||||
}
|
||||
|
||||
// AES requires exactly 16, 24, or 32 bytes
|
||||
if len(keyBytes) != 16 && len(keyBytes) != 24 && len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("key must be 16, 24, or 32 bytes (32, 48, or 64 hex characters), got %d bytes", len(keyBytes))
|
||||
}
|
||||
|
||||
// Create AES cipher block
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
// Create GCM (Galois/Counter Mode) for authenticated encryption
|
||||
// GCM provides both confidentiality and integrity
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("IP encryptor initialized with AES-GCM",
|
||||
zap.Int("key_length_bytes", len(keyBytes)),
|
||||
zap.Int("nonce_size", gcm.NonceSize()),
|
||||
zap.Int("overhead", gcm.Overhead()))
|
||||
|
||||
return &IPEncryptor{
|
||||
gcm: gcm,
|
||||
logger: logger.Named("ip-encryptor"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts an IP address for secure storage using AES-GCM
|
||||
// Returns base64-encoded encrypted IP address with embedded nonce
|
||||
// Format: base64(nonce + ciphertext + auth_tag)
|
||||
// Supports both IPv4 and IPv6 addresses
|
||||
//
|
||||
// Security Properties:
|
||||
// - Semantic security: same IP address produces different ciphertext each time
|
||||
// - Authentication: tampering with ciphertext is detected
|
||||
// - Unique nonce per encryption prevents pattern analysis
|
||||
func (e *IPEncryptor) Encrypt(ipAddress string) (string, error) {
|
||||
if ipAddress == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Parse IP address to validate format
|
||||
ip := net.ParseIP(ipAddress)
|
||||
if ip == nil {
|
||||
e.logger.Warn("invalid IP address format",
|
||||
zap.String("ip", validation.MaskIP(ipAddress)))
|
||||
return "", fmt.Errorf("invalid IP address: %s", validation.MaskIP(ipAddress))
|
||||
}
|
||||
|
||||
// Convert to 16-byte representation (IPv4 gets converted to IPv6 format)
|
||||
ipBytes := ip.To16()
|
||||
if ipBytes == nil {
|
||||
return "", fmt.Errorf("failed to convert IP to 16-byte format")
|
||||
}
|
||||
|
||||
// Generate a random nonce (number used once)
|
||||
// GCM requires a unique nonce for each encryption operation
|
||||
nonce := make([]byte, e.gcm.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
e.logger.Error("failed to generate nonce", zap.Error(err))
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the IP bytes using AES-GCM
|
||||
// GCM appends the authentication tag to the ciphertext
|
||||
// nil additional data means no associated data
|
||||
ciphertext := e.gcm.Seal(nil, nonce, ipBytes, nil)
|
||||
|
||||
// Prepend nonce to ciphertext for storage
|
||||
// Format: nonce || ciphertext+tag
|
||||
encryptedData := append(nonce, ciphertext...)
|
||||
|
||||
// Encode to base64 for database storage (text-safe)
|
||||
encryptedBase64 := base64.StdEncoding.EncodeToString(encryptedData)
|
||||
|
||||
e.logger.Debug("IP address encrypted with AES-GCM",
|
||||
zap.Int("plaintext_length", len(ipBytes)),
|
||||
zap.Int("nonce_length", len(nonce)),
|
||||
zap.Int("ciphertext_length", len(ciphertext)),
|
||||
zap.Int("total_encrypted_length", len(encryptedData)),
|
||||
zap.Int("base64_length", len(encryptedBase64)))
|
||||
|
||||
return encryptedBase64, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts an encrypted IP address
|
||||
// Takes base64-encoded encrypted IP and returns original IP address string
|
||||
// Verifies authentication tag to detect tampering
|
||||
func (e *IPEncryptor) Decrypt(encryptedBase64 string) (string, error) {
|
||||
if encryptedBase64 == "" {
|
||||
return "", nil // Empty string remains empty
|
||||
}
|
||||
|
||||
// Decode base64 to bytes
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
e.logger.Warn("invalid base64-encoded encrypted IP",
|
||||
zap.String("base64", encryptedBase64),
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
// Extract nonce from the beginning
|
||||
nonceSize := e.gcm.NonceSize()
|
||||
if len(encryptedData) < nonceSize {
|
||||
return "", fmt.Errorf("encrypted data too short: expected at least %d bytes, got %d", nonceSize, len(encryptedData))
|
||||
}
|
||||
|
||||
nonce := encryptedData[:nonceSize]
|
||||
ciphertext := encryptedData[nonceSize:]
|
||||
|
||||
// Decrypt and verify authentication tag using AES-GCM
|
||||
ipBytes, err := e.gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
e.logger.Warn("failed to decrypt IP address (authentication failed or corrupted data)",
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("decryption failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to IP address
|
||||
ip := net.IP(ipBytes)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("failed to parse decrypted IP bytes")
|
||||
}
|
||||
|
||||
// Convert to string
|
||||
ipString := ip.String()
|
||||
|
||||
e.logger.Debug("IP address decrypted with AES-GCM",
|
||||
zap.Int("encrypted_length", len(encryptedData)),
|
||||
zap.Int("decrypted_length", len(ipBytes)))
|
||||
|
||||
return ipString, nil
|
||||
}
|
||||
|
||||
// IsExpired checks if an IP address timestamp has expired (> 90 days old)
|
||||
// GDPR compliance: IP addresses must be deleted after 90 days
|
||||
func (e *IPEncryptor) IsExpired(timestamp time.Time) bool {
|
||||
if timestamp.IsZero() {
|
||||
return false // No timestamp means not expired (will be cleaned up later)
|
||||
}
|
||||
|
||||
// Calculate age in days
|
||||
age := time.Since(timestamp)
|
||||
ageInDays := int(age.Hours() / 24)
|
||||
|
||||
expired := ageInDays > 90
|
||||
|
||||
if expired {
|
||||
e.logger.Debug("IP timestamp expired",
|
||||
zap.Time("timestamp", timestamp),
|
||||
zap.Int("age_days", ageInDays))
|
||||
}
|
||||
|
||||
return expired
|
||||
}
|
||||
|
||||
// ShouldCleanup checks if an IP address should be cleaned up based on timestamp
|
||||
// Returns true if timestamp is older than 90 days OR if timestamp is zero (unset)
|
||||
func (e *IPEncryptor) ShouldCleanup(timestamp time.Time) bool {
|
||||
// Always cleanup if timestamp is not set (backwards compatibility)
|
||||
if timestamp.IsZero() {
|
||||
return false // Don't cleanup unset timestamps immediately
|
||||
}
|
||||
|
||||
return e.IsExpired(timestamp)
|
||||
}
|
||||
|
||||
// ValidateKey validates that a key is properly formatted for IP encryption
|
||||
// Returns true if key is valid 32-character hex string (AES-128) or 64-character (AES-256)
|
||||
func ValidateKey(keyHex string) error {
|
||||
// Check length (must be 16, 24, or 32 bytes = 32, 48, or 64 hex chars)
|
||||
if len(keyHex) != 32 && len(keyHex) != 48 && len(keyHex) != 64 {
|
||||
return fmt.Errorf("key must be 32, 48, or 64 hex characters, got %d characters", len(keyHex))
|
||||
}
|
||||
|
||||
// Check if valid hex
|
||||
_, err := hex.DecodeString(keyHex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("key must be valid hex string: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
13
cloud/maplefile-backend/pkg/security/ipcrypt/provider.go
Normal file
13
cloud/maplefile-backend/pkg/security/ipcrypt/provider.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
package ipcrypt
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
// ProvideIPEncryptor provides an IP encryptor instance
|
||||
// CWE-359: GDPR compliance for IP address storage
|
||||
func ProvideIPEncryptor(cfg *config.Config, logger *zap.Logger) (*IPEncryptor, error) {
|
||||
return NewIPEncryptor(cfg.Security.IPEncryptionKey, logger)
|
||||
}
|
||||
47
cloud/maplefile-backend/pkg/security/jwt/jwt.go
Normal file
47
cloud/maplefile-backend/pkg/security/jwt/jwt.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/jwt_utils"
|
||||
sbytes "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
)
|
||||
|
||||
// JWTProvider provides interface for abstracting JWT generation.
|
||||
type JWTProvider interface {
|
||||
GenerateJWTToken(uuid string, ad time.Duration) (string, time.Time, error)
|
||||
GenerateJWTTokenPair(uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error)
|
||||
ProcessJWTToken(reqToken string) (string, error)
|
||||
}
|
||||
|
||||
type jwtProvider struct {
|
||||
hmacSecret *sbytes.SecureBytes
|
||||
}
|
||||
|
||||
// NewProvider Constructor that returns the JWT generator.
|
||||
func NewJWTProvider(cfg *config.Configuration) JWTProvider {
|
||||
// Convert JWT secret string to SecureBytes
|
||||
secret, _ := sbytes.NewSecureBytes([]byte(cfg.JWT.Secret))
|
||||
return jwtProvider{
|
||||
hmacSecret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateJWTToken generates a single JWT token.
|
||||
func (p jwtProvider) GenerateJWTToken(uuid string, ad time.Duration) (string, time.Time, error) {
|
||||
return jwt_utils.GenerateJWTToken(p.hmacSecret.Bytes(), uuid, ad)
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair Generate the `access token` and `refresh token` for the secret key.
|
||||
func (p jwtProvider) GenerateJWTTokenPair(uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
return jwt_utils.GenerateJWTTokenPair(p.hmacSecret.Bytes(), uuid, ad, rd)
|
||||
}
|
||||
|
||||
func (p jwtProvider) ProcessJWTToken(reqToken string) (string, error) {
|
||||
if p.hmacSecret == nil {
|
||||
return "", errors.New("HMAC secret is required")
|
||||
}
|
||||
return jwt_utils.ProcessJWTToken(p.hmacSecret.Bytes(), reqToken)
|
||||
}
|
||||
98
cloud/maplefile-backend/pkg/security/jwt/jwt_test.go
Normal file
98
cloud/maplefile-backend/pkg/security/jwt/jwt_test.go
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
func setupTestProvider(t *testing.T) JWTProvider {
|
||||
cfg := &config.Configuration{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
},
|
||||
}
|
||||
return NewJWTProvider(cfg)
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
|
||||
func TestGenerateJWTToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
token, expiry, err := provider.GenerateJWTToken(uuid, duration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiry.After(time.Now()))
|
||||
assert.True(t, expiry.Before(time.Now().Add(duration).Add(time.Second)))
|
||||
}
|
||||
|
||||
func TestGenerateJWTTokenPair(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
accessDuration := time.Hour
|
||||
refreshDuration := time.Hour * 24
|
||||
|
||||
accessToken, accessExpiry, refreshToken, refreshExpiry, err := provider.GenerateJWTTokenPair(uuid, accessDuration, refreshDuration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, refreshToken)
|
||||
assert.True(t, accessExpiry.After(time.Now()))
|
||||
assert.True(t, refreshExpiry.After(time.Now()))
|
||||
assert.True(t, accessExpiry.Before(time.Now().Add(accessDuration).Add(time.Second)))
|
||||
assert.True(t, refreshExpiry.Before(time.Now().Add(refreshDuration).Add(time.Second)))
|
||||
}
|
||||
|
||||
func TestProcessJWTToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
// Generate a token first
|
||||
token, _, err := provider.GenerateJWTToken(uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Process the generated token
|
||||
processedUUID, err := provider.ProcessJWTToken(token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedUUID)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_InvalidToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
|
||||
_, err := provider.ProcessJWTToken("invalid-token")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_NilSecret(t *testing.T) {
|
||||
provider := jwtProvider{
|
||||
hmacSecret: nil,
|
||||
}
|
||||
|
||||
_, err := provider.ProcessJWTToken("any-token")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "HMAC secret is required", err.Error())
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_ExpiredToken(t *testing.T) {
|
||||
provider := setupTestProvider(t)
|
||||
uuid := "test-uuid"
|
||||
duration := -time.Hour // negative duration for expired token
|
||||
|
||||
token, _, err := provider.GenerateJWTToken(uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = provider.ProcessJWTToken(token)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
10
cloud/maplefile-backend/pkg/security/jwt/provider.go
Normal file
10
cloud/maplefile-backend/pkg/security/jwt/provider.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package jwt
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideJWTProvider provides a JWT provider instance for Wire DI
|
||||
func ProvideJWTProvider(cfg *config.Config) JWTProvider {
|
||||
return NewJWTProvider(cfg)
|
||||
}
|
||||
130
cloud/maplefile-backend/pkg/security/jwt_utils/jwt.go
Normal file
130
cloud/maplefile-backend/pkg/security/jwt_utils/jwt.go
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
package jwt_utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
jwt "github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// GenerateJWTToken Generate the `access token` for the secret key.
|
||||
// SECURITY: HMAC secret is wiped from memory after signing to prevent memory dump attacks.
|
||||
func GenerateJWTToken(hmacSecret []byte, uuid string, ad time.Duration) (string, time.Time, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
// Note: The original hmacSecret is owned by the caller
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after signing
|
||||
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
expiresIn := time.Now().Add(ad)
|
||||
|
||||
// CWE-391: Safe type assertion even though we just created the token
|
||||
// Defensive programming to prevent future panics if jwt library changes
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", expiresIn, jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
claims["session_uuid"] = uuid
|
||||
claims["exp"] = expiresIn.Unix()
|
||||
|
||||
tokenString, err := token.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", expiresIn, err
|
||||
}
|
||||
|
||||
return tokenString, expiresIn, nil
|
||||
}
|
||||
|
||||
// GenerateJWTTokenPair Generate the `access token` and `refresh token` for the secret key.
|
||||
// SECURITY: HMAC secret is wiped from memory after signing to prevent memory dump attacks.
|
||||
func GenerateJWTTokenPair(hmacSecret []byte, uuid string, ad time.Duration, rd time.Duration) (string, time.Time, string, time.Time, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after signing
|
||||
|
||||
//
|
||||
// Generate token.
|
||||
//
|
||||
token := jwt.New(jwt.SigningMethodHS256)
|
||||
expiresIn := time.Now().Add(ad)
|
||||
|
||||
// CWE-391: Safe type assertion even though we just created the token
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", time.Now(), "", time.Now(), jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
claims["session_uuid"] = uuid
|
||||
claims["exp"] = expiresIn.Unix()
|
||||
|
||||
tokenString, err := token.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", time.Now(), "", time.Now(), err
|
||||
}
|
||||
|
||||
//
|
||||
// Generate refresh token.
|
||||
//
|
||||
refreshToken := jwt.New(jwt.SigningMethodHS256)
|
||||
refreshExpiresIn := time.Now().Add(rd)
|
||||
|
||||
// CWE-391: Safe type assertion for refresh token
|
||||
rtClaims, ok := refreshToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return "", time.Now(), "", time.Now(), jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
|
||||
rtClaims["session_uuid"] = uuid
|
||||
rtClaims["exp"] = refreshExpiresIn.Unix()
|
||||
|
||||
refreshTokenString, err := refreshToken.SignedString(secretCopy)
|
||||
if err != nil {
|
||||
return "", time.Now(), "", time.Now(), err
|
||||
}
|
||||
|
||||
return tokenString, expiresIn, refreshTokenString, refreshExpiresIn, nil
|
||||
}
|
||||
|
||||
// ProcessJWTToken validates either the `access token` or `refresh token` and returns either the `uuid` if success or error on failure.
|
||||
// CWE-347: Implements proper algorithm validation to prevent JWT algorithm confusion attacks
|
||||
// OWASP A02:2021: Cryptographic Failures - Prevents token forgery through algorithm switching
|
||||
// SECURITY: HMAC secret copy is wiped from memory after validation.
|
||||
func ProcessJWTToken(hmacSecret []byte, reqToken string) (string, error) {
|
||||
// SECURITY: Create a copy of the secret and wipe the copy after use
|
||||
secretCopy := make([]byte, len(hmacSecret))
|
||||
copy(secretCopy, hmacSecret)
|
||||
defer memguard.WipeBytes(secretCopy) // SECURITY: Wipe secret copy after validation
|
||||
|
||||
token, err := jwt.Parse(reqToken, func(t *jwt.Token) (any, error) {
|
||||
// CRITICAL SECURITY FIX: Validate signing method to prevent algorithm confusion attacks
|
||||
// Protects against:
|
||||
// 1. "none" algorithm bypass (CVE-2015-9235)
|
||||
// 2. HS256/RS256 algorithm confusion (CVE-2016-5431)
|
||||
// 3. Token forgery through algorithm switching
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
|
||||
// Additional check: Ensure it's specifically HS256
|
||||
if t.Method.Alg() != "HS256" {
|
||||
return nil, jwt.ErrTokenSignatureInvalid
|
||||
}
|
||||
|
||||
return secretCopy, nil
|
||||
})
|
||||
if err == nil && token.Valid {
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
// Safe type assertion with validation
|
||||
sessionUUID, ok := claims["session_uuid"].(string)
|
||||
if !ok {
|
||||
return "", jwt.ErrTokenInvalidClaims
|
||||
}
|
||||
return sessionUUID, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
194
cloud/maplefile-backend/pkg/security/jwt_utils/jwt_test.go
Normal file
194
cloud/maplefile-backend/pkg/security/jwt_utils/jwt_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
package jwt_utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testSecret = []byte("test-secret-key")
|
||||
|
||||
func TestGenerateJWTToken(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
token, expiry, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiry.After(time.Now()))
|
||||
assert.True(t, expiry.Before(time.Now().Add(duration).Add(time.Second)))
|
||||
|
||||
// Verify token can be processed
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedUUID)
|
||||
}
|
||||
|
||||
func TestGenerateJWTTokenPair(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
accessDuration := time.Hour
|
||||
refreshDuration := time.Hour * 24
|
||||
|
||||
accessToken, accessExpiry, refreshToken, refreshExpiry, err := GenerateJWTTokenPair(
|
||||
testSecret,
|
||||
uuid,
|
||||
accessDuration,
|
||||
refreshDuration,
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, refreshToken)
|
||||
assert.True(t, accessExpiry.After(time.Now()))
|
||||
assert.True(t, refreshExpiry.After(time.Now()))
|
||||
assert.True(t, accessExpiry.Before(time.Now().Add(accessDuration).Add(time.Second)))
|
||||
assert.True(t, refreshExpiry.Before(time.Now().Add(refreshDuration).Add(time.Second)))
|
||||
|
||||
// Verify both tokens can be processed
|
||||
processedAccessUUID, err := ProcessJWTToken(testSecret, accessToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedAccessUUID)
|
||||
|
||||
processedRefreshUUID, err := ProcessJWTToken(testSecret, refreshToken)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uuid, processedRefreshUUID)
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed token",
|
||||
token: "not.a.token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong signature",
|
||||
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJ0ZXN0LXV1aWQiLCJleHAiOjE3MDQwNjc1NTF9.wrong",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
uuid, err := ProcessJWTToken(testSecret, tt.token)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, uuid)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, uuid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessJWTToken_Expired(t *testing.T) {
|
||||
uuid := "test-uuid"
|
||||
duration := -time.Hour // negative duration for expired token
|
||||
|
||||
token, _, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
assert.NoError(t, err)
|
||||
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, processedUUID)
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_AlgorithmConfusion tests protection against JWT algorithm confusion attacks
|
||||
// CVE-2015-9235: None algorithm bypass
|
||||
// CVE-2016-5431: HS256/RS256 algorithm confusion
|
||||
// CWE-347: Improper Verification of Cryptographic Signature
|
||||
func TestProcessJWTToken_AlgorithmConfusion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
description string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "none algorithm bypass attempt",
|
||||
// Token with "alg": "none" - should be rejected
|
||||
token: "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.",
|
||||
description: "Attacker tries to bypass signature verification using 'none' algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "RS256 algorithm confusion attempt",
|
||||
// Token with "alg": "RS256" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use RS256 to confuse HMAC validation",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "HS384 algorithm attempt",
|
||||
// Token with "alg": "HS384" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use different HMAC algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "HS512 algorithm attempt",
|
||||
// Token with "alg": "HS512" - should be rejected (we only accept HS256)
|
||||
token: "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJzZXNzaW9uX3V1aWQiOiJhdHRhY2tlci11dWlkIiwiZXhwIjo5OTk5OTk5OTk5fQ.invalid",
|
||||
description: "Attacker tries to use different HMAC algorithm",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Logf("Testing: %s", tt.description)
|
||||
uuid, err := ProcessJWTToken(testSecret, tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Expected error for security vulnerability: %s", tt.description)
|
||||
assert.Empty(t, uuid, "UUID should be empty when algorithm validation fails")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, uuid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_ValidHS256Only tests that only valid HS256 tokens are accepted
|
||||
func TestProcessJWTToken_ValidHS256Only(t *testing.T) {
|
||||
uuid := "valid-test-uuid"
|
||||
duration := time.Hour
|
||||
|
||||
// Generate a valid HS256 token
|
||||
token, _, err := GenerateJWTToken(testSecret, uuid, duration)
|
||||
assert.NoError(t, err, "Should generate valid token")
|
||||
|
||||
// Verify it's accepted
|
||||
processedUUID, err := ProcessJWTToken(testSecret, token)
|
||||
assert.NoError(t, err, "Valid HS256 token should be accepted")
|
||||
assert.Equal(t, uuid, processedUUID, "UUID should match")
|
||||
}
|
||||
|
||||
// TestProcessJWTToken_MissingSessionUUID tests protection against missing session_uuid claim
|
||||
func TestProcessJWTToken_MissingSessionUUID(t *testing.T) {
|
||||
// This test verifies the safe type assertion fix for CWE-391
|
||||
// A token without session_uuid claim should return an error, not panic
|
||||
|
||||
// Note: We can't easily create such a token with our GenerateJWTToken function
|
||||
// as it always includes session_uuid. In a real attack scenario, an attacker
|
||||
// would craft such a token manually. This test documents the expected behavior.
|
||||
|
||||
// For now, we verify that a malformed token is properly rejected
|
||||
malformedToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjk5OTk5OTk5OTl9.invalid"
|
||||
uuid, err := ProcessJWTToken(testSecret, malformedToken)
|
||||
assert.Error(t, err, "Token without session_uuid should be rejected")
|
||||
assert.Empty(t, uuid, "UUID should be empty for invalid token")
|
||||
}
|
||||
96
cloud/maplefile-backend/pkg/security/memutil/memutil.go
Normal file
96
cloud/maplefile-backend/pkg/security/memutil/memutil.go
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
// Package memutil provides utilities for secure memory handling.
|
||||
// These utilities help prevent sensitive data from remaining in memory
|
||||
// after use, protecting against memory dump attacks.
|
||||
package memutil
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
sbytes "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securebytes"
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
// WipeString overwrites a string's backing array with zeros and clears the string.
|
||||
// Note: This only works if the string variable is the only reference to the data.
|
||||
// For better security, use SecureString instead of plain strings for sensitive data.
|
||||
func WipeString(s *string) {
|
||||
if s == nil || *s == "" {
|
||||
return
|
||||
}
|
||||
// Convert to byte slice and wipe
|
||||
// Note: This creates a copy, but we wipe what we can
|
||||
bytes := []byte(*s)
|
||||
memguard.WipeBytes(bytes)
|
||||
*s = ""
|
||||
}
|
||||
|
||||
// SecureCompareStrings performs constant-time comparison of two strings.
|
||||
// This prevents timing attacks when comparing secrets.
|
||||
func SecureCompareStrings(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// SecureCompareBytes performs constant-time comparison of two byte slices.
|
||||
// If wipeAfter is true, both slices are wiped after comparison.
|
||||
func SecureCompareBytes(a, b []byte, wipeAfter bool) bool {
|
||||
if wipeAfter {
|
||||
defer memguard.WipeBytes(a)
|
||||
defer memguard.WipeBytes(b)
|
||||
}
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// WithSecureBytes executes a function with secure byte handling.
|
||||
// The bytes are automatically wiped after the function returns.
|
||||
func WithSecureBytes(data []byte, fn func([]byte) error) error {
|
||||
defer memguard.WipeBytes(data)
|
||||
return fn(data)
|
||||
}
|
||||
|
||||
// WithSecureString executes a function with secure string handling.
|
||||
// The SecureString is automatically wiped after the function returns.
|
||||
func WithSecureString(str string, fn func(*sstring.SecureString) error) error {
|
||||
secure, err := sstring.NewSecureString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer secure.Wipe()
|
||||
return fn(secure)
|
||||
}
|
||||
|
||||
// CloneAndWipe creates a copy of data and wipes the original.
|
||||
// Useful when you need to pass data to a function that will store it,
|
||||
// but want to ensure the original is wiped.
|
||||
func CloneAndWipe(data []byte) []byte {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
clone := make([]byte, len(data))
|
||||
copy(clone, data)
|
||||
memguard.WipeBytes(data)
|
||||
return clone
|
||||
}
|
||||
|
||||
// SecureZero overwrites memory with zeros.
|
||||
// This is a convenience wrapper around memguard.WipeBytes.
|
||||
func SecureZero(data []byte) {
|
||||
memguard.WipeBytes(data)
|
||||
}
|
||||
|
||||
// WipeSecureString wipes a SecureString if it's not nil.
|
||||
// This is a nil-safe convenience wrapper.
|
||||
func WipeSecureString(s *sstring.SecureString) {
|
||||
if s != nil {
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
|
||||
// WipeSecureBytes wipes a SecureBytes if it's not nil.
|
||||
// This is a nil-safe convenience wrapper.
|
||||
func WipeSecureBytes(s *sbytes.SecureBytes) {
|
||||
if s != nil {
|
||||
s.Wipe()
|
||||
}
|
||||
}
|
||||
186
cloud/maplefile-backend/pkg/security/password/password.go
Normal file
186
cloud/maplefile-backend/pkg/security/password/password.go
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/password/password.go
|
||||
package password
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"golang.org/x/crypto/argon2"
|
||||
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidHash = errors.New("the encoded hash is not in the correct format")
|
||||
ErrIncompatibleVersion = errors.New("incompatible version of argon2")
|
||||
)
|
||||
|
||||
type PasswordProvider interface {
|
||||
GenerateHashFromPassword(password *sstring.SecureString) (string, error)
|
||||
ComparePasswordAndHash(password *sstring.SecureString, hash string) (bool, error)
|
||||
AlgorithmName() string
|
||||
GenerateSecureRandomBytes(length int) ([]byte, error)
|
||||
GenerateSecureRandomString(length int) (string, error)
|
||||
}
|
||||
|
||||
type passwordProvider struct {
|
||||
memory uint32
|
||||
iterations uint32
|
||||
parallelism uint8
|
||||
saltLength uint32
|
||||
keyLength uint32
|
||||
}
|
||||
|
||||
func NewPasswordProvider() PasswordProvider {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Establish the parameters to use for Argon2.
|
||||
return &passwordProvider{
|
||||
memory: 64 * 1024,
|
||||
iterations: 3,
|
||||
parallelism: 2,
|
||||
saltLength: 16,
|
||||
keyLength: 32,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateHashFromPassword function takes the plaintext string and returns an Argon2 hashed string.
|
||||
// SECURITY: Password bytes are wiped from memory after hashing to prevent memory dump attacks.
|
||||
func (p *passwordProvider) GenerateHashFromPassword(password *sstring.SecureString) (string, error) {
|
||||
salt, err := generateRandomBytes(p.saltLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer memguard.WipeBytes(salt) // SECURITY: Wipe salt after use
|
||||
|
||||
passwordBytes := password.Bytes()
|
||||
defer memguard.WipeBytes(passwordBytes) // SECURITY: Wipe password bytes after hashing
|
||||
|
||||
hash := argon2.IDKey(passwordBytes, salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
defer memguard.WipeBytes(hash) // SECURITY: Wipe raw hash after encoding
|
||||
|
||||
// Base64 encode the salt and hashed password.
|
||||
b64Salt := base64.RawStdEncoding.EncodeToString(salt)
|
||||
b64Hash := base64.RawStdEncoding.EncodeToString(hash)
|
||||
|
||||
// Return a string using the standard encoded hash representation.
|
||||
encodedHash := fmt.Sprintf("$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", argon2.Version, p.memory, p.iterations, p.parallelism, b64Salt, b64Hash)
|
||||
|
||||
return encodedHash, nil
|
||||
}
|
||||
|
||||
// CheckPasswordHash function checks the plaintext string and hash string and returns either true
|
||||
// or false depending.
|
||||
// SECURITY: All sensitive bytes (password, salt, hashes) are wiped from memory after comparison.
|
||||
func (p *passwordProvider) ComparePasswordAndHash(password *sstring.SecureString, encodedHash string) (match bool, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
// Extract the parameters, salt and derived key from the encoded password
|
||||
// hash.
|
||||
p, salt, hash, err := decodeHash(encodedHash)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer memguard.WipeBytes(salt) // SECURITY: Wipe salt after use
|
||||
defer memguard.WipeBytes(hash) // SECURITY: Wipe stored hash after comparison
|
||||
|
||||
// Get password bytes and ensure they're wiped after use
|
||||
passwordBytes := password.Bytes()
|
||||
defer memguard.WipeBytes(passwordBytes)
|
||||
|
||||
// Derive the key from the other password using the same parameters.
|
||||
otherHash := argon2.IDKey(passwordBytes, salt, p.iterations, p.memory, p.parallelism, p.keyLength)
|
||||
defer memguard.WipeBytes(otherHash) // SECURITY: Wipe computed hash after comparison
|
||||
|
||||
// Check that the contents of the hashed passwords are identical. Note
|
||||
// that we are using the subtle.ConstantTimeCompare() function for this
|
||||
// to help prevent timing attacks.
|
||||
if subtle.ConstantTimeCompare(hash, otherHash) == 1 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// AlgorithmName function returns the algorithm used for hashing.
|
||||
func (p *passwordProvider) AlgorithmName() string {
|
||||
return "argon2id"
|
||||
}
|
||||
|
||||
func generateRandomBytes(n uint32) ([]byte, error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func decodeHash(encodedHash string) (p *passwordProvider, salt, hash []byte, err error) {
|
||||
// DEVELOPERS NOTE:
|
||||
// The following code was copy and pasted from: "How to Hash and Verify Passwords With Argon2 in Go" via https://www.alexedwards.net/blog/how-to-hash-and-verify-passwords-with-argon2-in-go
|
||||
|
||||
vals := strings.Split(encodedHash, "$")
|
||||
if len(vals) != 6 {
|
||||
return nil, nil, nil, ErrInvalidHash
|
||||
}
|
||||
|
||||
var version int
|
||||
_, err = fmt.Sscanf(vals[2], "v=%d", &version)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if version != argon2.Version {
|
||||
return nil, nil, nil, ErrIncompatibleVersion
|
||||
}
|
||||
|
||||
p = &passwordProvider{}
|
||||
_, err = fmt.Sscanf(vals[3], "m=%d,t=%d,p=%d", &p.memory, &p.iterations, &p.parallelism)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
salt, err = base64.RawStdEncoding.Strict().DecodeString(vals[4])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.saltLength = uint32(len(salt))
|
||||
|
||||
hash, err = base64.RawStdEncoding.Strict().DecodeString(vals[5])
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
p.keyLength = uint32(len(hash))
|
||||
|
||||
return p, salt, hash, nil
|
||||
}
|
||||
|
||||
// GenerateSecureRandomBytes generates a secure random byte slice of the specified length.
|
||||
func (p *passwordProvider) GenerateSecureRandomBytes(length int) ([]byte, error) {
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate secure random bytes: %v", err)
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
// GenerateSecureRandomString generates a secure random string of the specified length.
|
||||
func (p *passwordProvider) GenerateSecureRandomString(length int) (string, error) {
|
||||
bytes, err := p.GenerateSecureRandomBytes(length)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package password
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
sstring "codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/security/securestring"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
t.Log("TestPasswordHashing: Starting")
|
||||
|
||||
provider := NewPasswordProvider()
|
||||
t.Log("TestPasswordHashing: Provider created")
|
||||
|
||||
password, err := sstring.NewSecureString("test-password")
|
||||
require.NoError(t, err)
|
||||
t.Log("TestPasswordHashing: Password SecureString created")
|
||||
fmt.Println("TestPasswordHashing: Password SecureString created")
|
||||
|
||||
// Let's add a timeout to see if we can pinpoint the issue
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
fmt.Println("TestPasswordHashing: Generating hash...")
|
||||
hash, err := provider.GenerateHashFromPassword(password)
|
||||
fmt.Printf("TestPasswordHashing: Hash generated: %v, error: %v\n", hash != "", err)
|
||||
|
||||
if err == nil {
|
||||
fmt.Println("TestPasswordHashing: Comparing password and hash...")
|
||||
match, err := provider.ComparePasswordAndHash(password, hash)
|
||||
fmt.Printf("TestPasswordHashing: Comparison done: match=%v, error=%v\n", match, err)
|
||||
}
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
fmt.Println("TestPasswordHashing: Test completed successfully")
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatal("Test timed out after 10 seconds")
|
||||
}
|
||||
|
||||
fmt.Println("TestPasswordHashing: Cleaning up password...")
|
||||
password.Wipe()
|
||||
fmt.Println("TestPasswordHashing: Done")
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package password
|
||||
|
||||
// ProvidePasswordProvider provides a password provider instance for Wire DI
|
||||
func ProvidePasswordProvider() PasswordProvider {
|
||||
return NewPasswordProvider()
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securebytes.go
|
||||
package securebytes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureBytes is used to store a byte slice securely in memory.
|
||||
type SecureBytes struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureBytes creates a new SecureBytes instance from the given byte slice.
|
||||
func NewSecureBytes(b []byte) (*SecureBytes, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, errors.New("byte slice cannot be empty")
|
||||
}
|
||||
|
||||
buffer := memguard.NewBuffer(len(b))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
copy(buffer.Bytes(), b)
|
||||
|
||||
return &SecureBytes{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// Bytes returns the securely stored byte slice.
|
||||
func (sb *SecureBytes) Bytes() []byte {
|
||||
return sb.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the byte slice from memory and makes it unrecoverable.
|
||||
func (sb *SecureBytes) Wipe() error {
|
||||
sb.buffer.Wipe()
|
||||
sb.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securebytes_test.go
|
||||
package securebytes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSecureBytes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid input",
|
||||
input: []byte("test-data"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sb, err := NewSecureBytes(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, sb)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sb)
|
||||
assert.NotNil(t, sb.buffer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureBytes_Bytes(t *testing.T) {
|
||||
input := []byte("test-data")
|
||||
sb, err := NewSecureBytes(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure the SecureBytes object is properly closed after the test
|
||||
defer sb.Wipe()
|
||||
|
||||
output := sb.Bytes()
|
||||
assert.Equal(t, input, output)
|
||||
assert.NotSame(t, &input, &output) // Verify different memory addresses
|
||||
}
|
||||
|
||||
func TestSecureBytes_Wipe(t *testing.T) {
|
||||
sb, err := NewSecureBytes([]byte("test-data"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = sb.Wipe()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// After wiping, the internal buffer should be nil
|
||||
assert.Nil(t, sb.buffer)
|
||||
|
||||
// Attempting to access bytes after wiping might panic or return nil/empty slice
|
||||
// Based on the panic, calling Bytes() on a wiped buffer is unsafe.
|
||||
// We verify the buffer is nil instead of calling Bytes().
|
||||
}
|
||||
|
||||
func TestSecureBytes_DataIsolation(t *testing.T) {
|
||||
original := []byte("test-data")
|
||||
sb, err := NewSecureBytes(original)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Ensure the SecureBytes object is properly closed after the test
|
||||
defer sb.Wipe()
|
||||
|
||||
// Modify original data
|
||||
original[0] = 'x'
|
||||
|
||||
// Verify secure bytes remains unchanged
|
||||
stored := sb.Bytes()
|
||||
assert.NotEqual(t, original, stored)
|
||||
assert.Equal(t, []byte("test-data"), stored)
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
package secureconfig
|
||||
|
||||
import (
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// ProvideSecureConfigProvider provides a SecureConfigProvider for Wire DI.
|
||||
func ProvideSecureConfigProvider(cfg *config.Config) *SecureConfigProvider {
|
||||
return NewSecureConfigProvider(cfg)
|
||||
}
|
||||
|
|
@ -0,0 +1,187 @@
|
|||
// Package secureconfig provides secure access to configuration secrets.
|
||||
// It wraps sensitive configuration values in memguard-protected buffers
|
||||
// to prevent secret leakage through memory dumps.
|
||||
package secureconfig
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config"
|
||||
)
|
||||
|
||||
// SecureConfigProvider provides secure access to configuration secrets.
|
||||
// Secrets are stored in memguard LockedBuffers and wiped when no longer needed.
|
||||
type SecureConfigProvider struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Cached secure buffers - created on first access
|
||||
jwtSecret *memguard.LockedBuffer
|
||||
dbPassword *memguard.LockedBuffer
|
||||
cachePassword *memguard.LockedBuffer
|
||||
s3AccessKey *memguard.LockedBuffer
|
||||
s3SecretKey *memguard.LockedBuffer
|
||||
mailgunAPIKey *memguard.LockedBuffer
|
||||
|
||||
// Original config for initial loading
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSecureConfigProvider creates a new secure config provider from the given config.
|
||||
// The original config secrets are copied to secure buffers and should be cleared
|
||||
// from the original config after this call.
|
||||
func NewSecureConfigProvider(cfg *config.Config) *SecureConfigProvider {
|
||||
provider := &SecureConfigProvider{
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
// Pre-load secrets into secure buffers
|
||||
provider.loadSecrets()
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// loadSecrets copies secrets from config into memguard buffers.
|
||||
// SECURITY: Original config strings remain in memory but secure buffers provide
|
||||
// additional protection for long-lived secret access.
|
||||
func (p *SecureConfigProvider) loadSecrets() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// JWT Secret
|
||||
if p.cfg.JWT.Secret != "" {
|
||||
p.jwtSecret = memguard.NewBufferFromBytes([]byte(p.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// Database Password
|
||||
if p.cfg.Database.Password != "" {
|
||||
p.dbPassword = memguard.NewBufferFromBytes([]byte(p.cfg.Database.Password))
|
||||
}
|
||||
|
||||
// Cache Password
|
||||
if p.cfg.Cache.Password != "" {
|
||||
p.cachePassword = memguard.NewBufferFromBytes([]byte(p.cfg.Cache.Password))
|
||||
}
|
||||
|
||||
// S3 Access Key
|
||||
if p.cfg.S3.AccessKey != "" {
|
||||
p.s3AccessKey = memguard.NewBufferFromBytes([]byte(p.cfg.S3.AccessKey))
|
||||
}
|
||||
|
||||
// S3 Secret Key
|
||||
if p.cfg.S3.SecretKey != "" {
|
||||
p.s3SecretKey = memguard.NewBufferFromBytes([]byte(p.cfg.S3.SecretKey))
|
||||
}
|
||||
|
||||
// Mailgun API Key
|
||||
if p.cfg.Mailgun.APIKey != "" {
|
||||
p.mailgunAPIKey = memguard.NewBufferFromBytes([]byte(p.cfg.Mailgun.APIKey))
|
||||
}
|
||||
}
|
||||
|
||||
// JWTSecret returns the JWT secret as a secure byte slice.
|
||||
// The returned bytes should not be stored - use immediately and let GC collect.
|
||||
func (p *SecureConfigProvider) JWTSecret() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.jwtSecret == nil || !p.jwtSecret.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.jwtSecret.Bytes()
|
||||
}
|
||||
|
||||
// DatabasePassword returns the database password as a secure byte slice.
|
||||
func (p *SecureConfigProvider) DatabasePassword() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.dbPassword == nil || !p.dbPassword.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.dbPassword.Bytes()
|
||||
}
|
||||
|
||||
// CachePassword returns the cache password as a secure byte slice.
|
||||
func (p *SecureConfigProvider) CachePassword() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.cachePassword == nil || !p.cachePassword.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.cachePassword.Bytes()
|
||||
}
|
||||
|
||||
// S3AccessKey returns the S3 access key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) S3AccessKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.s3AccessKey == nil || !p.s3AccessKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.s3AccessKey.Bytes()
|
||||
}
|
||||
|
||||
// S3SecretKey returns the S3 secret key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) S3SecretKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.s3SecretKey == nil || !p.s3SecretKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.s3SecretKey.Bytes()
|
||||
}
|
||||
|
||||
// MailgunAPIKey returns the Mailgun API key as a secure byte slice.
|
||||
func (p *SecureConfigProvider) MailgunAPIKey() []byte {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.mailgunAPIKey == nil || !p.mailgunAPIKey.IsAlive() {
|
||||
return nil
|
||||
}
|
||||
return p.mailgunAPIKey.Bytes()
|
||||
}
|
||||
|
||||
// Destroy securely wipes all cached secrets from memory.
|
||||
// Should be called during application shutdown.
|
||||
func (p *SecureConfigProvider) Destroy() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.jwtSecret != nil && p.jwtSecret.IsAlive() {
|
||||
p.jwtSecret.Destroy()
|
||||
}
|
||||
if p.dbPassword != nil && p.dbPassword.IsAlive() {
|
||||
p.dbPassword.Destroy()
|
||||
}
|
||||
if p.cachePassword != nil && p.cachePassword.IsAlive() {
|
||||
p.cachePassword.Destroy()
|
||||
}
|
||||
if p.s3AccessKey != nil && p.s3AccessKey.IsAlive() {
|
||||
p.s3AccessKey.Destroy()
|
||||
}
|
||||
if p.s3SecretKey != nil && p.s3SecretKey.IsAlive() {
|
||||
p.s3SecretKey.Destroy()
|
||||
}
|
||||
if p.mailgunAPIKey != nil && p.mailgunAPIKey.IsAlive() {
|
||||
p.mailgunAPIKey.Destroy()
|
||||
}
|
||||
|
||||
p.jwtSecret = nil
|
||||
p.dbPassword = nil
|
||||
p.cachePassword = nil
|
||||
p.s3AccessKey = nil
|
||||
p.s3SecretKey = nil
|
||||
p.mailgunAPIKey = nil
|
||||
}
|
||||
|
||||
// Config returns the underlying config for non-secret access.
|
||||
// Prefer using the specific secret accessor methods for sensitive data.
|
||||
func (p *SecureConfigProvider) Config() *config.Config {
|
||||
return p.cfg
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securestring.go
|
||||
package securestring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
)
|
||||
|
||||
// SecureString is used to store a string securely in memory.
|
||||
type SecureString struct {
|
||||
buffer *memguard.LockedBuffer
|
||||
}
|
||||
|
||||
// NewSecureString creates a new SecureString instance from the given string.
|
||||
func NewSecureString(s string) (*SecureString, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, errors.New("string cannot be empty")
|
||||
}
|
||||
|
||||
// Use memguard's built-in method for creating from bytes
|
||||
buffer := memguard.NewBufferFromBytes([]byte(s))
|
||||
|
||||
// Check if buffer was created successfully
|
||||
if buffer == nil {
|
||||
return nil, errors.New("failed to create buffer")
|
||||
}
|
||||
|
||||
return &SecureString{buffer: buffer}, nil
|
||||
}
|
||||
|
||||
// String returns the securely stored string.
|
||||
func (ss *SecureString) String() string {
|
||||
if ss.buffer == nil {
|
||||
fmt.Println("String(): buffer is nil")
|
||||
return ""
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
fmt.Println("String(): buffer is not alive")
|
||||
return ""
|
||||
}
|
||||
return ss.buffer.String()
|
||||
}
|
||||
|
||||
func (ss *SecureString) Bytes() []byte {
|
||||
if ss.buffer == nil {
|
||||
fmt.Println("Bytes(): buffer is nil")
|
||||
return nil
|
||||
}
|
||||
if !ss.buffer.IsAlive() {
|
||||
fmt.Println("Bytes(): buffer is not alive")
|
||||
return nil
|
||||
}
|
||||
return ss.buffer.Bytes()
|
||||
}
|
||||
|
||||
// Wipe removes the string from memory and makes it unrecoverable.
|
||||
func (ss *SecureString) Wipe() error {
|
||||
|
||||
if ss.buffer != nil {
|
||||
if ss.buffer.IsAlive() {
|
||||
ss.buffer.Destroy()
|
||||
}
|
||||
} else {
|
||||
// fmt.Println("Wipe(): Buffer is nil")
|
||||
}
|
||||
ss.buffer = nil
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
// File Path: monorepo/cloud/maplefile-backend/pkg/security/securebytes/securestring_test.go
|
||||
package securestring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSecureString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid string",
|
||||
input: "test-string",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ss, err := NewSecureString(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ss)
|
||||
assert.NotNil(t, ss.buffer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureString_String(t *testing.T) {
|
||||
input := "test-string"
|
||||
ss, err := NewSecureString(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := ss.String()
|
||||
assert.Equal(t, input, output)
|
||||
}
|
||||
|
||||
func TestSecureString_Wipe(t *testing.T) {
|
||||
ss, err := NewSecureString("test-string")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = ss.Wipe()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, ss.buffer)
|
||||
|
||||
// Verify string is wiped
|
||||
output := ss.String()
|
||||
assert.Empty(t, output)
|
||||
}
|
||||
|
||||
func TestSecureString_DataIsolation(t *testing.T) {
|
||||
original := "test-string"
|
||||
ss, err := NewSecureString(original)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Attempt to modify original
|
||||
original = "modified"
|
||||
|
||||
// Verify secure string remains unchanged
|
||||
stored := ss.String()
|
||||
assert.NotEqual(t, original, stored)
|
||||
assert.Equal(t, "test-string", stored)
|
||||
}
|
||||
|
||||
func TestSecureString_StringConsistency(t *testing.T) {
|
||||
input := "test-string"
|
||||
ss, err := NewSecureString(input)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Multiple calls should return same value
|
||||
assert.Equal(t, ss.String(), ss.String())
|
||||
}
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplepress-backend/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinJWTSecretLength is the minimum required length for JWT secrets (256 bits)
|
||||
MinJWTSecretLength = 32
|
||||
|
||||
// RecommendedJWTSecretLength is the recommended length for JWT secrets (512 bits)
|
||||
RecommendedJWTSecretLength = 64
|
||||
|
||||
// MinEntropyBits is the minimum Shannon entropy in bits per character
|
||||
// For reference: random base64 has ~6 bits/char, we require minimum 4.0
|
||||
MinEntropyBits = 4.0
|
||||
|
||||
// MinProductionEntropyBits is the minimum entropy required for production
|
||||
MinProductionEntropyBits = 4.5
|
||||
|
||||
// MaxRepeatingCharacters is the maximum allowed consecutive repeating characters
|
||||
MaxRepeatingCharacters = 3
|
||||
)
|
||||
|
||||
// WeakSecrets contains common weak/default secrets that should never be used
|
||||
var WeakSecrets = []string{
|
||||
"secret",
|
||||
"password",
|
||||
"changeme",
|
||||
"change-me",
|
||||
"change_me",
|
||||
"12345",
|
||||
"123456",
|
||||
"1234567",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"default",
|
||||
"test",
|
||||
"testing",
|
||||
"admin",
|
||||
"administrator",
|
||||
"root",
|
||||
"qwerty",
|
||||
"qwertyuiop",
|
||||
"letmein",
|
||||
"welcome",
|
||||
"monkey",
|
||||
"dragon",
|
||||
"master",
|
||||
"sunshine",
|
||||
"princess",
|
||||
"football",
|
||||
"starwars",
|
||||
"baseball",
|
||||
"superman",
|
||||
"iloveyou",
|
||||
"trustno1",
|
||||
"hello",
|
||||
"abc123",
|
||||
"password123",
|
||||
"admin123",
|
||||
"guest",
|
||||
"user",
|
||||
"demo",
|
||||
"sample",
|
||||
"example",
|
||||
}
|
||||
|
||||
// DangerousPatterns contains patterns that indicate a secret should be changed
|
||||
var DangerousPatterns = []string{
|
||||
"change",
|
||||
"replace",
|
||||
"update",
|
||||
"modify",
|
||||
"sample",
|
||||
"example",
|
||||
"todo",
|
||||
"fixme",
|
||||
"temp",
|
||||
"temporary",
|
||||
}
|
||||
|
||||
// CredentialValidator validates credentials and secrets for security issues
|
||||
type CredentialValidator interface {
|
||||
ValidateJWTSecret(secret string, environment string) error
|
||||
ValidateAllCredentials(cfg *config.Config) error
|
||||
}
|
||||
|
||||
type credentialValidator struct{}
|
||||
|
||||
// NewCredentialValidator creates a new credential validator
|
||||
func NewCredentialValidator() CredentialValidator {
|
||||
return &credentialValidator{}
|
||||
}
|
||||
|
||||
// ValidateJWTSecret validates JWT secret strength and security
|
||||
// CWE-798: Comprehensive validation to prevent hard-coded/weak credentials
|
||||
func (v *credentialValidator) ValidateJWTSecret(secret string, environment string) error {
|
||||
// Check minimum length
|
||||
if len(secret) < MinJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short (%d characters). Minimum required: %d characters (256 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
MinJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for common weak secrets (case-insensitive)
|
||||
secretLower := strings.ToLower(secret)
|
||||
for _, weak := range WeakSecrets {
|
||||
if secretLower == weak || strings.Contains(secretLower, weak) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret cannot contain common weak value: '%s'. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
weak,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for dangerous patterns indicating default/placeholder values
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(secretLower, pattern) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains suspicious pattern '%s' which suggests it's a placeholder. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
pattern,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for repeating character patterns (e.g., "aaaa", "1111")
|
||||
if err := checkRepeatingPatterns(secret); err != nil {
|
||||
return fmt.Errorf(
|
||||
"JWT secret validation failed: %s. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sequential patterns (e.g., "abcd", "1234")
|
||||
if hasSequentialPattern(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret contains sequential patterns (e.g., 'abcd', '1234') which reduces entropy. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Calculate Shannon entropy
|
||||
entropy := calculateShannonEntropy(secret)
|
||||
minEntropy := MinEntropyBits
|
||||
if environment == "production" {
|
||||
minEntropy = MinProductionEntropyBits
|
||||
}
|
||||
|
||||
if entropy < minEntropy {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient entropy: %.2f bits/char (minimum: %.1f bits/char for %s). "+
|
||||
"The secret appears to have low randomness. "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
entropy,
|
||||
minEntropy,
|
||||
environment,
|
||||
)
|
||||
}
|
||||
|
||||
// In production, enforce stricter requirements
|
||||
if environment == "production" {
|
||||
// Check recommended length for production
|
||||
if len(secret) < RecommendedJWTSecretLength {
|
||||
return fmt.Errorf(
|
||||
"JWT secret is too short for production environment (%d characters). "+
|
||||
"Recommended: %d characters (512 bits). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
len(secret),
|
||||
RecommendedJWTSecretLength,
|
||||
)
|
||||
}
|
||||
|
||||
// Check for sufficient character complexity
|
||||
if !hasSufficientComplexity(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret has insufficient complexity for production. It should contain a mix of uppercase, lowercase, " +
|
||||
"digits, and special characters (at least 3 types). Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
|
||||
// Validate base64-like characteristics (recommended generation method)
|
||||
if !looksLikeBase64(secret) {
|
||||
return fmt.Errorf(
|
||||
"JWT secret does not appear to be randomly generated (expected base64-like characteristics). "+
|
||||
"Generate a secure secret with: openssl rand -base64 64",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAllCredentials validates all credentials in the configuration
|
||||
func (v *credentialValidator) ValidateAllCredentials(cfg *config.Config) error {
|
||||
var errors []string
|
||||
|
||||
// Validate JWT Secret
|
||||
if err := v.ValidateJWTSecret(cfg.App.JWTSecret, cfg.App.Environment); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("JWT Secret validation failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
// In production, ensure other critical configs are not using defaults/placeholders
|
||||
if cfg.App.Environment == "production" {
|
||||
// Check Meilisearch API key
|
||||
if cfg.Meilisearch.APIKey == "" {
|
||||
errors = append(errors, "Meilisearch API key must be set in production")
|
||||
} else if containsDangerousPattern(cfg.Meilisearch.APIKey) {
|
||||
errors = append(errors, "Meilisearch API key appears to be a placeholder/default value")
|
||||
}
|
||||
|
||||
// Check database hosts are not using localhost
|
||||
for _, host := range cfg.Database.Hosts {
|
||||
if strings.Contains(strings.ToLower(host), "localhost") || host == "127.0.0.1" {
|
||||
errors = append(errors, "Database hosts should not use localhost in production")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check cache host is not localhost
|
||||
if strings.Contains(strings.ToLower(cfg.Cache.Host), "localhost") || cfg.Cache.Host == "127.0.0.1" {
|
||||
errors = append(errors, "Cache host should not use localhost in production")
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("credential validation failed:\n - %s", strings.Join(errors, "\n - "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShannonEntropy calculates the Shannon entropy of a string in bits per character
|
||||
// Shannon entropy measures the randomness/unpredictability of data
|
||||
// Formula: H(X) = -Σ(p(x) * log2(p(x))) where p(x) is the probability of character x
|
||||
func calculateShannonEntropy(s string) float64 {
|
||||
if len(s) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Count character frequencies
|
||||
frequencies := make(map[rune]int)
|
||||
for _, char := range s {
|
||||
frequencies[char]++
|
||||
}
|
||||
|
||||
// Calculate entropy
|
||||
var entropy float64
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range frequencies {
|
||||
probability := float64(count) / length
|
||||
entropy -= probability * math.Log2(probability)
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// hasSufficientComplexity checks if the secret has a good mix of character types
|
||||
// Requires at least 3 out of 4 character types for production
|
||||
func hasSufficientComplexity(secret string) bool {
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
hasSpecial bool
|
||||
)
|
||||
|
||||
for _, char := range secret {
|
||||
switch {
|
||||
case unicode.IsUpper(char):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(char):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(char):
|
||||
hasDigit = true
|
||||
default:
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
// Require at least 3 out of 4 character types
|
||||
count := 0
|
||||
if hasUpper {
|
||||
count++
|
||||
}
|
||||
if hasLower {
|
||||
count++
|
||||
}
|
||||
if hasDigit {
|
||||
count++
|
||||
}
|
||||
if hasSpecial {
|
||||
count++
|
||||
}
|
||||
|
||||
return count >= 3
|
||||
}
|
||||
|
||||
// checkRepeatingPatterns checks for excessive repeating characters
|
||||
func checkRepeatingPatterns(s string) error {
|
||||
if len(s) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
repeatCount := 1
|
||||
lastChar := rune(s[0])
|
||||
|
||||
for _, char := range s[1:] {
|
||||
if char == lastChar {
|
||||
repeatCount++
|
||||
if repeatCount > MaxRepeatingCharacters {
|
||||
return fmt.Errorf(
|
||||
"contains %d consecutive repeating characters ('%c'), maximum allowed: %d",
|
||||
repeatCount,
|
||||
lastChar,
|
||||
MaxRepeatingCharacters,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
repeatCount = 1
|
||||
lastChar = char
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasSequentialPattern detects common sequential patterns
|
||||
func hasSequentialPattern(s string) bool {
|
||||
if len(s) < 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for at least 4 consecutive sequential characters
|
||||
for i := 0; i < len(s)-3; i++ {
|
||||
// Check ascending sequence (e.g., "abcd", "1234")
|
||||
if s[i+1] == s[i]+1 && s[i+2] == s[i]+2 && s[i+3] == s[i]+3 {
|
||||
return true
|
||||
}
|
||||
// Check descending sequence (e.g., "dcba", "4321")
|
||||
if s[i+1] == s[i]-1 && s[i+2] == s[i]-2 && s[i+3] == s[i]-3 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// looksLikeBase64 checks if the string has base64-like characteristics
|
||||
// Base64 uses: A-Z, a-z, 0-9, +, /, and = for padding
|
||||
func looksLikeBase64(s string) bool {
|
||||
if len(s) < MinJWTSecretLength {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
hasUpper bool
|
||||
hasLower bool
|
||||
hasDigit bool
|
||||
validChars int
|
||||
)
|
||||
|
||||
// Base64 valid characters
|
||||
for _, char := range s {
|
||||
switch {
|
||||
case char >= 'A' && char <= 'Z':
|
||||
hasUpper = true
|
||||
validChars++
|
||||
case char >= 'a' && char <= 'z':
|
||||
hasLower = true
|
||||
validChars++
|
||||
case char >= '0' && char <= '9':
|
||||
hasDigit = true
|
||||
validChars++
|
||||
case char == '+' || char == '/' || char == '=' || char == '-' || char == '_':
|
||||
validChars++
|
||||
default:
|
||||
// Invalid character for base64
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Should have good mix of character types typical of base64
|
||||
charTypesCount := 0
|
||||
if hasUpper {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasLower {
|
||||
charTypesCount++
|
||||
}
|
||||
if hasDigit {
|
||||
charTypesCount++
|
||||
}
|
||||
|
||||
// Base64 typically has at least uppercase, lowercase, and digits
|
||||
// Also check that it doesn't look like a repeated pattern
|
||||
if charTypesCount < 3 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for repeated patterns (e.g., "AbCd12!@" repeated)
|
||||
// If the string has low unique character count relative to its length, it's probably not random
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, char := range s {
|
||||
uniqueChars[char] = true
|
||||
}
|
||||
|
||||
// Random base64 should have at least 50% unique characters for strings over 32 chars
|
||||
uniqueRatio := float64(len(uniqueChars)) / float64(len(s))
|
||||
return uniqueRatio >= 0.4 // At least 40% unique characters
|
||||
}
|
||||
|
||||
// containsDangerousPattern checks if a string contains any dangerous patterns
|
||||
func containsDangerousPattern(value string) bool {
|
||||
valueLower := strings.ToLower(value)
|
||||
for _, pattern := range DangerousPatterns {
|
||||
if strings.Contains(valueLower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Simplified comprehensive test for JWT secret validation
|
||||
func TestJWTSecretValidation(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
// Good secrets - these should pass
|
||||
goodSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
}{
|
||||
{
|
||||
name: "Good 32-char for dev",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "development",
|
||||
},
|
||||
{
|
||||
name: "Good 64-char for prod",
|
||||
secret: "1WDduocStecRuIv+Us1t/RnYDoW1ZcEEbU+H+WykJG+IT5WnijzBb8uUPzGKju+D",
|
||||
env: "production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range goodSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for valid secret, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Bad secrets - these should fail
|
||||
badSecrets := []struct {
|
||||
name string
|
||||
secret string
|
||||
env string
|
||||
mustContain string
|
||||
}{
|
||||
{
|
||||
name: "Too short",
|
||||
secret: "short",
|
||||
env: "development",
|
||||
mustContain: "too short",
|
||||
},
|
||||
{
|
||||
name: "Common weak - password",
|
||||
secret: "password-is-not-secure-but-32char",
|
||||
env: "development",
|
||||
mustContain: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern",
|
||||
secret: "please-change-this-ima7xR+9nT0Yz",
|
||||
env: "development",
|
||||
mustContain: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
env: "development",
|
||||
mustContain: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
env: "development",
|
||||
mustContain: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy",
|
||||
secret: "abababababababababababababababab",
|
||||
env: "development",
|
||||
mustContain: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Prod too short",
|
||||
secret: "ima7xR+9nT0Yz0jKVu/QwtkqdAaU+3Ki",
|
||||
env: "production",
|
||||
mustContain: "too short for production",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range badSecrets {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.env)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error containing '%s', got no error", tt.mustContain)
|
||||
} else if !contains(err.Error(), tt.mustContain) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.mustContain, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && len(substr) > 0 && findSubstring(s, substr)))
|
||||
}
|
||||
|
||||
func findSubstring(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
@ -0,0 +1,535 @@
|
|||
package validator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCalculateShannonEntropy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
minBits float64
|
||||
maxBits float64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have 0 entropy",
|
||||
},
|
||||
{
|
||||
name: "All same character",
|
||||
input: "aaaaaaaaaa",
|
||||
minBits: 0,
|
||||
maxBits: 0,
|
||||
expected: "should have very low entropy",
|
||||
},
|
||||
{
|
||||
name: "Low entropy - repeated pattern",
|
||||
input: "abcabcabcabc",
|
||||
minBits: 1.5,
|
||||
maxBits: 2.0,
|
||||
expected: "should have low entropy",
|
||||
},
|
||||
{
|
||||
name: "Medium entropy - simple password",
|
||||
input: "Password123",
|
||||
minBits: 3.0,
|
||||
maxBits: 4.5,
|
||||
expected: "should have medium entropy",
|
||||
},
|
||||
{
|
||||
name: "High entropy - random base64",
|
||||
input: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
minBits: 4.0,
|
||||
maxBits: 6.0,
|
||||
expected: "should have high entropy",
|
||||
},
|
||||
{
|
||||
name: "Very high entropy - long random base64",
|
||||
input: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
minBits: 4.5,
|
||||
maxBits: 6.5,
|
||||
expected: "should have very high entropy",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
entropy := calculateShannonEntropy(tt.input)
|
||||
if entropy < tt.minBits || entropy > tt.maxBits {
|
||||
t.Errorf("%s: got %.2f bits/char, expected between %.1f and %.1f", tt.expected, entropy, tt.minBits, tt.maxBits)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSufficientComplexity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnop",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only uppercase",
|
||||
input: "ABCDEFGHIJKLMNOP",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only digits",
|
||||
input: "1234567890",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase",
|
||||
input: "AbCdEfGhIjKl",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits",
|
||||
input: "abc123def456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Uppercase + digits",
|
||||
input: "ABC123DEF456",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + digits",
|
||||
input: "Abc123Def456",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + uppercase + special",
|
||||
input: "AbC+DeF/GhI=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Lowercase + digits + special",
|
||||
input: "abc123+def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "All four types",
|
||||
input: "Abc123+Def456/",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 string",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b+xK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSufficientComplexity(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSufficientComplexity(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRepeatingPatterns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Single character",
|
||||
input: "a",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "No repeating",
|
||||
input: "abcdefgh",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Two repeating (ok)",
|
||||
input: "aabcdeef",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Three repeating (ok)",
|
||||
input: "aaabcdeee",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Four repeating (error)",
|
||||
input: "aaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Five repeating (error)",
|
||||
input: "aaaaabcde",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple groups of three (ok)",
|
||||
input: "aaabbbccc",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Repeating in middle (error)",
|
||||
input: "abcdddddef",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "Repeating at end (error)",
|
||||
input: "abcdefgggg",
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkRepeatingPatterns(tt.input)
|
||||
if (err != nil) != tt.shouldErr {
|
||||
t.Errorf("checkRepeatingPatterns(%q) error = %v, shouldErr = %v", tt.input, err, tt.shouldErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasSequentialPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "No sequential",
|
||||
input: "acegikmo",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Ascending sequence - abcd",
|
||||
input: "xyzabcdefg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending sequence - dcba",
|
||||
input: "xyzdcbafg",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Ascending digits - 1234",
|
||||
input: "abc1234def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Descending digits - 4321",
|
||||
input: "abc4321def",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Random characters",
|
||||
input: "xK8vN2mP9sQ4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Base64-like",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := hasSequentialPattern(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("hasSequentialPattern(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeBase64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
input: "abc",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Only lowercase",
|
||||
input: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Real base64",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b=",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 without padding",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Base64 with URL-safe chars",
|
||||
input: "K8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b-_",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Generated secret",
|
||||
input: "xK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6bxK8vN2mP9sQ4tR7wY3zA6b",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Simple password",
|
||||
input: "Password123!Password123!Password123!",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := looksLikeBase64(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("looksLikeBase64(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
secret string
|
||||
environment string
|
||||
shouldErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Too short - 20 chars",
|
||||
secret: "12345678901234567890",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "too short",
|
||||
},
|
||||
{
|
||||
name: "Minimum length - 32 chars (acceptable for dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains password",
|
||||
secret: "my-password-is-secure-123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - secret",
|
||||
secret: "secretsecretsecretsecretsecretsec",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Common weak secret - contains 12345",
|
||||
secret: "abcd12345efghijklmnopqrstuvwxyz",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "common weak value",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - change",
|
||||
secret: "please-change-this-j8EJm9ZKnuTYxcVK",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Dangerous pattern - sample",
|
||||
secret: "sample-secret-j8EJm9ZKnuTYxcVKQ",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "suspicious pattern",
|
||||
},
|
||||
{
|
||||
name: "Repeating characters",
|
||||
secret: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "consecutive repeating characters",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - abcd",
|
||||
secret: "abcdefghijklmnopqrstuvwxyzabcdef",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Sequential pattern - 1234",
|
||||
secret: "12345678901234567890123456789012",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "sequential patterns",
|
||||
},
|
||||
{
|
||||
name: "Low entropy secret",
|
||||
secret: "aAbBcCdDeEfFgGhHiIjJkKlLmMnNoOpP",
|
||||
environment: "development",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Good secret - base64 style (dev)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Good secret - longer (dev)",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "development",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - too short (32 chars)",
|
||||
secret: "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "too short for production",
|
||||
},
|
||||
{
|
||||
name: "Production - insufficient complexity",
|
||||
secret: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01",
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient complexity",
|
||||
},
|
||||
{
|
||||
name: "Production - low entropy pattern",
|
||||
secret: strings.Repeat("AbCd12!@", 8), // 64 chars but repetitive
|
||||
environment: "production",
|
||||
shouldErr: true,
|
||||
errContains: "insufficient entropy",
|
||||
},
|
||||
{
|
||||
name: "Production - good secret",
|
||||
secret: "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "Production - excellent secret with padding",
|
||||
secret: "7mK2nP8sR4wT6xZ3bA5cxK7mN1oQ9uS4vY2zA6bxK7mN1oQ9uS4vY2zA6b+W0E=",
|
||||
environment: "production",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateJWTSecret(tt.secret, tt.environment)
|
||||
|
||||
if tt.shouldErr {
|
||||
if err == nil {
|
||||
t.Errorf("ValidateJWTSecret() expected error containing %q, got no error", tt.errContains)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateJWTSecret() error = %q, should contain %q", err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("ValidateJWTSecret() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTSecret_EdgeCases(t *testing.T) {
|
||||
validator := NewCredentialValidator()
|
||||
|
||||
t.Run("Secret with mixed weak patterns", func(t *testing.T) {
|
||||
secret := "password123admin" // Contains multiple weak patterns
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err == nil {
|
||||
t.Error("Expected error for secret containing weak patterns, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at minimum length", func(t *testing.T) {
|
||||
// 32 characters exactly
|
||||
secret := "j8EJm9/ZKnuTYxcVKQK/NWcrt1Drgzx"
|
||||
err := validator.ValidateJWTSecret(secret, "development")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 32-char secret in development, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Secret exactly at recommended length", func(t *testing.T) {
|
||||
// 64 characters exactly - using real random base64
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFir"
|
||||
err := validator.ValidateJWTSecret(secret, "production")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for 64-char secret in production, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure validation is performant
|
||||
func BenchmarkCalculateShannonEntropy(b *testing.B) {
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
calculateShannonEntropy(secret)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateJWTSecret(b *testing.B) {
|
||||
validator := NewCredentialValidator()
|
||||
secret := "PKiQCYBT+AxkksUbC+F5NJsQBG+GDRvlc/5d+240xljW2uVtzsz0uqv0sjCJFirR"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = validator.ValidateJWTSecret(secret, "production")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
package validator
|
||||
|
||||
// ProvideCredentialValidator provides a credential validator for dependency injection
|
||||
func ProvideCredentialValidator() CredentialValidator {
|
||||
return NewCredentialValidator()
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue