dify/api/core/helper/ssrf_proxy.py

"""
Proxy requests to avoid SSRF
"""
import ipaddress
import logging
import time
from urllib.parse import urlparse
import httpx
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
logger = logging.getLogger(__name__)


def is_private_or_local_address(url: str) -> bool:
    """
    Check if a URL points to a private/local network address (SSRF protection).

    This function validates URLs to prevent Server-Side Request Forgery (SSRF) attacks
    by detecting private IP addresses, localhost, and local network domains.

    Args:
        url: The URL string to check

    Returns:
        True if the URL points to a private/local address, False otherwise

    Examples:
        >>> is_private_or_local_address("http://localhost/api")
        True
        >>> is_private_or_local_address("http://192.168.1.1/api")
        True
        >>> is_private_or_local_address("https://example.com/api")
        False
    """
    if not url:
        return False

    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
        if not hostname:
            return False

        hostname_lower = hostname.lower()

        # Check for localhost variants
        if hostname_lower in ("localhost", "127.0.0.1", "::1"):
            return True

        # Check for .local domains (link-local)
        if hostname_lower.endswith(".local"):
            return True

        # Try to parse as IP address
        try:
            ip = ipaddress.ip_address(hostname)
            # Check if it's a private, loopback, or link-local address.
            # - Private: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7
            # - Loopback: 127.0.0.0/8, ::1
            # - Link-local: 169.254.0.0/16, fe80::/10
            return ip.is_private or ip.is_loopback or ip.is_link_local
        except ValueError:
            # Not a valid IP address, might be a domain name.
            # Domain names could resolve to private IPs, but we only check the literal hostname here.
            # For more thorough checks, DNS resolution would be needed (but adds latency).
            return False
    except (ValueError, TypeError, AttributeError):
        return False
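

# Illustrative sketch only (not used by this module): the docstring above notes
# that domain names may resolve to private IPs. A stricter, higher-latency
# variant could resolve the hostname and re-check every returned address; the
# helper name below is hypothetical and only standard-library calls are assumed.
def _resolves_to_private_address_sketch(hostname: str) -> bool:
    import socket

    try:
        results = socket.getaddrinfo(hostname, None)
    except OSError:
        # Resolution failed; treat as not provably private.
        return False
    for result in results:
        address = result[4][0].split("%")[0]  # strip any IPv6 zone index
        ip = ipaddress.ip_address(address)
        if ip.is_private or ip.is_loopback or ip.is_link_local:
            return True
    return False

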
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
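# With the settings above, make_request() retries transport errors and responses
# whose status is in STATUS_FORCELIST, sleeping BACKOFF_FACTOR * 2 ** (attempt - 1)
# seconds between attempts (0.5s, 1s, 2s, ...).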
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
_SSRF_CLIENT_LIMITS = httpx.Limits(
    max_connections=dify_config.SSRF_POOL_MAX_CONNECTIONS,
    max_keepalive_connections=dify_config.SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS,
    keepalive_expiry=dify_config.SSRF_POOL_KEEPALIVE_EXPIRY,
)


class MaxRetriesExceededError(ValueError):
    """Raised when the maximum number of retries is exceeded."""


def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
    return {
        "http://": httpx.HTTPTransport(
            proxy=dify_config.SSRF_PROXY_HTTP_URL,
        ),
        "https://": httpx.HTTPTransport(
            proxy=dify_config.SSRF_PROXY_HTTPS_URL,
        ),
    }


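# The builder below picks a proxy configuration in priority order:
# SSRF_PROXY_ALL_URL first, then the SSRF_PROXY_HTTP_URL / SSRF_PROXY_HTTPS_URL
# pair via per-scheme mounts, and finally a direct (unproxied) client.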
def _build_ssrf_client(verify: bool) -> httpx.Client:
    if dify_config.SSRF_PROXY_ALL_URL:
        return httpx.Client(
            proxy=dify_config.SSRF_PROXY_ALL_URL,
            verify=verify,
            limits=_SSRF_CLIENT_LIMITS,
        )
    if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
        return httpx.Client(
            mounts=_create_proxy_mounts(),
            verify=verify,
            limits=_SSRF_CLIENT_LIMITS,
        )
    return httpx.Client(verify=verify, limits=_SSRF_CLIENT_LIMITS)


def _get_ssrf_client(ssl_verify_enabled: bool) -> httpx.Client:
    if not isinstance(ssl_verify_enabled, bool):
        raise ValueError("SSRF client verify flag must be a boolean")
    return get_pooled_http_client(
        _SSL_VERIFIED_POOL_KEY if ssl_verify_enabled else _SSL_UNVERIFIED_POOL_KEY,
        lambda: _build_ssrf_client(verify=ssl_verify_enabled),
    )
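

# The pool keeps one long-lived client per TLS-verification setting (keyed above),
# so repeated requests reuse pooled connections instead of constructing a new
# httpx.Client per call.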


def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send an HTTP request through the pooled (and optionally proxied) client,
    retrying on transport errors and force-list status codes."""
    # Map requests-style `allow_redirects` onto httpx's `follow_redirects`.
    if "allow_redirects" in kwargs:
        allow_redirects = kwargs.pop("allow_redirects")
        if "follow_redirects" not in kwargs:
            kwargs["follow_redirects"] = allow_redirects

    if "timeout" not in kwargs:
        kwargs["timeout"] = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        )

    # Prioritize the per-call option, which can be switched on and off inside the HTTP node on the web UI.
    verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
    client = _get_ssrf_client(verify_option)

    retries = 0
    while retries <= max_retries:
        try:
            response = client.request(method=method, url=url, **kwargs)
            if response.status_code not in STATUS_FORCELIST:
                return response
            else:
                logger.warning(
                    "Received status code %s for URL %s which is in the force list",
                    response.status_code,
                    url,
                )
        except httpx.RequestError as e:
            logger.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
            if max_retries == 0:
                raise

        retries += 1
        if retries <= max_retries:
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
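

# Example usage (illustrative): callers typically go through the verb helpers
# below, e.g. `ssrf_proxy.get(url, timeout=10.0, ssl_verify=False)`. A per-call
# `ssl_verify` overrides dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, and a
# requests-style `allow_redirects` argument is mapped to httpx's `follow_redirects`.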


def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("GET", url, max_retries=max_retries, **kwargs)


def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("POST", url, max_retries=max_retries, **kwargs)


def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("PUT", url, max_retries=max_retries, **kwargs)


def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("PATCH", url, max_retries=max_retries, **kwargs)


def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("DELETE", url, max_retries=max_retries, **kwargs)


def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    return make_request("HEAD", url, max_retries=max_retries, **kwargs)