refactor(document_extractor): Extract configs (#31828)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-02-16 23:39:50 +08:00
committed by GitHub
parent 7656d514b9
commit 41a4a57d2e
5 changed files with 110 additions and 39 deletions

View File

@@ -114,7 +114,6 @@ ignore_imports =
core.workflow.nodes.datasource.datasource_node -> models.model core.workflow.nodes.datasource.datasource_node -> models.model
core.workflow.nodes.datasource.datasource_node -> models.tools core.workflow.nodes.datasource.datasource_node -> models.tools
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
core.workflow.nodes.document_extractor.node -> configs
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.entities -> configs core.workflow.nodes.http_request.entities -> configs
core.workflow.nodes.http_request.executor -> configs core.workflow.nodes.http_request.executor -> configs

View File

@@ -16,6 +16,7 @@ from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request.node import HttpRequestNode from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@@ -44,7 +45,6 @@ class DifyNodeFactory(NodeFactory):
self, self,
graph_init_params: "GraphInitParams", graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState", graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None, code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None, code_limits: CodeNodeLimits | None = None,
@@ -53,6 +53,7 @@ class DifyNodeFactory(NodeFactory):
http_request_http_client: HttpClientProtocol | None = None, http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None, http_request_file_manager: FileManagerProtocol | None = None,
document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
) -> None: ) -> None:
self.graph_init_params = graph_init_params self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state
@@ -78,6 +79,13 @@ class DifyNodeFactory(NodeFactory):
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval() self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = (
document_extractor_unstructured_api_config
or UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
)
)
@override @override
def create_node(self, node_config: NodeConfigDict) -> Node: def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -152,6 +160,15 @@ class DifyNodeFactory(NodeFactory):
rag_retrieval=self._rag_retrieval, rag_retrieval=self._rag_retrieval,
) )
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
)
return node_class( return node_class(
id=node_id, id=node_id,
config=node_config, config=node_config,

View File

@@ -1,4 +1,4 @@
from .entities import DocumentExtractorNodeData from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .node import DocumentExtractorNode from .node import DocumentExtractorNode
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] __all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"]

View File

@@ -1,7 +1,14 @@
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData): class DocumentExtractorNodeData(BaseNodeData):
variable_selector: Sequence[str] variable_selector: Sequence[str]
@dataclass(frozen=True)
class UnstructuredApiConfig:
api_url: str | None = None
api_key: str = ""

View File

