mirror of
https://github.com/langgenius/dify.git
synced 2026-02-24 18:05:11 +00:00
Compare commits
118 Commits
1f864fe8e7
...
review-mys
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93d40d50d3 | ||
|
|
1c756e0073 | ||
|
|
e1681b1a16 | ||
|
|
a63e0dc2d9 | ||
|
|
a811983daa | ||
|
|
23c3319594 | ||
|
|
186ba8a0b4 | ||
|
|
94bda9fda1 | ||
|
|
68a2168db6 | ||
|
|
da6bf01c08 | ||
|
|
4c702ce923 | ||
|
|
dbe37dbd71 | ||
|
|
55dd4a0f89 | ||
|
|
96b280cf9b | ||
|
|
ba75c37c16 | ||
|
|
8cecca58ff | ||
|
|
b2e4072664 | ||
|
|
313ef28a8d | ||
|
|
3cb7194433 | ||
|
|
abbff8d05f | ||
|
|
9d8cff1571 | ||
|
|
7f401d3c69 | ||
|
|
255abccefd | ||
|
|
875f3de415 | ||
|
|
1c4102c1af | ||
|
|
331ce867d9 | ||
|
|
e4c9196465 | ||
|
|
4044d8a8db | ||
|
|
979592e183 | ||
|
|
7ca6219559 | ||
|
|
a7211c6338 | ||
|
|
45032116e0 | ||
|
|
9c8fa5a295 | ||
|
|
489b1fb87d | ||
|
|
88bdef9c04 | ||
|
|
7f959c09c0 | ||
|
|
1fd1d6f503 | ||
|
|
1e1e446ff6 | ||
|
|
d8b5243e6a | ||
|
|
2ab9fa90fd | ||
|
|
b82f82bfcb | ||
|
|
8e05cd4c2e | ||
|
|
e557acc1d3 | ||
|
|
4da6175307 | ||
|
|
db24e36d9f | ||
|
|
f823b1df3b | ||
|
|
35f0d9e857 | ||
|
|
4f94fa81c3 | ||
|
|
5627ca685a | ||
|
|
b753f37a89 | ||
|
|
85018d557c | ||
|
|
16d0096491 | ||
|
|
4f7ee0d66e | ||
|
|
c810b0f472 | ||
|
|
077ee51753 | ||
|
|
83c1da0d09 | ||
|
|
a9927f24ca | ||
|
|
50b3b0111f | ||
|
|
062b101e84 | ||
|
|
3afe9f15d7 | ||
|
|
a9cf7490c4 | ||
|
|
9a41d147be | ||
|
|
cf0760234c | ||
|
|
70ecf8d1bd | ||
|
|
a70b31cd07 | ||
|
|
ad8d253720 | ||
|
|
3459c07974 | ||
|
|
5fc4030477 | ||
|
|
d37657440a | ||
|
|
2616debd1e | ||
|
|
fdfe90c8c0 | ||
|
|
c8a3b92e47 | ||
|
|
ecbe760d2f | ||
|
|
c786aee9e9 | ||
|
|
5cc4ecfbbc | ||
|
|
8f4f8da714 | ||
|
|
819214ba76 | ||
|
|
9d9ab89f80 | ||
|
|
e8cf6d6e1f | ||
|
|
f201f57cd2 | ||
|
|
823cbc304b | ||
|
|
77eb424dd7 | ||
|
|
522aced46b | ||
|
|
46c22330e8 | ||
|
|
a9d7f54b1d | ||
|
|
e7a7506099 | ||
|
|
1d8531161d | ||
|
|
23b42a22a9 | ||
|
|
1c8feece1b | ||
|
|
e2d784f726 | ||
|
|
65681d8351 | ||
|
|
f207c20514 | ||
|
|
ae40313fab | ||
|
|
6e83348be9 | ||
|
|
a1224cd023 | ||
|
|
25552dbd38 | ||
|
|
fc559e5449 | ||
|
|
3af042066f | ||
|
|
12ce014205 | ||
|
|
659c3bd0c4 | ||
|
|
a9b76d957f | ||
|
|
c95adf2534 | ||
|
|
8741717a97 | ||
|
|
4ef51f545e | ||
|
|
78fee247bd | ||
|
|
f07110ff66 | ||
|
|
7feca6a5a5 | ||
|
|
50545d82e4 | ||
|
|
37d80fafb8 | ||
|
|
81c64e3d05 | ||
|
|
4f9891b0fa | ||
|
|
f59029d8ed | ||
|
|
de76c38f7a | ||
|
|
a7d035fa4d | ||
|
|
2781cd8d79 | ||
|
|
49ab0605cd | ||
|
|
fd80ed1c99 | ||
|
|
beb3ce172d |
@@ -114,6 +114,7 @@ ignore_imports =
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> configs
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.entities -> configs
|
||||
core.workflow.nodes.http_request.executor -> configs
|
||||
|
||||
@@ -16,7 +16,6 @@ from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
@@ -45,6 +44,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
self,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
@@ -53,7 +53,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
document_extractor_unstructured_api_config: UnstructuredApiConfig | None = None,
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
@@ -79,13 +78,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
self._document_extractor_unstructured_api_config = (
|
||||
document_extractor_unstructured_api_config
|
||||
or UnstructuredApiConfig(
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@@ -160,15 +152,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -33,6 +33,18 @@ class SortOrder(StrEnum):
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
_METADATA_KEY_WHITELIST = {
|
||||
"annotation_id",
|
||||
"app_id",
|
||||
"batch",
|
||||
"dataset_id",
|
||||
"doc_hash",
|
||||
"doc_id",
|
||||
"document_id",
|
||||
"lang",
|
||||
"source",
|
||||
}
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
@@ -45,10 +57,17 @@ class MyScaleVector(BaseVector):
|
||||
password=config.password,
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
self._qualified_table = f"{self._config.database}.{self._collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
@classmethod
|
||||
def _validate_metadata_key(cls, key: str) -> str:
|
||||
if key not in cls._METADATA_KEY_WHITELIST:
|
||||
raise ValueError(f"Unsupported metadata key: {key!r}")
|
||||
return key
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@@ -59,7 +78,7 @@ class MyScaleVector(BaseVector):
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
|
||||
CREATE TABLE IF NOT EXISTS {self._qualified_table}(
|
||||
id String,
|
||||
text String,
|
||||
vector Array(Float32),
|
||||
@@ -74,73 +93,98 @@ class MyScaleVector(BaseVector):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
values = []
|
||||
rows = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
rows.append(
|
||||
(
|
||||
doc_id,
|
||||
doc.page_content,
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata or {}),
|
||||
)
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
"""
|
||||
self._client.command(sql)
|
||||
if rows:
|
||||
self._client.insert(self._qualified_table, rows, column_names=columns)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def escape_str(value: Any) -> str:
|
||||
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
results = self._client.query(
|
||||
f"SELECT id FROM {self._qualified_table} WHERE id = %(id)s LIMIT 1",
|
||||
parameters={"id": id},
|
||||
)
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
placeholders, params = self._build_in_params("id", ids)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
f"DELETE FROM {self._qualified_table} WHERE id IN ({placeholders})",
|
||||
parameters=params,
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
f"SELECT DISTINCT id FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
f"DELETE FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
return self._search(
|
||||
"TextSearch('enable_nlq=false')(text, %(query)s)",
|
||||
SortOrder.DESC,
|
||||
parameters={"query": query},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
@staticmethod
|
||||
def _build_in_params(prefix: str, values: list[str]) -> tuple[str, dict[str, str]]:
|
||||
params: dict[str, str] = {}
|
||||
placeholders = []
|
||||
for i, value in enumerate(values):
|
||||
name = f"{prefix}_{i}"
|
||||
placeholders.append(f"%({name})s")
|
||||
params[name] = value
|
||||
return ", ".join(placeholders), params
|
||||
|
||||
def _search(
|
||||
self,
|
||||
dist: str,
|
||||
order: SortOrder,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
where_clauses = []
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0:
|
||||
where_clauses.append(f"dist < {1 - score_threshold}")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
query_params = dict(parameters or {})
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
placeholders, params = self._build_in_params("document_id", document_ids_filter)
|
||||
where_clauses.append(f"metadata['document_id'] IN ({placeholders})")
|
||||
query_params.update(params)
|
||||
where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._qualified_table}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
@@ -150,14 +194,14 @@ class MyScaleVector(BaseVector):
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
for r in self._client.query(sql, parameters=query_params).named_results()
|
||||
]
|
||||
except Exception:
|
||||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._qualified_table}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import charset_normalizer
|
||||
import docx
|
||||
@@ -20,6 +20,7 @@ from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayStringSegment, FileSegment
|
||||
@@ -28,15 +29,11 @@ from core.workflow.file import File, FileTransferMethod, file_manager
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .entities import DocumentExtractorNodeData, UnstructuredApiConfig
|
||||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
"""
|
||||
@@ -50,23 +47,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
unstructured_api_config: UnstructuredApiConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig()
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self.node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
@@ -84,10 +64,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = [
|
||||
_extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config)
|
||||
for file in value
|
||||
]
|
||||
extracted_text_list = list(map(_extract_text_from_file, value))
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@@ -95,7 +72,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
outputs={"text": ArrayStringSegment(value=extracted_text_list)},
|
||||
)
|
||||
elif isinstance(value, File):
|
||||
extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config)
|
||||
extracted_text = _extract_text_from_file(value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@@ -126,12 +103,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(
|
||||
*,
|
||||
file_content: bytes,
|
||||
mime_type: str,
|
||||
unstructured_api_config: UnstructuredApiConfig,
|
||||
) -> str:
|
||||
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
|
||||
"""Extract text from a file based on its MIME type."""
|
||||
match mime_type:
|
||||
case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml":
|
||||
@@ -139,7 +111,7 @@ def _extract_text_by_mime_type(
|
||||
case "application/pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case "application/msword":
|
||||
return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_doc(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case "text/csv":
|
||||
@@ -147,11 +119,11 @@ def _extract_text_by_mime_type(
|
||||
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
|
||||
return _extract_text_from_excel(file_content)
|
||||
case "application/vnd.ms-powerpoint":
|
||||
return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_ppt(file_content)
|
||||
case "application/vnd.openxmlformats-officedocument.presentationml.presentation":
|
||||
return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_pptx(file_content)
|
||||
case "application/epub+zip":
|
||||
return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_epub(file_content)
|
||||
case "message/rfc822":
|
||||
return _extract_text_from_eml(file_content)
|
||||
case "application/vnd.ms-outlook":
|
||||
@@ -168,12 +140,7 @@ def _extract_text_by_mime_type(
|
||||
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
|
||||
|
||||
|
||||
def _extract_text_by_file_extension(
|
||||
*,
|
||||
file_content: bytes,
|
||||
file_extension: str,
|
||||
unstructured_api_config: UnstructuredApiConfig,
|
||||
) -> str:
|
||||
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
|
||||
"""Extract text from a file based on its file extension."""
|
||||
match file_extension:
|
||||
case (
|
||||
@@ -236,7 +203,7 @@ def _extract_text_by_file_extension(
|
||||
case ".pdf":
|
||||
return _extract_text_from_pdf(file_content)
|
||||
case ".doc":
|
||||
return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_doc(file_content)
|
||||
case ".docx":
|
||||
return _extract_text_from_docx(file_content)
|
||||
case ".csv":
|
||||
@@ -244,11 +211,11 @@ def _extract_text_by_file_extension(
|
||||
case ".xls" | ".xlsx":
|
||||
return _extract_text_from_excel(file_content)
|
||||
case ".ppt":
|
||||
return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_ppt(file_content)
|
||||
case ".pptx":
|
||||
return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_pptx(file_content)
|
||||
case ".epub":
|
||||
return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config)
|
||||
return _extract_text_from_epub(file_content)
|
||||
case ".eml":
|
||||
return _extract_text_from_eml(file_content)
|
||||
case ".msg":
|
||||
@@ -345,15 +312,14 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
||||
raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC file.
|
||||
"""
|
||||
from unstructured.partition.api import partition_via_api
|
||||
|
||||
if not unstructured_api_config.api_url:
|
||||
raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.")
|
||||
api_key = unstructured_api_config.api_key or ""
|
||||
if not dify_config.UNSTRUCTURED_API_URL:
|
||||
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
|
||||
@@ -363,8 +329,8 @@ def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: Unst
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=unstructured_api_config.api_url,
|
||||
api_key=api_key,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
@@ -454,20 +420,12 @@ def _download_file_content(file: File) -> bytes:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
def _extract_text_from_file(file: File):
|
||||
file_content = _download_file_content(file)
|
||||
if file.extension:
|
||||
extracted_text = _extract_text_by_file_extension(
|
||||
file_content=file_content,
|
||||
file_extension=file.extension,
|
||||
unstructured_api_config=unstructured_api_config,
|
||||
)
|
||||
extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension)
|
||||
elif file.mime_type:
|
||||
extracted_text = _extract_text_by_mime_type(
|
||||
file_content=file_content,
|
||||
mime_type=file.mime_type,
|
||||
unstructured_api_config=unstructured_api_config,
|
||||
)
|
||||
extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type)
|
||||
else:
|
||||
raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing")
|
||||
return extracted_text
|
||||
@@ -559,14 +517,12 @@ def _extract_text_from_excel(file_content: bytes) -> str:
|
||||
raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
def _extract_text_from_ppt(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
|
||||
api_key = unstructured_api_config.api_key or ""
|
||||
|
||||
try:
|
||||
if unstructured_api_config.api_url:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
@@ -574,8 +530,8 @@ def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: Unst
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=unstructured_api_config.api_url,
|
||||
api_key=api_key,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
@@ -587,14 +543,12 @@ def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: Unst
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
api_key = unstructured_api_config.api_key or ""
|
||||
|
||||
try:
|
||||
if unstructured_api_config.api_url:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
@@ -602,8 +556,8 @@ def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: Uns
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=unstructured_api_config.api_url,
|
||||
api_key=api_key,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
@@ -614,14 +568,12 @@ def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: Uns
|
||||
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e
|
||||
|
||||
|
||||
def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str:
|
||||
def _extract_text_from_epub(file_content: bytes) -> str:
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.epub import partition_epub
|
||||
|
||||
api_key = unstructured_api_config.api_key or ""
|
||||
|
||||
try:
|
||||
if unstructured_api_config.api_url:
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file:
|
||||
temp_file.write(file_content)
|
||||
temp_file.flush()
|
||||
@@ -629,8 +581,8 @@ def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: Uns
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
metadata_filename=temp_file.name,
|
||||
api_url=unstructured_api_config.api_url,
|
||||
api_key=api_key,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
|
||||
)
|
||||
os.unlink(temp_file.name)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user