mirror of
https://github.com/langgenius/dify.git
synced 2026-03-06 15:45:14 +00:00
fix: handle enterprise API errors properly to prevent KeyError crashes
When enterprise API returns 403/404, the response contains error JSON instead of expected data structure. Code was accessing fields directly causing KeyError → 500 Internal Server Error. Changes: - Add enterprise-specific error classes (EnterpriseAPIError, etc.) - Implement centralized error validation in EnterpriseRequest.send_request() - Extract error messages from API responses (message/error/detail fields) - Raise domain-specific errors based on HTTP status codes - Preserve backward compatibility with raise_for_status parameter This prevents KeyError crashes and returns proper HTTP error codes (403/404) instead of 500 errors.
This commit is contained in:
@@ -6,6 +6,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIBadRequestError,
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,10 +71,59 @@ class BaseRequest:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
|
||||
# Always validate HTTP status and raise domain-specific errors
|
||||
if not response.is_success:
|
||||
cls._handle_error_response(response)
|
||||
|
||||
# Legacy support: still respect raise_for_status parameter
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
def _handle_error_response(cls, response: httpx.Response) -> None:
|
||||
"""
|
||||
Handle non-2xx HTTP responses by raising appropriate domain errors.
|
||||
|
||||
Attempts to extract error message from JSON response body,
|
||||
falls back to status text if parsing fails.
|
||||
"""
|
||||
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
|
||||
|
||||
# Try to extract error message from JSON response
|
||||
try:
|
||||
error_data = response.json()
|
||||
if isinstance(error_data, dict):
|
||||
# Common error response formats:
|
||||
# {"error": "...", "message": "..."}
|
||||
# {"message": "..."}
|
||||
# {"detail": "..."}
|
||||
error_message = (
|
||||
error_data.get("message")
|
||||
or error_data.get("error")
|
||||
or error_data.get("detail")
|
||||
or error_message
|
||||
)
|
||||
except Exception:
|
||||
# If JSON parsing fails, use the default message
|
||||
logger.debug(
|
||||
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
|
||||
)
|
||||
|
||||
# Raise specific error based on status code
|
||||
if response.status_code == 400:
|
||||
raise EnterpriseAPIBadRequestError(error_message)
|
||||
elif response.status_code == 401:
|
||||
raise EnterpriseAPIUnauthorizedError(error_message)
|
||||
elif response.status_code == 403:
|
||||
raise EnterpriseAPIForbiddenError(error_message)
|
||||
elif response.status_code == 404:
|
||||
raise EnterpriseAPINotFoundError(error_message)
|
||||
else:
|
||||
raise EnterpriseAPIError(error_message, status_code=response.status_code)
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import (
|
||||
conversation,
|
||||
dataset,
|
||||
document,
|
||||
enterprise,
|
||||
file,
|
||||
index,
|
||||
message,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"enterprise",
|
||||
"file",
|
||||
"index",
|
||||
"message",
|
||||
|
||||
45
api/services/errors/enterprise.py
Normal file
45
api/services/errors/enterprise.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Enterprise service errors."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class EnterpriseServiceError(BaseServiceError):
|
||||
"""Base exception for enterprise service errors."""
|
||||
|
||||
def __init__(self, description: str | None = None, status_code: int | None = None):
|
||||
super().__init__(description)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class EnterpriseAPIError(EnterpriseServiceError):
|
||||
"""Generic enterprise API error (non-2xx response)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EnterpriseAPINotFoundError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 404 Not Found."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=404)
|
||||
|
||||
|
||||
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 403 Forbidden."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=403)
|
||||
|
||||
|
||||
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 401 Unauthorized."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=401)
|
||||
|
||||
|
||||
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
|
||||
"""Enterprise API returned 400 Bad Request."""
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
super().__init__(description, status_code=400)
|
||||
Reference in New Issue
Block a user