Compare commits

...

5 Commits

3 changed files with 13 additions and 10 deletions

View File

@@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.encryption import create_provider_encrypter
if TYPE_CHECKING:
from models.tools import MCPToolProvider
@@ -272,6 +271,9 @@ class MCPProviderEntity(BaseModel):
def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
"""Generic method to decrypt dictionary fields"""
# Lazy import to avoid circular dependency
from core.tools.utils.encryption import create_provider_encrypter
if not data:
return {}

View File

@@ -109,12 +109,16 @@ class ClientSession(
self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
# Only set capabilities if non-default callbacks are provided
# This prevents servers from attempting callbacks when we don't actually support them
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
roots = (
types.RootsCapability(
# Only enable listChanged if we have a custom callback
listChanged=True,
)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = self.send_request(

View File

@@ -395,9 +395,6 @@ def test_client_capabilities_default():
# Assert default capabilities
assert received_capabilities is not None
assert received_capabilities.sampling is not None
assert received_capabilities.roots is not None
assert received_capabilities.roots.listChanged is True
def test_client_capabilities_with_custom_callbacks():