Initial commit: Open sourcing all of the Maple Open Technologies code.

This commit is contained in:
Bartlomiej Mika 2025-12-02 14:33:08 -05:00
commit 755d54a99d
2010 changed files with 448675 additions and 0 deletions

View file

@ -0,0 +1,275 @@
// File Path: monorepo/cloud/maplepress-backend/pkg/validation/email.go
package validation
import (
"fmt"
"strings"
)
// EmailValidator provides comprehensive email validation and normalization
// CWE-20: Improper Input Validation - Ensures email addresses are properly validated
//
// It wraps a *Validator for the basic format check and adds normalization and
// heuristic checks (typo detection, IP-literal domains, special-character
// counting) on top of it.
type EmailValidator struct {
	// validator performs the basic format check used by ValidateAndNormalize.
	validator *Validator
}
// NewEmailValidator returns an EmailValidator backed by a freshly created
// Validator.
func NewEmailValidator() *EmailValidator {
	ev := new(EmailValidator)
	ev.validator = NewValidator()
	return ev
}
// ValidateAndNormalize runs the full email pipeline: basic format validation,
// normalization, then the additional security constraints.
//
// It returns the normalized address on success. On failure it returns an empty
// string together with the error from whichever stage rejected the input;
// fieldName is used as the prefix of that error.
func (ev *EmailValidator) ValidateAndNormalize(email, fieldName string) (string, error) {
	// Stage 1: reject anything that fails the basic format rules.
	formatErr := ev.validator.ValidateEmail(email, fieldName)
	if formatErr != nil {
		return "", formatErr
	}

	// Stage 2: produce the canonical form used for storage and comparison.
	canonical := ev.Normalize(email)

	// Stage 3: heuristic checks run against the canonical form.
	if secErr := ev.ValidateSecurityConstraints(canonical, fieldName); secErr != nil {
		return "", secErr
	}

	return canonical, nil
}
// Normalize normalizes an email address for consistent storage and comparison
// CWE-180: Incorrect Behavior Order: Validate Before Canonicalize
//
// The result has null bytes removed, surrounding whitespace trimmed and all
// letters lower-cased. Normalize performs no validation of its own.
func (ev *EmailValidator) Normalize(email string) string {
	// Remove null bytes first: a leading or trailing "\x00" would otherwise
	// shield adjacent whitespace from TrimSpace and leave it in the result.
	email = strings.ReplaceAll(email, "\x00", "")

	// Trim whitespace
	email = strings.TrimSpace(email)

	// Convert to lowercase (email local parts are case-sensitive per RFC 5321,
	// but most providers treat them as case-insensitive for better UX)
	email = strings.ToLower(email)

	// Gmail-specific normalization (optional - commented out by default)
	// This removes dots and plus-aliases from Gmail addresses
	// Uncomment if you want to prevent abuse via Gmail aliases
	// email = ev.normalizeGmail(email)

	return email
}
// ValidateSecurityConstraints applies heuristic checks beyond basic format
// validation. It returns an error prefixed with fieldName when the address has
// more than ten of the characters + . _ - %, when its domain is a known typo of
// a popular provider, or when its domain looks like an IP address.
func (ev *EmailValidator) ValidateSecurityConstraints(email, fieldName string) error {
	// 1. An unusually high number of separator characters suggests obfuscation.
	separators := 0
	for _, r := range email {
		switch r {
		case '+', '.', '_', '-', '%':
			separators++
		}
	}
	if separators > 10 {
		return fmt.Errorf("%s: contains too many special characters", fieldName)
	}

	// 2. Disposable-provider detection is advisory only: the check is run but
	// its result is deliberately ignored, so such addresses are still allowed.
	_ = ev.isLikelyDisposable(email)

	// 3. Suggest a correction for common misspellings of popular domains.
	if suggestion := ev.detectCommonDomainTypo(email); suggestion != "" {
		return fmt.Errorf("%s: possible typo detected, did you mean %s?", fieldName, suggestion)
	}

	// 4. Addresses whose domain is an IP address are rejected.
	if ev.hasIPAddress(email) {
		return fmt.Errorf("%s: IP-based email addresses are not allowed", fieldName)
	}

	return nil
}
// isLikelyDisposable reports whether the address's domain looks like a
// disposable-mail provider. The domain is matched case-insensitively, first
// against a list of substrings and then against a short list of exact domains.
// Addresses that do not contain exactly one "@" are reported as false.
//
// The substring match is deliberately broad (for example "temp" also matches
// "template.org"); for production use consider a maintained list such as
// https://github.com/disposable/disposable-email-domains or an API service.
func (ev *EmailValidator) isLikelyDisposable(email string) bool {
	// Exactly one "@" is required to identify the domain.
	_, host, found := strings.Cut(email, "@")
	if !found || strings.Contains(host, "@") {
		return false
	}
	host = strings.ToLower(host)

	// Substrings that commonly appear in throwaway-mail domains.
	for _, marker := range [...]string{
		"temp",
		"disposable",
		"throwaway",
		"guerrilla",
		"mailinator",
		"10minute",
		"trashmail",
		"yopmail",
		"fakeinbox",
	} {
		if strings.Contains(host, marker) {
			return true
		}
	}

	// Known disposable domains (small sample - expand as needed).
	switch host {
	case "mailinator.com",
		"guerrillamail.com",
		"10minutemail.com",
		"tempmailaddress.com",
		"yopmail.com",
		"fakeinbox.com",
		"trashmail.com",
		"throwaway.email":
		return true
	}
	return false
}
// detectCommonDomainTypo returns a corrected address when the domain is a known
// misspelling of a popular provider, or "" when no typo is recognized. The
// domain is compared case-insensitively; the local part is returned unchanged.
// Addresses that do not contain exactly one "@" yield "".
func (ev *EmailValidator) detectCommonDomainTypo(email string) string {
	user, host, found := strings.Cut(email, "@")
	if !found || strings.Contains(host, "@") {
		return ""
	}

	// Map each known misspelling to the domain the user most likely meant.
	var intended string
	switch strings.ToLower(host) {
	case "gmial.com", "gmai.com", "gmil.com":
		intended = "gmail.com"
	case "yahooo.com", "yaho.com":
		intended = "yahoo.com"
	case "hotmial.com", "hotmal.com":
		intended = "hotmail.com"
	case "outlok.com", "outloo.com":
		intended = "outlook.com"
	case "iclodu.com", "iclod.com":
		intended = "icloud.com"
	case "protonmai.com", "protonmal.com":
		intended = "protonmail.com"
	default:
		return ""
	}
	return user + "@" + intended
}
// hasIPAddress reports whether the address's domain looks like an IP address.
// A domain counts as an IP when it is wrapped in square brackets (for example
// [192.168.1.1]) or when it consists solely of digits and dots with at least
// three dots. Addresses that do not contain exactly one "@" are reported as
// false.
func (ev *EmailValidator) hasIPAddress(email string) bool {
	_, host, found := strings.Cut(email, "@")
	if !found || strings.Contains(host, "@") {
		return false
	}

	// Bracketed address literal.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return true
	}

	// Unbracketed form (less common): a simple heuristic that requires only
	// digits and dots, with at least three dots.
	notDigitOrDot := func(r rune) bool {
		return r != '.' && (r < '0' || r > '9')
	}
	if strings.IndexFunc(host, notDigitOrDot) >= 0 {
		return false
	}
	return strings.Count(host, ".") >= 3
}
// normalizeGmail collapses Gmail aliases onto a single address by removing dots
// from the local part and dropping everything from the first "+" onwards.
// Example: john.doe+test@gmail.com -> johndoe@gmail.com
//
// Only gmail.com and googlemail.com (matched case-insensitively) are rewritten;
// any other input, including input without exactly one "@", is returned
// unchanged. The returned domain is lower-cased; the local part keeps its case.
func (ev *EmailValidator) normalizeGmail(email string) string {
	user, host, found := strings.Cut(email, "@")
	if !found || strings.Contains(host, "@") {
		return email
	}

	host = strings.ToLower(host)
	switch host {
	case "gmail.com", "googlemail.com":
		// Gmail-hosted: fall through to the rewrite below.
	default:
		return email
	}

	// Gmail ignores dots in the local part.
	user = strings.ReplaceAll(user, ".", "")
	// Gmail treats everything after "+" as an alias tag.
	user, _, _ = strings.Cut(user, "+")

	return user + "@" + host
}
// ValidateEmailList validates and normalizes every address in emails.
//
// It stops at the first invalid entry and returns nil together with that
// entry's error, whose field label has the form "fieldName[index]". When all
// entries are valid it returns the normalized addresses in input order.
func (ev *EmailValidator) ValidateEmailList(emails []string, fieldName string) ([]string, error) {
	cleaned := make([]string, len(emails))
	for idx := range emails {
		label := fmt.Sprintf("%s[%d]", fieldName, idx)
		addr, err := ev.ValidateAndNormalize(emails[idx], label)
		if err != nil {
			return nil, err
		}
		cleaned[idx] = addr
	}
	return cleaned, nil
}
// IsValidEmailDomain performs a lightweight structural check on the domain of
// an address: the address must contain exactly one "@", and the domain must
// contain a dot followed by a final label of at least two characters.
// It does no DNS lookup - for production, consider MX record validation.
func (ev *EmailValidator) IsValidEmailDomain(email string) bool {
	_, host, found := strings.Cut(email, "@")
	if !found || strings.Contains(host, "@") {
		return false
	}
	host = strings.ToLower(host)

	// Locate the final label; no dot at all means there is no TLD.
	lastDot := strings.LastIndex(host, ".")
	if lastDot < 0 {
		return false
	}

	// TLD must be at least 2 characters.
	tld := host[lastDot+1:]
	return len(tld) >= 2
}

