From ba843c26911954b08908bcd46e5f981d3be9e9ba Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 26 Jun 2025 11:44:00 +0800 Subject: [PATCH] feat(oauth): update api --- .../console/workspace/model_providers.py | 1 + .../console/workspace/tool_providers.py | 6 +- api/core/tools/tool_manager.py | 90 +++++------- .../python/examples/github/provider/github.py | 67 --------- .../tools/builtin_tools_manage_service.py | 128 ++++++------------ api/services/tools/tools_transform_service.py | 2 +- 6 files changed, 84 insertions(+), 210 deletions(-) delete mode 100644 api/dify-plugin-sdks/python/examples/github/provider/github.py diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 32139781b0..ff0fcbda6e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -35,6 +35,7 @@ class ModelProviderListApi(Resource): model_provider_service = ModelProviderService() provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) + return jsonable_encoder({"data": provider_list}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c581a39200..ceea178214 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -371,12 +371,12 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): + def get(self, provider, credential_type): user = current_user tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) + return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id) class ToolApiProviderSchemaApi(Resource): @@ -789,7 +789,7 @@ api.add_resource( ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema", + "/workspaces/current/tool-provider/builtin///credentials_schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 86ffa01667..bd4a635923 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, - ToolProviderType, -) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ProviderConfigEncrypter, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -64,8 +55,11 @@ class ToolManager: @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -109,7 +103,12 @@ class ToolManager: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + with contexts.plugin_tool_providers_lock.get(): + # double check plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] @@ -127,25 +126,7 @@ class ToolManager: ) plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool + return controller @classmethod def get_tool_runtime( @@ -563,6 +544,22 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] + @classmethod + def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]: + """ + list all the builtin providers + """ + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + @classmethod def list_providers_from_api( cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral @@ -577,30 +574,13 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - - def get_builtin_providers(tenant_id): - # according to multi credentials, select the one with is_default=True first, then created_at oldest - # for compatibility with old version - sql = """ - SELECT DISTINCT ON (tenant_id, provider) id - FROM tool_builtin_providers - WHERE tenant_id = :tenant_id - ORDER BY tenant_id, provider, is_default DESC, created_at DESC - """ - ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] - return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() - builtin_providers = cls.list_builtin_providers(tenant_id) - # get builtin providers - db_builtin_providers = get_builtin_providers(tenant_id) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - db_provider.provider = str(ToolProviderID(db_provider.provider)) - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) + # key: provider name, value: provider + db_builtin_providers = { + str(ToolProviderID(provider.provider)): provider + for provider in cls.list_default_builtin_providers(tenant_id) + } # append builtin providers for provider in builtin_providers: @@ -612,10 +592,9 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), + db_provider=db_builtin_providers.get(provider.entity.identity.name), decrypt_credentials=False, ) @@ -625,7 +604,6 @@ class ToolManager: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider # get db api providers - if "api" in filters: db_api_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() diff --git a/api/dify-plugin-sdks/python/examples/github/provider/github.py b/api/dify-plugin-sdks/python/examples/github/provider/github.py deleted file mode 100644 index 7fb7bd33df..0000000000 --- a/api/dify-plugin-sdks/python/examples/github/provider/github.py +++ /dev/null @@ -1,67 +0,0 @@ -import secrets -import urllib.parse -from collections.abc import Mapping -from typing import Any - -import requests -from dify_plugin import ToolProvider -from dify_plugin.errors.tool import ToolProviderCredentialValidationError -from werkzeug import Request - - -class GithubProvider(ToolProvider): - _AUTH_URL = "https://github.com/login/oauth/authorize" - _TOKEN_URL = "https://github.com/login/oauth/access_token" - _API_USER_URL = "https://api.github.com/user" - - def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str: - """ - Generate the authorization URL for the Github OAuth. - """ - state = secrets.token_urlsafe(16) - params = { - "client_id": system_credentials["client_id"], - "redirect_uri": system_credentials["redirect_uri"], - "scope": system_credentials.get("scope", "read:user"), - "state": state, - # Optionally: allow_signup, login, etc. - } - return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - - def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]: - """ - Exchange code for access_token. - """ - code = request.args.get("code") - state = request.args.get("state") - if not code: - raise ValueError("No code provided") - # Optionally: validate state here - - data = { - "client_id": system_credentials["client_id"], - "client_secret": system_credentials["client_secret"], - "code": code, - "redirect_uri": system_credentials["redirect_uri"], - } - headers = {"Accept": "application/json"} - response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10) - response_json = response.json() - access_token = response_json.get("access_token") - if not access_token: - raise ValueError(f"Error in GitHub OAuth: {response_json}") - return {"access_token": access_token} - - def _validate_credentials(self, credentials: dict) -> None: - try: - if "access_token" not in credentials or not credentials.get("access_token"): - raise ToolProviderCredentialValidationError("GitHub API Access Token is required.") - headers = { - "Authorization": f"Bearer {credentials['access_token']}", - "Accept": "application/vnd.github+json", - } - response = requests.get(self._API_USER_URL, headers=headers, timeout=10) - if response.status_code != 200: - raise ToolProviderCredentialValidationError(response.json().get("message")) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b4f043c647..0137e13b20 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging import re from pathlib import Path +from typing import Optional, Union from sqlalchemy import ColumnExpressionArgument from sqlalchemy.orm import Session @@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError +from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType @@ -40,12 +42,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) @@ -53,7 +50,7 @@ class BuiltinToolManageService: if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + credentials = tool_configuration.decrypt(credentials) result: list[ToolApiEntity] = [] for tool in tools or []: @@ -74,12 +71,7 @@ class BuiltinToolManageService: get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) @@ -87,7 +79,7 @@ class BuiltinToolManageService: if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + credentials = tool_configuration.decrypt(credentials) entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -100,7 +92,7 @@ class BuiltinToolManageService: return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str): """ list builtin provider credentials schema @@ -123,35 +115,28 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f"you have not added provider {provider_name}") - + try: if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Decrypt and restore original credentials for masked values original_credentials = tool_configuration.decrypt(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: # type: ignore - credentials[name] = original_credentials[name] # type: ignore + for key, value in credentials.items(): + if key in masked_credentials and value == masked_credentials[key]: + credentials[key] = original_credentials[key] # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) - else: - raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one") # update name if provided if name is not None and provider.name != name: @@ -180,8 +165,8 @@ class BuiltinToolManageService: """ add builtin tool provider """ - lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}" - with redis_client.lock(lock_name, timeout=20): + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}" + with redis_client.lock(lock, timeout=20): if name is None: name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) @@ -198,12 +183,7 @@ class BuiltinToolManageService: if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( @@ -268,23 +248,17 @@ class BuiltinToolManageService: return [] provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) credentials: list[ToolProviderCredentialApiEntity] = [] for provider in providers: decrypt_credential = tool_configuration.mask_tool_credentials( tool_configuration.decrypt(provider.credentials) ) - credentials.append( - ToolTransformService.convert_builtin_provider_to_credential_api_entity( - provider=provider, - credentials=decrypt_credential, - ) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, ) + credentials.append(credential_entity) return credentials @staticmethod @@ -292,22 +266,17 @@ class BuiltinToolManageService: """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) - if provider_obj is None: + if tool_provider is None: raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider_obj) + db.session.delete(tool_provider) db.session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) tool_configuration.delete_tool_credentials_cache() return {"result": "success"} @@ -334,7 +303,9 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str): + def get_builtin_tool_oauth_client( + tenant_id: str, provider: str, plugin_id: str + ) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]: """ get builtin tool provider """ @@ -350,14 +321,12 @@ class BuiltinToolManageService: .first() ) if user_client: - plugin_oauth_config = user_client - else: - plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + return user_client - if plugin_oauth_config: - return plugin_oauth_config - - raise ValueError("no oauth available config found for this plugin") + system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + if system_client is None: + raise ValueError("no oauth available client config found for this tool provider") + return system_client @staticmethod def get_builtin_tool_provider_icon(provider: str): @@ -379,9 +348,7 @@ class BuiltinToolManageService: with db.session.no_autoflush: # get all user added providers - db_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: @@ -432,8 +399,8 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: - provider = ( + def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: + provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -444,14 +411,14 @@ class BuiltinToolManageService: return provider @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: + def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ - def _query(provider_filters: list[ColumnExpressionArgument[bool]]): + def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]: return ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) @@ -484,21 +451,16 @@ class BuiltinToolManageService: return provider except Exception: # it's an old provider without organization - provider_obj = _query([BuiltinToolProvider.provider == provider_name]) - return provider_obj + return _query([BuiltinToolProvider.provider == provider_name]) @staticmethod - def _decrypt_and_restore_credentials(tool_configuration, provider, credentials): - """ - Decrypt original credentials and restore masked values from the input credentials - - :param tool_configuration: the tool configuration encrypter - :param provider: the provider object from database - :param credentials: the input credentials from user - :return: the processed credentials with original values restored - """ - - return credentials + def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): + return ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) @staticmethod def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b896f6c88f..66be67dbe6 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -307,7 +307,7 @@ class ToolTransformService: ) @staticmethod - def convert_builtin_provider_to_credential_api_entity( + def convert_builtin_provider_to_credential_entity( provider: BuiltinToolProvider, credentials: dict ) -> ToolProviderCredentialApiEntity: return ToolProviderCredentialApiEntity(