Compare commits

...

2 Commits

Author SHA1 Message Date
-LAN-
c2a11ffa97 fix: satisfy type checks in nodes 2026-02-04 21:01:37 +08:00
-LAN-
5eaf535a0d refactor(document_extractor): Extract configs
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-02 16:54:04 +08:00
5 changed files with 113 additions and 57 deletions

View File

@@ -112,7 +112,6 @@ ignore_imports =
core.workflow.nodes.datasource.datasource_node -> models.model
core.workflow.nodes.datasource.datasource_node -> models.tools
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
core.workflow.nodes.document_extractor.node -> configs
core.workflow.nodes.document_extractor.node -> core.file.file_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.entities -> configs

View File

@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing import TYPE_CHECKING, Any, cast, final
from typing_extensions import override
@@ -15,6 +15,7 @@ from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
@@ -50,6 +51,7 @@ class DifyNodeFactory(NodeFactory):
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@@ -71,6 +73,13 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._document_extractor_unstructured_api_config = (
document_extractor_unstructured_api_config
or UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
)
)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -103,13 +112,17 @@ class DifyNodeFactory(NodeFactory):
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
common_kwargs: dict[str, Any] = {
"id": node_id,
"config": node_config,
"graph_init_params": self.graph_init_params,
"graph_runtime_state": self.graph_runtime_state,
}
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**common_kwargs,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
@@ -117,27 +130,23 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**common_kwargs,
template_renderer=self._template_renderer,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
**common_kwargs,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
return node_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
if node_type == NodeType.DOCUMENT_EXTRACTOR:
document_extractor_class = cast(type[DocumentExtractorNode], node_class)
return document_extractor_class(
**common_kwargs,
unstructured_api_config=self._document_extractor_unstructured_api_config,
)
return node_class(**common_kwargs)

View File

@@ -1,4 +1,4 @@
from .entities import DocumentExtractorNodeData
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .node import DocumentExtractorNode
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"]
__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"]

View File

@@ -1,7 +1,14 @@
from collections.abc import Sequence
from dataclasses import dataclass
from core.workflow.nodes.base import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData):
variable_selector: Sequence[str]
@dataclass(frozen=True)
class UnstructuredApiConfig:
api_url: str | None = None
api_key: str = ""

View File

@@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
import charset_normalizer
import docx
@@ -20,7 +20,6 @@ from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
@@ -29,11 +28,15 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from .entities import DocumentExtractorNodeData
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
"""
@@ -47,6 +50,23 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def version(cls) -> str:
return "1"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
unstructured_api_config: UnstructuredApiConfig | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
def _run(self):
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@@ -64,7 +84,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
try:
if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value))
extracted_text_list = [
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
for file in value
]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -72,7 +95,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)
extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -103,7 +126,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_mime_type(
*,
file_content: bytes,
mime_type: str,
unstructured_api_config: UnstructuredApiConfig,
) -> str:
"""Extract text from a file based on its MIME type."""
match mime_type:
case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
@@ -111,7 +139,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
case "application/pdf":
return _extract_text_from_pdf(file_content)
case "application/msword":
return _extract_text_from_doc(file_content)
return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return _extract_text_from_docx(file_content)
case "text/csv":
@@ -119,11 +147,11 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
return _extract_text_from_excel(file_content)
case "application/vnd.ms-powerpoint":
return _extract_text_from_ppt(file_content)
return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
return _extract_text_from_pptx(file_content)
return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
case "application/epub+zip":
return _extract_text_from_epub(file_content)
return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
case "message/rfc822":
return _extract_text_from_eml(file_content)
case "application/vnd.ms-outlook":
@@ -140,7 +168,12 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
def _extract_text_by_file_extension(
*,
file_content: bytes,
file_extension: str,
unstructured_api_config: UnstructuredApiConfig,
) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
case (
@@ -203,7 +236,7 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
case ".pdf":
return _extract_text_from_pdf(file_content)
case ".doc":
return _extract_text_from_doc(file_content)
return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
case ".docx":
return _extract_text_from_docx(file_content)
case ".csv":
@@ -211,11 +244,11 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
case ".xls" | ".xlsx":
return _extract_text_from_excel(file_content)
case ".ppt":
return _extract_text_from_ppt(file_content)
return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
case ".pptx":
return _extract_text_from_pptx(file_content)
return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
case ".epub":
return _extract_text_from_epub(file_content)
return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
case ".eml":
return _extract_text_from_eml(file_content)
case ".msg":
@@ -312,14 +345,14 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
def _extract_text_from_doc(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
"""
Extract text from a DOC file.
"""
from unstructured.partition.api import partition_via_api
if not dify_config.UNSTRUCTURED_API_URL:
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
if not unstructured_api_config.api_url:
raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.")
try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
@@ -329,8 +362,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
api_url=unstructured_api_config.api_url,
api_key=unstructured_api_config.api_key,
)
os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements])
@@ -420,12 +453,20 @@ def _download_file_content(file: File) -> bytes:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
def _extract_text_from_file(file: File):
def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
file_content = _download_file_content(file)
if file.extension:
extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension)
extracted_text = _extract_text_by_file_extension(
file_content=file_content,
file_extension=file.extension,
unstructured_api_config=unstructured_api_config,
)
elif file.mime_type:
extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type)
extracted_text = _extract_text_by_mime_type(
file_content=file_content,
mime_type=file.mime_type,
unstructured_api_config=unstructured_api_config,
)
else:
raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
return extracted_text
@@ -517,12 +558,12 @@ def _extract_text_from_excel(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.ppt import partition_ppt
try:
if dify_config.UNSTRUCTURED_API_URL:
if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
@@ -530,8 +571,8 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
api_url=unstructured_api_config.api_url,
api_key=unstructured_api_config.api_key,
)
os.unlink(temp_file.name)
else:
@@ -543,12 +584,12 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL:
if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
@@ -556,8 +597,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
api_url=unstructured_api_config.api_url,
api_key=unstructured_api_config.api_key,
)
os.unlink(temp_file.name)
else:
@@ -568,12 +609,12 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.epub import partition_epub
try:
if dify_config.UNSTRUCTURED_API_URL:
if unstructured_api_config.api_url:
with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
@@ -581,8 +622,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
api_url=unstructured_api_config.api_url,
api_key=unstructured_api_config.api_key,
)
os.unlink(temp_file.name)
else: