167 lines
3.9 KiB
Go
167 lines
3.9 KiB
Go
package inputvalidation
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
// AllowedDownloadHosts is the allowlist of hosts from which presigned
// download URLs may be fetched; ValidateDownloadURL rejects every other host.
// Entries beginning with "." are domain suffixes that match any subdomain;
// the remaining entries are compared against the URL's host name (see
// isAllowedHost for the exact rules). Matching is case-insensitive.
//
// NOTE(review): the provider suffixes (Spaces, S3) admit every bucket hosted
// by that provider, not only ours, and the local-development entries are
// present in all builds. Consider narrowing both.
var AllowedDownloadHosts = []string{
	".digitaloceanspaces.com", // production object storage (DigitalOcean Spaces)

	// AWS S3 endpoints, kept for a possible future migration.
	".s3.amazonaws.com",
	".s3.us-east-1.amazonaws.com",
	".s3.us-west-2.amazonaws.com",
	".s3.eu-west-1.amazonaws.com",

	".maplefile.ca", // MapleFile's own domains, for files served directly

	// Local development.
	"localhost",
	"127.0.0.1",
}
|
|
|
|
// ValidateDownloadURL validates a presigned download URL before use.
|
|
// This prevents SSRF attacks by ensuring downloads only happen from trusted hosts.
|
|
func ValidateDownloadURL(rawURL string) error {
|
|
if rawURL == "" {
|
|
return fmt.Errorf("download URL is required")
|
|
}
|
|
|
|
// Parse the URL
|
|
parsedURL, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL format: %w", err)
|
|
}
|
|
|
|
// Validate scheme - must be HTTPS (except localhost for development)
|
|
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
|
return fmt.Errorf("URL must use HTTP or HTTPS scheme")
|
|
}
|
|
|
|
// Get host without port
|
|
host := parsedURL.Hostname()
|
|
if host == "" {
|
|
return fmt.Errorf("URL must have a valid host")
|
|
}
|
|
|
|
// For HTTPS requirement - only allow HTTP for localhost/local IPs
|
|
if parsedURL.Scheme == "http" {
|
|
if !isLocalHost(host) {
|
|
return fmt.Errorf("non-local URLs must use HTTPS")
|
|
}
|
|
}
|
|
|
|
// Check if host is in allowed list
|
|
if !isAllowedHost(host) {
|
|
return fmt.Errorf("download from host %q is not allowed", host)
|
|
}
|
|
|
|
// Check for credentials in URL (security risk)
|
|
if parsedURL.User != nil {
|
|
return fmt.Errorf("URL must not contain credentials")
|
|
}
|
|
|
|
// Check for suspicious path traversal in URL path
|
|
if strings.Contains(parsedURL.Path, "..") {
|
|
return fmt.Errorf("URL path contains invalid sequences")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isAllowedHost checks if a host is in the allowed download hosts list
|
|
func isAllowedHost(host string) bool {
|
|
host = strings.ToLower(host)
|
|
|
|
for _, allowed := range AllowedDownloadHosts {
|
|
allowed = strings.ToLower(allowed)
|
|
|
|
// Exact match
|
|
if host == allowed {
|
|
return true
|
|
}
|
|
|
|
// Suffix match for wildcard domains (e.g., ".digitaloceanspaces.com")
|
|
if strings.HasPrefix(allowed, ".") && strings.HasSuffix(host, allowed) {
|
|
return true
|
|
}
|
|
|
|
// Handle subdomains for non-wildcard entries
|
|
if !strings.HasPrefix(allowed, ".") {
|
|
if host == allowed || strings.HasSuffix(host, "."+allowed) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isLocalHost reports whether host refers to the local machine or to an
// address on a private network. It accepts the names "localhost",
// "127.0.0.1" and "::1" (case-insensitively), any loopback IP, and any IP in
// a private range as classified by net.IP.IsPrivate (10.0.0.0/8,
// 172.16.0.0/12, 192.168.0.0/16, fc00::/7). Only IP literals are classified;
// no DNS lookup is performed, so other host names are rejected.
func isLocalHost(host string) bool {
	lowered := strings.ToLower(host)

	switch lowered {
	case "localhost", "127.0.0.1", "::1":
		return true
	}

	addr := net.ParseIP(lowered)
	return addr != nil && (addr.IsLoopback() || addr.IsPrivate())
}
|
|
|
|
// ValidateAPIBaseURL validates a base URL for API requests
|
|
func ValidateAPIBaseURL(rawURL string) error {
|
|
if rawURL == "" {
|
|
return fmt.Errorf("API URL is required")
|
|
}
|
|
|
|
parsedURL, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid URL format: %w", err)
|
|
}
|
|
|
|
// Validate scheme
|
|
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
|
return fmt.Errorf("URL must use HTTP or HTTPS scheme")
|
|
}
|
|
|
|
// Get host
|
|
host := parsedURL.Hostname()
|
|
if host == "" {
|
|
return fmt.Errorf("URL must have a valid host")
|
|
}
|
|
|
|
// For HTTPS requirement - only allow HTTP for localhost/local IPs
|
|
if parsedURL.Scheme == "http" {
|
|
if !isLocalHost(host) {
|
|
return fmt.Errorf("non-local URLs must use HTTPS")
|
|
}
|
|
}
|
|
|
|
// Check for credentials in URL
|
|
if parsedURL.User != nil {
|
|
return fmt.Errorf("URL must not contain credentials")
|
|
}
|
|
|
|
return nil
|
|
}
|