diff --git a/api/.importlinter b/api/.importlinter
index c9364a0896..725999c28e 100644
--- a/api/.importlinter
+++ b/api/.importlinter
@@ -50,7 +50,6 @@ forbidden_modules =
 allow_indirect_imports = True
 ignore_imports =
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
-    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
     core.workflow.nodes.llm.llm_utils -> extensions.ext_database
@@ -106,9 +105,6 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> core.model_manager
     core.workflow.nodes.agent.agent_node -> core.provider_manager
     core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
-    core.workflow.nodes.datasource.datasource_node -> models.model
-    core.workflow.nodes.datasource.datasource_node -> models.tools
-    core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
     core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
     core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
     core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
@@ -146,8 +142,6 @@ ignore_imports =
     core.workflow.workflow_entry -> core.app.apps.exc
     core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
     core.workflow.workflow_entry -> core.app.workflow.node_factory
-    core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
-    core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
    core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
     core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
     core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -160,7 +154,6 @@ ignore_imports =
     core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
     core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
     core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
-    core.workflow.nodes.datasource.datasource_node -> core.variables.variables
     core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
     core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
     core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
@@ -197,7 +190,6 @@ ignore_imports =
     core.workflow.nodes.code.code_node -> core.variables.segments
     core.workflow.nodes.code.code_node -> core.variables.types
     core.workflow.nodes.code.entities -> core.variables.types
-    core.workflow.nodes.datasource.datasource_node -> core.variables.segments
     core.workflow.nodes.document_extractor.node -> core.variables
     core.workflow.nodes.document_extractor.node -> core.variables.segments
     core.workflow.nodes.http_request.executor -> core.variables.segments
@@ -240,7 +232,6 @@ ignore_imports =
     core.workflow.variable_loader -> core.variables.consts
     core.workflow.workflow_type_encoder -> core.variables
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
-    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
     core.workflow.nodes.llm.llm_utils -> extensions.ext_database
diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py
index 07dec1b070..3eeb1d5d58 100644
--- a/api/core/app/workflow/node_factory.py
+++ b/api/core/app/workflow/node_factory.py
@@ -5,6 +5,7 @@ from typing_extensions import override
 
 from configs import dify_config
 from core.app.llm.model_access import build_dify_model_access
+from core.datasource.datasource_manager import DatasourceManager
 from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.ssrf_proxy import ssrf_proxy
@@ -18,6 +19,7 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
 from core.workflow.nodes.code.entities import CodeLanguage
 from core.workflow.nodes.code.limits import CodeNodeLimits
+from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@@ -178,6 +180,15 @@ class DifyNodeFactory(NodeFactory):
                 model_factory=self._llm_model_factory,
             )
 
+        if node_type == NodeType.DATASOURCE:
+            return DatasourceNode(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                datasource_manager=DatasourceManager,
+            )
+
         if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
             return KnowledgeRetrievalNode(
                 id=node_id,
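Note that the factory injects the DatasourceManager class itself rather than an instance: every method the node consumes is a classmethod, so there is no per-node manager state to build, and any class with matching classmethods satisfies the protocol structurally. A minimal sketch of the test seam this opens up (StubManager and its return values are hypothetical; the real wiring passes DatasourceManager exactly as in the hunk above):

# Sketch only -- StubManager is a hypothetical stand-in; it satisfies
# DatasourceManagerProtocol without inheriting from it.
from core.workflow.node_events import StreamChunkEvent


class StubManager:
    @classmethod
    def get_icon_url(cls, **_kwargs) -> str:
        return "https://example.invalid/icon.svg"  # placeholder icon URL

    @classmethod
    def stream_node_events(cls, **_kwargs):
        # Emit a single text chunk; the real manager ends the stream with an
        # empty final chunk followed by a StreamCompletedEvent.
        yield StreamChunkEvent(selector=["node", "text"], chunk="stub", is_final=False)

    @classmethod
    def get_upload_file_by_id(cls, **_kwargs):
        raise NotImplementedError  # not exercised in this sketch

# DatasourceNode(id=..., config=..., graph_init_params=..., graph_runtime_state=...,
#                datasource_manager=StubManager)  # construction elided; see the unit tests below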
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index 002415a7db..9c48f755a9 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -1,16 +1,39 @@
 import logging
+from collections.abc import Generator
 from threading import Lock
+from typing import Any, cast
+
+from sqlalchemy import select
 
 import contexts
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
-from core.datasource.entities.datasource_entities import DatasourceProviderType
+from core.datasource.entities.datasource_entities import (
+    DatasourceMessage,
+    DatasourceProviderType,
+    GetOnlineDocumentPageContentRequest,
+    OnlineDriveDownloadFileRequest,
+)
 from core.datasource.errors import DatasourceProviderNotFoundError
 from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
+from core.db.session_factory import session_factory
 from core.plugin.impl.datasource import PluginDatasourceManager
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey
+from core.workflow.file import File
+from core.workflow.file.enums import FileTransferMethod, FileType
+from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
+from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
+from factories import file_factory
+from models.model import UploadFile
+from models.tools import ToolFile
+from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
@@ -103,3 +126,238 @@ class DatasourceManager:
             tenant_id,
             datasource_type,
         ).get_datasource(datasource_name)
+
+    @classmethod
+    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
+        datasource_runtime = cls.get_datasource_runtime(
+            provider_id=provider_id,
+            datasource_name=datasource_name,
+            tenant_id=tenant_id,
+            datasource_type=DatasourceProviderType.value_of(datasource_type),
+        )
+        return datasource_runtime.get_icon_url(tenant_id)
+
+    @classmethod
+    def stream_online_results(
+        cls,
+        *,
+        user_id: str,
+        datasource_name: str,
+        datasource_type: str,
+        provider_id: str,
+        tenant_id: str,
+        provider: str,
+        plugin_id: str,
+        credential_id: str,
+        datasource_param: DatasourceParameter | None = None,
+        online_drive_request: OnlineDriveDownloadFileParam | None = None,
+    ) -> Generator[DatasourceMessage, None, Any]:
+        """
+        Pull-based streaming of domain messages from datasource plugins.
+        Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
+        Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
+        """
+        ds_type = DatasourceProviderType.value_of(datasource_type)
+        runtime = cls.get_datasource_runtime(
+            provider_id=provider_id,
+            datasource_name=datasource_name,
+            tenant_id=tenant_id,
+            datasource_type=ds_type,
+        )
+
+        dsp_service = DatasourceProviderService()
+        credentials = dsp_service.get_datasource_credentials(
+            tenant_id=tenant_id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credential_id=credential_id,
+        )
+
+        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
+            doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
+            if credentials:
+                doc_runtime.runtime.credentials = credentials
+            if datasource_param is None:
+                raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
+            inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
+                user_id=user_id,
+                datasource_parameters=GetOnlineDocumentPageContentRequest(
+                    workspace_id=datasource_param.workspace_id,
+                    page_id=datasource_param.page_id,
+                    type=datasource_param.type,
+                ),
+                provider_type=ds_type,
+            )
+        elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
+            drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
+            if credentials:
+                drive_runtime.runtime.credentials = credentials
+            if online_drive_request is None:
+                raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
+            inner_gen = drive_runtime.online_drive_download_file(
+                user_id=user_id,
+                request=OnlineDriveDownloadFileRequest(
+                    id=online_drive_request.id,
+                    bucket=online_drive_request.bucket,
+                ),
+                provider_type=ds_type,
+            )
+        else:
+            raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
+
+        # Bridge through to the caller while preserving the generator return contract
+        yield from inner_gen
+        # No structured final data here; the node/adapter assembles the outputs
+        return {}
+
+    @classmethod
+    def stream_node_events(
+        cls,
+        *,
+        node_id: str,
+        user_id: str,
+        datasource_name: str,
+        datasource_type: str,
+        provider_id: str,
+        tenant_id: str,
+        provider: str,
+        plugin_id: str,
+        credential_id: str,
+        parameters_for_log: dict[str, Any],
+        datasource_info: dict[str, Any],
+        variable_pool: Any,
+        datasource_param: DatasourceParameter | None = None,
+        online_drive_request: OnlineDriveDownloadFileParam | None = None,
+    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
+        ds_type = DatasourceProviderType.value_of(datasource_type)
+
+        messages = cls.stream_online_results(
+            user_id=user_id,
+            datasource_name=datasource_name,
+            datasource_type=datasource_type,
+            provider_id=provider_id,
+            tenant_id=tenant_id,
+            provider=provider,
+            plugin_id=plugin_id,
+            credential_id=credential_id,
+            datasource_param=datasource_param,
+            online_drive_request=online_drive_request,
+        )
+
+        transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
+            messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
+        )
+
+        variables: dict[str, Any] = {}
+        file_out: File | None = None
+
+        for message in transformed:
+            mtype = message.type
+            if mtype in {
+                DatasourceMessage.MessageType.IMAGE_LINK,
+                DatasourceMessage.MessageType.BINARY_LINK,
+                DatasourceMessage.MessageType.IMAGE,
+            }:
+                wanted_ds_type = ds_type in {
+                    DatasourceProviderType.ONLINE_DRIVE,
+                    DatasourceProviderType.ONLINE_DOCUMENT,
+                }
+                if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
+                    url = message.message.text
+                    datasource_file_id = str(url).split("/")[-1].split(".")[0]
+                    with session_factory.create_session() as session:
+                        stmt = select(ToolFile).where(
+                            ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
+                        )
+                        datasource_file = session.scalar(stmt)
+                        if not datasource_file:
+                            raise ValueError(
+                                f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
+                            )
+                        mime_type = datasource_file.mimetype
+                    mapping = {
+                        "tool_file_id": datasource_file_id,
+                        "type": file_factory.get_file_type_by_mime_type(mime_type),
+                        "transfer_method": FileTransferMethod.TOOL_FILE,
+                        "url": url,
+                    }
+                    file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
+            elif mtype == DatasourceMessage.MessageType.TEXT:
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
+                yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
+            elif mtype == DatasourceMessage.MessageType.LINK:
+                assert isinstance(message.message, DatasourceMessage.TextMessage)
+                yield StreamChunkEvent(
+                    selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
+                )
+            elif mtype == DatasourceMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, DatasourceMessage.VariableMessage)
+                name = message.message.variable_name
+                value = message.message.variable_value
+                if message.message.stream:
+                    assert isinstance(value, str), "stream variable_value must be str"
+                    variables[name] = variables.get(name, "") + value
+                    yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
+                else:
+                    variables[name] = value
+            elif mtype == DatasourceMessage.MessageType.FILE:
+                if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
+                    f = message.meta.get("file")
+                    if isinstance(f, File):
+                        file_out = f
+
+        yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
+
+        if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
+            variable_pool.add([node_id, "file"], file_out)
+
+        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    inputs=parameters_for_log,
+                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                    outputs={**variables},
+                )
+            )
+        else:
+            yield StreamCompletedEvent(
+                node_run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                    inputs=parameters_for_log,
+                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
+                    outputs={
+                        "file": file_out,
+                        "datasource_type": ds_type,
+                    },
+                )
+            )
+
+    @classmethod
+    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
+        with session_factory.create_session() as session:
+            upload_file = (
+                session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
+            )
+            if not upload_file:
+                raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
+
+            file_info = File(
+                id=upload_file.id,
+                filename=upload_file.name,
+                extension="." + upload_file.extension,
+                mime_type=upload_file.mime_type,
+                tenant_id=tenant_id,
+                type=FileType.CUSTOM,
+                transfer_method=FileTransferMethod.LOCAL_FILE,
+                remote_url=upload_file.source_url,
+                related_id=upload_file.id,
+                size=upload_file.size,
+                storage_key=upload_file.key,
+                url=upload_file.source_url,
+            )
+            return file_info
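stream_node_events always terminates the text stream with an empty is_final chunk and then emits exactly one StreamCompletedEvent, so callers can fold the generator without special-casing datasource types. A small consumer sketch under that contract (collect() and its selector convention are illustrative, not part of the diff):

from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent


def collect(events):
    # Fold the event stream into (full_text, final_result); relies on the
    # contract above: chunks first, exactly one completion event last.
    text_parts: list[str] = []
    result = None
    for event in events:
        if isinstance(event, StreamChunkEvent) and event.selector[-1] == "text":
            text_parts.append(event.chunk)
        elif isinstance(event, StreamCompletedEvent):
            result = event.node_run_result
    return "".join(text_parts), result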
diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py
index dde7d59726..a063a3680b 100644
--- a/api/core/datasource/entities/datasource_entities.py
+++ b/api/core/datasource/entities/datasource_entities.py
@@ -379,4 +379,11 @@ class OnlineDriveDownloadFileRequest(BaseModel):
     """
 
     id: str = Field(..., description="The id of the file")
-    bucket: str | None = Field(None, description="The name of the bucket")
+    bucket: str = Field("", description="The name of the bucket")
+
+    @field_validator("bucket", mode="before")
+    @classmethod
+    def _coerce_bucket(cls, v) -> str:
+        if v is None:
+            return ""
+        return str(v)
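The widened bucket field keeps the model backward compatible: legacy callers that passed None under the old `str | None` signature are coerced to an empty string before validation, so downstream code can assume a str. A behavior sketch:

from core.datasource.entities.datasource_entities import OnlineDriveDownloadFileRequest

# None from legacy callers normalizes to "" via the mode="before" validator
assert OnlineDriveDownloadFileRequest(id="f1", bucket=None).bucket == ""
# an omitted bucket falls back to the new default
assert OnlineDriveDownloadFileRequest(id="f1").bucket == ""
# explicit values pass through unchanged
assert OnlineDriveDownloadFileRequest(id="f1", bucket="assets").bucket == "assets"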
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 80869ac7f7..17f8bcb2db 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -1,40 +1,26 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.datasource.entities.datasource_entities import (
-    DatasourceMessage,
-    DatasourceParameter,
-    DatasourceProviderType,
-    GetOnlineDocumentPageContentRequest,
-    OnlineDriveDownloadFileRequest,
-)
-from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
-from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
-from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
+from core.datasource.entities.datasource_entities import DatasourceProviderType
 from core.plugin.impl.exc import PluginDaemonClientSideError
-from core.variables.segments import ArrayAnySegment
-from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
-from core.workflow.file import File
-from core.workflow.file.enums import FileTransferMethod, FileType
-from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
+from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
-from core.workflow.nodes.tool.exc import ToolFileError
-from core.workflow.runtime import VariablePool
-from extensions.ext_database import db
-from factories import file_factory
-from models.model import UploadFile
-from models.tools import ToolFile
-from services.datasource_provider_service import DatasourceProviderService
+from core.workflow.repositories.datasource_manager_protocol import (
+    DatasourceManagerProtocol,
+    DatasourceParameter,
+    OnlineDriveDownloadFileParam,
+)
 
 from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from .entities import DatasourceNodeData
-from .exc import DatasourceNodeError, DatasourceParameterError
+from .exc import DatasourceNodeError
+
+if TYPE_CHECKING:
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState
 
 
 class DatasourceNode(Node[DatasourceNodeData]):
@@ -45,6 +31,22 @@
     node_type = NodeType.DATASOURCE
     execution_type = NodeExecutionType.ROOT
 
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph_runtime_state: "GraphRuntimeState",
+        datasource_manager: DatasourceManagerProtocol,
+    ):
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        self.datasource_manager = datasource_manager
+
     def _run(self) -> Generator:
         """
         Run the datasource node
@@ -52,84 +54,69 @@
         node_data = self.node_data
         variable_pool = self.graph_runtime_state.variable_pool
 
-        datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
-        if not datasource_type_segement:
+        datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
+        if not datasource_type_segment:
             raise DatasourceNodeError("Datasource type is not set")
-        datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
-        datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
-        if not datasource_info_segement:
+        datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
+        datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
+        if not datasource_info_segment:
             raise DatasourceNodeError("Datasource info is not set")
-        datasource_info_value = datasource_info_segement.value
+        datasource_info_value = datasource_info_segment.value
         if not isinstance(datasource_info_value, dict):
             raise DatasourceNodeError("Invalid datasource info format")
         datasource_info: dict[str, Any] = datasource_info_value
-        # get datasource runtime
-        from core.datasource.datasource_manager import DatasourceManager
 
         if datasource_type is None:
             raise DatasourceNodeError("Datasource type is not set")
         datasource_type = DatasourceProviderType.value_of(datasource_type)
+        provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
 
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
+        datasource_info["icon"] = self.datasource_manager.get_icon_url(
+            provider_id=provider_id,
             datasource_name=node_data.datasource_name or "",
             tenant_id=self.tenant_id,
-            datasource_type=datasource_type,
+            datasource_type=datasource_type.value,
         )
-        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
         parameters_for_log = datasource_info
 
         try:
-            datasource_provider_service = DatasourceProviderService()
-            credentials = datasource_provider_service.get_datasource_credentials(
-                tenant_id=self.tenant_id,
-                provider=node_data.provider_name,
-                plugin_id=node_data.plugin_id,
-                credential_id=datasource_info.get("credential_id", ""),
-            )
             match datasource_type:
-                case DatasourceProviderType.ONLINE_DOCUMENT:
-                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                    if credentials:
-                        datasource_runtime.runtime.credentials = credentials
-                    online_document_result: Generator[DatasourceMessage, None, None] = (
-                        datasource_runtime.get_online_document_page_content(
-                            user_id=self.user_id,
-                            datasource_parameters=GetOnlineDocumentPageContentRequest(
-                                workspace_id=datasource_info.get("workspace_id", ""),
-                                page_id=datasource_info.get("page", {}).get("page_id", ""),
-                                type=datasource_info.get("page", {}).get("type", ""),
-                            ),
-                            provider_type=datasource_type,
-                        )
-                    )
-                    yield from self._transform_message(
-                        messages=online_document_result,
-                        parameters_for_log=parameters_for_log,
-                        datasource_info=datasource_info,
-                    )
-                case DatasourceProviderType.ONLINE_DRIVE:
-                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
-                    if credentials:
-                        datasource_runtime.runtime.credentials = credentials
-                    online_drive_result: Generator[DatasourceMessage, None, None] = (
-                        datasource_runtime.online_drive_download_file(
-                            user_id=self.user_id,
-                            request=OnlineDriveDownloadFileRequest(
-                                id=datasource_info.get("id", ""),
-                                bucket=datasource_info.get("bucket"),
-                            ),
-                            provider_type=datasource_type,
-                        )
-                    )
-                    yield from self._transform_datasource_file_message(
-                        messages=online_drive_result,
+                case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
+                    # Build typed request objects
+                    datasource_parameters = None
+                    if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
+                        datasource_parameters = DatasourceParameter(
+                            workspace_id=datasource_info.get("workspace_id", ""),
+                            page_id=datasource_info.get("page", {}).get("page_id", ""),
+                            type=datasource_info.get("page", {}).get("type", ""),
+                        )
+
+                    online_drive_request = None
+                    if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
+                        online_drive_request = OnlineDriveDownloadFileParam(
+                            id=datasource_info.get("id", ""),
+                            bucket=datasource_info.get("bucket", ""),
+                        )
+
+                    credential_id = datasource_info.get("credential_id", "")
+
+                    yield from self.datasource_manager.stream_node_events(
+                        node_id=self._node_id,
+                        user_id=self.user_id,
+                        datasource_name=node_data.datasource_name or "",
+                        datasource_type=datasource_type.value,
+                        provider_id=provider_id,
+                        tenant_id=self.tenant_id,
+                        provider=node_data.provider_name,
+                        plugin_id=node_data.plugin_id,
+                        credential_id=credential_id,
                         parameters_for_log=parameters_for_log,
                         datasource_info=datasource_info,
                         variable_pool=variable_pool,
-                        datasource_type=datasource_type,
+                        datasource_param=datasource_parameters,
+                        online_drive_request=online_drive_request,
                     )
                 case DatasourceProviderType.WEBSITE_CRAWL:
                     yield StreamCompletedEvent(
is not exist") - upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first() - if not upload_file: - raise ValueError("Invalid upload file Info") - file_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.tenant_id, - type=FileType.CUSTOM, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=upload_file.source_url, + file_info = self.datasource_manager.get_upload_file_by_id( + file_id=related_id, tenant_id=self.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) @@ -201,55 +174,6 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) - def _generate_parameters( - self, - *, - datasource_parameters: Sequence[DatasourceParameter], - variable_pool: VariablePool, - node_data: DatasourceNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. - - """ - datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} - - result: dict[str, Any] = {} - if node_data.datasource_parameters: - for parameter_name in node_data.datasource_parameters: - parameter = datasource_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - datasource_input = node_data.datasource_parameters[parameter_name] - if datasource_input.type == "variable": - variable = variable_pool.get(datasource_input.value) - if variable is None: - raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") - parameter_value = variable.value - elif datasource_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(datasource_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") - result[parameter_name] = parameter_value - - return result - - def _fetch_files(self, variable_pool: VariablePool) -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -287,206 +211,6 @@ class DatasourceNode(Node[DatasourceNodeData]): return result - def _transform_message( - self, - messages: Generator[DatasourceMessage, None, None], - parameters_for_log: dict[str, Any], - datasource_info: dict[str, Any], - ) -> Generator: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages( - messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict | list] = [] - - variables: dict[str, Any] = {} - - for message in 
@@ -287,206 +211,6 @@
 
         return result
 
-    def _transform_message(
-        self,
-        messages: Generator[DatasourceMessage, None, None],
-        parameters_for_log: dict[str, Any],
-        datasource_info: dict[str, Any],
-    ) -> Generator:
-        """
-        Convert ToolInvokeMessages into tuple[plain_text, files]
-        """
-        # transform message and handle file storage
-        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
-            messages=messages,
-            user_id=self.user_id,
-            tenant_id=self.tenant_id,
-            conversation_id=None,
-        )
-
-        text = ""
-        files: list[File] = []
-        json: list[dict | list] = []
-
-        variables: dict[str, Any] = {}
-
-        for message in message_stream:
-            match message.type:
-                case (
-                    DatasourceMessage.MessageType.IMAGE_LINK
-                    | DatasourceMessage.MessageType.BINARY_LINK
-                    | DatasourceMessage.MessageType.IMAGE
-                ):
-                    assert isinstance(message.message, DatasourceMessage.TextMessage)
-
-                    url = message.message.text
-                    transfer_method = FileTransferMethod.TOOL_FILE
-
-                    datasource_file_id = str(url).split("/")[-1].split(".")[0]
-
-                    with Session(db.engine) as session:
-                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
-                        datasource_file = session.scalar(stmt)
-                        if datasource_file is None:
-                            raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
-
-                    mapping = {
-                        "tool_file_id": datasource_file_id,
-                        "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
-                        "transfer_method": transfer_method,
-                        "url": url,
-                    }
-                    file = file_factory.build_from_mapping(
-                        mapping=mapping,
-                        tenant_id=self.tenant_id,
-                    )
-                    files.append(file)
-                case DatasourceMessage.MessageType.BLOB:
-                    # get tool file id
-                    assert isinstance(message.message, DatasourceMessage.TextMessage)
-                    assert message.meta
-
-                    datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
-                    with Session(db.engine) as session:
-                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
-                        datasource_file = session.scalar(stmt)
-                        if datasource_file is None:
-                            raise ToolFileError(f"datasource file {datasource_file_id} not exists")
-
-                    mapping = {
-                        "tool_file_id": datasource_file_id,
-                        "transfer_method": FileTransferMethod.TOOL_FILE,
-                    }
-
-                    files.append(
-                        file_factory.build_from_mapping(
-                            mapping=mapping,
-                            tenant_id=self.tenant_id,
-                        )
-                    )
-                case DatasourceMessage.MessageType.TEXT:
-                    assert isinstance(message.message, DatasourceMessage.TextMessage)
-                    text += message.message.text
-                    yield StreamChunkEvent(
-                        selector=[self._node_id, "text"],
-                        chunk=message.message.text,
-                        is_final=False,
-                    )
-                case DatasourceMessage.MessageType.JSON:
-                    assert isinstance(message.message, DatasourceMessage.JsonMessage)
-                    json.append(message.message.json_object)
-                case DatasourceMessage.MessageType.LINK:
-                    assert isinstance(message.message, DatasourceMessage.TextMessage)
-                    stream_text = f"Link: {message.message.text}\n"
-                    text += stream_text
-                    yield StreamChunkEvent(
-                        selector=[self._node_id, "text"],
-                        chunk=stream_text,
-                        is_final=False,
-                    )
-                case DatasourceMessage.MessageType.VARIABLE:
-                    assert isinstance(message.message, DatasourceMessage.VariableMessage)
-                    variable_name = message.message.variable_name
-                    variable_value = message.message.variable_value
-                    if message.message.stream:
-                        if not isinstance(variable_value, str):
-                            raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
-                        if variable_name not in variables:
-                            variables[variable_name] = ""
-                        variables[variable_name] += variable_value
-
-                        yield StreamChunkEvent(
-                            selector=[self._node_id, variable_name],
-                            chunk=variable_value,
-                            is_final=False,
-                        )
-                    else:
-                        variables[variable_name] = variable_value
-                case DatasourceMessage.MessageType.FILE:
-                    assert message.meta is not None
-                    files.append(message.meta["file"])
-                case (
-                    DatasourceMessage.MessageType.BLOB_CHUNK
-                    | DatasourceMessage.MessageType.LOG
-                    | DatasourceMessage.MessageType.RETRIEVER_RESOURCES
-                ):
-                    pass
-
-        # mark the end of the stream
-        yield StreamChunkEvent(
-            selector=[self._node_id, "text"],
-            chunk="",
-            is_final=True,
-        )
-        yield StreamCompletedEvent(
-            node_run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={**variables},
-                metadata={
-                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
-                },
-                inputs=parameters_for_log,
-            )
-        )
-
     @classmethod
     def version(cls) -> str:
         return "1"
-
-    def _transform_datasource_file_message(
-        self,
-        messages: Generator[DatasourceMessage, None, None],
-        parameters_for_log: dict[str, Any],
-        datasource_info: dict[str, Any],
-        variable_pool: VariablePool,
-        datasource_type: DatasourceProviderType,
-    ) -> Generator:
-        """
-        Convert ToolInvokeMessages into tuple[plain_text, files]
-        """
-        # transform message and handle file storage
-        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
-            messages=messages,
-            user_id=self.user_id,
-            tenant_id=self.tenant_id,
-            conversation_id=None,
-        )
-        file = None
-        for message in message_stream:
-            if message.type == DatasourceMessage.MessageType.BINARY_LINK:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-
-                url = message.message.text
-                transfer_method = FileTransferMethod.TOOL_FILE
-
-                datasource_file_id = str(url).split("/")[-1].split(".")[0]
-
-                with Session(db.engine) as session:
-                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
-                    datasource_file = session.scalar(stmt)
-                    if datasource_file is None:
-                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
-
-                mapping = {
-                    "tool_file_id": datasource_file_id,
-                    "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
-                    "transfer_method": transfer_method,
-                    "url": url,
-                }
-                file = file_factory.build_from_mapping(
-                    mapping=mapping,
-                    tenant_id=self.tenant_id,
-                )
-        if file:
-            variable_pool.add([self._node_id, "file"], file)
-        yield StreamCompletedEvent(
-            node_run_result=NodeRunResult(
-                status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                inputs=parameters_for_log,
-                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                outputs={
-                    "file": file,
-                    "datasource_type": datasource_type,
-                },
-            )
-        )
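After this change the node no longer inspects message payloads at all; it only resolves the sys variables, builds the typed request params, and forwards to the manager. For reference, these are the shapes _run expects in sys.datasource_info, inferred from the .get() lookups in the hunks above (the values are illustrative, not from the source):

# Illustrative payloads only -- keys mirror the .get() lookups in _run above.
online_document_info = {
    "workspace_id": "w1",
    "page": {"page_id": "pg1", "type": "page"},
    "credential_id": "",  # optional; defaults to ""
}
online_drive_info = {
    "id": "drive-file-id",
    "bucket": "my-bucket",  # optional; coerced to "" when absent
    "credential_id": "",
}
local_file_info = {
    "related_id": "upload-file-id",  # resolved via get_upload_file_by_id
}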
diff --git a/api/core/workflow/repositories/datasource_manager_protocol.py b/api/core/workflow/repositories/datasource_manager_protocol.py
new file mode 100644
index 0000000000..4acf486bef
--- /dev/null
+++ b/api/core/workflow/repositories/datasource_manager_protocol.py
@@ -0,0 +1,50 @@
+from collections.abc import Generator
+from typing import Any, Protocol
+
+from pydantic import BaseModel
+
+from core.workflow.file import File
+from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
+
+
+class DatasourceParameter(BaseModel):
+    workspace_id: str
+    page_id: str
+    type: str
+
+
+class OnlineDriveDownloadFileParam(BaseModel):
+    id: str
+    bucket: str
+
+
+class DatasourceFinal(BaseModel):
+    data: dict[str, Any] | None = None
+
+
+class DatasourceManagerProtocol(Protocol):
+    @classmethod
+    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ...
+
+    @classmethod
+    def stream_node_events(
+        cls,
+        *,
+        node_id: str,
+        user_id: str,
+        datasource_name: str,
+        datasource_type: str,
+        provider_id: str,
+        tenant_id: str,
+        provider: str,
+        plugin_id: str,
+        credential_id: str,
+        parameters_for_log: dict[str, Any],
+        datasource_info: dict[str, Any],
+        variable_pool: Any,
+        datasource_param: DatasourceParameter | None = None,
+        online_drive_request: OnlineDriveDownloadFileParam | None = None,
+    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ...
+
+    @classmethod
+    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ...
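Because Protocol matching is structural, DatasourceManager neither imports nor subclasses this module, which is what allows the .importlinter entries above to be deleted. Conformance remains checkable statically; a one-line sketch that mypy should accept (an assumption on my part; nothing runs at import time beyond the two imports):

from core.datasource.datasource_manager import DatasourceManager
from core.workflow.repositories.datasource_manager_protocol import DatasourceManagerProtocol

# The node annotates its dependency as DatasourceManagerProtocol and the factory
# passes the DatasourceManager class, so this assignment is the conformance check.
_conforms: DatasourceManagerProtocol = DatasourceManager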
diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py
new file mode 100644
index 0000000000..003bb356e5
--- /dev/null
+++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py
@@ -0,0 +1,42 @@
+from collections.abc import Generator
+
+from core.datasource.datasource_manager import DatasourceManager
+from core.datasource.entities.datasource_entities import DatasourceMessage
+from core.workflow.node_events import StreamCompletedEvent
+
+
+def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
+    # produce a streamed variable "a" == "xy"
+    yield DatasourceMessage(
+        type=DatasourceMessage.MessageType.VARIABLE,
+        message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="x", stream=True),
+        meta=None,
+    )
+    yield DatasourceMessage(
+        type=DatasourceMessage.MessageType.VARIABLE,
+        message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="y", stream=True),
+        meta=None,
+    )
+
+
+def test_stream_node_events_accumulates_variables(mocker):
+    mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
+    events = list(
+        DatasourceManager.stream_node_events(
+            node_id="A",
+            user_id="u",
+            datasource_name="ds",
+            datasource_type="online_document",
+            provider_id="p/x",
+            tenant_id="t",
+            provider="prov",
+            plugin_id="plug",
+            credential_id="",
+            parameters_for_log={},
+            datasource_info={"user_id": "u"},
+            variable_pool=mocker.Mock(),
+            datasource_param=type("P", (), {"workspace_id": "w", "page_id": "pg", "type": "t"})(),
+            online_drive_request=None,
+        )
+    )
+    assert isinstance(events[-1], StreamCompletedEvent)
"datasource", + "version": "1", + "title": "Datasource", + "provider_type": "plugin", + "provider_name": "p", + "plugin_id": "plug", + "datasource_name": "ds", + }, + }, + graph_init_params=_GP(), + graph_runtime_state=_GS(vp), + datasource_manager=_Mgr, + ) + + out = list(node._run()) + assert isinstance(out[-1], StreamCompletedEvent) diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py new file mode 100644 index 0000000000..9ee1df8bdc --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -0,0 +1,135 @@ +import types +from collections.abc import Generator + +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent + + +def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage(text=text), + meta=None, + ) + + +def test_get_icon_url_calls_runtime(mocker): + fake_runtime = mocker.Mock() + fake_runtime.get_icon_url.return_value = "https://icon" + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime) + + url = DatasourceManager.get_icon_url( + provider_id="p/x", + tenant_id="t1", + datasource_name="ds", + datasource_type="online_document", + ) + assert url == "https://icon" + DatasourceManager.get_datasource_runtime.assert_called_once() + + +def test_stream_online_results_yields_messages_online_document(mocker): + # stub runtime to yield a text message + def _doc_messages(**_): + yield from _gen_messages_text_only("hello") + + fake_runtime = mocker.Mock() + fake_runtime.get_online_document_page_content.side_effect = _doc_messages + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value=None, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + msgs = list(gen) + assert len(msgs) == 1 + assert msgs[0].message.text == "hello" + + +def test_stream_node_events_emits_events_online_document(mocker): + # make manager's low-level stream produce TEXT only + mocker.patch.object( + DatasourceManager, + "stream_online_results", + return_value=_gen_messages_text_only("hello"), + ) + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={"k": "v"}, + datasource_info={"user_id": "u1"}, + variable_pool=mocker.Mock(), + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + # should contain one StreamChunkEvent then a final chunk (empty) and a completed event + assert isinstance(events[0], StreamChunkEvent) + assert events[0].chunk == "hello" + 
+    assert isinstance(events[-1], StreamCompletedEvent)
+    assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+
+
+def test_get_upload_file_by_id_builds_file(mocker):
+    # fake UploadFile row
+    fake_row = types.SimpleNamespace(
+        id="fid",
+        name="f",
+        extension="txt",
+        mime_type="text/plain",
+        size=1,
+        key="k",
+        source_url="http://x",
+    )
+
+    class _Q:
+        def __init__(self, row):
+            self._row = row
+
+        def where(self, *_args, **_kwargs):
+            return self
+
+        def first(self):
+            return self._row
+
+    class _S:
+        def __init__(self, row):
+            self._row = row
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *exc):
+            return False
+
+        def query(self, *_):
+            return _Q(self._row)
+
+    mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
+
+    f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
+    assert f.related_id == "fid"
+    assert f.extension == ".txt"
diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
new file mode 100644
index 0000000000..584ed23e91
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py
@@ -0,0 +1,93 @@
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
+from core.workflow.nodes.datasource.datasource_node import DatasourceNode
+
+
+class _VarSeg:
+    def __init__(self, v):
+        self.value = v
+
+
+class _VarPool:
+    def __init__(self, mapping):
+        self._m = mapping
+
+    def get(self, selector):
+        d = self._m
+        for k in selector:
+            d = d[k]
+        return _VarSeg(d)
+
+    def add(self, *_args, **_kwargs):
+        pass
+
+
+class _GraphState:
+    def __init__(self, var_pool):
+        self.variable_pool = var_pool
+
+
+class _GraphParams:
+    tenant_id = "t1"
+    app_id = "app-1"
+    workflow_id = "wf-1"
+    graph_config = {}
+    user_id = "u1"
+    user_from = "account"
+    invoke_from = "debugger"
+    call_depth = 0
+
+
+def test_datasource_node_delegates_to_manager_stream(mocker):
+    # prepare sys variables
+    sys_vars = {
+        "sys": {
+            "datasource_type": "online_document",
+            "datasource_info": {
+                "workspace_id": "w",
+                "page": {"page_id": "pg", "type": "t"},
+                "credential_id": "",
+            },
+        }
+    }
+    var_pool = _VarPool(sys_vars)
+    gs = _GraphState(var_pool)
+    gp = _GraphParams()
+
+    # stub manager class
+    class _Mgr:
+        @classmethod
+        def get_icon_url(cls, **_):
+            return "icon"
+
+        @classmethod
+        def stream_node_events(cls, **_):
+            yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False)
+            yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
+
+        @classmethod
+        def get_upload_file_by_id(cls, **_):
+            raise AssertionError("not called")
+
+    node = DatasourceNode(
+        id="n",
+        config={
+            "id": "n",
+            "data": {
+                "type": "datasource",
+                "version": "1",
+                "title": "Datasource",
+                "provider_type": "plugin",
+                "provider_name": "p",
+                "plugin_id": "plug",
+                "datasource_name": "ds",
+            },
+        },
+        graph_init_params=gp,
+        graph_runtime_state=gs,
+        datasource_manager=_Mgr,
+    )
+
+    evts = list(node._run())
+    assert isinstance(evts[0], StreamChunkEvent)
+    assert isinstance(evts[-1], StreamCompletedEvent)