diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index f32df9763d..fea74ba492 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -7,6 +7,7 @@ from typing import Any, Optional from sqlalchemy.orm import Session from configs import dify_config +from constants import HIDDEN_VALUE from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID @@ -114,52 +115,65 @@ class BuiltinToolManageService: """ update builtin tool provider """ - # get if the provider exists - db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) - - if db_provider is None: - raise ValueError(f"you have not added provider {provider}") - - try: - if CredentialType.of(db_provider.credential_type).is_editable(): - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider} does not need credentials") - - encrypter, cache = BuiltinToolManageService.create_tool_encrypter( - tenant_id, db_provider, provider, provider_controller + with Session(db.engine) as session: + # get if the provider exists + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, ) + .first() + ) + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - # Decrypt and restore original credentials for masked values - original_credentials = encrypter.decrypt(db_provider.credentials) - masked_credentials = encrypter.mask_tool_credentials(original_credentials) + try: + if CredentialType.of(db_provider.credential_type).is_editable(): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") - # check if the credential has changed, save the original credential - for key, value in credentials.items(): - if key in masked_credentials and value == masked_credentials[key]: - credentials[key] = original_credentials[key] + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) - if CredentialType.of(db_provider.credential_type).is_validate_allowed(): - provider_controller.validate_credentials(user_id, credentials) + original_credentials = encrypter.decrypt(db_provider.credentials) + new_credentials: dict = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE) + for key, value in credentials.items() + } - # encrypt credentials - db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + if CredentialType.of(db_provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, new_credentials) - cache.delete() + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) - # update name if provided - if name is not None and db_provider.name != name: - db_provider.name = name + cache.delete() - db.session.commit() - except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, - ) as e: - db.session.rollback() - raise ValueError(str(e)) + # update name if provided + if name is not None and db_provider.name != name: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + db_provider.name = name + + session.commit() + except ( + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, + ) as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @@ -175,59 +189,69 @@ class BuiltinToolManageService: """ add builtin tool provider """ - lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" try: - with redis_client.lock(lock, timeout=20): - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider} does not need credentials") + with Session(db.engine) as session: + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" + with redis_client.lock(lock, timeout=20): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") - provider_count = ( - db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() - ) - - # check if the provider count is reached the limit - if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: - raise ValueError(f"you have reached the maximum number of providers for {provider}") - - # validate credentials if allowed - if CredentialType.of(api_type).is_validate_allowed(): - provider_controller.validate_credentials(user_id, credentials) - - # generate name if not provided - if name is None: - name = BuiltinToolManageService.generate_builtin_tool_provider_name( - tenant_id=tenant_id, provider=provider, credential_type=api_type + provider_count = ( + session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() ) - # create encrypter - encrypter, _ = create_provider_encrypter( - tenant_id=tenant_id, - config=[ - x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema_by_type(api_type) - ], - cache=NoOpProviderCredentialCache(), - ) + # check if the provider count is reached the limit + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") - db_provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider, - encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), - credential_type=api_type.value, - name=name, - ) + # validate credentials if allowed + if CredentialType.of(api_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) - db.session.add(db_provider) - db.session.commit() + # generate name if not provided + if name is None: + name = BuiltinToolManageService.generate_builtin_tool_provider_name( + session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type + ) + else: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + # create encrypter + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(api_type) + ], + cache=NoOpProviderCredentialCache(), + ) + + db_provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=api_type.value, + name=name, + ) + + session.add(db_provider) + session.commit() except ( PluginDaemonClientSideError, ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError, ) as e: - db.session.rollback() + session.rollback() raise ValueError(str(e)) return {"result": "success"} @@ -249,10 +273,12 @@ class BuiltinToolManageService: return encrypter, cache @staticmethod - def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str: + def generate_builtin_tool_provider_name( + session: Session, tenant_id: str, provider: str, credential_type: CredentialType + ) -> str: try: db_providers = ( - db.session.query(BuiltinToolProvider) + session.query(BuiltinToolProvider) .filter_by( tenant_id=tenant_id, provider=provider, @@ -308,7 +334,7 @@ class BuiltinToolManageService: default_provider = providers[0] default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) - encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + encrypter, _ = BuiltinToolManageService.create_tool_encrypter( tenant_id, default_provider, default_provider.provider, provider_controller ) @@ -343,20 +369,28 @@ class BuiltinToolManageService: """ delete tool provider """ - tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) + with Session(db.engine) as session: + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) - if tool_provider is None: - raise ValueError(f"you have not added provider {provider}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - db.session.delete(tool_provider) - db.session.commit() + session.delete(db_provider) + session.commit() - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - _, cache = BuiltinToolManageService.create_tool_encrypter( - tenant_id, tool_provider, provider, provider_controller - ) - cache.delete() + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) + cache.delete() return {"result": "success"} @@ -507,18 +541,6 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) - @staticmethod - def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: - provider: Optional[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - .first() - ) - return provider - @staticmethod def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """