diff --git a/api/.importlinter b/api/.importlinter index 5fe76ce4c8..b9d688c1fa 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -114,7 +114,6 @@ ignore_imports = core.workflow.nodes.datasource.datasource_node -> models.model core.workflow.nodes.datasource.datasource_node -> models.tools core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service - core.workflow.nodes.document_extractor.node -> configs core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy core.workflow.nodes.http_request.entities -> configs core.workflow.nodes.http_request.executor -> configs diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index bd58bcb6b0..efb2a74176 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -16,6 +16,7 @@ from core.workflow.graph.graph import NodeFactory from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits +from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -44,7 +45,6 @@ class DifyNodeFactory(NodeFactory): self, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - *, code_executor: type[CodeExecutor] | None = None, code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits | None = None, @@ -53,6 +53,7 @@ class DifyNodeFactory(NodeFactory): http_request_http_client: HttpClientProtocol | None = None, http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, http_request_file_manager: FileManagerProtocol | None = None, + document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state @@ -78,6 +79,13 @@ class DifyNodeFactory(NodeFactory): self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory self._http_request_file_manager = http_request_file_manager or file_manager self._rag_retrieval = DatasetRetrieval() + self._document_extractor_unstructured_api_config = ( + document_extractor_unstructured_api_config + or UnstructuredApiConfig( + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY or "", + ) + ) @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -152,6 +160,15 @@ class DifyNodeFactory(NodeFactory): rag_retrieval=self._rag_retrieval, ) + if node_type == NodeType.DOCUMENT_EXTRACTOR: + return DocumentExtractorNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + unstructured_api_config=self._document_extractor_unstructured_api_config, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py index 3cc5fae187..9922e3949d 100644 --- a/api/core/workflow/nodes/document_extractor/__init__.py +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -1,4 +1,4 @@ -from .entities import DocumentExtractorNodeData +from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .node import DocumentExtractorNode -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py index 7e9ffaa889..db05bbf4fe 100644 --- a/api/core/workflow/nodes/document_extractor/entities.py +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -1,7 +1,14 @@ from collections.abc import Sequence +from dataclasses import dataclass from core.workflow.nodes.base import BaseNodeData class DocumentExtractorNodeData(BaseNodeData): variable_selector: Sequence[str] + + +@dataclass(frozen=True) +class UnstructuredApiConfig: + api_url: str | None = None + api_key: str = "" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 0a14b81633..c442e01854 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import charset_normalizer import docx @@ -20,7 +20,6 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from configs import dify_config from core.helper import ssrf_proxy from core.variables import ArrayFileSegment from core.variables.segments import ArrayStringSegment, FileSegment @@ -29,11 +28,15 @@ from core.workflow.file import File, FileTransferMethod, file_manager from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base.node import Node -from .entities import DocumentExtractorNodeData +from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.runtime import GraphRuntimeState + class DocumentExtractorNode(Node[DocumentExtractorNodeData]): """ @@ -47,6 +50,23 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): def version(cls) -> str: return "1" + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + unstructured_api_config: UnstructuredApiConfig | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() + def _run(self): variable_selector = self.node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) @@ -64,7 +84,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): try: if isinstance(value, list): - extracted_text_list = list(map(_extract_text_from_file, value)) + extracted_text_list = [ + _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config) + for file in value + ] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -72,7 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): - extracted_text = _extract_text_from_file(value) + extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -103,7 +126,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): return {node_id + ".files": typed_node_data.variable_selector} -def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: +def _extract_text_by_mime_type( + *, + file_content: bytes, + mime_type: str, + unstructured_api_config: UnstructuredApiConfig, +) -> str: """Extract text from a file based on its MIME type.""" match mime_type: case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": @@ -111,7 +139,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: case "application/pdf": return _extract_text_from_pdf(file_content) case "application/msword": - return _extract_text_from_doc(file_content) + return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": return _extract_text_from_docx(file_content) case "text/csv": @@ -119,11 +147,11 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": return _extract_text_from_excel(file_content) case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content) + return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content) + return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) case "application/epub+zip": - return _extract_text_from_epub(file_content) + return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) case "message/rfc822": return _extract_text_from_eml(file_content) case "application/vnd.ms-outlook": @@ -140,7 +168,12 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") -def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: +def _extract_text_by_file_extension( + *, + file_content: bytes, + file_extension: str, + unstructured_api_config: UnstructuredApiConfig, +) -> str: """Extract text from a file based on its file extension.""" match file_extension: case ( @@ -203,7 +236,7 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) case ".pdf": return _extract_text_from_pdf(file_content) case ".doc": - return _extract_text_from_doc(file_content) + return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) case ".docx": return _extract_text_from_docx(file_content) case ".csv": @@ -211,11 +244,11 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) case ".xls" | ".xlsx": return _extract_text_from_excel(file_content) case ".ppt": - return _extract_text_from_ppt(file_content) + return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) case ".pptx": - return _extract_text_from_pptx(file_content) + return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) case ".epub": - return _extract_text_from_epub(file_content) + return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) case ".eml": return _extract_text_from_eml(file_content) case ".msg": @@ -312,14 +345,15 @@ def _extract_text_from_pdf(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e -def _extract_text_from_doc(file_content: bytes) -> str: +def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: """ Extract text from a DOC file. """ from unstructured.partition.api import partition_via_api - if not dify_config.UNSTRUCTURED_API_URL: - raise TextExtractionError("UNSTRUCTURED_API_URL must be set") + if not unstructured_api_config.api_url: + raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") + api_key = unstructured_api_config.api_key or "" try: with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: @@ -329,8 +363,8 @@ def _extract_text_from_doc(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=api_key, ) os.unlink(temp_file.name) return "\n".join([getattr(element, "text", "") for element in elements]) @@ -420,12 +454,20 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError(f"Error downloading file: {str(e)}") from e -def _extract_text_from_file(file: File): +def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str: file_content = _download_file_content(file) if file.extension: - extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) + extracted_text = _extract_text_by_file_extension( + file_content=file_content, + file_extension=file.extension, + unstructured_api_config=unstructured_api_config, + ) elif file.mime_type: - extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) + extracted_text = _extract_text_by_mime_type( + file_content=file_content, + mime_type=file.mime_type, + unstructured_api_config=unstructured_api_config, + ) else: raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") return extracted_text @@ -517,12 +559,14 @@ def _extract_text_from_excel(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e -def _extract_text_from_ppt(file_content: bytes) -> str: +def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.ppt import partition_ppt + api_key = unstructured_api_config.api_key or "" + try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -530,8 +574,8 @@ def _extract_text_from_ppt(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=api_key, ) os.unlink(temp_file.name) else: @@ -543,12 +587,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e -def _extract_text_from_pptx(file_content: bytes) -> str: +def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.pptx import partition_pptx + api_key = unstructured_api_config.api_key or "" + try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -556,8 +602,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=api_key, ) os.unlink(temp_file.name) else: @@ -568,12 +614,14 @@ def _extract_text_from_pptx(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e -def _extract_text_from_epub(file_content: bytes) -> str: +def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: from unstructured.partition.api import partition_via_api from unstructured.partition.epub import partition_epub + api_key = unstructured_api_config.api_key or "" + try: - if dify_config.UNSTRUCTURED_API_URL: + if unstructured_api_config.api_url: with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: temp_file.write(file_content) temp_file.flush() @@ -581,8 +629,8 @@ def _extract_text_from_epub(file_content: bytes) -> str: elements = partition_via_api( file=file, metadata_filename=temp_file.name, - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore + api_url=unstructured_api_config.api_url, + api_key=api_key, ) os.unlink(temp_file.name) else: