mirror of
https://github.com/langgenius/dify.git
synced 2026-01-06 06:26:00 +00:00
feat(oauth): update api
This commit is contained in:
@@ -35,6 +35,7 @@ class ModelProviderListApi(Resource):
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
|
||||
|
||||
return jsonable_encoder({"data": provider_list})
|
||||
|
||||
|
||||
|
||||
@@ -371,12 +371,12 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider, credential_type):
|
||||
user = current_user
|
||||
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
|
||||
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
|
||||
|
||||
|
||||
class ToolApiProviderSchemaApi(Resource):
|
||||
@@ -789,7 +789,7 @@ api.add_resource(
|
||||
)
|
||||
api.add_resource(
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
|
||||
"/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
|
||||
)
|
||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@@ -64,8 +55,11 @@ class ToolManager:
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
|
||||
get the hardcoded provider
|
||||
|
||||
"""
|
||||
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
@@ -109,7 +103,12 @@ class ToolManager:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(Lock())
|
||||
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
|
||||
with contexts.plugin_tool_providers_lock.get():
|
||||
# double check
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
@@ -127,25 +126,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(
|
||||
@@ -563,6 +544,22 @@ class ToolManager:
|
||||
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||
@@ -577,30 +574,13 @@ class ToolManager:
|
||||
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
|
||||
def get_builtin_providers(tenant_id):
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get builtin providers
|
||||
db_builtin_providers = get_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
str(ToolProviderID(provider.provider)): provider
|
||||
for provider in cls.list_default_builtin_providers(tenant_id)
|
||||
}
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
@@ -612,10 +592,9 @@ class ToolManager:
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
db_provider=db_builtin_providers.get(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
@@ -625,7 +604,6 @@ class ToolManager:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from dify_plugin import ToolProvider
|
||||
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
|
||||
from werkzeug import Request
|
||||
|
||||
|
||||
class GithubProvider(ToolProvider):
|
||||
_AUTH_URL = "https://github.com/login/oauth/authorize"
|
||||
_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
_API_USER_URL = "https://api.github.com/user"
|
||||
|
||||
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Generate the authorization URL for the Github OAuth.
|
||||
"""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": system_credentials["client_id"],
|
||||
"redirect_uri": system_credentials["redirect_uri"],
|
||||
"scope": system_credentials.get("scope", "read:user"),
|
||||
"state": state,
|
||||
# Optionally: allow_signup, login, etc.
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
|
||||
"""
|
||||
Exchange code for access_token.
|
||||
"""
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
if not code:
|
||||
raise ValueError("No code provided")
|
||||
# Optionally: validate state here
|
||||
|
||||
data = {
|
||||
"client_id": system_credentials["client_id"],
|
||||
"client_secret": system_credentials["client_secret"],
|
||||
"code": code,
|
||||
"redirect_uri": system_credentials["redirect_uri"],
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in GitHub OAuth: {response_json}")
|
||||
return {"access_token": access_token}
|
||||
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "access_token" not in credentials or not credentials.get("access_token"):
|
||||
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials['access_token']}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
}
|
||||
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError(response.json().get("message"))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import ColumnExpressionArgument
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
@@ -40,12 +42,7 @@ class BuiltinToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
@@ -53,7 +50,7 @@ class BuiltinToolManageService:
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
@@ -74,12 +71,7 @@ class BuiltinToolManageService:
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
@@ -87,7 +79,7 @@ class BuiltinToolManageService:
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
@@ -100,7 +92,7 @@ class BuiltinToolManageService:
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
@@ -123,35 +115,28 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
|
||||
try:
|
||||
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Decrypt and restore original credentials for masked values
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
|
||||
credentials[name] = original_credentials[name] # type: ignore
|
||||
for key, value in credentials.items():
|
||||
if key in masked_credentials and value == masked_credentials[key]:
|
||||
credentials[key] = original_credentials[key]
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one")
|
||||
|
||||
# update name if provided
|
||||
if name is not None and provider.name != name:
|
||||
@@ -180,8 +165,8 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
if name is None:
|
||||
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
|
||||
|
||||
@@ -198,12 +183,7 @@ class BuiltinToolManageService:
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
@@ -268,23 +248,17 @@ class BuiltinToolManageService:
|
||||
return []
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
)
|
||||
credentials.append(
|
||||
ToolTransformService.convert_builtin_provider_to_credential_api_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
@@ -292,22 +266,17 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if provider_obj is None:
|
||||
if tool_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider_obj)
|
||||
db.session.delete(tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {"result": "success"}
|
||||
@@ -334,7 +303,9 @@ class BuiltinToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str):
|
||||
def get_builtin_tool_oauth_client(
|
||||
tenant_id: str, provider: str, plugin_id: str
|
||||
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
|
||||
"""
|
||||
get builtin tool provider
|
||||
"""
|
||||
@@ -350,14 +321,12 @@ class BuiltinToolManageService:
|
||||
.first()
|
||||
)
|
||||
if user_client:
|
||||
plugin_oauth_config = user_client
|
||||
else:
|
||||
plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
|
||||
return user_client
|
||||
|
||||
if plugin_oauth_config:
|
||||
return plugin_oauth_config
|
||||
|
||||
raise ValueError("no oauth available config found for this plugin")
|
||||
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
|
||||
if system_client is None:
|
||||
raise ValueError("no oauth available client config found for this tool provider")
|
||||
return system_client
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
@@ -379,9 +348,7 @@ class BuiltinToolManageService:
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
@@ -432,8 +399,8 @@ class BuiltinToolManageService:
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
|
||||
provider = (
|
||||
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
||||
provider: Optional[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
@@ -444,14 +411,14 @@ class BuiltinToolManageService:
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
|
||||
def _query(provider_filters: list[ColumnExpressionArgument[bool]]):
|
||||
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
|
||||
return (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
|
||||
@@ -484,21 +451,16 @@ class BuiltinToolManageService:
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
provider_obj = _query([BuiltinToolProvider.provider == provider_name])
|
||||
return provider_obj
|
||||
return _query([BuiltinToolProvider.provider == provider_name])
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_and_restore_credentials(tool_configuration, provider, credentials):
|
||||
"""
|
||||
Decrypt original credentials and restore masked values from the input credentials
|
||||
|
||||
:param tool_configuration: the tool configuration encrypter
|
||||
:param provider: the provider object from database
|
||||
:param credentials: the input credentials from user
|
||||
:return: the processed credentials with original values restored
|
||||
"""
|
||||
|
||||
return credentials
|
||||
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
||||
return ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
|
||||
|
||||
@@ -307,7 +307,7 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_builtin_provider_to_credential_api_entity(
|
||||
def convert_builtin_provider_to_credential_entity(
|
||||
provider: BuiltinToolProvider, credentials: dict
|
||||
) -> ToolProviderCredentialApiEntity:
|
||||
return ToolProviderCredentialApiEntity(
|
||||
|
||||
Reference in New Issue
Block a user