@@ -5,7 +5,7 @@ import logging
import os import os
import tempfile import tempfile
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any from typing import TYPE_CHECKING, Any
import charset_normalizer import charset_normalizer
import docx import docx
@@ -20,7 +20,6 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table from docx.table import Table
from docx.text.paragraph import Paragraph from docx.text.paragraph import Paragraph
from configs import dify_config
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment from core.variables import ArrayFileSegment
from core.variables.segments import ArrayStringSegment, FileSegment from core.variables.segments import ArrayStringSegment, FileSegment
@@ -29,11 +28,15 @@ from core.workflow.file import File, FileTransferMethod, file_manager
from core.workflow.node_events import NodeRunResult from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class DocumentExtractorNode(Node[DocumentExtractorNodeData]): class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
""" """
@@ -47,6 +50,23 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
unstructured_api_config: UnstructuredApiConfig | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@@ -64,7 +84,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
try: try:
if isinstance(value, list): if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value)) extracted_text_list = [
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
for file in value
]
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
@@ -72,7 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
outputs={"text": ArrayStringSegment(value=extracted_text_list)}, outputs={"text": ArrayStringSegment(value=extracted_text_list)},
) )
elif isinstance(value, File): elif isinstance(value, File):
extracted_text = _extract_text_from_file(value) extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
@@ -103,7 +126,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
return {node_id + ".files": typed_node_data.variable_selector} return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_mime_type(
*,
file_content: bytes,
mime_type: str,
unstructured_api_config: UnstructuredApiConfig,
) -> str:
"""Extract text from a file based on its MIME type.""" """Extract text from a file based on its MIME type."""
match mime_type: match mime_type:
case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
@@ -111,7 +139,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
case "application/pdf": case "application/pdf":
return _extract_text_from_pdf(file_content) return _extract_text_from_pdf(file_content)
case "application/msword": case "application/msword":
return _extract_text_from_doc(file_content) return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return _extract_text_from_docx(file_content) return _extract_text_from_docx(file_content)
case "text/csv": case "text/csv":
@@ -119,11 +147,11 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
return _extract_text_from_excel(file_content) return _extract_text_from_excel(file_content)
case "application/vnd.ms-powerpoint": case "application/vnd.ms-powerpoint":
return _extract_text_from_ppt(file_content) return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.presentationml.presentation": case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
return _extract_text_from_pptx(file_content) return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
case "application/epub+zip": case "application/epub+zip":
return _extract_text_from_epub(file_content) return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
case "message/rfc822": case "message/rfc822":
return _extract_text_from_eml(file_content) return _extract_text_from_eml(file_content)
case "application/vnd.ms-outlook": case "application/vnd.ms-outlook":
@@ -140,7 +168,12 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: def _extract_text_by_file_extension(
*,
file_content: bytes,
file_extension: str,
unstructured_api_config: UnstructuredApiConfig,
) -> str:
"""Extract text from a file based on its file extension.""" """Extract text from a file based on its file extension."""
match file_extension: match file_extension:
case ( case (
@@ -203,7 +236,7 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
case ".pdf": case ".pdf":
return _extract_text_from_pdf(file_content) return _extract_text_from_pdf(file_content)
case ".doc": case ".doc":
return _extract_text_from_doc(file_content) return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
case ".docx": case ".docx":
return _extract_text_from_docx(file_content) return _extract_text_from_docx(file_content)
case ".csv": case ".csv":
@@ -211,11 +244,11 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
case ".xls" | ".xlsx": case ".xls" | ".xlsx":
return _extract_text_from_excel(file_content) return _extract_text_from_excel(file_content)
case ".ppt": case ".ppt":
return _extract_text_from_ppt(file_content) return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
case ".pptx": case ".pptx":
return _extract_text_from_pptx(file_content) return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
case ".epub": case ".epub":
return _extract_text_from_epub(file_content) return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
case ".eml": case ".eml":
return _extract_text_from_eml(file_content) return _extract_text_from_eml(file_content)
case ".msg": case ".msg":
@@ -312,14 +345,15 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes) -> str: def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
""" """
Extract text from a DOC file. Extract text from a DOC file.
""" """
from unstructured.partition.api import partition_via_api from unstructured.partition.api import partition_via_api
if not dify_config.UNSTRUCTURED_API_URL: if not unstructured_api_config.api_url:
raise TextExtractionError("UNSTRUCTURED_API_URL must be set") raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.")
api_key = unstructured_api_config.api_key or ""
try: try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
@@ -329,8 +363,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
elements = partition_via_api( elements = partition_via_api(
file=file, file=file,
metadata_filename=temp_file.name, metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL, api_url=unstructured_api_config.api_url,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore api_key=api_key,
) )
os.unlink(temp_file.name) os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements]) return "\n".join([getattr(element, "text", "") for element in elements])
@@ -420,12 +454,20 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(file: File): def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
file_content = _download_file_content(file) file_content = _download_file_content(file)
if file.extension: if file.extension:
extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) extracted_text = _extract_text_by_file_extension(
file_content=file_content,
file_extension=file.extension,
unstructured_api_config=unstructured_api_config,
)
elif file.mime_type: elif file.mime_type:
extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) extracted_text = _extract_text_by_mime_type(
file_content=file_content,
mime_type=file.mime_type,
unstructured_api_config=unstructured_api_config,
)
else: else:
raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
return extracted_text return extracted_text
@@ -517,12 +559,14 @@ def _extract_text_from_excel(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes) -> str: def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api from unstructured.partition.api import partition_via_api
from unstructured.partition.ppt import partition_ppt from unstructured.partition.ppt import partition_ppt
api_key = unstructured_api_config.api_key or ""
try: try:
if dify_config.UNSTRUCTURED_API_URL: if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
temp_file.write(file_content) temp_file.write(file_content)
temp_file.flush() temp_file.flush()
@@ -530,8 +574,8 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
elements = partition_via_api( elements = partition_via_api(
file=file, file=file,
metadata_filename=temp_file.name, metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL, api_url=unstructured_api_config.api_url,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore api_key=api_key,
) )
os.unlink(temp_file.name) os.unlink(temp_file.name)
else: else:
@@ -543,12 +587,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_pptx(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx from unstructured.partition.pptx import partition_pptx
api_key = unstructured_api_config.api_key or ""
try: try:
if dify_config.UNSTRUCTURED_API_URL: if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
temp_file.write(file_content) temp_file.write(file_content)
temp_file.flush() temp_file.flush()
@@ -556,8 +602,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
elements = partition_via_api( elements = partition_via_api(
file=file, file=file,
metadata_filename=temp_file.name, metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL, api_url=unstructured_api_config.api_url,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore api_key=api_key,
) )
os.unlink(temp_file.name) os.unlink(temp_file.name)
else: else:
@@ -568,12 +614,14 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_epub(file_content: bytes) -> str: def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api from unstructured.partition.api import partition_via_api
from unstructured.partition.epub import partition_epub from unstructured.partition.epub import partition_epub
api_key = unstructured_api_config.api_key or ""
try: try:
if dify_config.UNSTRUCTURED_API_URL: if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
temp_file.write(file_content) temp_file.write(file_content)
temp_file.flush() temp_file.flush()
@@ -581,8 +629,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
elements = partition_via_api( elements = partition_via_api(
file=file, file=file,
metadata_filename=temp_file.name, metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL, api_url=unstructured_api_config.api_url,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore api_key=api_key,
) )
os.unlink(temp_file.name) os.unlink(temp_file.name)
else: else: