From eca97b90833662c4634d2cfde1525c049c2ab560 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 16 Feb 2026 19:29:04 +0800 Subject: [PATCH] feat: combine changes after rebase onto main --- api/.importlinter | 1 - api/core/app/workflow/node_factory.py | 47 +++++--- .../nodes/document_extractor/__init__.py | 4 +- .../nodes/document_extractor/entities.py | 7 ++ .../workflow/nodes/document_extractor/node.py | 111 ++++++++++++------ 5 files changed, 113 insertions(+), 57 deletions(-) diff --git a/api/.importlinter b/api/.importlinter index e30f498ba9..68243baa4f 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -114,7 +114,6 @@ ignore_imports = core.workflow.nodes.datasource.datasource_node -> models.model core.workflow.nodes.datasource.datasource_node -> models.tools core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service - core.workflow.nodes.document_extractor.node -> configs core.workflow.nodes.document_extractor.node -> core.file.file_manager core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy core.workflow.nodes.http_request.entities -> configs diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index 18db750d28..efb6a51ea9 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, final +from typing import TYPE_CHECKING, Any, cast, final from typing_extensions import override @@ -16,6 +16,7 @@ from core.workflow.graph.graph import NodeFactory from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -53,6 +54,7 @@ class DifyNodeFactory(NodeFactory): http_request_http_client: HttpClientProtocol | None = None, http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, http_request_file_manager: FileManagerProtocol | None = None, + document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state @@ -78,6 +80,13 @@ class DifyNodeFactory(NodeFactory): self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory self._http_request_file_manager = http_request_file_manager or file_manager self._rag_retrieval = DatasetRetrieval() + self._document_extractor_unstructured_api_config = ( + document_extractor_unstructured_api_config + or UnstructuredApiConfig( + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY or "", + ) + ) @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -110,13 +119,17 @@ class DifyNodeFactory(NodeFactory): if not node_class: raise ValueError(f"No latest version class found for node type: {node_type}") + common_kwargs: dict[str, Any] = { + "id": node_id, + "config": node_config, + "graph_init_params": self.graph_init_params, + "graph_runtime_state": self.graph_runtime_state, + } + # Create node instance if node_type == NodeType.CODE: return CodeNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, + **common_kwargs, code_executor=self._code_executor, code_providers=self._code_providers, code_limits=self._code_limits, @@ -124,20 +137,14 @@ class DifyNodeFactory(NodeFactory): if node_type == NodeType.TEMPLATE_TRANSFORM: return TemplateTransformNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, + **common_kwargs, template_renderer=self._template_renderer, max_output_length=self._template_transform_max_output_length, ) if node_type == NodeType.HTTP_REQUEST: return HttpRequestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, + **common_kwargs, http_client=self._http_request_http_client, tool_file_manager_factory=self._http_request_tool_file_manager_factory, file_manager=self._http_request_file_manager, @@ -152,9 +159,11 @@ class DifyNodeFactory(NodeFactory): rag_retrieval=self._rag_retrieval, ) - return node_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) + if node_type == NodeType.DOCUMENT_EXTRACTOR: + document_extractor_class = cast(type[DocumentExtractorNode], node_class) + return document_extractor_class( + **common_kwargs, + unstructured_api_config=self._document_extractor_unstructured_api_config, + ) + + return node_class(**common_kwargs) diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py index 3cc5fae187..9922e3949d 100644 --- a/api/core/workflow/nodes/document_extractor/__init__.py +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -1,4 +1,4 @@ -from .entities import DocumentExtractorNodeData +from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .node import DocumentExtractorNode -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py index 7e9ffaa889..db05bbf4fe 100644 --- a/api/core/workflow/nodes/document_extractor/entities.py +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -1,7 +1,14 @@ from collections.abc import Sequence +from dataclasses import dataclass from core.workflow.nodes.base import BaseNodeData class DocumentExtractorNodeData(BaseNodeData): variable_selector: Sequence[str] + + +@dataclass(frozen=True) +class UnstructuredApiConfig: + api_url: str | None = None + api_key: str = "" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 14ebd1f9ae..957d3ded04 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import charset_normalizer import docx @@ -20,7 +20,6 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment @@ -29,11 +28,15 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node -from .entities import DocumentExtractorNodeData +from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class DocumentExtractorNode(Node[DocumentExtractorNodeData]): """ @@ -47,6 +50,23 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): def version(cls) -> str: return "1" + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + unstructured_api_config: UnstructuredApiConfig | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() + def _run(self): variable_selector = self.node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) @@ -64,7 +84,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): try: if isinstance(value, list): - extracted_text_list = list(map(_extract_text_from_file, value)) + extracted_text_list = [ + _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config) + for file in value + ] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -72,7 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): - extracted_text = _extract_text_from_file(value) + extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -103,7 +126,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): return {node_id + ".files": typed_node_data.variable_selector} -def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: +def _extract_text_by_mime_type( + *, + file_content: bytes, + mime_type: str, + unstructured_api_config: UnstructuredApiConfig, +) -> str: """Extract text from a file based on its MIME type.""" match mime_type: case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": @@ -111,7 +139,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: case "application/pdf": return _extract_text_from_pdf(file_content) case "application/msword": - return _extract_text_from_doc(file_content) + return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": return _extract_text_from_docx(file_content) case "text/csv": @@ -119,11 +147,11 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": return _extract_text_from_excel(file_content) case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content) + return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content) + return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) case "application/epub+zip": - return _extract_text_from_epub(file_content) + return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) case "message/rfc822": return _extract_text_from_eml(file_content) case "application/vnd.ms-outlook": @@ -140,7 +168,12 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") -def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: +def _extract_text_by_file_extension( + *, + file_content: bytes, + file_extension: str, + unstructured_api_config: UnstructuredApiConfig, +) -> str: """Extract text from a file based on its file extension.""" match file_extension: case ( @@ -203,7 +236,7 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) case ".pdf": return _extract_text_from_pdf(file_content) case ".doc": - return _extract_text_from_doc(file_content) + return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) case ".docx": return _extract_text_from_docx(file_content) case ".csv": @@ -211,11 +244,11 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) case ".xls" | ".xlsx": return _extract_text_from_excel(file_content) case ".ppt": - return _extract_text_from_ppt(file_content) + return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) case ".pptx": - return _extract_text_from_pptx(file_content) + return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) case ".epub": - return _extract_text_from_epub(file_content) + return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) case ".eml": return _extract_text_from_eml(file_content) case ".msg": @@ -312,14 +345,14 @@ def _extract_text_from_pdf(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e -def _extract_text_from_doc(file_content: bytes) -> str: +def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: """ Extract text from a DOC file. """ from unstructured.partition.api import partition_via_api - if not dify_config.UNSTRUCTURED_API_URL: - raise TextExtractionError("UNSTRUCTURED_API_URL must be set") + if not unstructured_api_config.api_url: + raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") try: with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: @@ -329,8 +362,8 @@ def _extract_text_from_doc(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=unstructured_api_config.api_key, ) os.unlink(temp_file.name) return "\n".join([getattr(element, "text", "") for element in elements]) @@ -420,12 +453,20 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError(f"Error downloading file: {str(e)}") from e -def _extract_text_from_file(file: File): +def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str: file_content = _download_file_content(file) if file.extension: - extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) + extracted_text = _extract_text_by_file_extension( + file_content=file_content, + file_extension=file.extension, + unstructured_api_config=unstructured_api_config, + ) elif file.mime_type: - extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) + extracted_text = _extract_text_by_mime_type( + file_content=file_content, + mime_type=file.mime_type, + unstructured_api_config=unstructured_api_config, + ) else: raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") return extracted_text @@ -517,12 +558,12 @@ def _extract_text_from_excel(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e -def _extract_text_from_ppt(file_content: bytes) -> str: +def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.ppt import partition_ppt try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -530,8 +571,8 @@ def _extract_text_from_ppt(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=unstructured_api_config.api_key, ) os.unlink(temp_file.name) else: @@ -543,12 +584,12 @@ def _extract_text_from_ppt(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e -def _extract_text_from_pptx(file_content: bytes) -> str: +def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.pptx import partition_pptx try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -556,8 +597,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=unstructured_api_config.api_key, ) os.unlink(temp_file.name) else: @@ -568,12 +609,12 @@ def _extract_text_from_pptx(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e -def _extract_text_from_epub(file_content: bytes) -> str: +def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.epub import partition_epub try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -581,8 +622,8 @@ def _extract_text_from_epub(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=unstructured_api_config.api_key, ) os.unlink(temp_file.name) else: