Compare commits

...

52 Commits

Author SHA1 Message Date
Harry
c238b88641 feat(oauth): improve error handling for provider retrieval and clarify logging messages 2025-07-17 12:05:05 +08:00
Harry
5db388716b feat(uuid): enhance UUID validation to check for empty strings 2025-07-16 15:46:44 +08:00
Harry
f9f3c207f4 feat(oauth): add UUID validation and enhance error handling for credential retrieval 2025-07-16 15:45:38 +08:00
Harry
2bafcd596f feat(oauth): add setup command for system tool OAuth client 2025-07-16 12:56:59 +08:00
Harry
d162905bb5 feat(oauth): enhance error handling for authorization URL and credentials retrieval 2025-07-16 11:01:13 +08:00
Harry
fd2651f5aa refactor: simplify error handling by removing specific exceptions in tool management 2025-07-15 23:16:02 +08:00
Harry
9cdbf30238 feat(env): update localhost URLs in .env.example 2025-07-15 19:33:12 +08:00
Harry
6a085fab26 Merge remote-tracking branch 'origin/main' into feat/tool-plugin-oauth 2025-07-15 17:28:31 +08:00
Harry
22297d0326 feat(oauth): add functionality to delete custom OAuth client parameters and verify plugin status 2025-07-14 19:58:20 +08:00
Harry
37be099442 feat(oauth): update client parameters handling and improve oauth_params parsing 2025-07-14 14:43:50 +08:00
Harry
f68201af0b feat(oauth): implement name length validation 2025-07-14 12:00:41 +08:00
Harry
06802afc94 Merge branch 'main' into feat/tool-plugin-oauth 2025-07-14 11:28:10 +08:00
Harry
458e44133e feat(oauth): enhance tool provider updates with name validation and include credential ID in agent tools 2025-07-13 13:39:21 +08:00
Harry
f9c4897ff3 feat(oauth): replace HIDDEN_VALUE with UNKNOWN_VALUE for better credential handling 2025-07-13 12:45:56 +08:00
Harry
6cb4a6f692 feat(dsl): filter out credential IDs from workflow and model configuration exports 2025-07-13 12:37:19 +08:00
Harry
7de3436e6b feat(oauth): add credential handling and context support for tool invocations 2025-07-13 01:44:29 +08:00
Harry
8fc5ccab35 fix(oauth): add error handling for OAuth parameter decryption 2025-07-11 22:34:07 +08:00
Harry
5090f63df5 feat(tool): update tool provider methods to handle optional credentials and name 2025-07-11 22:10:13 +08:00
Harry
31e1261ae2 fix(oauth): improve name validation logic for tool providers 2025-07-11 21:44:45 +08:00
Harry
ace6e11a6f feat(oauth): implement AES encryption and decryption for system OAuth parameters 2025-07-11 21:28:02 +08:00
Harry
7ba09dfa06 fix(migrations): update down_revision references in migration files 2025-07-11 17:09:40 +08:00
Harry
ab6ae1f209 feat(oauth): improve credential schema validation in provider 2025-07-11 16:48:38 +08:00
Harry
0532135a9c Merge remote-tracking branch 'origin/main' into feat/tool-plugin-oauth 2025-07-11 16:36:52 +08:00
Harry
adc39f7b0d feat(oauth): enhance OAuth client management and validation 2025-07-11 16:28:40 +08:00
Harry
fb9e4a4227 feat(oauth): migrations 2025-07-11 16:05:11 +08:00
Harry
545c21b196 feat(oauth): clean up imports and streamline OAuth client parameter retrieval 2025-07-11 13:51:31 +08:00
Harry
f3bbab0eed Merge remote-tracking branch 'origin/main' into feat/tool-plugin-oauth
# Conflicts:
#	api/controllers/console/workspace/tool_providers.py
#	api/core/tools/entities/api_entities.py
#	api/core/tools/tool_manager.py
#	api/core/tools/utils/configuration.py
#	api/services/tools/tools_transform_service.py
2025-07-11 13:48:41 +08:00
Harry
f35b8d6245 feat(oauth): refactor session management in tool provider operations 2025-07-09 14:45:52 +08:00
Harry
ef330fec2c feat(oauth): add credential validation for providers 2025-07-09 11:57:31 +08:00
Harry
0dc5bfb2c7 feat(oauth): refactor tool encryption utils 2025-07-04 17:28:22 +08:00
Harry
eaefa1b7e6 feat(oauth): refactor encryption 2025-07-04 17:28:13 +08:00
Harry
9f053f3bbc feat(oauth): rename ToolProviderCredentialType to CredentialType for consistency 2025-07-04 17:28:09 +08:00
Harry
26b46b88c9 feat(oauth): add multi credentials support 2025-07-04 17:28:06 +08:00
Harry
b316867bab Merge remote-tracking branch 'origin/main' into feat/tool-plugin-oauth 2025-07-02 21:54:50 +08:00
Harry
988a76066d feat(oauth): enhance OAuth client handling and add custom client support 2025-07-02 20:19:04 +08:00
Harry
6ef1e017df feat(oauth): add support for retrieving credential info and OAuth client schema 2025-07-02 14:58:50 +08:00
Harry
7951a1c4df refactor(tool): implement multi provider credentials support 2025-07-02 10:05:18 +08:00
Harry
daec82bd44 feat(oauth): refactor tool provider methods and enhance credential handling 2025-07-01 12:53:48 +08:00
Harry
8a954c0b19 Merge branch 'main' into feat/tool-plugin-oauth 2025-06-26 13:29:15 +08:00
Harry
f4f6e41074 feat(oauth): add oauth redirect_uri parameters 2025-06-26 13:28:37 +08:00
Harry
ba843c2691 feat(oauth): update api 2025-06-26 11:59:20 +08:00
Harry
6c9e99b0c6 Merge branch 'main' into feat/tool-plugin-oauth 2025-06-25 15:14:03 +08:00
Harry
ce4cc54cc9 feat(oauth): merge tool oauth and remove sequence number branches 2025-06-25 14:51:55 +08:00
Harry
1a2dfd950e Merge branch 'main' into feat/tool-plugin-oauth
# Conflicts:
#	api/core/plugin/impl/oauth.py
#	api/services/plugin/oauth_service.py
2025-06-25 14:31:15 +08:00
Harry
a58e99c671 Merge branch 'main' into feat/tool-plugin-oauth 2025-06-25 14:29:39 +08:00
Harry
8bd05aee4b Merge branch 'feat/plugin-oauth' into feat/tool-plugin-oauth
# Conflicts:
#	api/services/plugin/oauth_service.py
2025-06-25 10:38:18 +08:00
Harry
fcfaa7ce13 feat(oauth): plugin oauth service 2025-06-25 10:13:41 +08:00
Harry
7979e05ade Merge branch 'main' into feat/tool-plugin-oauth
# Conflicts:
#	README.md
#	api/services/tools/builtin_tools_manage_service.py
2025-06-24 21:09:15 +08:00
Harry
5e7c5863ef refactor(tool oauth): update api implementation 2025-06-24 21:07:45 +08:00
Harry
7f292dc261 fix: remove debugging flags 2025-06-23 12:49:18 +08:00
Harry
b3a8dbe2f5 fix: typo 2025-06-23 11:20:54 +08:00
Harry
12c20ec7f6 feat: plugin OAuth with stateful 2025-06-23 10:48:20 +08:00
44 changed files with 1910 additions and 552 deletions

View File

@@ -5,17 +5,17 @@
SECRET_KEY=
# Console API base URL
CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000
CONSOLE_API_URL=http://localhost:5001
CONSOLE_WEB_URL=http://localhost:3000
# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001
SERVICE_API_URL=http://localhost:5001
# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000
APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
@@ -138,8 +138,8 @@ SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone

View File

@@ -11,10 +11,12 @@ from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -27,6 +29,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
@@ -1155,3 +1158,49 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_tool_oauth_client(provider, client_params):
"""
Setup system tool oauth client
"""
provider_id = ToolProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
json.loads(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = ToolOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))

View File

@@ -1,6 +1,7 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3

View File

@@ -1,23 +1,32 @@
import io
from urllib.parse import urlparse
from flask import redirect, send_file
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from flask_restful import (
Resource,
reqparse,
)
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
@@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required
def post(self, provider):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
tenant_id,
provider,
args["credential_id"],
)
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=args["credentials"],
name=args["name"],
api_type=CredentialType.of(args["type"]),
)
@@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credential_id=args["credential_id"],
credentials=args.get("credentials", None),
name=args.get("name", ""),
)
return result
@@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider):
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
)
)
@@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, CredentialType.of(credential_type), tenant_id
)
)
class ToolApiProviderSchemaApi(Resource):
@@ -586,15 +631,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
]
@@ -631,6 +673,178 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
)
)
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
)
)
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@@ -794,17 +1008,33 @@ class ToolMCPCallbackApi(Resource):
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