View file

@ -0,0 +1,120 @@
package validation
import (
"fmt"
"net/http"
"strconv"
)
// ValidatePathUUID reads the path parameter paramName from r and checks that it
// is a UUID accepted by Validator.ValidateUUID.
// CWE-20: Improper Input Validation
//
// It returns the raw path value on success, or an error when the parameter is
// missing or malformed.
func ValidatePathUUID(r *http.Request, paramName string) (string, error) {
	raw := r.PathValue(paramName)
	if len(raw) == 0 {
		return "", fmt.Errorf("%s is required", paramName)
	}

	if err := NewValidator().ValidateUUID(raw, paramName); err != nil {
		return "", err
	}
	return raw, nil
}
// ValidatePathSlug reads the path parameter paramName from r and checks that it
// is a slug accepted by Validator.ValidateSlug.
// CWE-20: Improper Input Validation
//
// It returns the raw path value on success, or an error when the parameter is
// missing or rejected.
func ValidatePathSlug(r *http.Request, paramName string) (string, error) {
	raw := r.PathValue(paramName)
	if len(raw) == 0 {
		return "", fmt.Errorf("%s is required", paramName)
	}

	if err := NewValidator().ValidateSlug(raw, paramName); err != nil {
		return "", err
	}
	return raw, nil
}
// ValidatePathInt validates an integer path parameter
// CWE-20: Improper Input Validation
func ValidatePathInt(r *http.Request, paramName string) (int64, error) {
valueStr := r.PathValue(paramName)
if valueStr == "" {
return 0, fmt.Errorf("%s is required", paramName)
}
value, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("%s must be a valid integer", paramName)
}
if value <= 0 {
return 0, fmt.Errorf("%s must be greater than 0", paramName)
}
return value, nil
}
// ValidatePagination reads the "limit" and "offset" query parameters from r.
//
// A missing or empty limit falls back to defaultLimit (which is not itself
// bounds-checked); a missing or empty offset falls back to 0. A supplied limit
// must be an integer from 1 to 100 and a supplied offset must be an integer
// >= 0; otherwise 0, 0 and an error are returned.
func ValidatePagination(r *http.Request, defaultLimit int) (limit int, offset int, err error) {
	query := r.URL.Query()
	limit, offset = defaultLimit, 0

	// Validate limit
	if raw := query.Get("limit"); raw != "" {
		n, convErr := strconv.Atoi(raw)
		if convErr != nil || n <= 0 || n > 100 {
			return 0, 0, fmt.Errorf("limit must be between 1 and 100")
		}
		limit = n
	}

	// Validate offset
	if raw := query.Get("offset"); raw != "" {
		n, convErr := strconv.Atoi(raw)
		if convErr != nil || n < 0 {
			return 0, 0, fmt.Errorf("offset must be >= 0")
		}
		offset = n
	}

	return limit, offset, nil
}
// ValidateSortField reads the "sort_by" query parameter from r and accepts it
// only when it exactly matches one of allowedFields.
// CWE-89: SQL Injection prevention via whitelist
//
// The parameter is optional: when it is absent or empty, "" and a nil error are
// returned. A value outside the whitelist yields an error listing the allowed
// fields.
func ValidateSortField(r *http.Request, allowedFields []string) (string, error) {
	requested := r.URL.Query().Get("sort_by")
	if len(requested) == 0 {
		return "", nil // Optional field
	}

	for i := range allowedFields {
		if allowedFields[i] == requested {
			return requested, nil
		}
	}

	return "", fmt.Errorf("invalid sort_by field (allowed: %v)", allowedFields)
}
// ValidateQueryEmail reads the query parameter paramName from r and runs it
// through EmailValidator.ValidateAndNormalize.
// CWE-20: Improper Input Validation
//
// It returns the normalized address on success, or an error when the parameter
// is missing or is rejected by the email validator.
func ValidateQueryEmail(r *http.Request, paramName string) (string, error) {
	raw := r.URL.Query().Get(paramName)
	if len(raw) == 0 {
		return "", fmt.Errorf("%s is required", paramName)
	}

	normalized, err := NewEmailValidator().ValidateAndNormalize(raw, paramName)
	if err != nil {
		return "", err
	}
	return normalized, nil
}

