package inputvalidation

import (
	"fmt"
	"net"
	"net/url"
	"strings"
)

// AllowedDownloadHosts lists the allowed hosts for presigned download URLs.
// These are the only hosts from which the application will download files.
//
// Entries starting with "." are wildcard suffixes: they match any proper
// subdomain (e.g. ".maplefile.ca" matches "files.maplefile.ca" but neither
// "maplefile.ca" nor ".maplefile.ca"). Other DNS-name entries match the name
// itself and its subdomains. IP-literal entries match only exactly.
//
// NOTE(review): the local-development entries are allowed unconditionally,
// including in production builds; consider gating them on a dev flag.
var AllowedDownloadHosts = []string{
	// Production S3-compatible storage (Digital Ocean Spaces)
	".digitaloceanspaces.com",

	// AWS S3 (if used in future)
	".s3.amazonaws.com",
	".s3.us-east-1.amazonaws.com",
	".s3.us-west-2.amazonaws.com",
	".s3.eu-west-1.amazonaws.com",

	// MapleFile domains (if serving files directly)
	".maplefile.ca",

	// Local development
	"localhost",
	"127.0.0.1",
}

// ValidateDownloadURL validates a presigned download URL before use.
// This prevents SSRF attacks by ensuring downloads only happen from trusted
// hosts listed in AllowedDownloadHosts.
//
// The URL must be non-empty and http(s). Plain http is accepted only for
// local hosts (see isLocalHost). The URL must not embed credentials, and its
// path must not contain "..". A nil return means the URL passed every check.
func ValidateDownloadURL(rawURL string) error {
	if rawURL == "" {
		return fmt.Errorf("download URL is required")
	}

	parsedURL, host, err := parseAndCheckHTTPURL(rawURL)
	if err != nil {
		return err
	}

	if !isAllowedHost(host) {
		return fmt.Errorf("download from host %q is not allowed", host)
	}

	// Credentials in the URL leak secrets into logs and can confuse host
	// parsing in other components.
	if parsedURL.User != nil {
		return fmt.Errorf("URL must not contain credentials")
	}

	// url.URL.Path is already percent-decoded, so encoded forms such as
	// "%2e%2e" are caught here too. Deliberately broad: any "..", not just a
	// ".." path segment, is rejected.
	if strings.Contains(parsedURL.Path, "..") {
		return fmt.Errorf("URL path contains invalid sequences")
	}

	return nil
}

// parseAndCheckHTTPURL parses rawURL and applies the checks shared by every
// validator in this file:
//   - the scheme must be "http" or "https" (url.Parse lowercases the scheme);
//   - the URL must name a host;
//   - plain "http" is allowed only when isLocalHost reports the host as local.
//
// It returns the parsed URL and its host name without port or IPv6 brackets
// (url.URL.Hostname). On error, the returned URL and host are meaningless.
func parseAndCheckHTTPURL(rawURL string) (*url.URL, string, error) {
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		return nil, "", fmt.Errorf("invalid URL format: %w", err)
	}

	if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
		return nil, "", fmt.Errorf("URL must use HTTP or HTTPS scheme")
	}

	host := parsedURL.Hostname()
	if host == "" {
		return nil, "", fmt.Errorf("URL must have a valid host")
	}

	if parsedURL.Scheme == "http" && !isLocalHost(host) {
		return nil, "", fmt.Errorf("non-local URLs must use HTTPS")
	}

	return parsedURL, host, nil
}

// isAllowedHost reports whether host (compared case-insensitively) matches an
// entry in AllowedDownloadHosts, using the matching rules documented on that
// variable.
func isAllowedHost(host string) bool {
	host = strings.ToLower(host)
	for _, allowed := range AllowedDownloadHosts {
		allowed = strings.ToLower(allowed)
		switch {
		case strings.HasPrefix(allowed, "."):
			// Wildcard suffix: require at least one character before the
			// suffix so a host equal to the bare suffix (".maplefile.ca")
			// is not accepted.
			if len(host) > len(allowed) && strings.HasSuffix(host, allowed) {
				return true
			}
		case host == allowed:
			return true
		case net.ParseIP(allowed) == nil && strings.HasSuffix(host, "."+allowed):
			// Non-wildcard DNS name: also accept its subdomains. IP literals
			// never match by suffix ("x.127.0.0.1" is not that address).
			return true
		}
	}
	return false
}

// isLocalHost reports whether host is "localhost" or an IP address that is
// loopback (127.0.0.0/8, ::1) or in a private range per net.IP.IsPrivate
// (10/8, 172.16/12, 192.168/16, fc00::/7). Such hosts may be reached over
// plain http.
func isLocalHost(host string) bool {
	host = strings.ToLower(host)

	switch host {
	case "localhost", "127.0.0.1", "::1":
		return true
	}

	ip := net.ParseIP(host)
	return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
}

// ValidateAPIBaseURL validates a base URL for API requests.
//
// The URL must be non-empty and http(s). Plain http is accepted only for
// local hosts (see isLocalHost), and the URL must not embed credentials.
// Unlike ValidateDownloadURL, the host is not checked against an allowlist.
func ValidateAPIBaseURL(rawURL string) error {
	if rawURL == "" {
		return fmt.Errorf("API URL is required")
	}

	parsedURL, _, err := parseAndCheckHTTPURL(rawURL)
	if err != nil {
		return err
	}

	if parsedURL.User != nil {
		return fmt.Errorf("URL must not contain credentials")
	}

	return nil
}