diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 3e906a9f97..d92e33df5d 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -1,6 +1,5 @@ -import random - -from flask import redirect, request +from fastapi.encoders import jsonable_encoder +from flask import make_response, redirect, request from flask_login import current_user # type: ignore from flask_restful import ( # type: ignore Resource, # type: ignore @@ -15,76 +14,101 @@ from controllers.console.wraps import ( setup_required, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.entities.plugin import DatasourceProviderID from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db from libs.login import login_required -from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from models.oauth import DatasourceOauthParamConfig from services.datasource_provider_service import DatasourceProviderService +from services.plugin.oauth_service import OAuthProxyService -class DatasourcePluginOauthApi(Resource): +class DatasourcePluginOAuthAuthorizationUrl(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - provider = args["provider"] - plugin_id = args["plugin_id"] - # Check user role first + def get(self, provider: str): + user = current_user + tenant_id = user.current_tenant_id if not current_user.is_editor: raise Forbidden() - # get all plugin oauth configs - plugin_oauth_config = ( - db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + + provider_id = DatasourceProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + oauth_config = ( + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first() + ) + if not oauth_config: + raise ValueError(f"No OAuth Client Config for {provider}") + + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) - if not plugin_oauth_config: - raise NotFound() oauth_handler = OAuthHandler() - redirect_url = ( - f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" + redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback" + oauth_client_params = oauth_config.system_credentials + + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, ) - system_credentials = plugin_oauth_config.system_credentials - if system_credentials: - system_credentials["redirect_url"] = redirect_url - response = oauth_handler.get_authorization_url( - current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, ) - return response.model_dump() + return response -class DatasourceOauthCallback(Resource): +class DatasourceOAuthCallback(Resource): @setup_required - @login_required - @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - provider = args["provider"] - plugin_id = args["plugin_id"] - oauth_handler = OAuthHandler() + def get(self, provider: str): + if not current_user.is_editor: + raise Forbidden() + + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + provider_id = DatasourceProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id plugin_oauth_config = ( - db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first() ) if not plugin_oauth_config: raise NotFound() - credentials = oauth_handler.get_credentials( - current_user.current_tenant.id, - current_user.id, - plugin_id, - provider, + redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback" + oauth_handler = OAuthHandler() + oauth_response = oauth_handler.get_credentials( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_id.provider_name, + redirect_uri=redirect_uri, system_credentials=plugin_oauth_config.system_credentials, request=request, ) - datasource_provider = DatasourceProvider( - plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.add_datasource_oauth_provider( + tenant_id=tenant_id, + provider_id=provider_id, + credentials=dict(oauth_response.credentials), + name=None, ) - db.session.add(datasource_provider) - db.session.commit() return redirect(f"{dify_config.CONSOLE_WEB_URL}") @@ -92,26 +116,23 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def post(self): + def post(self, provider: str): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None) parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() - + provider_id = DatasourceProviderID(provider) datasource_provider_service = DatasourceProviderService() try: - datasource_provider_service.datasource_provider_credentials_validate( + datasource_provider_service.add_datasource_api_key_provider( tenant_id=current_user.current_tenant_id, - provider=args["provider"], - plugin_id=args["plugin_id"], + provider_id=provider_id, credentials=args["credentials"], - name="test" + str(random.randint(1, 1000000)), # noqa: S311 + name=args["name"], ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) @@ -121,14 +142,13 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + def get(self, provider: str): + provider_id = DatasourceProviderID(provider) datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"] + tenant_id=current_user.current_tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, ) return {"result": datasources}, 200 @@ -137,29 +157,27 @@ class DatasourceAuthUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, auth_id: str): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + def delete(self, provider: str, auth_id: str): + provider_id = DatasourceProviderID(provider) + plugin_id = provider_id.plugin_id + provider_name = provider_id.provider_name if not current_user.is_editor: raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( tenant_id=current_user.current_tenant_id, auth_id=auth_id, - provider=args["provider"], - plugin_id=args["plugin_id"], + provider=provider_name, + plugin_id=plugin_id, ) return {"result": "success"}, 200 @setup_required @login_required @account_initialization_required - def patch(self, auth_id: str): + def patch(self, provider: str, auth_id: str): + provider_id = DatasourceProviderID(provider) parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() if not current_user.is_editor: @@ -169,8 +187,8 @@ class DatasourceAuthUpdateDeleteApi(Resource): datasource_provider_service.update_datasource_credentials( tenant_id=current_user.current_tenant_id, auth_id=auth_id, - provider=args["provider"], - plugin_id=args["plugin_id"], + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: @@ -193,21 +211,21 @@ class DatasourceAuthListApi(Resource): # Import Rag Pipeline api.add_resource( - DatasourcePluginOauthApi, - "/oauth/plugin/datasource", + DatasourcePluginOAuthAuthorizationUrl, + "/oauth/plugin//datasource/get-authorization-url", ) api.add_resource( - DatasourceOauthCallback, - "/oauth/plugin/datasource/callback", + DatasourceOAuthCallback, + "/oauth/plugin//datasource/callback", ) api.add_resource( DatasourceAuth, - "/auth/plugin/datasource", + "/auth/plugin/datasource/", ) api.add_resource( DatasourceAuthUpdateDeleteApi, - "/auth/plugin/datasource/", + "/auth/plugin/datasource//", ) api.add_resource( diff --git a/api/core/helper/provider_name_generator.py b/api/core/helper/provider_name_generator.py new file mode 100644 index 0000000000..dd5816aa24 --- /dev/null +++ b/api/core/helper/provider_name_generator.py @@ -0,0 +1,35 @@ +import logging +import re +from collections.abc import Sequence +from typing import Any + +from core.tools.entities.tool_entities import CredentialType + +logger = logging.getLogger(__name__) + + +def generate_provider_name( + providers: Sequence[Any], + credential_type: CredentialType, + fallback_context: str = "provider" +) -> str: + try: + default_pattern = f"{credential_type.get_name()}" + + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] + + for provider in providers: + if provider.name: + match = re.match(pattern, provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + if not numbers: + return f"{default_pattern} 1" + + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {fallback_context}: {str(e)}") + return f"{credential_type.get_name()} 1" \ No newline at end of file diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index bbb043aacd..5f1251d9d8 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -1,13 +1,18 @@ import logging from flask_login import current_user +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE from core.helper import encrypter +from core.helper.provider_name_generator import generate_provider_name from core.model_runtime.entities.provider_entities import FormType from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.entities.plugin import DatasourceProviderID from core.plugin.impl.datasource import PluginDatasourceManager +from core.tools.entities.tool_entities import CredentialType from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.oauth import DatasourceProvider logger = logging.getLogger(__name__) @@ -21,8 +26,71 @@ class DatasourceProviderService: def __init__(self) -> None: self.provider_manager = PluginDatasourceManager() - def datasource_provider_credentials_validate( - self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str + @staticmethod + def generate_next_datasource_provider_name( + session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType + ) -> str: + db_providers = ( + session.query(DatasourceProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + auth_type=credential_type.value, + ) + .all() + ) + return generate_provider_name(db_providers, credential_type, f"datasource provider {provider_id}") + + def add_datasource_oauth_provider( + self, + name: str | None, + tenant_id: str, + provider_id: DatasourceProviderID, + credentials: dict, + ) -> None: + """ + add datasource oauth provider + """ + credential_type = CredentialType.OAUTH2 + with Session(db.engine) as session: + lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}" + with redis_client.lock(lock, timeout=20): + db_provider_name = name or self.generate_next_datasource_provider_name( + session=session, + tenant_id=tenant_id, + provider_id=provider_id, + credential_type=credential_type, + ) + + if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0: + raise ValueError("name is already exists") + + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}" + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_id.provider_name, + plugin_id=provider_id.plugin_id, + auth_type=credential_type.value, + encrypted_credentials=credentials, + ) + session.add(datasource_provider) + session.commit() + + def add_datasource_api_key_provider( + self, + name: str | None, + tenant_id: str, + provider_id: DatasourceProviderID, + credentials: dict, ) -> None: """ validate datasource provider credentials. @@ -31,45 +99,49 @@ class DatasourceProviderService: :param provider: :param credentials: """ - # check name is exist - datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first() - if datasource_provider: - raise ValueError("Authorization name is already exists") + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + with Session(db.engine) as session: + lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key" + with redis_client.lock(lock, timeout=20): + db_provider_name = name or self.generate_next_datasource_provider_name( + session=session, + tenant_id=tenant_id, + provider_id=provider_id, + credential_type=CredentialType.API_KEY, + ) - credential_valid = self.provider_manager.validate_provider_credentials( - tenant_id=tenant_id, - user_id=current_user.id, - provider=provider, - plugin_id=plugin_id, - credentials=credentials, - ) - if credential_valid: - # Get all provider configurations of the current workspace - datasource_provider = ( - db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key") - .first() - ) + # check name is exist + if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0: + raise ValueError("Authorization name is already exists") - provider_credential_secret_variables = self.extract_secret_variables( - tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" - ) - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - credentials[key] = encrypter.encrypt_token(tenant_id, value) - datasource_provider = DatasourceProvider( - tenant_id=tenant_id, - name=name, - provider=provider, - plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials, - ) - db.session.add(datasource_provider) - db.session.commit() - else: - raise CredentialsValidateFailedError() + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, + user_id=current_user.id, + provider=provider_name, + plugin_id=plugin_id, + credentials=credentials, + ) + if credential_valid: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, provider_id=f"{provider_id}" + ) + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + name=db_provider_name, + provider=provider_name, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials, + ) + db.session.add(datasource_provider) + db.session.commit() + else: + raise CredentialsValidateFailedError() def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: """ diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 430575b532..f12371f77c 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,5 @@ import json import logging -import re from collections.abc import Mapping from pathlib import Path from typing import Any, Optional @@ -11,6 +10,7 @@ from configs import dify_config from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache +from core.helper.provider_name_generator import generate_provider_name from core.plugin.entities.plugin import ToolProviderID from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -299,42 +299,18 @@ class BuiltinToolManageService: def generate_builtin_tool_provider_name( session: Session, tenant_id: str, provider: str, credential_type: CredentialType ) -> str: - try: - db_providers = ( - session.query(BuiltinToolProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider, - credential_type=credential_type.value, - ) - .order_by(BuiltinToolProvider.created_at.desc()) - .all() + db_providers = ( + session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type.value, ) - - # Get the default name pattern - default_pattern = f"{credential_type.get_name()}" - - # Find all names that match the default pattern: "{default_pattern} {number}" - pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" - numbers = [] - - for db_provider in db_providers: - if db_provider.name: - match = re.match(pattern, db_provider.name.strip()) - if match: - numbers.append(int(match.group(1))) - - # If no default pattern names found, start with 1 - if not numbers: - return f"{default_pattern} 1" - - # Find the next number - max_number = max(numbers) - return f"{default_pattern} {max_number + 1}" - except Exception as e: - logger.warning(f"Error generating next provider name for {provider}: {str(e)}") - # fallback - return f"{credential_type.get_name()} 1" + .order_by(BuiltinToolProvider.created_at.desc()) + .all() + ) + + return generate_provider_name(db_providers, credential_type, f"builtin tool provider {provider}") @staticmethod def get_builtin_tool_provider_credentials(