View File

@@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id
),
)

View File

@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
credential_id: str | None = None
class AgentPromptEntity(BaseModel):

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC):
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@@ -58,4 +60,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
context=PluginInvokeContext(
credentials=credentials or InvokeCredentials()
),
)

View File

@@ -39,6 +39,7 @@ class AgentConfigManager:
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))

View File

@@ -0,0 +1,84 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCache(ABC):
"""Base class for provider credentials cache"""
def __init__(self, **kwargs):
self.cache_key = self._generate_cache_key(**kwargs)
@abstractmethod
def _generate_cache_key(self, **kwargs) -> str:
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
try:
cached_credentials = cached_credentials.decode("utf-8")
return dict(json.loads(cached_credentials))
except JSONDecodeError:
return None
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None:
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
pass
def delete(self) -> None:
"""Delete cached provider credentials"""
pass

View File

@@ -1,51 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter(
encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
provider_type=payload.namespace,
provider_identity=payload.identity,
cache=SingletonProviderCredentialsCache(
tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
)
if payload.opt == "encrypt":
@@ -22,7 +26,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data),
}
elif payload.opt == "clear":
encrypter.delete_tool_credentials_cache()
cache.delete()
return {
"data": {},
}

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
@@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

View File

@@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
)
class InvokeCredentials(BaseModel):
tool_credentials: dict[str, str] = Field(
default_factory=dict,
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
)
class PluginInvokeContext(BaseModel):
credentials: Optional[InvokeCredentials] = Field(
default_factory=InvokeCredentials,
description="Credentials context for the plugin invocation or backward invocation.",
)
class RequestInvokeTool(BaseModel):
"""
Request to invoke a tool
@@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
provider: str
tool: str
tool_parameters: dict
credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel):

View File

@@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
@@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
"context": context.model_dump() if context else {},
"data": {
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,

View File

@@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
},
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting authorization URL: {e}")
def get_credentials(
self,
@@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
@@ -50,30 +56,34 @@ class OAuthHandler(BasePluginClient):
Get credentials from the given request.
"""
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
try:
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
},
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):
@@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
credential_type: CredentialType,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
@@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name,
"tool": tool_name,
"credentials": credentials,
"credential_type": credential_type,
"tool_parameters": tool_parameters,
},
},

View File

@@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[CredentialType] = CredentialType.API_KEY
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return self.entity.credentials_schema.copy()
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
:return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
return types
def get_tools(self) -> list[BuiltinTool]:
"""
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property
def provider_type(self) -> ToolProviderType:

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

View File

@@ -355,10 +355,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@@ -438,6 +446,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@@ -445,3 +454,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

View File

@@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService
@@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -68,8 +70,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -113,7 +118,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -131,25 +141,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
return controller
@classmethod
def get_tool_runtime(
@@ -160,6 +152,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -170,6 +163,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool
"""
@@ -193,49 +187,70 @@ class ToolManager:
)
),
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get credentials
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.first()
)
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
except Exception as e:
builtin_provider = None
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
# if the provider has been deleted, raise an error
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
# fallback to the default provider
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@@ -245,22 +260,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
controller=api_provider,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@@ -320,6 +329,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -362,6 +372,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()
@@ -391,6 +402,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@@ -402,6 +414,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -551,6 +564,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -565,21 +594,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# key: provider name, value: provider
db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
for provider in cls.list_default_builtin_providers(tenant_id)
}
# append builtin providers
for provider in builtin_providers:
@@ -591,10 +612,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -604,7 +624,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@@ -764,15 +783,12 @@ class ToolManager:
auth_type,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
controller=controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

View File

@@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
if use_cache:
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
if use_cache:
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager

View File

@@ -0,0 +1,141 @@
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

View File

@@ -0,0 +1,192 @@
import base64
import hashlib
import json
import logging
from typing import Any, Optional
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: Optional[str] = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
secret_key = secret_key or dify_config.SECRET_KEY or ""
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: str) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
if not oauth_params:
raise ValueError("oauth_params cannot be empty")
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(oauth_params.encode("utf-8"), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
params_json = unpadded_data.decode("utf-8")
oauth_params = json.loads(params_json)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except (ValueError, TypeError) as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: str) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)

View File

@@ -1,8 +1,11 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True
except Exception:

View File

@@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
@@ -84,6 +91,7 @@ class AgentNode(ToolNode):
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -94,6 +102,7 @@ class AgentNode(ToolNode):
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
yield RunCompletedEvent(
@@ -246,6 +255,7 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
@@ -276,6 +286,7 @@ class AgentNode(ToolNode):
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
@@ -305,6 +316,27 @@ class AgentNode(ToolNode):
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")

View File

@@ -20,6 +20,7 @@ def handle(sender, **kwargs):
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,

View File

@@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
)
@@ -40,6 +41,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -0,0 +1,41 @@
"""empty message
Revision ID: 16081485540c
Revises: d28f2004b072
Create Date: 2025-05-15 16:35:39.113777
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16081485540c'
down_revision = '2adcbe1f5dfb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_plugin_auto_upgrade_strategies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tenant_plugin_auto_upgrade_strategies')
# ### end Alembic commands ###

View File

@@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4474872b0ee6'
down_revision = '2adcbe1f5dfb'
down_revision = '16081485540c'
branch_labels = None
depends_on = None

View File

@@ -0,0 +1,62 @@
"""tool oauth
Revision ID: 71f5020c6470
Revises: 4474872b0ee6
Create Date: 2025-06-24 17:05:43.118647
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '71f5020c6470'
down_revision = '58eb7bdb93fe'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_tenant_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
)
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.drop_column('credential_type')
batch_op.drop_column('is_default')
batch_op.drop_column('name')
op.drop_table('tool_oauth_tenant_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

View File

@@ -21,6 +21,43 @@ from .model import Account, App, Tenant
from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@@ -29,12 +66,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@@ -49,6 +88,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
@property
def credentials(self) -> dict:
@@ -68,7 +112,7 @@ class ApiToolProvider(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
name = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon
icon = db.Column(db.String(255), nullable=False)
# original schema
@@ -281,18 +325,17 @@ class MCPToolProvider(Base):
@property
def decrypted_credentials(self) -> dict:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter(
return create_provider_encrypter(
tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)[0].decrypt(self.credentials)
class ToolModelInvoke(Base):

View File

@@ -575,13 +575,26 @@ class AppDslService:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
# TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node_data.get("dataset_ids", [])
node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
@@ -602,7 +615,15 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
model_config = app_model_config.to_dict()
# TODO: refactor: we need a better way to filter workspace related data from model config
# filter credential id from model config
for tool in model_config.get("agent_mode", {}).get("tools", []):
tool.pop("credential_id", None)
export_data["model_config"] = model_config
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
@@ -38,11 +38,9 @@ class PluginParameterService:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# check if credentials are required
@@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials)
credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")

View File

@@ -196,6 +196,17 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
@staticmethod
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
Check if the plugin is verified
"""
manager = PluginInstaller()
try:
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
except Exception:
return False
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""

View File

@@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@@ -297,28 +293,26 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ProviderConfigEncrypter(
encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt(credentials)
credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@@ -416,15 +410,13 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
@@ -446,7 +438,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@@ -474,7 +466,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
tenant_id=tenant_id, tool=tool, labels=labels
)
)

View File

@@ -1,28 +1,83 @@
import json
import logging
import re
from pathlib import Path
from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.entities.api_entities import (
ToolApiEntity,
ToolProviderApiEntity,
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
"""
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
"""
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
)
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
provider_name
)
result = {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
"is_system_oauth_params_exists": is_system_oauth_params_exists,
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
}
return result
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@@ -36,27 +91,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@@ -65,25 +104,15 @@ class BuiltinToolManageService:
return result
@staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
if builtin_provider is None:
raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -92,128 +121,407 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema())
return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
user_id: str,
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
name: str | None = None,
):
"""
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
with Session(db.engine) as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials)
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
raise ValueError(str(e))
try:
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
db.session.add(provider)
else:
provider.encrypted_credentials = json.dumps(credentials)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
# delete cache
tool_configuration.delete_tool_credentials_cache()
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, new_credentials)
db.session.commit()
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
def add_builtin_tool_provider(
user_id: str,
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided
if name is None or name == "":
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if provider_obj is None:
return {}
if len(providers) == 0:
return []
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
credentials: list[ToolProviderCredentialApiEntity] = []
encrypters = {}
for provider in providers:
credential_type = provider.credential_type
if credential_type not in encrypters:
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)[0]
encrypter = encrypters[credential_type]
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
"""
get builtin tool provider credential info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
supported_credential_types = provider_controller.get_supported_credential_types()
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
)
return credential_info
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
db.session.delete(provider_obj)
db.session.commit()
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration.delete_tool_credentials_cache()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_system_client_exists(provider_name: str) -> bool:
"""
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
return system_client is not None
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
"""
get builtin tool provider
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: dict[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
# only verified provider can use custom oauth client
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
if not is_verified:
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
"""
@@ -234,9 +542,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@@ -275,7 +581,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@@ -287,43 +592,153 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
.first()
)
else:
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider_obj is None:
return None
@staticmethod
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
return provider_obj
except Exception:
# it's an old provider without organization
return (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name),
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
original_params = encrypter.decrypt(custom_client_params.oauth_params)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

View File

@@ -7,13 +7,14 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -69,6 +70,7 @@ class MCPToolManageService:
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
@@ -197,8 +199,7 @@ class MCPToolManageService:
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()

View File

@@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -119,7 +121,12 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
)
}
for name, value in schema.items():
if result.masked_credentials:
@@ -136,15 +143,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@@ -287,16 +302,14 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@@ -306,7 +319,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@@ -316,7 +328,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials or {},
credentials={},
tenant_id=tenant_id,
)
)
@@ -357,6 +369,19 @@ class ToolTransformService:
labels=labels or [],
)
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
"""

View File

@@ -98,7 +98,7 @@ const Question: FC<QuestionProps> = ({
return (
<div className='mb-2 flex justify-end last:mb-0'>
<div className={cn('group relative mr-4 flex max-w-full items-start pl-14 overflow-x-hidden', isEditing && 'flex-1')}>
<div className={cn('group relative mr-4 flex max-w-full items-start overflow-x-hidden pl-14', isEditing && 'flex-1')}>
<div className={cn('mr-2 gap-1', isEditing ? 'hidden' : 'flex')}>
<div
className="absolute hidden gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm group-hover:flex"
@@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
</div>
<div
ref={contentRef}
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
>
{

View File

@@ -61,7 +61,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
>
<div
className={classNames(
'size-5 border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge relative flex items-center justify-center rounded-[6px]',
'relative flex size-5 items-center justify-center rounded-[6px] border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge',
)}
ref={containerRef}
>
@@ -73,7 +73,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
src={icon}
alt='tool icon'
className={classNames(
'w-full h-full size-3.5 object-cover',
'size-3.5 h-full w-full object-cover',
notSuccess && 'opacity-50',
)}
onError={() => setIconFetchError(true)}
@@ -82,7 +82,7 @@ export const ToolIcon = memo(({ providerName }: ToolIconProps) => {
if (typeof icon === 'object') {
return <AppIcon
className={classNames(
'w-full h-full size-3.5 object-cover',
'size-3.5 h-full w-full object-cover',
notSuccess && 'opacity-50',
)}
icon={icon?.content}