mirror of
https://github.com/langgenius/dify.git
synced 2025-12-22 15:27:32 +00:00
Compare commits
2 Commits
codex/run-
...
fix/mcp-oa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
855347caf8 | ||
|
|
8925606f33 |
@@ -25,6 +25,48 @@ OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
|||||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_url_security(url: str) -> None:
|
||||||
|
"""Validate URL to prevent XSS attacks by ensuring only safe protocols are allowed."""
|
||||||
|
if not url:
|
||||||
|
raise ValueError("URL cannot be empty")
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid URL format: {e}")
|
||||||
|
|
||||||
|
# Only allow http and https protocols
|
||||||
|
allowed_schemes = ["http", "https"]
|
||||||
|
if parsed_url.scheme.lower() not in allowed_schemes:
|
||||||
|
raise ValueError(f"Unsafe URL protocol '{parsed_url.scheme}'. Only {allowed_schemes} are allowed")
|
||||||
|
|
||||||
|
# Ensure the URL has a valid netloc (domain)
|
||||||
|
if not parsed_url.netloc:
|
||||||
|
raise ValueError("URL must have a valid domain")
|
||||||
|
|
||||||
|
# Additional check for suspicious patterns that could indicate XSS attempts
|
||||||
|
url_lower = url.lower()
|
||||||
|
dangerous_patterns = ["javascript:", "data:", "vbscript:", "file:", "ftp:"]
|
||||||
|
for pattern in dangerous_patterns:
|
||||||
|
if pattern in url_lower:
|
||||||
|
raise ValueError(f"URL contains dangerous pattern: {pattern}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_oauth_metadata_urls(metadata: OAuthMetadata) -> None:
|
||||||
|
"""Validate all URLs in OAuth metadata to prevent XSS attacks."""
|
||||||
|
# Validate authorization endpoint
|
||||||
|
if metadata.authorization_endpoint:
|
||||||
|
_validate_url_security(metadata.authorization_endpoint)
|
||||||
|
|
||||||
|
# Validate token endpoint
|
||||||
|
if metadata.token_endpoint:
|
||||||
|
_validate_url_security(metadata.token_endpoint)
|
||||||
|
|
||||||
|
# Validate registration endpoint
|
||||||
|
if metadata.registration_endpoint:
|
||||||
|
_validate_url_security(metadata.registration_endpoint)
|
||||||
|
|
||||||
|
|
||||||
class OAuthCallbackState(BaseModel):
|
class OAuthCallbackState(BaseModel):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
@@ -113,7 +155,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
|||||||
if 200 <= response.status_code < 300:
|
if 200 <= response.status_code < 300:
|
||||||
body = response.json()
|
body = response.json()
|
||||||
if "authorization_server_url" in body:
|
if "authorization_server_url" in body:
|
||||||
return True, body["authorization_server_url"][0]
|
auth_server_url = body["authorization_server_url"][0]
|
||||||
|
# Validate the authorization server URL to prevent XSS attacks
|
||||||
|
_validate_url_security(auth_server_url)
|
||||||
|
return True, auth_server_url
|
||||||
else:
|
else:
|
||||||
return False, ""
|
return False, ""
|
||||||
return False, ""
|
return False, ""
|
||||||
@@ -124,10 +169,15 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
|
|||||||
|
|
||||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||||
|
# Validate the server URL first
|
||||||
|
_validate_url_security(server_url)
|
||||||
|
|
||||||
# First check if the server supports OAuth 2.0 Resource Discovery
|
# First check if the server supports OAuth 2.0 Resource Discovery
|
||||||
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
|
||||||
if support_resource_discovery:
|
if support_resource_discovery:
|
||||||
url = oauth_discovery_url
|
url = oauth_discovery_url
|
||||||
|
# Validate the discovered OAuth URL
|
||||||
|
_validate_url_security(url)
|
||||||
else:
|
else:
|
||||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||||
|
|
||||||
@@ -138,7 +188,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
|
|||||||
return None
|
return None
|
||||||
if not response.is_success:
|
if not response.is_success:
|
||||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||||
return OAuthMetadata.model_validate(response.json())
|
|
||||||
|
metadata = OAuthMetadata.model_validate(response.json())
|
||||||
|
# Validate all URLs in the metadata to prevent XSS attacks
|
||||||
|
_validate_oauth_metadata_urls(metadata)
|
||||||
|
return metadata
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
if isinstance(e, httpx.ConnectError):
|
if isinstance(e, httpx.ConnectError):
|
||||||
response = httpx.get(url)
|
response = httpx.get(url)
|
||||||
@@ -146,7 +200,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
|
|||||||
return None
|
return None
|
||||||
if not response.is_success:
|
if not response.is_success:
|
||||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||||
return OAuthMetadata.model_validate(response.json())
|
|
||||||
|
metadata = OAuthMetadata.model_validate(response.json())
|
||||||
|
# Validate all URLs in the metadata to prevent XSS attacks
|
||||||
|
_validate_oauth_metadata_urls(metadata)
|
||||||
|
return metadata
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -164,6 +222,9 @@ def start_authorization(
|
|||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
authorization_url = metadata.authorization_endpoint
|
authorization_url = metadata.authorization_endpoint
|
||||||
|
# Validate the authorization endpoint URL to prevent XSS attacks
|
||||||
|
_validate_url_security(authorization_url)
|
||||||
|
|
||||||
if response_type not in metadata.response_types_supported:
|
if response_type not in metadata.response_types_supported:
|
||||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||||
if (
|
if (
|
||||||
@@ -175,6 +236,8 @@ def start_authorization(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
authorization_url = urljoin(server_url, "/authorize")
|
authorization_url = urljoin(server_url, "/authorize")
|
||||||
|
# Validate the constructed authorization URL
|
||||||
|
_validate_url_security(authorization_url)
|
||||||
|
|
||||||
code_verifier, code_challenge = generate_pkce_challenge()
|
code_verifier, code_challenge = generate_pkce_challenge()
|
||||||
|
|
||||||
@@ -218,10 +281,15 @@ def exchange_authorization(
|
|||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
token_url = metadata.token_endpoint
|
token_url = metadata.token_endpoint
|
||||||
|
# Validate the token endpoint URL to prevent XSS attacks
|
||||||
|
_validate_url_security(token_url)
|
||||||
|
|
||||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||||
else:
|
else:
|
||||||
token_url = urljoin(server_url, "/token")
|
token_url = urljoin(server_url, "/token")
|
||||||
|
# Validate the constructed token URL
|
||||||
|
_validate_url_security(token_url)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"grant_type": grant_type,
|
"grant_type": grant_type,
|
||||||
@@ -251,10 +319,15 @@ def refresh_authorization(
|
|||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
token_url = metadata.token_endpoint
|
token_url = metadata.token_endpoint
|
||||||
|
# Validate the token endpoint URL to prevent XSS attacks
|
||||||
|
_validate_url_security(token_url)
|
||||||
|
|
||||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||||
else:
|
else:
|
||||||
token_url = urljoin(server_url, "/token")
|
token_url = urljoin(server_url, "/token")
|
||||||
|
# Validate the constructed token URL
|
||||||
|
_validate_url_security(token_url)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"grant_type": grant_type,
|
"grant_type": grant_type,
|
||||||
@@ -281,8 +354,12 @@ def register_client(
|
|||||||
if not metadata.registration_endpoint:
|
if not metadata.registration_endpoint:
|
||||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||||
registration_url = metadata.registration_endpoint
|
registration_url = metadata.registration_endpoint
|
||||||
|
# Validate the registration endpoint URL to prevent XSS attacks
|
||||||
|
_validate_url_security(registration_url)
|
||||||
else:
|
else:
|
||||||
registration_url = urljoin(server_url, "/register")
|
registration_url = urljoin(server_url, "/register")
|
||||||
|
# Validate the constructed registration URL
|
||||||
|
_validate_url_security(registration_url)
|
||||||
|
|
||||||
response = httpx.post(
|
response = httpx.post(
|
||||||
registration_url,
|
registration_url,
|
||||||
|
|||||||
@@ -13,24 +13,37 @@ export const useOAuthCallback = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const openOAuthPopup = (url: string, callback: () => void) => {
|
export const openOAuthPopup = (url: string, callback: () => void) => {
|
||||||
const width = 600
|
try {
|
||||||
const height = 600
|
const parsedUrl = new URL(url)
|
||||||
const left = window.screenX + (window.outerWidth - width) / 2
|
|
||||||
const top = window.screenY + (window.outerHeight - height) / 2
|
|
||||||
|
|
||||||
const popup = window.open(
|
if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:') {
|
||||||
url,
|
console.error('Invalid URL protocol, only http: and https: are allowed')
|
||||||
'OAuth',
|
return null
|
||||||
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
|
|
||||||
)
|
|
||||||
|
|
||||||
const handleMessage = (event: MessageEvent) => {
|
|
||||||
if (event.data?.type === 'oauth_callback') {
|
|
||||||
window.removeEventListener('message', handleMessage)
|
|
||||||
callback()
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
window.addEventListener('message', handleMessage)
|
const width = 600
|
||||||
return popup
|
const height = 600
|
||||||
|
const left = window.screenX + (window.outerWidth - width) / 2
|
||||||
|
const top = window.screenY + (window.outerHeight - height) / 2
|
||||||
|
|
||||||
|
const popup = window.open(
|
||||||
|
parsedUrl.toString(), // 使用解析和验证后的 URL
|
||||||
|
'OAuth',
|
||||||
|
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
|
||||||
|
)
|
||||||
|
|
||||||
|
const handleMessage = (event: MessageEvent) => {
|
||||||
|
if (event.data?.type === 'oauth_callback') {
|
||||||
|
window.removeEventListener('message', handleMessage)
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('message', handleMessage)
|
||||||
|
return popup
|
||||||
|
}
|
||||||
|
catch (error) {
|
||||||
|
console.error('Invalid URL:', error)
|
||||||
|
return null
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user