mirror of
https://github.com/langgenius/dify.git
synced 2025-12-20 14:42:37 +00:00
Compare commits
2 Commits
copilot/re
...
fix/mcp-oa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
855347caf8 | ||
|
|
8925606f33 |
@@ -25,6 +25,48 @@ OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
def _validate_url_security(url: str) -> None:
|
||||
"""Validate URL to prevent XSS attacks by ensuring only safe protocols are allowed."""
|
||||
if not url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid URL format: {e}")
|
||||
|
||||
# Only allow http and https protocols
|
||||
allowed_schemes = ["http", "https"]
|
||||
if parsed_url.scheme.lower() not in allowed_schemes:
|
||||
raise ValueError(f"Unsafe URL protocol '{parsed_url.scheme}'. Only {allowed_schemes} are allowed")
|
||||
|
||||
# Ensure the URL has a valid netloc (domain)
|
||||
if not parsed_url.netloc:
|
||||
raise ValueError("URL must have a valid domain")
|
||||
|
||||
# Additional check for suspicious patterns that could indicate XSS attempts
|
||||
url_lower = url.lower()
|
||||
dangerous_patterns = ["javascript:", "data:", "vbscript:", "file:", "ftp:"]
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in url_lower:
|
||||
raise ValueError(f"URL contains dangerous pattern: {pattern}")
|
||||
|
||||
|
||||
def _validate_oauth_metadata_urls(metadata: OAuthMetadata) -> None:
|
||||
"""Validate all URLs in OAuth metadata to prevent XSS attacks."""
|
||||
# Validate authorization endpoint
|
||||
if metadata.authorization_endpoint:
|
||||
_validate_url_security(metadata.authorization_endpoint)
|
||||
|
||||
# Validate token endpoint
|
||||
if metadata.token_endpoint:
|
||||
_validate_url_security(metadata.token_endpoint)
|
||||
|
||||
# Validate registration endpoint
|
||||
if metadata.registration_endpoint:
|
||||
_validate_url_security(metadata.registration_endpoint)
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
@@ -113,7 +155,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
if 200 <= response.status_code < 300:
|
||||
body = response.json()
|
||||
if "authorization_server_url" in body:
|
||||
return True, body["authorization_server_url"][0]
|
||||
auth_server_url = body["authorization_server_url"][0]
|
||||
# Validate the authorization server URL to prevent XSS attacks
|
||||
_validate_url_security(auth_server_url)
|
||||
return True, auth_server_url
|
||||
else:
|
||||
return False, ""
|
||||
return False, ""
|
||||
@@ -124,10 +169,15 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
# Validate the server URL first
|
||||
_validate_url_security(server_url)
|
||||
|
||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||
if support_resource_discovery:
|
||||
url = oauth_discovery_url
|
||||
# Validate the discovered OAuth URL
|
||||
_validate_url_security(url)
|
||||
else:
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
@@ -138,7 +188,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
|
||||
return None
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
|
||||
metadata = OAuthMetadata.model_validate(response.json())
|
||||
# Validate all URLs in the metadata to prevent XSS attacks
|
||||
_validate_oauth_metadata_urls(metadata)
|
||||
return metadata
|
||||
except httpx.RequestError as e:
|
||||
if isinstance(e, httpx.ConnectError):
|
||||
response = httpx.get(url)
|
||||
@@ -146,7 +200,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
|
||||
return None
|
||||
if not response.is_success:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
|
||||
metadata = OAuthMetadata.model_validate(response.json())
|
||||
# Validate all URLs in the metadata to prevent XSS attacks
|
||||
_validate_oauth_metadata_urls(metadata)
|
||||
return metadata
|
||||
raise
|
||||
|
||||
|
||||
@@ -164,6 +222,9 @@ def start_authorization(
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
# Validate the authorization endpoint URL to prevent XSS attacks
|
||||
_validate_url_security(authorization_url)
|
||||
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
@@ -175,6 +236,8 @@ def start_authorization(
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
# Validate the constructed authorization URL
|
||||
_validate_url_security(authorization_url)
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
@@ -218,10 +281,15 @@ def exchange_authorization(
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
# Validate the token endpoint URL to prevent XSS attacks
|
||||
_validate_url_security(token_url)
|
||||
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
# Validate the constructed token URL
|
||||
_validate_url_security(token_url)
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
@@ -251,10 +319,15 @@ def refresh_authorization(
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
# Validate the token endpoint URL to prevent XSS attacks
|
||||
_validate_url_security(token_url)
|
||||
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
# Validate the constructed token URL
|
||||
_validate_url_security(token_url)
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
@@ -281,8 +354,12 @@ def register_client(
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
# Validate the registration endpoint URL to prevent XSS attacks
|
||||
_validate_url_security(registration_url)
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
# Validate the constructed registration URL
|
||||
_validate_url_security(registration_url)
|
||||
|
||||
response = httpx.post(
|
||||
registration_url,
|
||||
|
||||
@@ -13,24 +13,37 @@ export const useOAuthCallback = () => {
|
||||
}
|
||||
|
||||
export const openOAuthPopup = (url: string, callback: () => void) => {
|
||||
const width = 600
|
||||
const height = 600
|
||||
const left = window.screenX + (window.outerWidth - width) / 2
|
||||
const top = window.screenY + (window.outerHeight - height) / 2
|
||||
try {
|
||||
const parsedUrl = new URL(url)
|
||||
|
||||
const popup = window.open(
|
||||
url,
|
||||
'OAuth',
|
||||
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
|
||||
)
|
||||
|
||||
const handleMessage = (event: MessageEvent) => {
|
||||
if (event.data?.type === 'oauth_callback') {
|
||||
window.removeEventListener('message', handleMessage)
|
||||
callback()
|
||||
if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:') {
|
||||
console.error('Invalid URL protocol, only http: and https: are allowed')
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener('message', handleMessage)
|
||||
return popup
|
||||
const width = 600
|
||||
const height = 600
|
||||
const left = window.screenX + (window.outerWidth - width) / 2
|
||||
const top = window.screenY + (window.outerHeight - height) / 2
|
||||
|
||||
const popup = window.open(
|
||||
parsedUrl.toString(), // 使用解析和验证后的 URL
|
||||
'OAuth',
|
||||
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
|
||||
)
|
||||
|
||||
const handleMessage = (event: MessageEvent) => {
|
||||
if (event.data?.type === 'oauth_callback') {
|
||||
window.removeEventListener('message', handleMessage)
|
||||
callback()
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener('message', handleMessage)
|
||||
return popup
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Invalid URL:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user