Compare commits

...

2 Commits

Author SHA1 Message Date
JzoNg
855347caf8 fix: prevent XSS in mcp server oauth url 2025-09-08 11:05:09 +08:00
Novice
8925606f33 fix(mcp): prevent XSS attacks by validating OAuth endpoint URLs 2025-09-05 11:11:28 +08:00
2 changed files with 110 additions and 20 deletions

View File

@@ -25,6 +25,48 @@ OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
def _validate_url_security(url: str) -> None:
"""Validate URL to prevent XSS attacks by ensuring only safe protocols are allowed."""
if not url:
raise ValueError("URL cannot be empty")
try:
parsed_url = urlparse(url)
except Exception as e:
raise ValueError(f"Invalid URL format: {e}")
# Only allow http and https protocols
allowed_schemes = ["http", "https"]
if parsed_url.scheme.lower() not in allowed_schemes:
raise ValueError(f"Unsafe URL protocol '{parsed_url.scheme}'. Only {allowed_schemes} are allowed")
# Ensure the URL has a valid netloc (domain)
if not parsed_url.netloc:
raise ValueError("URL must have a valid domain")
# Additional check for suspicious patterns that could indicate XSS attempts
url_lower = url.lower()
dangerous_patterns = ["javascript:", "data:", "vbscript:", "file:", "ftp:"]
for pattern in dangerous_patterns:
if pattern in url_lower:
raise ValueError(f"URL contains dangerous pattern: {pattern}")
def _validate_oauth_metadata_urls(metadata: OAuthMetadata) -> None:
"""Validate all URLs in OAuth metadata to prevent XSS attacks."""
# Validate authorization endpoint
if metadata.authorization_endpoint:
_validate_url_security(metadata.authorization_endpoint)
# Validate token endpoint
if metadata.token_endpoint:
_validate_url_security(metadata.token_endpoint)
# Validate registration endpoint
if metadata.registration_endpoint:
_validate_url_security(metadata.registration_endpoint)
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
@@ -113,7 +155,10 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
if 200 <= response.status_code < 300:
body = response.json()
if "authorization_server_url" in body:
return True, body["authorization_server_url"][0]
auth_server_url = body["authorization_server_url"][0]
# Validate the authorization server URL to prevent XSS attacks
_validate_url_security(auth_server_url)
return True, auth_server_url
else:
return False, ""
return False, ""
@@ -124,10 +169,15 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
# Validate the server URL first
_validate_url_security(server_url)
# First check if the server supports OAuth 2.0 Resource Discovery
support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
if support_resource_discovery:
url = oauth_discovery_url
# Validate the discovered OAuth URL
_validate_url_security(url)
else:
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
@@ -138,7 +188,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
metadata = OAuthMetadata.model_validate(response.json())
# Validate all URLs in the metadata to prevent XSS attacks
_validate_oauth_metadata_urls(metadata)
return metadata
except httpx.RequestError as e:
if isinstance(e, httpx.ConnectError):
response = httpx.get(url)
@@ -146,7 +200,11 @@ def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = N
return None
if not response.is_success:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
metadata = OAuthMetadata.model_validate(response.json())
# Validate all URLs in the metadata to prevent XSS attacks
_validate_oauth_metadata_urls(metadata)
return metadata
raise
@@ -164,6 +222,9 @@ def start_authorization(
if metadata:
authorization_url = metadata.authorization_endpoint
# Validate the authorization endpoint URL to prevent XSS attacks
_validate_url_security(authorization_url)
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
@@ -175,6 +236,8 @@ def start_authorization(
)
else:
authorization_url = urljoin(server_url, "/authorize")
# Validate the constructed authorization URL
_validate_url_security(authorization_url)
code_verifier, code_challenge = generate_pkce_challenge()
@@ -218,10 +281,15 @@ def exchange_authorization(
if metadata:
token_url = metadata.token_endpoint
# Validate the token endpoint URL to prevent XSS attacks
_validate_url_security(token_url)
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Validate the constructed token URL
_validate_url_security(token_url)
params = {
"grant_type": grant_type,
@@ -251,10 +319,15 @@ def refresh_authorization(
if metadata:
token_url = metadata.token_endpoint
# Validate the token endpoint URL to prevent XSS attacks
_validate_url_security(token_url)
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
# Validate the constructed token URL
_validate_url_security(token_url)
params = {
"grant_type": grant_type,
@@ -281,8 +354,12 @@ def register_client(
if not metadata.registration_endpoint:
raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
# Validate the registration endpoint URL to prevent XSS attacks
_validate_url_security(registration_url)
else:
registration_url = urljoin(server_url, "/register")
# Validate the constructed registration URL
_validate_url_security(registration_url)
response = httpx.post(
registration_url,

View File

@@ -13,24 +13,37 @@ export const useOAuthCallback = () => {
}
export const openOAuthPopup = (url: string, callback: () => void) => {
const width = 600
const height = 600
const left = window.screenX + (window.outerWidth - width) / 2
const top = window.screenY + (window.outerHeight - height) / 2
try {
const parsedUrl = new URL(url)
const popup = window.open(
url,
'OAuth',
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
)
const handleMessage = (event: MessageEvent) => {
if (event.data?.type === 'oauth_callback') {
window.removeEventListener('message', handleMessage)
callback()
if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:') {
console.error('Invalid URL protocol, only http: and https: are allowed')
return null
}
}
window.addEventListener('message', handleMessage)
return popup
const width = 600
const height = 600
const left = window.screenX + (window.outerWidth - width) / 2
const top = window.screenY + (window.outerHeight - height) / 2
const popup = window.open(
parsedUrl.toString(), // 使用解析和验证后的 URL
'OAuth',
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
)
const handleMessage = (event: MessageEvent) => {
if (event.data?.type === 'oauth_callback') {
window.removeEventListener('message', handleMessage)
callback()
}
}
window.addEventListener('message', handleMessage)
return popup
}
catch (error) {
console.error('Invalid URL:', error)
return null
}
}