View file

@ -0,0 +1,6 @@
package validation
// ProvideValidator provides a Validator instance
//
// It simply delegates to NewValidator.
// NOTE(review): the "Provide" prefix suggests this exists for dependency-injection
// wiring - confirm against the project's DI setup.
func ProvideValidator() *Validator {
	return NewValidator()
}

View file

@ -0,0 +1,498 @@
package validation
import (
"fmt"
"net/mail"
"net/url"
"regexp"
"strings"
"time"
"unicode"
)
// Common validation errors
//
// These are sentinel values: the validators in this package wrap them with %w
// together with the field name, so callers can test for a category of failure
// with errors.Is rather than comparing message text.
var (
	ErrRequired          = fmt.Errorf("field is required")
	ErrInvalidEmail      = fmt.Errorf("invalid email format")
	ErrInvalidURL        = fmt.Errorf("invalid URL format")
	ErrInvalidDomain     = fmt.Errorf("invalid domain format")
	ErrTooShort          = fmt.Errorf("value is too short")
	ErrTooLong           = fmt.Errorf("value is too long")
	ErrInvalidCharacters = fmt.Errorf("contains invalid characters")
	ErrInvalidFormat     = fmt.Errorf("invalid format")
	ErrInvalidValue      = fmt.Errorf("invalid value")
	ErrWhitespaceOnly    = fmt.Errorf("cannot contain only whitespace")
	ErrContainsHTML      = fmt.Errorf("cannot contain HTML tags")
	ErrInvalidSlug       = fmt.Errorf("invalid slug format")
)
// Regex patterns for validation
//
// All patterns are compiled once at package initialization.
var (
	// Email validation: a pragmatic subset of the RFC 5322 addr-spec (it is NOT
	// fully RFC compliant - quoted local parts, comments and address literals
	// are rejected). The local part and the domain must each start and end with
	// an alphanumeric character; the inner groups are optional so that
	// single-character local parts ("a@example.com") and single-character
	// domain labels ("user@x.com") are accepted.
	emailRegex = regexp.MustCompile(`^[a-zA-Z0-9](?:[a-zA-Z0-9._%+\-]*[a-zA-Z0-9])?@[a-zA-Z0-9](?:[a-zA-Z0-9.\-]*[a-zA-Z0-9])?\.[a-zA-Z]{2,}$`)

	// Domain validation: dot-separated labels of 1-63 alphanumeric/hyphen
	// characters that neither start nor end with a hyphen, followed by an
	// alphabetic TLD of at least two letters.
	domainRegex = regexp.MustCompile(`^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)

	// Slug validation: lowercase alphanumeric groups separated by single hyphens
	slugRegex = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)

	// HTML tag detection: any "<", one or more non-">" characters, then ">"
	htmlTagRegex = regexp.MustCompile(`<[^>]+>`)

	// UUID validation (version 4, lowercase hex, RFC 4122 variant nibble 8/9/a/b)
	uuidRegex = regexp.MustCompile(`^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$`)

	// Alphanumeric only (ASCII letters and digits, at least one character)
	alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
)
// Reserved slugs that cannot be used for tenant names
//
// ValidateSlug rejects any slug that is a key of this map. Keys are lowercase.
var ReservedSlugs = map[string]bool{
	"api":       true,
	"admin":     true,
	"www":       true,
	"mail":      true,
	"email":     true,
	"health":    true,
	"status":    true,
	"metrics":   true,
	"static":    true,
	"cdn":       true,
	"assets":    true,
	"blog":      true,
	"docs":      true,
	"help":      true,
	"support":   true,
	"login":     true,
	"logout":    true,
	"signup":    true,
	"register":  true,
	"app":       true,
	"dashboard": true,
	"settings":  true,
	"account":   true,
	"profile":   true,
	"root":      true,
	"system":    true,
	"public":    true,
	"private":   true,
}
// Validator groups the package's input validation and sanitization helpers.
// It has no fields.
type Validator struct{}

// NewValidator returns a pointer to a new, empty Validator.
func NewValidator() *Validator {
	return new(Validator)
}
// ==================== String Validation ====================
// ValidateRequired reports an error wrapping ErrRequired when value is empty or
// consists only of whitespace; otherwise it returns nil. fieldName prefixes the
// error message.
func (v *Validator) ValidateRequired(value, fieldName string) error {
	if trimmed := strings.TrimSpace(value); len(trimmed) > 0 {
		return nil
	}
	return fmt.Errorf("%s: %w", fieldName, ErrRequired)
}
// ValidateLength checks that value, after trimming surrounding whitespace, has
// between min and max characters inclusive.
//
// Length is measured in Unicode code points rather than bytes, so the limits
// match the "characters" wording of the error messages for non-ASCII text.
// A max of 0 (or less) disables the upper bound. Errors wrap ErrTooShort or
// ErrTooLong and are prefixed with fieldName.
func (v *Validator) ValidateLength(value, fieldName string, min, max int) error {
	// len([]rune(s)) counts code points; byte length would over-count any
	// multi-byte character.
	length := len([]rune(strings.TrimSpace(value)))

	if length < min {
		return fmt.Errorf("%s: %w (minimum %d characters)", fieldName, ErrTooShort, min)
	}

	if max > 0 && length > max {
		return fmt.Errorf("%s: %w (maximum %d characters)", fieldName, ErrTooLong, max)
	}

	return nil
}
// ValidateNotWhitespaceOnly reports an error wrapping ErrWhitespaceOnly when
// value is non-empty but contains nothing except whitespace. An empty string is
// accepted; use ValidateRequired to reject it.
func (v *Validator) ValidateNotWhitespaceOnly(value, fieldName string) error {
	if value == "" {
		return nil
	}
	if strings.TrimSpace(value) != "" {
		return nil
	}
	return fmt.Errorf("%s: %w", fieldName, ErrWhitespaceOnly)
}
// ValidateNoHTML reports an error wrapping ErrContainsHTML when value contains
// anything matched by htmlTagRegex; otherwise it returns nil. fieldName
// prefixes the error message.
func (v *Validator) ValidateNoHTML(value, fieldName string) error {
	if !htmlTagRegex.MatchString(value) {
		return nil
	}
	return fmt.Errorf("%s: %w", fieldName, ErrContainsHTML)
}
// ValidateAlphanumeric reports an error wrapping ErrInvalidCharacters unless
// value matches alphanumericRegex in full. The empty string does not match and
// is therefore rejected.
func (v *Validator) ValidateAlphanumeric(value, fieldName string) error {
	if alphanumericRegex.MatchString(value) {
		return nil
	}
	return fmt.Errorf("%s: %w (only letters and numbers allowed)", fieldName, ErrInvalidCharacters)
}
// ValidatePrintable reports an error wrapping ErrInvalidCharacters when value
// contains a rune that is neither printable (unicode.IsPrint) nor whitespace
// (unicode.IsSpace); otherwise it returns nil.
func (v *Validator) ValidatePrintable(value, fieldName string) error {
	disallowed := func(r rune) bool {
		return !unicode.IsPrint(r) && !unicode.IsSpace(r)
	}
	if strings.IndexFunc(value, disallowed) < 0 {
		return nil
	}
	return fmt.Errorf("%s: %w (contains non-printable characters)", fieldName, ErrInvalidCharacters)
}
// ==================== Email Validation ====================
// ValidateEmail validates the format of an email address.
//
// The address is trimmed, then checked in this order: non-empty, at most 320
// bytes, no consecutive dots, no leading/trailing dot in the local part, the
// package's emailRegex, and finally net/mail.ParseAddress. Every format failure
// wraps ErrInvalidEmail (ErrRequired / ErrTooLong for the first two checks) and
// is prefixed with fieldName.
//
// The dot-placement checks run before the generic ones on purpose: the regex
// and ParseAddress reject the same inputs, so running them first would make the
// more specific messages unreachable.
func (v *Validator) ValidateEmail(email, fieldName string) error {
	email = strings.TrimSpace(email)

	// Check required
	if email == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}

	// Check length (64-octet local part + "@" + 255-octet domain = 320)
	if len(email) > 320 {
		return fmt.Errorf("%s: %w (maximum 320 characters)", fieldName, ErrTooLong)
	}

	// Check for consecutive dots
	if strings.Contains(email, "..") {
		return fmt.Errorf("%s: %w (consecutive dots not allowed)", fieldName, ErrInvalidEmail)
	}

	// Check for leading/trailing dots in local part
	if local, _, hasAt := strings.Cut(email, "@"); hasAt {
		if strings.HasPrefix(local, ".") || strings.HasSuffix(local, ".") {
			return fmt.Errorf("%s: %w (local part cannot start or end with dot)", fieldName, ErrInvalidEmail)
		}
	}

	// Validate using regex
	if !emailRegex.MatchString(email) {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
	}

	// Additional validation using net/mail package
	if _, err := mail.ParseAddress(email); err != nil {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
	}

	return nil
}
// ==================== URL Validation ====================
// ValidateURL validates that urlStr is an absolute http or https URL.
//
// The value is trimmed, then must be non-empty, at most 2048 bytes, parseable
// by net/url, use the http or https scheme, and name a host. The host check
// uses the hostname without the port, so a URL such as "https://:8080" (port
// but no hostname) is rejected. Failures wrap ErrRequired, ErrTooLong or
// ErrInvalidURL and are prefixed with fieldName.
func (v *Validator) ValidateURL(urlStr, fieldName string) error {
	urlStr = strings.TrimSpace(urlStr)

	// Check required
	if urlStr == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}

	// Check length (max 2048 chars for URL)
	if len(urlStr) > 2048 {
		return fmt.Errorf("%s: %w (maximum 2048 characters)", fieldName, ErrTooLong)
	}

	// Parse URL
	parsedURL, err := url.Parse(urlStr)
	if err != nil {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidURL)
	}

	// Ensure scheme is present and valid
	if parsedURL.Scheme == "" {
		return fmt.Errorf("%s: %w (missing scheme)", fieldName, ErrInvalidURL)
	}

	// Only allow http and https
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return fmt.Errorf("%s: %w (only http and https schemes allowed)", fieldName, ErrInvalidURL)
	}

	// Ensure a hostname is present. Host alone is not enough: it still holds
	// ":port" when the hostname itself is empty.
	if parsedURL.Hostname() == "" {
		return fmt.Errorf("%s: %w (missing host)", fieldName, ErrInvalidURL)
	}

	return nil
}
// ValidateHTTPSURL applies every ValidateURL check and additionally requires
// the https scheme. It returns ValidateURL's error unchanged when that check
// fails; otherwise it returns an error prefixed with fieldName when the URL is
// not https.
func (v *Validator) ValidateHTTPSURL(urlStr, fieldName string) error {
	if err := v.ValidateURL(urlStr, fieldName); err != nil {
		return err
	}

	// ValidateURL trims its own copy of the input, so parse the same trimmed
	// text here; otherwise a URL with surrounding whitespace passes the first
	// check and then fails to parse.
	parsedURL, err := url.Parse(strings.TrimSpace(urlStr))
	if err != nil {
		return fmt.Errorf("%s: invalid URL format", fieldName)
	}

	if parsedURL.Scheme != "https" {
		return fmt.Errorf("%s: must use HTTPS protocol", fieldName)
	}

	return nil
}
// ==================== Domain Validation ====================
// ValidateDomain validates domain name format.
//
// The value is trimmed and lower-cased, must be non-empty and between 4 and 253
// bytes long, and must then be one of:
//   - "localhost", optionally followed by ":port" (development);
//   - a dotted-quad IPv4 address in 127.0.0.0/8, 10.0.0.0/8 or 192.168.0.0/16,
//     optionally followed by ":port" (development);
//   - a standard domain such as example.com that matches domainRegex, with
//     every label at most 63 characters.
//
// The development forms are matched exactly: a value that merely starts with
// "localhost" or "10." is not waved through but validated as a standard domain.
// A port, when present, must be a decimal number from 1 to 65535. Failures wrap
// ErrRequired, ErrTooLong, ErrTooShort or ErrInvalidDomain and are prefixed
// with fieldName.
func (v *Validator) ValidateDomain(domain, fieldName string) error {
	domain = strings.TrimSpace(strings.ToLower(domain))

	// Check required
	if domain == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}

	// Check length (max 253 chars per RFC 1035)
	if len(domain) > 253 {
		return fmt.Errorf("%s: %w (maximum 253 characters)", fieldName, ErrTooLong)
	}

	// Check minimum length
	if len(domain) < 4 {
		return fmt.Errorf("%s: %w (minimum 4 characters)", fieldName, ErrTooShort)
	}

	// Development hosts: localhost or a loopback/private IPv4 literal, each
	// with an optional port. Examples: localhost, localhost:8080, 127.0.0.1:3000
	host, port, hasPort := strings.Cut(domain, ":")
	if host == "localhost" || isLocalDevIPv4(host) {
		if !hasPort {
			return nil
		}
		if problem := devPortProblem(port); problem != "" {
			return fmt.Errorf("%s: %w (%s)", fieldName, ErrInvalidDomain, problem)
		}
		return nil
	}

	// Validate standard domain format (example.com)
	if !domainRegex.MatchString(domain) {
		return fmt.Errorf("%s: %w", fieldName, ErrInvalidDomain)
	}

	// Check each label length (max 63 chars per RFC 1035)
	for _, label := range strings.Split(domain, ".") {
		if len(label) > 63 {
			return fmt.Errorf("%s: %w (label exceeds 63 characters)", fieldName, ErrInvalidDomain)
		}
	}

	return nil
}

// isLocalDevIPv4 reports whether host is a dotted-quad IPv4 address inside
// 127.0.0.0/8, 10.0.0.0/8 or 192.168.0.0/16.
func isLocalDevIPv4(host string) bool {
	octets := strings.Split(host, ".")
	if len(octets) != 4 {
		return false
	}

	var vals [4]int
	for i, octet := range octets {
		if len(octet) == 0 || len(octet) > 3 {
			return false
		}
		n := 0
		for _, c := range octet {
			if c < '0' || c > '9' {
				return false
			}
			n = n*10 + int(c-'0')
		}
		if n > 255 {
			return false
		}
		vals[i] = n
	}

	switch {
	case vals[0] == 127, vals[0] == 10:
		return true
	case vals[0] == 192 && vals[1] == 168:
		return true
	}
	return false
}

// devPortProblem returns "" when port is a decimal number from 1 to 65535, and
// otherwise a short description of what is wrong with it.
func devPortProblem(port string) string {
	if port == "" {
		return "missing port number"
	}
	n := 0
	for _, c := range port {
		if c < '0' || c > '9' {
			return "port must be numeric"
		}
		// Bail out as soon as the range is exceeded so long digit strings
		// cannot overflow n.
		if n = n*10 + int(c-'0'); n > 65535 {
			return "port out of range"
		}
	}
	if n == 0 {
		return "port out of range"
	}
	return ""
}
// ==================== Slug Validation ====================
// ValidateSlug validates slug format (lowercase alphanumeric with hyphens).
//
// The value is trimmed, then must be non-empty, 3 to 63 bytes long, match
// slugRegex, and not be a key of ReservedSlugs. The slug is NOT lower-cased
// before matching: callers keep using the value they passed in, so accepting
// "MyCompany" here would let a non-canonical slug through even though the
// error text promises lowercase only. Failures wrap ErrRequired, ErrTooShort,
// ErrTooLong or ErrInvalidSlug (the reserved-slug error is a plain message) and
// are prefixed with fieldName.
func (v *Validator) ValidateSlug(slug, fieldName string) error {
	slug = strings.TrimSpace(slug)

	// Check required
	if slug == "" {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}

	// Check length (3-63 chars)
	if len(slug) < 3 {
		return fmt.Errorf("%s: %w (minimum 3 characters)", fieldName, ErrTooShort)
	}
	if len(slug) > 63 {
		return fmt.Errorf("%s: %w (maximum 63 characters)", fieldName, ErrTooLong)
	}

	// Validate format
	if !slugRegex.MatchString(slug) {
		return fmt.Errorf("%s: %w (only lowercase letters, numbers, and hyphens allowed)", fieldName, ErrInvalidSlug)
	}

	// Check for reserved slugs
	if ReservedSlugs[slug] {
		return fmt.Errorf("%s: '%s' is a reserved slug and cannot be used", fieldName, slug)
	}

	return nil
}
// GenerateSlug generates a URL-friendly slug from a name.
//
// The name is trimmed and lower-cased; every run of characters other than a-z
// and 0-9 becomes a single hyphen, and leading/trailing hyphens are removed.
// The result always matches slugRegex and is 3 to 63 bytes long:
//   - a result shorter than 3 characters gets a zero-padded 4-digit suffix
//     derived from the current time (joined with a hyphen unless the base is
//     empty, which would otherwise produce a leading hyphen);
//   - a result longer than 63 characters is truncated, and any hyphen left at
//     the cut is removed.
//
// The suffix is time-derived, not random, and does not guarantee uniqueness.
// The result is not checked against ReservedSlugs.
func (v *Validator) GenerateSlug(name string) string {
	// Convert to lowercase and trim spaces
	slug := strings.TrimSpace(strings.ToLower(name))

	// Replace any non-alphanumeric characters with hyphens, collapsing runs
	var result strings.Builder
	result.Grow(len(slug))
	prevWasHyphen := false

	for _, char := range slug {
		if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') {
			result.WriteRune(char)
			prevWasHyphen = false
		} else if !prevWasHyphen {
			// Replace any non-alphanumeric character with a hyphen
			// But don't add consecutive hyphens
			result.WriteRune('-')
			prevWasHyphen = true
		}
	}

	// Remove leading and trailing hyphens
	slug = strings.Trim(result.String(), "-")

	// Enforce length constraints (3-63 chars)
	if len(slug) < 3 {
		// Zero-padding keeps the suffix at exactly 4 digits so the padded
		// slug is never shorter than 3 characters.
		suffix := fmt.Sprintf("%04d", time.Now().UnixNano()%10000)
		if slug == "" {
			slug = suffix
		} else {
			slug = slug + "-" + suffix
		}
	}

	if len(slug) > 63 {
		// slug is pure ASCII here, so slicing by byte cannot split a character.
		slug = strings.TrimRight(slug[:63], "-")
	}

	return slug
}
// ==================== UUID Validation ====================
// ValidateUUID checks that id is a version-4 UUID. The value is trimmed and
// lower-cased before matching, so uppercase hex digits are accepted. An empty
// value yields an error wrapping ErrRequired; any other mismatch wraps
// ErrInvalidFormat. fieldName prefixes the error message.
func (v *Validator) ValidateUUID(id, fieldName string) error {
	candidate := strings.ToLower(id)
	candidate = strings.TrimSpace(candidate)

	switch {
	case candidate == "":
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	case !uuidRegex.MatchString(candidate):
		return fmt.Errorf("%s: %w (must be a valid UUID v4)", fieldName, ErrInvalidFormat)
	}
	return nil
}
// ==================== Enum Validation ====================
// ValidateEnum performs whitelist validation: value, after trimming, must be
// exactly equal (case-sensitively) to one of allowedValues. An empty value
// yields an error wrapping ErrRequired; a value outside the list wraps
// ErrInvalidValue and lists the allowed values. fieldName prefixes the message.
func (v *Validator) ValidateEnum(value, fieldName string, allowedValues []string) error {
	candidate := strings.TrimSpace(value)
	if len(candidate) == 0 {
		return fmt.Errorf("%s: %w", fieldName, ErrRequired)
	}

	for i := range allowedValues {
		if allowedValues[i] == candidate {
			return nil
		}
	}

	choices := strings.Join(allowedValues, ", ")
	return fmt.Errorf("%s: %w (allowed values: %s)", fieldName, ErrInvalidValue, choices)
}
// ==================== Number Validation ====================
// ValidateRange checks that value lies between min and max inclusive. A max of
// 0 (or less) disables the upper bound. The returned error is prefixed with
// fieldName and names the bound that was violated.
func (v *Validator) ValidateRange(value int, fieldName string, min, max int) error {
	switch {
	case value < min:
		return fmt.Errorf("%s: value must be at least %d", fieldName, min)
	case max > 0 && value > max:
		return fmt.Errorf("%s: value must be at most %d", fieldName, max)
	default:
		return nil
	}
}
// ==================== Sanitization ====================
// SanitizeString removes null bytes from value and trims surrounding
// whitespace. It performs no other filtering and no Unicode normalization.
func (v *Validator) SanitizeString(value string) string {
	// Remove null bytes first: a leading or trailing "\x00" would otherwise
	// shield adjacent whitespace from TrimSpace and leave it in the result.
	value = strings.ReplaceAll(value, "\x00", "")

	// Trim whitespace
	value = strings.TrimSpace(value)

	// Normalize Unicode
	// Note: For production, consider using golang.org/x/text/unicode/norm

	return value
}
// StripHTML deletes every substring of value matched by htmlTagRegex and
// returns the remainder. Text between tags is kept.
func (v *Validator) StripHTML(value string) string {
	stripped := htmlTagRegex.ReplaceAllLiteralString(value, "")
	return stripped
}
// ==================== Combined Validations ====================
// ValidateAndSanitizeString sanitizes value with SanitizeString and then runs
// the required, length (minLen..maxLen) and printable checks against the
// sanitized text, in that order.
//
// It returns the sanitized string when every check passes; otherwise it returns
// "" together with the first check's error.
func (v *Validator) ValidateAndSanitizeString(value, fieldName string, minLen, maxLen int) (string, error) {
	// Sanitize first so the checks see exactly what will be returned.
	clean := v.SanitizeString(value)

	checks := [...]func() error{
		func() error { return v.ValidateRequired(clean, fieldName) },
		func() error { return v.ValidateLength(clean, fieldName, minLen, maxLen) },
		func() error { return v.ValidatePrintable(clean, fieldName) },
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return "", err
		}
	}

	return clean, nil
}

View file

@ -0,0 +1,472 @@
package validation
import (
"strings"
"testing"
)
// TestValidateRequired checks that ValidateRequired rejects empty and
// whitespace-only input and accepts everything else.
func TestValidateRequired(t *testing.T) {
	type requiredCase struct {
		label   string
		input   string
		wantErr bool
	}
	cases := []requiredCase{
		{"Valid non-empty string", "test", false},
		{"Empty string", "", true},
		{"Whitespace only", " ", true},
		{"Tab only", "\t", true},
		{"Newline only", "\n", true},
		{"Valid with spaces", "hello world", false},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateRequired(tc.input, "test_field")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateRequired() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateLength checks the minimum and maximum bounds of ValidateLength,
// including the "no maximum" case (max = 0) and trimming of surrounding
// whitespace before measuring.
func TestValidateLength(t *testing.T) {
	type lengthCase struct {
		label   string
		input   string
		minLen  int
		maxLen  int
		wantErr bool
	}
	cases := []lengthCase{
		{"Valid length", "hello", 3, 10, false},
		{"Too short", "ab", 3, 10, true},
		{"Too long", "hello world this is too long", 3, 10, true},
		{"Exact minimum", "abc", 3, 10, false},
		{"Exact maximum", "0123456789", 3, 10, false},
		{"No maximum (0)", "very long string here", 3, 0, false},
		{"Whitespace counted correctly", " test ", 4, 10, false},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateLength(tc.input, "test_field", tc.minLen, tc.maxLen)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateLength() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateEmail checks ValidateEmail against well-formed addresses and a
// range of malformed ones (missing parts, misplaced dots, illegal characters,
// excessive length).
func TestValidateEmail(t *testing.T) {
	type emailCase struct {
		label   string
		address string
		wantErr bool
	}
	cases := []emailCase{
		// Addresses that must be accepted
		{"Valid email", "user@example.com", false},
		{"Valid email with plus", "user+tag@example.com", false},
		{"Valid email with dot", "first.last@example.com", false},
		{"Valid email with hyphen", "user-name@example-domain.com", false},
		{"Valid email with numbers", "user123@example456.com", false},
		{"Valid email with subdomain", "user@sub.example.com", false},

		// Addresses that must be rejected
		{"Empty email", "", true},
		{"Whitespace only", " ", true},
		{"Missing @", "userexample.com", true},
		{"Missing domain", "user@", true},
		{"Missing local part", "@example.com", true},
		{"No TLD", "user@localhost", true},
		{"Consecutive dots in local", "user..name@example.com", true},
		{"Leading dot in local", ".user@example.com", true},
		{"Trailing dot in local", "user.@example.com", true},
		{"Double @", "user@@example.com", true},
		{"Spaces in email", "user name@example.com", true},
		{"Invalid characters", "user<>@example.com", true},
		{"Too long", strings.Repeat("a", 320) + "@example.com", true},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateEmail(tc.address, "email")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateEmail() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateURL checks ValidateURL against valid http/https URLs and against
// inputs with a missing or unsupported scheme, a missing host, illegal
// characters, or excessive length.
func TestValidateURL(t *testing.T) {
	type urlCase struct {
		label   string
		rawURL  string
		wantErr bool
	}
	cases := []urlCase{
		// URLs that must be accepted
		{"Valid HTTP URL", "http://example.com", false},
		{"Valid HTTPS URL", "https://example.com", false},
		{"Valid URL with path", "https://example.com/path/to/resource", false},
		{"Valid URL with query", "https://example.com?param=value", false},
		{"Valid URL with port", "https://example.com:8080", false},
		{"Valid URL with subdomain", "https://sub.example.com", false},

		// URLs that must be rejected
		{"Empty URL", "", true},
		{"Whitespace only", " ", true},
		{"Missing scheme", "example.com", true},
		{"Invalid scheme", "ftp://example.com", true},
		{"Missing host", "https://", true},
		{"Invalid characters", "https://exam ple.com", true},
		{"Too long", "https://" + strings.Repeat("a", 2048) + ".com", true},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateURL(tc.rawURL, "url")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateURL() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateHTTPSURL checks that ValidateHTTPSURL accepts only https URLs.
func TestValidateHTTPSURL(t *testing.T) {
	type httpsCase struct {
		label   string
		rawURL  string
		wantErr bool
	}
	cases := []httpsCase{
		{"Valid HTTPS URL", "https://example.com", false},
		{"HTTP URL (should fail)", "http://example.com", true},
		{"FTP URL (should fail)", "ftp://example.com", true},
		{"Invalid URL", "not-a-url", true},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateHTTPSURL(tc.rawURL, "url")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateHTTPSURL() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateDomain checks ValidateDomain against standard domains, the
// localhost development form, and a range of malformed inputs.
//
// ValidateDomain documents support for localhost (with an optional port) for
// development, so bare "localhost" is expected to be accepted; the former
// "No TLD" case asserted the opposite and could never pass.
func TestValidateDomain(t *testing.T) {
	v := NewValidator()

	tests := []struct {
		name      string
		domain    string
		wantError bool
	}{
		// Valid domains
		{"Valid domain", "example.com", false},
		{"Valid subdomain", "sub.example.com", false},
		{"Valid deep subdomain", "deep.sub.example.com", false},
		{"Valid with hyphen", "my-site.example.com", false},
		{"Valid with numbers", "site123.example456.com", false},

		// Development hosts
		{"Localhost allowed for development", "localhost", false},
		{"Localhost with port", "localhost:8080", false},
		{"Localhost with empty port", "localhost:", true},
		{"Localhost with non-numeric port", "localhost:abc", true},

		// Invalid domains
		{"Empty domain", "", true},
		{"Whitespace only", " ", true},
		{"Too short", "a.b", true},
		{"Too long", strings.Repeat("a", 254) + ".com", true},
		{"Label too long", strings.Repeat("a", 64) + ".example.com", true},
		{"Leading hyphen", "-example.com", true},
		{"Trailing hyphen", "example-.com", true},
		{"Double dot", "example..com", true},
		{"Leading dot", ".example.com", true},
		{"Trailing dot", "example.com.", true},
		{"Underscore", "my_site.example.com", true},
		{"Spaces", "my site.example.com", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := v.ValidateDomain(tt.domain, "domain")
			if (err != nil) != tt.wantError {
				t.Errorf("ValidateDomain() error = %v, wantError %v", err, tt.wantError)
			}
		})
	}
}
// TestValidateSlug checks ValidateSlug against well-formed slugs, malformed
// ones (length, case, hyphen placement, illegal characters) and reserved names.
func TestValidateSlug(t *testing.T) {
	type slugCase struct {
		label   string
		input   string
		wantErr bool
	}
	cases := []slugCase{
		// Slugs that must be accepted
		{"Valid slug", "my-company", false},
		{"Valid slug with numbers", "company123", false},
		{"Valid slug all lowercase", "testcompany", false},
		{"Valid slug with multiple hyphens", "my-test-company", false},

		// Slugs that must be rejected
		{"Empty slug", "", true},
		{"Whitespace only", " ", true},
		{"Too short", "ab", true},
		{"Too long", strings.Repeat("a", 64), true},
		{"Uppercase letters", "MyCompany", true},
		{"Leading hyphen", "-company", true},
		{"Trailing hyphen", "company-", true},
		{"Double hyphen", "my--company", true},
		{"Underscore", "my_company", true},
		{"Spaces", "my company", true},
		{"Special characters", "my@company", true},

		// Reserved names
		{"Reserved: api", "api", true},
		{"Reserved: admin", "admin", true},
		{"Reserved: www", "www", true},
		{"Reserved: login", "login", true},
		{"Reserved: register", "register", true},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateSlug(tc.input, "slug")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateSlug() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateUUID checks that ValidateUUID accepts version-4 UUIDs in either
// letter case and rejects empty, malformed and non-v4 values.
func TestValidateUUID(t *testing.T) {
	type uuidCase struct {
		label   string
		input   string
		wantErr bool
	}
	cases := []uuidCase{
		{"Valid UUID v4", "550e8400-e29b-41d4-a716-446655440000", false},
		{"Valid UUID v4 lowercase", "123e4567-e89b-42d3-a456-426614174000", false},
		{"Empty UUID", "", true},
		{"Invalid format", "not-a-uuid", true},
		{"Invalid version", "550e8400-e29b-21d4-a716-446655440000", true},
		{"Missing hyphens", "550e8400e29b41d4a716446655440000", true},
		{"Too short", "550e8400-e29b-41d4-a716", true},
		{"With uppercase", "550E8400-E29B-41D4-A716-446655440000", false},
	}

	validator := NewValidator()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := validator.ValidateUUID(tc.input, "id")
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ValidateUUID() error = %v, wantError %v", err, tc.wantErr)
			}
		})
	}
}
// TestValidateEnum verifies Validator.ValidateEnum against a fixed list of
// plan tiers: every listed value must pass, and anything else (including a
// different letter case or the empty string) must produce an error.
func TestValidateEnum(t *testing.T) {
	v := NewValidator()
	allowedValues := []string{"free", "basic", "pro", "enterprise"}

	cases := []struct {
		name      string
		value     string
		wantError bool
	}{
		{name: "Valid: free", value: "free", wantError: false},
		{name: "Valid: basic", value: "basic", wantError: false},
		{name: "Valid: pro", value: "pro", wantError: false},
		{name: "Valid: enterprise", value: "enterprise", wantError: false},
		{name: "Invalid: premium", value: "premium", wantError: true},
		{name: "Invalid: empty", value: "", wantError: true},
		{name: "Invalid: wrong case", value: "FREE", wantError: true},
		{name: "Invalid: typo", value: "basi", wantError: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := v.ValidateEnum(tc.value, "plan_tier", allowedValues)
			failed := err != nil
			if failed != tc.wantError {
				t.Errorf("ValidateEnum() error = %v, wantError %v", err, tc.wantError)
			}
		})
	}
}
// TestValidateRange checks Validator.ValidateRange at, inside and outside the
// supplied bounds. The final case passes a maximum of 0 and expects success;
// its name suggests 0 is treated as "no upper bound" (confirm against the
// implementation).
func TestValidateRange(t *testing.T) {
	v := NewValidator()

	// check runs one named subtest for a value and its [lo, hi] bounds.
	check := func(name string, value, lo, hi int, wantError bool) {
		t.Run(name, func(t *testing.T) {
			err := v.ValidateRange(value, "count", lo, hi)
			if gotErr := err != nil; gotErr != wantError {
				t.Errorf("ValidateRange() error = %v, wantError %v", err, wantError)
			}
		})
	}

	check("Valid within range", 5, 1, 10, false)
	check("Valid at minimum", 1, 1, 10, false)
	check("Valid at maximum", 10, 1, 10, false)
	check("Below minimum", 0, 1, 10, true)
	check("Above maximum", 11, 1, 10, true)
	check("No maximum (0)", 1000, 1, 0, false)
}
// TestValidateNoHTML confirms that Validator.ValidateNoHTML rejects strings
// containing HTML tags while still accepting plain text that merely uses the
// '<' and '>' characters as comparison symbols.
func TestValidateNoHTML(t *testing.T) {
	v := NewValidator()

	cases := []struct {
		name      string
		value     string
		wantError bool
	}{
		{name: "Plain text", value: "Hello world", wantError: false},
		{name: "Text with punctuation", value: "Hello, world!", wantError: false},
		{name: "HTML tag <script>", value: "<script>alert('xss')</script>", wantError: true},
		{name: "HTML tag <img>", value: "<img src='x'>", wantError: true},
		{name: "HTML tag <div>", value: "<div>content</div>", wantError: true},
		{name: "HTML tag <a>", value: "<a href='#'>link</a>", wantError: true},
		{name: "Less than symbol", value: "5 < 10", wantError: false},
		{name: "Greater than symbol", value: "10 > 5", wantError: false},
		{name: "Both symbols", value: "5 < x < 10", wantError: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := v.ValidateNoHTML(tc.value, "content")
			gotErr := err != nil
			if gotErr != tc.wantError {
				t.Errorf("ValidateNoHTML() error = %v, wantError %v", err, tc.wantError)
			}
		})
	}
}
// TestSanitizeString compares the output of Validator.SanitizeString with the
// expected cleaned string for inputs containing surrounding whitespace, null
// bytes, or nothing to clean at all.
func TestSanitizeString(t *testing.T) {
	v := NewValidator()

	// check runs one named subtest comparing the sanitized input to expected.
	check := func(name, input, expected string) {
		t.Run(name, func(t *testing.T) {
			if got := v.SanitizeString(input); got != expected {
				t.Errorf("SanitizeString() = %q, want %q", got, expected)
			}
		})
	}

	check("Trim whitespace", " hello ", "hello")
	check("Remove null bytes", "hello\x00world", "helloworld")
	check("Already clean", "hello", "hello")
	check("Empty string", "", "")
	check("Only whitespace", " ", "")
}
// TestStripHTML checks that Validator.StripHTML removes tag markup and leaves
// the text between tags (and tag-free input) unchanged.
func TestStripHTML(t *testing.T) {
	v := NewValidator()

	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{name: "Remove script tag", input: "<script>alert('xss')</script>", expected: "alert('xss')"},
		{name: "Remove div tag", input: "<div>content</div>", expected: "content"},
		{name: "Remove multiple tags", input: "<p>Hello <b>world</b></p>", expected: "Hello world"},
		{name: "No tags", input: "plain text", expected: "plain text"},
		{name: "Empty string", input: "", expected: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := v.StripHTML(tc.input)
			if got == tc.expected {
				return
			}
			t.Errorf("StripHTML() = %q, want %q", got, tc.expected)
		})
	}
}
// TestValidateAndSanitizeString exercises Validator.ValidateAndSanitizeString
// with length bounds of 3..10. Each case checks whether an error is returned
// and, for cases expected to succeed, that the returned string matches the
// expected sanitized value.
func TestValidateAndSanitizeString(t *testing.T) {
	v := NewValidator()

	cases := []struct {
		name      string
		input     string
		minLen    int
		maxLen    int
		wantValue string
		wantError bool
	}{
		{name: "Valid and clean", input: "hello", minLen: 3, maxLen: 10, wantValue: "hello", wantError: false},
		{name: "Trim and validate", input: " hello ", minLen: 3, maxLen: 10, wantValue: "hello", wantError: false},
		{name: "Too short after trim", input: " a ", minLen: 3, maxLen: 10, wantValue: "", wantError: true},
		{name: "Too long", input: "hello world this is too long", minLen: 3, maxLen: 10, wantValue: "", wantError: true},
		{name: "Empty after trim", input: " ", minLen: 3, maxLen: 10, wantValue: "", wantError: true},
		{name: "Valid with null byte removed", input: "hel\x00lo", minLen: 3, maxLen: 10, wantValue: "hello", wantError: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := v.ValidateAndSanitizeString(tc.input, "test_field", tc.minLen, tc.maxLen)
			if gotErr := err != nil; gotErr != tc.wantError {
				t.Errorf("ValidateAndSanitizeString() error = %v, wantError %v", err, tc.wantError)
			}

			// The returned value is only compared for cases expected to succeed.
			if tc.wantError {
				return
			}
			if got != tc.wantValue {
				t.Errorf("ValidateAndSanitizeString() = %q, want %q", got, tc.wantValue)
			}
		})
	}
}
// TestValidatePrintable checks Validator.ValidatePrintable: ordinary text,
// tabs/newlines, non-ASCII letters and the empty string are expected to pass,
// while inputs containing the control bytes 0x01 or 0x07 are expected to fail.
func TestValidatePrintable(t *testing.T) {
	v := NewValidator()

	// check runs one named subtest for a single input string.
	check := func(name, value string, wantError bool) {
		t.Run(name, func(t *testing.T) {
			err := v.ValidatePrintable(value, "test_field")
			if gotErr := err != nil; gotErr != wantError {
				t.Errorf("ValidatePrintable() error = %v, wantError %v", err, wantError)
			}
		})
	}

	check("All printable", "Hello World 123!", false)
	check("With tabs and newlines", "Hello\tWorld\n", false)
	check("With control character", "Hello\x01World", true)
	check("With bell character", "Hello\x07", true)
	check("Empty string", "", false)
	check("Unicode printable", "Hello 世界", false)
}