Compare commits

...

5 Commits

4 changed files with 14 additions and 11 deletions

View File

@@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING: if TYPE_CHECKING:
from models.tools import MCPToolProvider from models.tools import MCPToolProvider
@@ -272,6 +271,8 @@ class MCPProviderEntity(BaseModel):
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]: def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields""" """Generic method to decrypt dictionary fields"""
from core.tools.utils.encryption import create_provider_encrypter
if not data: if not data:
return {} return {}

View File

@@ -109,13 +109,17 @@ class ClientSession(
self._message_handler = message_handler or _default_message_handler self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult: def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() # Only set capabilities if non-default callbacks are provided
roots = types.RootsCapability( # This prevents servers from attempting callbacks when we don't actually support them
# TODO: Should this be based on whether we sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
# _will_ send notifications, or only whether roots = (
# they're supported? types.RootsCapability(
# Only enable listChanged if we have a custom callback
listChanged=True, listChanged=True,
) )
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = self.send_request( result = self.send_request(
types.ClientRequest( types.ClientRequest(

View File

@@ -7,7 +7,6 @@ from pydantic import ValidationError
from yarl import URL from yarl import URL
from configs import dify_config from configs import dify_config
from core.entities.mcp_provider import MCPConfiguration
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity from core.plugin.entities.plugin_daemon import PluginDatasourceProviderEntity
@@ -240,6 +239,8 @@ class ToolTransformService:
user_name: str | None = None, user_name: str | None = None,
include_sensitive: bool = True, include_sensitive: bool = True,
) -> ToolProviderApiEntity: ) -> ToolProviderApiEntity:
from core.entities.mcp_provider import MCPConfiguration
# Use provided user_name to avoid N+1 query, fallback to load_user() if not provided # Use provided user_name to avoid N+1 query, fallback to load_user() if not provided
if user_name is None: if user_name is None:
user = db_provider.load_user() user = db_provider.load_user()

View File

@@ -395,9 +395,6 @@ def test_client_capabilities_default():
# Assert default capabilities # Assert default capabilities
assert received_capabilities is not None assert received_capabilities is not None
assert received_capabilities.sampling is not None
assert received_capabilities.roots is not None
assert received_capabilities.roots.listChanged is True
def test_client_capabilities_with_custom_callbacks(): def test_client_capabilities_with_custom_callbacks():