Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
275
cloud/maplepress-backend/pkg/validation/email.go
Normal file
275
cloud/maplepress-backend/pkg/validation/email.go
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
// File Path: monorepo/cloud/maplepress-backend/pkg/validation/email.go
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EmailValidator provides comprehensive email validation and normalization
|
||||
// CWE-20: Improper Input Validation - Ensures email addresses are properly validated
|
||||
type EmailValidator struct {
|
||||
validator *Validator
|
||||
}
|
||||
|
||||
// NewEmailValidator creates a new email validator
|
||||
func NewEmailValidator() *EmailValidator {
|
||||
return &EmailValidator{
|
||||
validator: NewValidator(),
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateAndNormalize validates and normalizes an email address
|
||||
// Returns the normalized email and any validation error
|
||||
func (ev *EmailValidator) ValidateAndNormalize(email, fieldName string) (string, error) {
|
||||
// Step 1: Basic validation using existing validator
|
||||
if err := ev.validator.ValidateEmail(email, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Step 2: Normalize the email
|
||||
normalized := ev.Normalize(email)
|
||||
|
||||
// Step 3: Additional security checks
|
||||
if err := ev.ValidateSecurityConstraints(normalized, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// Normalize normalizes an email address for consistent storage and comparison
|
||||
// CWE-180: Incorrect Behavior Order: Validate Before Canonicalize
|
||||
func (ev *EmailValidator) Normalize(email string) string {
|
||||
// Trim whitespace
|
||||
email = strings.TrimSpace(email)
|
||||
|
||||
// Convert to lowercase (email local parts are case-sensitive per RFC 5321,
|
||||
// but most providers treat them as case-insensitive for better UX)
|
||||
email = strings.ToLower(email)
|
||||
|
||||
// Remove any null bytes
|
||||
email = strings.ReplaceAll(email, "\x00", "")
|
||||
|
||||
// Gmail-specific normalization (optional - commented out by default)
|
||||
// This removes dots and plus-aliases from Gmail addresses
|
||||
// Uncomment if you want to prevent abuse via Gmail aliases
|
||||
// email = ev.normalizeGmail(email)
|
||||
|
||||
return email
|
||||
}
|
||||
|
||||
// ValidateSecurityConstraints performs additional security validation
|
||||
func (ev *EmailValidator) ValidateSecurityConstraints(email, fieldName string) error {
|
||||
// Check for suspicious patterns
|
||||
|
||||
// 1. Detect emails with excessive special characters (potential obfuscation)
|
||||
specialCharCount := 0
|
||||
for _, ch := range email {
|
||||
if ch == '+' || ch == '.' || ch == '_' || ch == '-' || ch == '%' {
|
||||
specialCharCount++
|
||||
}
|
||||
}
|
||||
if specialCharCount > 10 {
|
||||
return fmt.Errorf("%s: contains too many special characters", fieldName)
|
||||
}
|
||||
|
||||
// 2. Detect potentially disposable email patterns
|
||||
if ev.isLikelyDisposable(email) {
|
||||
// Note: This is a warning-level check. In production, you might want to
|
||||
// either reject these or flag them for review.
|
||||
// For now, we'll allow them but this can be configured.
|
||||
}
|
||||
|
||||
// 3. Check for common typos in popular domains
|
||||
if typo := ev.detectCommonDomainTypo(email); typo != "" {
|
||||
return fmt.Errorf("%s: possible typo detected, did you mean %s?", fieldName, typo)
|
||||
}
|
||||
|
||||
// 4. Prevent IP-based email addresses
|
||||
if ev.hasIPAddress(email) {
|
||||
return fmt.Errorf("%s: IP-based email addresses are not allowed", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLikelyDisposable checks if email is from a known disposable email provider
|
||||
// This is a basic implementation - in production, use a service like:
|
||||
// - https://github.com/disposable/disposable-email-domains
|
||||
// - or an API service
|
||||
func (ev *EmailValidator) isLikelyDisposable(email string) bool {
|
||||
// Extract domain
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Common disposable email patterns
|
||||
disposablePatterns := []string{
|
||||
"temp",
|
||||
"disposable",
|
||||
"throwaway",
|
||||
"guerrilla",
|
||||
"mailinator",
|
||||
"10minute",
|
||||
"trashmail",
|
||||
"yopmail",
|
||||
"fakeinbox",
|
||||
}
|
||||
|
||||
for _, pattern := range disposablePatterns {
|
||||
if strings.Contains(domain, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Known disposable domains (small sample - expand as needed)
|
||||
disposableDomains := map[string]bool{
|
||||
"mailinator.com": true,
|
||||
"guerrillamail.com": true,
|
||||
"10minutemail.com": true,
|
||||
"tempmailaddress.com": true,
|
||||
"yopmail.com": true,
|
||||
"fakeinbox.com": true,
|
||||
"trashmail.com": true,
|
||||
"throwaway.email": true,
|
||||
}
|
||||
|
||||
return disposableDomains[domain]
|
||||
}
|
||||
|
||||
// detectCommonDomainTypo checks for common typos in popular email domains
|
||||
func (ev *EmailValidator) detectCommonDomainTypo(email string) string {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
localPart := parts[0]
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Common typos map: typo -> correct
|
||||
typos := map[string]string{
|
||||
"gmial.com": "gmail.com",
|
||||
"gmai.com": "gmail.com",
|
||||
"gmil.com": "gmail.com",
|
||||
"yahooo.com": "yahoo.com",
|
||||
"yaho.com": "yahoo.com",
|
||||
"hotmial.com": "hotmail.com",
|
||||
"hotmal.com": "hotmail.com",
|
||||
"outlok.com": "outlook.com",
|
||||
"outloo.com": "outlook.com",
|
||||
"iclodu.com": "icloud.com",
|
||||
"iclod.com": "icloud.com",
|
||||
"protonmai.com": "protonmail.com",
|
||||
"protonmal.com": "protonmail.com",
|
||||
}
|
||||
|
||||
if correct, found := typos[domain]; found {
|
||||
return localPart + "@" + correct
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasIPAddress checks if email domain is an IP address
|
||||
func (ev *EmailValidator) hasIPAddress(email string) bool {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := parts[1]
|
||||
|
||||
// Check for IPv4 pattern: [192.168.1.1]
|
||||
if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unbracketed IP patterns (less common but possible)
|
||||
// Simple heuristic: contains only digits and dots
|
||||
hasOnlyDigitsAndDots := true
|
||||
for _, ch := range domain {
|
||||
if ch != '.' && (ch < '0' || ch > '9') {
|
||||
hasOnlyDigitsAndDots = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return hasOnlyDigitsAndDots && strings.Count(domain, ".") >= 3
|
||||
}
|
||||
|
||||
// normalizeGmail normalizes Gmail addresses by removing dots and plus-aliases
|
||||
// Gmail ignores dots in the local part and treats everything after + as an alias
|
||||
// Example: john.doe+test@gmail.com -> johndoe@gmail.com
|
||||
func (ev *EmailValidator) normalizeGmail(email string) string {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return email
|
||||
}
|
||||
|
||||
localPart := parts[0]
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Only normalize for Gmail and Googlemail
|
||||
if domain != "gmail.com" && domain != "googlemail.com" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Remove dots from local part
|
||||
localPart = strings.ReplaceAll(localPart, ".", "")
|
||||
|
||||
// Remove everything after + (plus-alias)
|
||||
if plusIndex := strings.Index(localPart, "+"); plusIndex != -1 {
|
||||
localPart = localPart[:plusIndex]
|
||||
}
|
||||
|
||||
return localPart + "@" + domain
|
||||
}
|
||||
|
||||
// ValidateEmailList validates a list of email addresses
|
||||
// Returns the first error encountered, or nil if all are valid
|
||||
func (ev *EmailValidator) ValidateEmailList(emails []string, fieldName string) ([]string, error) {
|
||||
normalized := make([]string, 0, len(emails))
|
||||
|
||||
for i, email := range emails {
|
||||
norm, err := ev.ValidateAndNormalize(email, fmt.Sprintf("%s[%d]", fieldName, i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
normalized = append(normalized, norm)
|
||||
}
|
||||
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// IsValidEmailDomain checks if a domain is likely valid (has proper structure)
|
||||
// This is a lightweight check - for production, consider DNS MX record validation
|
||||
func (ev *EmailValidator) IsValidEmailDomain(email string) bool {
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := strings.ToLower(parts[1])
|
||||
|
||||
// Must have at least one dot
|
||||
if !strings.Contains(domain, ".") {
|
||||
return false
|
||||
}
|
||||
|
||||
// TLD must be at least 2 characters
|
||||
tldParts := strings.Split(domain, ".")
|
||||
if len(tldParts) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
tld := tldParts[len(tldParts)-1]
|
||||
if len(tld) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
120
cloud/maplepress-backend/pkg/validation/helpers.go
Normal file
120
cloud/maplepress-backend/pkg/validation/helpers.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// ValidatePathUUID validates a UUID path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathUUID(r *http.Request, paramName string) (string, error) {
|
||||
value := r.PathValue(paramName)
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
validator := NewValidator()
|
||||
if err := validator.ValidateUUID(value, paramName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePathSlug validates a slug path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathSlug(r *http.Request, paramName string) (string, error) {
|
||||
value := r.PathValue(paramName)
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
validator := NewValidator()
|
||||
if err := validator.ValidateSlug(value, paramName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePathInt validates an integer path parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidatePathInt(r *http.Request, paramName string) (int64, error) {
|
||||
valueStr := r.PathValue(paramName)
|
||||
if valueStr == "" {
|
||||
return 0, fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
value, err := strconv.ParseInt(valueStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%s must be a valid integer", paramName)
|
||||
}
|
||||
|
||||
if value <= 0 {
|
||||
return 0, fmt.Errorf("%s must be greater than 0", paramName)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// ValidatePagination validates pagination query parameters
|
||||
// Returns limit and offset with defaults and bounds checking
|
||||
func ValidatePagination(r *http.Request, defaultLimit int) (limit int, offset int, err error) {
|
||||
limit = defaultLimit
|
||||
offset = 0
|
||||
|
||||
// Validate limit
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
parsedLimit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || parsedLimit <= 0 || parsedLimit > 100 {
|
||||
return 0, 0, fmt.Errorf("limit must be between 1 and 100")
|
||||
}
|
||||
limit = parsedLimit
|
||||
}
|
||||
|
||||
// Validate offset
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
parsedOffset, err := strconv.Atoi(offsetStr)
|
||||
if err != nil || parsedOffset < 0 {
|
||||
return 0, 0, fmt.Errorf("offset must be >= 0")
|
||||
}
|
||||
offset = parsedOffset
|
||||
}
|
||||
|
||||
return limit, offset, nil
|
||||
}
|
||||
|
||||
// ValidateSortField validates sort field against whitelist
|
||||
// CWE-89: SQL Injection prevention via whitelist
|
||||
func ValidateSortField(r *http.Request, allowedFields []string) (string, error) {
|
||||
sortBy := r.URL.Query().Get("sort_by")
|
||||
if sortBy == "" {
|
||||
return "", nil // Optional field
|
||||
}
|
||||
|
||||
for _, allowed := range allowedFields {
|
||||
if sortBy == allowed {
|
||||
return sortBy, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("invalid sort_by field (allowed: %v)", allowedFields)
|
||||
}
|
||||
|
||||
// ValidateQueryEmail validates an email query parameter
|
||||
// CWE-20: Improper Input Validation
|
||||
func ValidateQueryEmail(r *http.Request, paramName string) (string, error) {
|
||||
email := r.URL.Query().Get(paramName)
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("%s is required", paramName)
|
||||
}
|
||||
|
||||
emailValidator := NewEmailValidator()
|
||||
normalizedEmail, err := emailValidator.ValidateAndNormalize(email, paramName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return normalizedEmail, nil
|
||||
}
|
||||
6
cloud/maplepress-backend/pkg/validation/provider.go
Normal file
6
cloud/maplepress-backend/pkg/validation/provider.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
package validation
|
||||
|
||||
// ProvideValidator provides a Validator instance
|
||||
func ProvideValidator() *Validator {
|
||||
return NewValidator()
|
||||
}
|
||||
498
cloud/maplepress-backend/pkg/validation/validator.go
Normal file
498
cloud/maplepress-backend/pkg/validation/validator.go
Normal file
|
|
@ -0,0 +1,498 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Common validation errors
|
||||
var (
|
||||
ErrRequired = fmt.Errorf("field is required")
|
||||
ErrInvalidEmail = fmt.Errorf("invalid email format")
|
||||
ErrInvalidURL = fmt.Errorf("invalid URL format")
|
||||
ErrInvalidDomain = fmt.Errorf("invalid domain format")
|
||||
ErrTooShort = fmt.Errorf("value is too short")
|
||||
ErrTooLong = fmt.Errorf("value is too long")
|
||||
ErrInvalidCharacters = fmt.Errorf("contains invalid characters")
|
||||
ErrInvalidFormat = fmt.Errorf("invalid format")
|
||||
ErrInvalidValue = fmt.Errorf("invalid value")
|
||||
ErrWhitespaceOnly = fmt.Errorf("cannot contain only whitespace")
|
||||
ErrContainsHTML = fmt.Errorf("cannot contain HTML tags")
|
||||
ErrInvalidSlug = fmt.Errorf("invalid slug format")
|
||||
)
|
||||
|
||||
// Regex patterns for validation
|
||||
var (
|
||||
// Email validation: RFC 5322 compliant
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._%+\-]*[a-zA-Z0-9]@[a-zA-Z0-9][a-zA-Z0-9.\-]*[a-zA-Z0-9]\.[a-zA-Z]{2,}$`)
|
||||
|
||||
// Domain validation: alphanumeric with dots and hyphens
|
||||
domainRegex = regexp.MustCompile(`^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$`)
|
||||
|
||||
// Slug validation: lowercase alphanumeric with hyphens
|
||||
slugRegex = regexp.MustCompile(`^[a-z0-9]+(?:-[a-z0-9]+)*$`)
|
||||
|
||||
// HTML tag detection
|
||||
htmlTagRegex = regexp.MustCompile(`<[^>]+>`)
|
||||
|
||||
// UUID validation (version 4)
|
||||
uuidRegex = regexp.MustCompile(`^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$`)
|
||||
|
||||
// Alphanumeric only
|
||||
alphanumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
|
||||
)
|
||||
|
||||
// Reserved slugs that cannot be used for tenant names
|
||||
var ReservedSlugs = map[string]bool{
|
||||
"api": true,
|
||||
"admin": true,
|
||||
"www": true,
|
||||
"mail": true,
|
||||
"email": true,
|
||||
"health": true,
|
||||
"status": true,
|
||||
"metrics": true,
|
||||
"static": true,
|
||||
"cdn": true,
|
||||
"assets": true,
|
||||
"blog": true,
|
||||
"docs": true,
|
||||
"help": true,
|
||||
"support": true,
|
||||
"login": true,
|
||||
"logout": true,
|
||||
"signup": true,
|
||||
"register": true,
|
||||
"app": true,
|
||||
"dashboard": true,
|
||||
"settings": true,
|
||||
"account": true,
|
||||
"profile": true,
|
||||
"root": true,
|
||||
"system": true,
|
||||
"public": true,
|
||||
"private": true,
|
||||
}
|
||||
|
||||
// Validator provides input validation utilities
|
||||
type Validator struct{}
|
||||
|
||||
// NewValidator creates a new validator instance
|
||||
func NewValidator() *Validator {
|
||||
return &Validator{}
|
||||
}
|
||||
|
||||
// ==================== String Validation ====================
|
||||
|
||||
// ValidateRequired checks if a string is not empty
|
||||
func (v *Validator) ValidateRequired(value, fieldName string) error {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateLength checks if string length is within range
|
||||
func (v *Validator) ValidateLength(value, fieldName string, min, max int) error {
|
||||
length := len(strings.TrimSpace(value))
|
||||
|
||||
if length < min {
|
||||
return fmt.Errorf("%s: %w (minimum %d characters)", fieldName, ErrTooShort, min)
|
||||
}
|
||||
|
||||
if max > 0 && length > max {
|
||||
return fmt.Errorf("%s: %w (maximum %d characters)", fieldName, ErrTooLong, max)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNotWhitespaceOnly ensures the string contains non-whitespace characters
|
||||
func (v *Validator) ValidateNotWhitespaceOnly(value, fieldName string) error {
|
||||
if len(strings.TrimSpace(value)) == 0 && len(value) > 0 {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrWhitespaceOnly)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNoHTML checks that the string doesn't contain HTML tags
|
||||
func (v *Validator) ValidateNoHTML(value, fieldName string) error {
|
||||
if htmlTagRegex.MatchString(value) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrContainsHTML)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateAlphanumeric checks if string contains only alphanumeric characters
|
||||
func (v *Validator) ValidateAlphanumeric(value, fieldName string) error {
|
||||
if !alphanumericRegex.MatchString(value) {
|
||||
return fmt.Errorf("%s: %w (only letters and numbers allowed)", fieldName, ErrInvalidCharacters)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePrintable ensures string contains only printable characters
|
||||
func (v *Validator) ValidatePrintable(value, fieldName string) error {
|
||||
for _, r := range value {
|
||||
if !unicode.IsPrint(r) && !unicode.IsSpace(r) {
|
||||
return fmt.Errorf("%s: %w (contains non-printable characters)", fieldName, ErrInvalidCharacters)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Email Validation ====================
|
||||
|
||||
// ValidateEmail validates email format using RFC 5322 compliant regex
|
||||
func (v *Validator) ValidateEmail(email, fieldName string) error {
|
||||
email = strings.TrimSpace(email)
|
||||
|
||||
// Check required
|
||||
if email == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (RFC 5321: max 320 chars)
|
||||
if len(email) > 320 {
|
||||
return fmt.Errorf("%s: %w (maximum 320 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Validate using regex
|
||||
if !emailRegex.MatchString(email) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Additional validation using net/mail package
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Check for consecutive dots
|
||||
if strings.Contains(email, "..") {
|
||||
return fmt.Errorf("%s: %w (consecutive dots not allowed)", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
|
||||
// Check for leading/trailing dots in local part
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) == 2 {
|
||||
if strings.HasPrefix(parts[0], ".") || strings.HasSuffix(parts[0], ".") {
|
||||
return fmt.Errorf("%s: %w (local part cannot start or end with dot)", fieldName, ErrInvalidEmail)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== URL Validation ====================
|
||||
|
||||
// ValidateURL validates URL format and ensures it has a valid scheme
|
||||
func (v *Validator) ValidateURL(urlStr, fieldName string) error {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
|
||||
// Check required
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (max 2048 chars for URL)
|
||||
if len(urlStr) > 2048 {
|
||||
return fmt.Errorf("%s: %w (maximum 2048 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Parse URL
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Ensure scheme is present and valid
|
||||
if parsedURL.Scheme == "" {
|
||||
return fmt.Errorf("%s: %w (missing scheme)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Only allow http and https
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("%s: %w (only http and https schemes allowed)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
// Ensure host is present
|
||||
if parsedURL.Host == "" {
|
||||
return fmt.Errorf("%s: %w (missing host)", fieldName, ErrInvalidURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateHTTPSURL validates URL and ensures it uses HTTPS
|
||||
func (v *Validator) ValidateHTTPSURL(urlStr, fieldName string) error {
|
||||
if err := v.ValidateURL(urlStr, fieldName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: invalid URL format", fieldName)
|
||||
}
|
||||
if parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("%s: must use HTTPS protocol", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Domain Validation ====================
|
||||
|
||||
// ValidateDomain validates domain name format
|
||||
// Supports standard domains (example.com) and localhost with ports (localhost:8081) for development
|
||||
func (v *Validator) ValidateDomain(domain, fieldName string) error {
|
||||
domain = strings.TrimSpace(strings.ToLower(domain))
|
||||
|
||||
// Check required
|
||||
if domain == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (max 253 chars per RFC 1035)
|
||||
if len(domain) > 253 {
|
||||
return fmt.Errorf("%s: %w (maximum 253 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(domain) < 4 {
|
||||
return fmt.Errorf("%s: %w (minimum 4 characters)", fieldName, ErrTooShort)
|
||||
}
|
||||
|
||||
// Allow localhost with optional port for development
|
||||
// Examples: localhost, localhost:8080, localhost:3000
|
||||
if strings.HasPrefix(domain, "localhost") {
|
||||
// If it has a port, validate the port format
|
||||
if strings.Contains(domain, ":") {
|
||||
parts := strings.Split(domain, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("%s: %w (invalid localhost format)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
// Port should be numeric
|
||||
if parts[1] == "" {
|
||||
return fmt.Errorf("%s: %w (missing port number)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
// Basic port validation (could be more strict)
|
||||
for _, c := range parts[1] {
|
||||
if c < '0' || c > '9' {
|
||||
return fmt.Errorf("%s: %w (port must be numeric)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Allow 127.0.0.1 and other local IPs with optional port for development
|
||||
if strings.HasPrefix(domain, "127.") || strings.HasPrefix(domain, "192.168.") || strings.HasPrefix(domain, "10.") {
|
||||
// If it has a port, just verify format (IP:port)
|
||||
if strings.Contains(domain, ":") {
|
||||
parts := strings.Split(domain, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("%s: %w (invalid IP format)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate standard domain format (example.com)
|
||||
if !domainRegex.MatchString(domain) {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
|
||||
// Check each label length (max 63 chars per RFC 1035)
|
||||
labels := strings.Split(domain, ".")
|
||||
for _, label := range labels {
|
||||
if len(label) > 63 {
|
||||
return fmt.Errorf("%s: %w (label exceeds 63 characters)", fieldName, ErrInvalidDomain)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Slug Validation ====================
|
||||
|
||||
// ValidateSlug validates slug format (lowercase alphanumeric with hyphens)
|
||||
func (v *Validator) ValidateSlug(slug, fieldName string) error {
|
||||
slug = strings.TrimSpace(strings.ToLower(slug))
|
||||
|
||||
// Check required
|
||||
if slug == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check length (3-63 chars)
|
||||
if len(slug) < 3 {
|
||||
return fmt.Errorf("%s: %w (minimum 3 characters)", fieldName, ErrTooShort)
|
||||
}
|
||||
|
||||
if len(slug) > 63 {
|
||||
return fmt.Errorf("%s: %w (maximum 63 characters)", fieldName, ErrTooLong)
|
||||
}
|
||||
|
||||
// Validate format
|
||||
if !slugRegex.MatchString(slug) {
|
||||
return fmt.Errorf("%s: %w (only lowercase letters, numbers, and hyphens allowed)", fieldName, ErrInvalidSlug)
|
||||
}
|
||||
|
||||
// Check for reserved slugs
|
||||
if ReservedSlugs[slug] {
|
||||
return fmt.Errorf("%s: '%s' is a reserved slug and cannot be used", fieldName, slug)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSlug generates a URL-friendly slug from a name
|
||||
// Converts to lowercase, replaces spaces and special chars with hyphens
|
||||
// Ensures the slug matches the slug validation regex
|
||||
func (v *Validator) GenerateSlug(name string) string {
|
||||
// Convert to lowercase and trim spaces
|
||||
slug := strings.TrimSpace(strings.ToLower(name))
|
||||
|
||||
// Replace any non-alphanumeric characters (except hyphens) with hyphens
|
||||
var result strings.Builder
|
||||
prevWasHyphen := false
|
||||
|
||||
for _, char := range slug {
|
||||
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') {
|
||||
result.WriteRune(char)
|
||||
prevWasHyphen = false
|
||||
} else if !prevWasHyphen {
|
||||
// Replace any non-alphanumeric character with a hyphen
|
||||
// But don't add consecutive hyphens
|
||||
result.WriteRune('-')
|
||||
prevWasHyphen = true
|
||||
}
|
||||
}
|
||||
|
||||
slug = result.String()
|
||||
|
||||
// Remove leading and trailing hyphens
|
||||
slug = strings.Trim(slug, "-")
|
||||
|
||||
// Enforce length constraints (3-63 chars)
|
||||
if len(slug) < 3 {
|
||||
// If too short, pad with random suffix
|
||||
slug = slug + "-" + strings.ToLower(fmt.Sprintf("%d", time.Now().UnixNano()%10000))
|
||||
}
|
||||
|
||||
if len(slug) > 63 {
|
||||
// Truncate to 63 chars
|
||||
slug = slug[:63]
|
||||
// Remove trailing hyphen if any
|
||||
slug = strings.TrimRight(slug, "-")
|
||||
}
|
||||
|
||||
return slug
|
||||
}
|
||||
|
||||
// ==================== UUID Validation ====================
|
||||
|
||||
// ValidateUUID validates UUID format (version 4)
|
||||
func (v *Validator) ValidateUUID(id, fieldName string) error {
|
||||
id = strings.TrimSpace(strings.ToLower(id))
|
||||
|
||||
// Check required
|
||||
if id == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Validate format
|
||||
if !uuidRegex.MatchString(id) {
|
||||
return fmt.Errorf("%s: %w (must be a valid UUID v4)", fieldName, ErrInvalidFormat)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Enum Validation ====================
|
||||
|
||||
// ValidateEnum checks if value is in the allowed list (whitelist validation)
|
||||
func (v *Validator) ValidateEnum(value, fieldName string, allowedValues []string) error {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Check required
|
||||
if value == "" {
|
||||
return fmt.Errorf("%s: %w", fieldName, ErrRequired)
|
||||
}
|
||||
|
||||
// Check if value is in allowed list
|
||||
for _, allowed := range allowedValues {
|
||||
if value == allowed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: %w (allowed values: %s)", fieldName, ErrInvalidValue, strings.Join(allowedValues, ", "))
|
||||
}
|
||||
|
||||
// ==================== Number Validation ====================
|
||||
|
||||
// ValidateRange checks if a number is within the specified range
|
||||
func (v *Validator) ValidateRange(value int, fieldName string, min, max int) error {
|
||||
if value < min {
|
||||
return fmt.Errorf("%s: value must be at least %d", fieldName, min)
|
||||
}
|
||||
|
||||
if max > 0 && value > max {
|
||||
return fmt.Errorf("%s: value must be at most %d", fieldName, max)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== Sanitization ====================
|
||||
|
||||
// SanitizeString removes potentially dangerous characters and trims whitespace
|
||||
func (v *Validator) SanitizeString(value string) string {
|
||||
// Trim whitespace
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
// Remove null bytes
|
||||
value = strings.ReplaceAll(value, "\x00", "")
|
||||
|
||||
// Normalize Unicode
|
||||
// Note: For production, consider using golang.org/x/text/unicode/norm
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// StripHTML removes all HTML tags from a string
|
||||
func (v *Validator) StripHTML(value string) string {
|
||||
return htmlTagRegex.ReplaceAllString(value, "")
|
||||
}
|
||||
|
||||
// ==================== Combined Validations ====================
|
||||
|
||||
// ValidateAndSanitizeString performs validation and sanitization
|
||||
func (v *Validator) ValidateAndSanitizeString(value, fieldName string, minLen, maxLen int) (string, error) {
|
||||
// Sanitize first
|
||||
value = v.SanitizeString(value)
|
||||
|
||||
// Validate required
|
||||
if err := v.ValidateRequired(value, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate length
|
||||
if err := v.ValidateLength(value, fieldName, minLen, maxLen); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Validate printable characters
|
||||
if err := v.ValidatePrintable(value, fieldName); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
472
cloud/maplepress-backend/pkg/validation/validator_test.go
Normal file
472
cloud/maplepress-backend/pkg/validation/validator_test.go
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateRequired(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid non-empty string", "test", false},
|
||||
{"Empty string", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Tab only", "\t", true},
|
||||
{"Newline only", "\n", true},
|
||||
{"Valid with spaces", "hello world", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateRequired(tt.value, "test_field")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateRequired() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateLength(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
min int
|
||||
max int
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid length", "hello", 3, 10, false},
|
||||
{"Too short", "ab", 3, 10, true},
|
||||
{"Too long", "hello world this is too long", 3, 10, true},
|
||||
{"Exact minimum", "abc", 3, 10, false},
|
||||
{"Exact maximum", "0123456789", 3, 10, false},
|
||||
{"No maximum (0)", "very long string here", 3, 0, false},
|
||||
{"Whitespace counted correctly", " test ", 4, 10, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateLength(tt.value, "test_field", tt.min, tt.max)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateLength() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid emails
|
||||
{"Valid email", "user@example.com", false},
|
||||
{"Valid email with plus", "user+tag@example.com", false},
|
||||
{"Valid email with dot", "first.last@example.com", false},
|
||||
{"Valid email with hyphen", "user-name@example-domain.com", false},
|
||||
{"Valid email with numbers", "user123@example456.com", false},
|
||||
{"Valid email with subdomain", "user@sub.example.com", false},
|
||||
|
||||
// Invalid emails
|
||||
{"Empty email", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Missing @", "userexample.com", true},
|
||||
{"Missing domain", "user@", true},
|
||||
{"Missing local part", "@example.com", true},
|
||||
{"No TLD", "user@localhost", true},
|
||||
{"Consecutive dots in local", "user..name@example.com", true},
|
||||
{"Leading dot in local", ".user@example.com", true},
|
||||
{"Trailing dot in local", "user.@example.com", true},
|
||||
{"Double @", "user@@example.com", true},
|
||||
{"Spaces in email", "user name@example.com", true},
|
||||
{"Invalid characters", "user<>@example.com", true},
|
||||
{"Too long", strings.Repeat("a", 320) + "@example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateEmail(tt.email, "email")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateEmail() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateURL(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid URLs
|
||||
{"Valid HTTP URL", "http://example.com", false},
|
||||
{"Valid HTTPS URL", "https://example.com", false},
|
||||
{"Valid URL with path", "https://example.com/path/to/resource", false},
|
||||
{"Valid URL with query", "https://example.com?param=value", false},
|
||||
{"Valid URL with port", "https://example.com:8080", false},
|
||||
{"Valid URL with subdomain", "https://sub.example.com", false},
|
||||
|
||||
// Invalid URLs
|
||||
{"Empty URL", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Missing scheme", "example.com", true},
|
||||
{"Invalid scheme", "ftp://example.com", true},
|
||||
{"Missing host", "https://", true},
|
||||
{"Invalid characters", "https://exam ple.com", true},
|
||||
{"Too long", "https://" + strings.Repeat("a", 2048) + ".com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateURL(tt.url, "url")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateURL() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateHTTPSURL(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid HTTPS URL", "https://example.com", false},
|
||||
{"HTTP URL (should fail)", "http://example.com", true},
|
||||
{"FTP URL (should fail)", "ftp://example.com", true},
|
||||
{"Invalid URL", "not-a-url", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateHTTPSURL(tt.url, "url")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateHTTPSURL() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid domains
|
||||
{"Valid domain", "example.com", false},
|
||||
{"Valid subdomain", "sub.example.com", false},
|
||||
{"Valid deep subdomain", "deep.sub.example.com", false},
|
||||
{"Valid with hyphen", "my-site.example.com", false},
|
||||
{"Valid with numbers", "site123.example456.com", false},
|
||||
|
||||
// Invalid domains
|
||||
{"Empty domain", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Too short", "a.b", true},
|
||||
{"Too long", strings.Repeat("a", 254) + ".com", true},
|
||||
{"Label too long", strings.Repeat("a", 64) + ".example.com", true},
|
||||
{"No TLD", "localhost", true},
|
||||
{"Leading hyphen", "-example.com", true},
|
||||
{"Trailing hyphen", "example-.com", true},
|
||||
{"Double dot", "example..com", true},
|
||||
{"Leading dot", ".example.com", true},
|
||||
{"Trailing dot", "example.com.", true},
|
||||
{"Underscore", "my_site.example.com", true},
|
||||
{"Spaces", "my site.example.com", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateDomain(tt.domain, "domain")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateDomain() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSlug(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
slug string
|
||||
wantError bool
|
||||
}{
|
||||
// Valid slugs
|
||||
{"Valid slug", "my-company", false},
|
||||
{"Valid slug with numbers", "company123", false},
|
||||
{"Valid slug all lowercase", "testcompany", false},
|
||||
{"Valid slug with multiple hyphens", "my-test-company", false},
|
||||
|
||||
// Invalid slugs
|
||||
{"Empty slug", "", true},
|
||||
{"Whitespace only", " ", true},
|
||||
{"Too short", "ab", true},
|
||||
{"Too long", strings.Repeat("a", 64), true},
|
||||
{"Uppercase letters", "MyCompany", true},
|
||||
{"Leading hyphen", "-company", true},
|
||||
{"Trailing hyphen", "company-", true},
|
||||
{"Double hyphen", "my--company", true},
|
||||
{"Underscore", "my_company", true},
|
||||
{"Spaces", "my company", true},
|
||||
{"Special characters", "my@company", true},
|
||||
|
||||
// Reserved slugs
|
||||
{"Reserved: api", "api", true},
|
||||
{"Reserved: admin", "admin", true},
|
||||
{"Reserved: www", "www", true},
|
||||
{"Reserved: login", "login", true},
|
||||
{"Reserved: register", "register", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateSlug(tt.slug, "slug")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateSlug() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUUID(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid UUID v4", "550e8400-e29b-41d4-a716-446655440000", false},
|
||||
{"Valid UUID v4 lowercase", "123e4567-e89b-42d3-a456-426614174000", false},
|
||||
{"Empty UUID", "", true},
|
||||
{"Invalid format", "not-a-uuid", true},
|
||||
{"Invalid version", "550e8400-e29b-21d4-a716-446655440000", true},
|
||||
{"Missing hyphens", "550e8400e29b41d4a716446655440000", true},
|
||||
{"Too short", "550e8400-e29b-41d4-a716", true},
|
||||
{"With uppercase", "550E8400-E29B-41D4-A716-446655440000", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateUUID(tt.uuid, "id")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateUUID() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEnum(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
allowedValues := []string{"free", "basic", "pro", "enterprise"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid: free", "free", false},
|
||||
{"Valid: basic", "basic", false},
|
||||
{"Valid: pro", "pro", false},
|
||||
{"Valid: enterprise", "enterprise", false},
|
||||
{"Invalid: premium", "premium", true},
|
||||
{"Invalid: empty", "", true},
|
||||
{"Invalid: wrong case", "FREE", true},
|
||||
{"Invalid: typo", "basi", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateEnum(tt.value, "plan_tier", allowedValues)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateEnum() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRange(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value int
|
||||
min int
|
||||
max int
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid within range", 5, 1, 10, false},
|
||||
{"Valid at minimum", 1, 1, 10, false},
|
||||
{"Valid at maximum", 10, 1, 10, false},
|
||||
{"Below minimum", 0, 1, 10, true},
|
||||
{"Above maximum", 11, 1, 10, true},
|
||||
{"No maximum (0)", 1000, 1, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateRange(tt.value, "count", tt.min, tt.max)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateRange() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateNoHTML(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"Plain text", "Hello world", false},
|
||||
{"Text with punctuation", "Hello, world!", false},
|
||||
{"HTML tag <script>", "<script>alert('xss')</script>", true},
|
||||
{"HTML tag <img>", "<img src='x'>", true},
|
||||
{"HTML tag <div>", "<div>content</div>", true},
|
||||
{"HTML tag <a>", "<a href='#'>link</a>", true},
|
||||
{"Less than symbol", "5 < 10", false},
|
||||
{"Greater than symbol", "10 > 5", false},
|
||||
{"Both symbols", "5 < x < 10", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidateNoHTML(tt.value, "content")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateNoHTML() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeString(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Trim whitespace", " hello ", "hello"},
|
||||
{"Remove null bytes", "hello\x00world", "helloworld"},
|
||||
{"Already clean", "hello", "hello"},
|
||||
{"Empty string", "", ""},
|
||||
{"Only whitespace", " ", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := v.SanitizeString(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeString() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripHTML(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"Remove script tag", "<script>alert('xss')</script>", "alert('xss')"},
|
||||
{"Remove div tag", "<div>content</div>", "content"},
|
||||
{"Remove multiple tags", "<p>Hello <b>world</b></p>", "Hello world"},
|
||||
{"No tags", "plain text", "plain text"},
|
||||
{"Empty string", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := v.StripHTML(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("StripHTML() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAndSanitizeString(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
minLen int
|
||||
maxLen int
|
||||
wantValue string
|
||||
wantError bool
|
||||
}{
|
||||
{"Valid and clean", "hello", 3, 10, "hello", false},
|
||||
{"Trim and validate", " hello ", 3, 10, "hello", false},
|
||||
{"Too short after trim", " a ", 3, 10, "", true},
|
||||
{"Too long", "hello world this is too long", 3, 10, "", true},
|
||||
{"Empty after trim", " ", 3, 10, "", true},
|
||||
{"Valid with null byte removed", "hel\x00lo", 3, 10, "hello", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := v.ValidateAndSanitizeString(tt.input, "test_field", tt.minLen, tt.maxLen)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateAndSanitizeString() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
if !tt.wantError && result != tt.wantValue {
|
||||
t.Errorf("ValidateAndSanitizeString() = %q, want %q", result, tt.wantValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePrintable(t *testing.T) {
|
||||
v := NewValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantError bool
|
||||
}{
|
||||
{"All printable", "Hello World 123!", false},
|
||||
{"With tabs and newlines", "Hello\tWorld\n", false},
|
||||
{"With control character", "Hello\x01World", true},
|
||||
{"With bell character", "Hello\x07", true},
|
||||
{"Empty string", "", false},
|
||||
{"Unicode printable", "Hello 世界", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := v.ValidatePrintable(tt.value, "test_field")
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidatePrintable() error = %v, wantError %v", err, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue