feat(oauth): update api

This commit is contained in:
Harry
2025-06-26 11:44:00 +08:00
parent 6c9e99b0c6
commit ba843c2691
6 changed files with 84 additions and 210 deletions

View File

@@ -35,6 +35,7 @@ class ModelProviderListApi(Resource):
model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
return jsonable_encoder({"data": provider_list})

View File

@@ -371,12 +371,12 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
class ToolApiProviderSchemaApi(Resource):
@@ -789,7 +789,7 @@ api.add_resource(
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
"/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

View File

@@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@@ -64,8 +55,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -109,7 +103,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -127,25 +126,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
return controller
@classmethod
def get_tool_runtime(
@@ -563,6 +544,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -577,30 +574,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
def get_builtin_providers(tenant_id):
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
builtin_providers = cls.list_builtin_providers(tenant_id)
# get builtin providers
db_builtin_providers = get_builtin_providers(tenant_id)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# key: provider name, value: provider
db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
for provider in cls.list_default_builtin_providers(tenant_id)
}
# append builtin providers
for provider in builtin_providers:
@@ -612,10 +592,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -625,7 +604,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()

View File

@@ -1,67 +0,0 @@
import secrets
import urllib.parse
from collections.abc import Mapping
from typing import Any
import requests
from dify_plugin import ToolProvider
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
from werkzeug import Request
class GithubProvider(ToolProvider):
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_API_USER_URL = "https://api.github.com/user"
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
"""
Generate the authorization URL for the Github OAuth.
"""
state = secrets.token_urlsafe(16)
params = {
"client_id": system_credentials["client_id"],
"redirect_uri": system_credentials["redirect_uri"],
"scope": system_credentials.get("scope", "read:user"),
"state": state,
# Optionally: allow_signup, login, etc.
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
"""
Exchange code for access_token.
"""
code = request.args.get("code")
state = request.args.get("state")
if not code:
raise ValueError("No code provided")
# Optionally: validate state here
data = {
"client_id": system_credentials["client_id"],
"client_secret": system_credentials["client_secret"],
"code": code,
"redirect_uri": system_credentials["redirect_uri"],
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return {"access_token": access_token}
def _validate_credentials(self, credentials: dict) -> None:
try:
if "access_token" not in credentials or not credentials.get("access_token"):
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
headers = {
"Authorization": f"Bearer {credentials['access_token']}",
"Accept": "application/vnd.github+json",
}
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(response.json().get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -2,6 +2,7 @@ import json
import logging
import re
from pathlib import Path
from typing import Optional, Union
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session
@@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
@@ -40,12 +42,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@@ -53,7 +50,7 @@ class BuiltinToolManageService:
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
credentials = tool_configuration.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
@@ -74,12 +71,7 @@ class BuiltinToolManageService:
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@@ -87,7 +79,7 @@ class BuiltinToolManageService:
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
credentials = tool_configuration.decrypt(credentials)
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -100,7 +92,7 @@ class BuiltinToolManageService:
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
"""
list builtin provider credentials schema
@@ -123,35 +115,28 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Decrypt and restore original credentials for masked values
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
for key, value in credentials.items():
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
else:
raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one")
# update name if provided
if name is not None and provider.name != name:
@@ -180,8 +165,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}"
with redis_client.lock(lock_name, timeout=20):
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
with redis_client.lock(lock, timeout=20):
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
@@ -198,12 +183,7 @@ class BuiltinToolManageService:
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
@@ -268,23 +248,17 @@ class BuiltinToolManageService:
return []
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
credentials.append(
ToolTransformService.convert_builtin_provider_to_credential_api_entity(
provider=provider,
credentials=decrypt_credential,
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
@@ -292,22 +266,17 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider_obj is None:
if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider_obj)
db.session.delete(tool_provider)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tool_configuration.delete_tool_credentials_cache()
return {"result": "success"}
@@ -334,7 +303,9 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str):
def get_builtin_tool_oauth_client(
tenant_id: str, provider: str, plugin_id: str
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
"""
get builtin tool provider
"""
@@ -350,14 +321,12 @@ class BuiltinToolManageService:
.first()
)
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
return user_client
if plugin_oauth_config:
return plugin_oauth_config
raise ValueError("no oauth available config found for this plugin")
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if system_client is None:
raise ValueError("no oauth available client config found for this tool provider")
return system_client
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@@ -379,9 +348,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@@ -432,8 +399,8 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
provider = (
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@@ -444,14 +411,14 @@ class BuiltinToolManageService:
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
def _query(provider_filters: list[ColumnExpressionArgument[bool]]):
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
return (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
@@ -484,21 +451,16 @@ class BuiltinToolManageService:
return provider
except Exception:
# it's an old provider without organization
provider_obj = _query([BuiltinToolProvider.provider == provider_name])
return provider_obj
return _query([BuiltinToolProvider.provider == provider_name])
@staticmethod
def _decrypt_and_restore_credentials(tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the input credentials from user
:return: the processed credentials with original values restored
"""
return credentials
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
return ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):

View File

@@ -307,7 +307,7 @@ class ToolTransformService:
)
@staticmethod
def convert_builtin_provider_to_credential_api_entity(
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(