mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 14:27:05 +00:00
Compare commits
25 Commits
optional-p
...
yanli/pyre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e5e5fe14ff | ||
|
|
aa30eeaf27 | ||
|
|
0745f573e6 | ||
|
|
94b05b2ca1 | ||
|
|
c3b17fc833 | ||
|
|
8f99dc1ac1 | ||
|
|
2028e2c3b8 | ||
|
|
7ff470d8a0 | ||
|
|
a4dbb76d3a | ||
|
|
134ae75a9b | ||
|
|
2bbb45e97f | ||
|
|
a0017183b6 | ||
|
|
db7d5e30cb | ||
|
|
295587718d | ||
|
|
b85d010e42 | ||
|
|
c71e407c39 | ||
|
|
1e5e65326e | ||
|
|
fe18405f1d | ||
|
|
7ccc736929 | ||
|
|
bcac77c212 | ||
|
|
2b53f1bfea | ||
|
|
eec9c76b7b | ||
|
|
98a94019c4 | ||
|
|
e8f120d87b | ||
|
|
7572db15ff |
@@ -88,7 +88,7 @@ class LindormVectorStore(BaseVector):
|
||||
batch_size: int = 64,
|
||||
timeout: int = 60,
|
||||
**kwargs,
|
||||
):
|
||||
) -> list[str]:
|
||||
logger.info("Total documents to add: %s", len(documents))
|
||||
uuids = self._get_uuids(documents)
|
||||
|
||||
@@ -130,8 +130,11 @@ class LindormVectorStore(BaseVector):
|
||||
Field.METADATA_KEY: documents[i].metadata,
|
||||
}
|
||||
if self._using_ugc:
|
||||
action_header["index"]["routing"] = self._routing
|
||||
action_values[ROUTING_FIELD] = self._routing
|
||||
routing = self._routing
|
||||
if routing is None:
|
||||
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
|
||||
action_header["index"]["routing"] = routing
|
||||
action_values[ROUTING_FIELD] = routing
|
||||
|
||||
actions.append(action_header)
|
||||
actions.append(action_values)
|
||||
@@ -147,6 +150,8 @@ class LindormVectorStore(BaseVector):
|
||||
logger.exception("Failed to process batch %s", batch_num + 1)
|
||||
raise
|
||||
|
||||
return uuids
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query: dict[str, Any] = {
|
||||
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
|
||||
@@ -378,18 +383,21 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
raise ValueError("LINDORM_USING_UGC is not set")
|
||||
routing_value = None
|
||||
if dataset.index_struct:
|
||||
index_struct_dict = dataset.index_struct_dict
|
||||
if index_struct_dict is None:
|
||||
raise ValueError("dataset.index_struct_dict is missing")
|
||||
# if an existed record's index_struct_dict doesn't contain using_ugc field,
|
||||
# it actually stores in the normal index format
|
||||
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
|
||||
stored_in_ugc: bool = index_struct_dict.get("using_ugc", False)
|
||||
using_ugc = stored_in_ugc
|
||||
if stored_in_ugc:
|
||||
dimension = dataset.index_struct_dict["dimension"]
|
||||
index_type = dataset.index_struct_dict["index_type"]
|
||||
distance_type = dataset.index_struct_dict["distance_type"]
|
||||
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
dimension = index_struct_dict["dimension"]
|
||||
index_type = index_struct_dict["index_type"]
|
||||
distance_type = index_struct_dict["distance_type"]
|
||||
routing_value = index_struct_dict["vector_store"]["class_prefix"]
|
||||
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
|
||||
else:
|
||||
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
|
||||
index_name = index_struct_dict["vector_store"]["class_prefix"].lower()
|
||||
else:
|
||||
embedding_vector = embeddings.embed_query("hello word")
|
||||
dimension = len(embedding_vector)
|
||||
|
||||
@@ -7,7 +7,9 @@ from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseVector(ABC):
|
||||
def __init__(self, collection_name: str):
|
||||
_collection_name: str
|
||||
|
||||
def __init__(self, collection_name: str) -> None:
|
||||
self._collection_name = collection_name
|
||||
|
||||
@abstractmethod
|
||||
@@ -30,7 +32,7 @@ class BaseVector(ABC):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -63,5 +65,5 @@ class BaseVector(ABC):
|
||||
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
|
||||
|
||||
@property
|
||||
def collection_name(self):
|
||||
def collection_name(self) -> str:
|
||||
return self._collection_name
|
||||
|
||||
@@ -2,7 +2,8 @@ import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -13,7 +14,7 @@ from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -24,19 +25,34 @@ from models.model import UploadFile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStoreIndexConfig(TypedDict):
|
||||
class_prefix: str
|
||||
|
||||
|
||||
class VectorIndexStructDict(TypedDict):
|
||||
type: VectorType
|
||||
vector_store: VectorStoreIndexConfig
|
||||
|
||||
|
||||
VectorDocumentInput = Document | ChildDocument | AttachmentDocument
|
||||
|
||||
|
||||
class AbstractVectorFactory(ABC):
|
||||
@abstractmethod
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
|
||||
def init_vector(self, dataset: Dataset, attributes: list[str], embeddings: Embeddings) -> BaseVector:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
|
||||
index_struct_dict: VectorIndexStructDict = {
|
||||
"type": vector_type,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
def __init__(self, dataset: Dataset, attributes: list[str] | None = None) -> None:
|
||||
if attributes is None:
|
||||
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
|
||||
self._dataset = dataset
|
||||
@@ -198,12 +214,12 @@ class Vector:
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
def create(self, texts: Sequence[Document | ChildDocument] | None = None, **kwargs: Any) -> None:
|
||||
if texts:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
batch_size = 1000
|
||||
total_batches = len(texts) + batch_size - 1
|
||||
total_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
@@ -212,29 +228,33 @@ class Vector:
|
||||
logger.info(
|
||||
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||
)
|
||||
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
|
||||
self._vector_processor.create(
|
||||
texts=self._normalize_documents(batch), embeddings=batch_embeddings, **kwargs
|
||||
)
|
||||
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
|
||||
|
||||
def create_multimodal(self, file_documents: list | None = None, **kwargs):
|
||||
def create_multimodal(self, file_documents: list[AttachmentDocument] | None = None, **kwargs: Any) -> None:
|
||||
if file_documents:
|
||||
start = time.time()
|
||||
logger.info("start embedding %s files %s", len(file_documents), start)
|
||||
batch_size = 1000
|
||||
total_batches = len(file_documents) + batch_size - 1
|
||||
total_batches = (len(file_documents) + batch_size - 1) // batch_size
|
||||
for i in range(0, len(file_documents), batch_size):
|
||||
batch = file_documents[i : i + batch_size]
|
||||
batch_start = time.time()
|
||||
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
|
||||
|
||||
# Batch query all upload files to avoid N+1 queries
|
||||
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
|
||||
attachment_ids = [doc.metadata["doc_id"] for doc in batch if doc.metadata is not None]
|
||||
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
upload_files = db.session.scalars(stmt).all()
|
||||
upload_file_map = {str(f.id): f for f in upload_files}
|
||||
|
||||
file_base64_list = []
|
||||
real_batch = []
|
||||
file_base64_list: list[dict[str, str]] = []
|
||||
real_batch: list[AttachmentDocument] = []
|
||||
for document in batch:
|
||||
if document.metadata is None:
|
||||
continue
|
||||
attachment_id = document.metadata["doc_id"]
|
||||
doc_type = document.metadata["doc_type"]
|
||||
upload_file = upload_file_map.get(attachment_id)
|
||||
@@ -249,14 +269,20 @@ class Vector:
|
||||
}
|
||||
)
|
||||
real_batch.append(document)
|
||||
if not real_batch:
|
||||
continue
|
||||
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
|
||||
logger.info(
|
||||
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
|
||||
)
|
||||
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
|
||||
self._vector_processor.create(
|
||||
texts=self._normalize_documents(real_batch),
|
||||
embeddings=batch_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
def add_texts(self, documents: list[Document], **kwargs: Any) -> None:
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
|
||||
@@ -266,10 +292,10 @@ class Vector:
|
||||
def text_exists(self, id: str) -> bool:
|
||||
return self._vector_processor.text_exists(id)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._vector_processor.delete_by_ids(ids)
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
def delete_by_metadata_field(self, key: str, value: str) -> None:
|
||||
self._vector_processor.delete_by_metadata_field(key, value)
|
||||
|
||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
@@ -295,7 +321,7 @@ class Vector:
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._vector_processor.search_by_full_text(query, **kwargs)
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
self._vector_processor.delete()
|
||||
# delete collection redis cache
|
||||
if self._vector_processor.collection_name:
|
||||
@@ -325,7 +351,26 @@ class Vector:
|
||||
|
||||
return texts
|
||||
|
||||
def __getattr__(self, name):
|
||||
@staticmethod
|
||||
def _normalize_documents(documents: Sequence[VectorDocumentInput]) -> list[Document]:
|
||||
normalized_documents: list[Document] = []
|
||||
for document in documents:
|
||||
if isinstance(document, Document):
|
||||
normalized_documents.append(document)
|
||||
continue
|
||||
|
||||
normalized_documents.append(
|
||||
Document(
|
||||
page_content=document.page_content,
|
||||
vector=document.vector,
|
||||
metadata=document.metadata,
|
||||
provider=(document.provider or "dify") if isinstance(document, AttachmentDocument) else "dify",
|
||||
)
|
||||
)
|
||||
|
||||
return normalized_documents
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if self._vector_processor is not None:
|
||||
method = getattr(self._vector_processor, name)
|
||||
if callable(method):
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from models.dataset import Document
|
||||
from models.model import UploadFile
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
|
||||
class NotionInfo(BaseModel):
|
||||
@@ -12,7 +16,7 @@ class NotionInfo(BaseModel):
|
||||
credential_id: str | None = None
|
||||
notion_workspace_id: str | None = ""
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
notion_page_type: Literal["database", "page"]
|
||||
document: Document | None = None
|
||||
tenant_id: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -25,20 +29,27 @@ class WebsiteInfo(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
provider: str
|
||||
provider: AuthType
|
||||
job_id: str
|
||||
url: str
|
||||
mode: str
|
||||
mode: Literal["crawl", "crawl_return_urls", "scrape"]
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
@field_validator("mode", mode="before")
|
||||
@classmethod
|
||||
def _normalize_legacy_mode(cls, value: str) -> str:
|
||||
if value == "single":
|
||||
return "crawl"
|
||||
return value
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
datasource_type: str
|
||||
datasource_type: DatasourceType
|
||||
upload_file: UploadFile | None = None
|
||||
notion_info: NotionInfo | None = None
|
||||
website_info: WebsiteInfo | None = None
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import TypeAlias
|
||||
from urllib.parse import unquote
|
||||
|
||||
from configs import dify_config
|
||||
@@ -31,19 +32,27 @@ from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import UploadFile
|
||||
from services.auth.auth_type import AuthType
|
||||
|
||||
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
|
||||
USER_AGENT = (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
|
||||
" Safari/537.36"
|
||||
)
|
||||
ExtractProcessorOutput: TypeAlias = list[Document] | str
|
||||
|
||||
|
||||
class ExtractProcessor:
|
||||
@staticmethod
|
||||
def _build_temp_file_path(temp_dir: str, suffix: str) -> str:
|
||||
file_descriptor, file_path = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
|
||||
os.close(file_descriptor)
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
def load_from_upload_file(
|
||||
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
|
||||
) -> Union[list[Document], str]:
|
||||
) -> ExtractProcessorOutput:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
|
||||
)
|
||||
@@ -54,7 +63,7 @@ class ExtractProcessor:
|
||||
return cls.extract(extract_setting, is_automatic)
|
||||
|
||||
@classmethod
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
|
||||
def load_from_url(cls, url: str, return_text: bool = False) -> ExtractProcessorOutput:
|
||||
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
@@ -65,17 +74,16 @@ class ExtractProcessor:
|
||||
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
|
||||
else:
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
match = re.search(r"\.(\w+)$", filename)
|
||||
if match:
|
||||
suffix = "." + match.group(1)
|
||||
else:
|
||||
suffix = ""
|
||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
# Generate a temporary filename under the created temp_dir and ensure the directory exists
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
if content_disposition:
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
match = re.search(r"\.(\w+)$", filename)
|
||||
if match:
|
||||
suffix = "." + match.group(1)
|
||||
else:
|
||||
suffix = ""
|
||||
file_path = cls._build_temp_file_path(temp_dir, suffix)
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
|
||||
if return_text:
|
||||
@@ -94,13 +102,13 @@ class ExtractProcessor:
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE:
|
||||
upload_file = extract_setting.upload_file
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
if not file_path:
|
||||
assert extract_setting.upload_file is not None, "upload_file is required"
|
||||
upload_file: UploadFile = extract_setting.upload_file
|
||||
upload_file = extract_setting.upload_file
|
||||
suffix = Path(upload_file.key).suffix
|
||||
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
file_path = cls._build_temp_file_path(temp_dir, suffix)
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
@@ -113,7 +121,11 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
extractor = PdfExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = (
|
||||
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
@@ -123,7 +135,11 @@ class ExtractProcessor:
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
extractor = WordExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension == ".doc":
|
||||
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
|
||||
elif file_extension == ".csv":
|
||||
@@ -149,13 +165,21 @@ class ExtractProcessor:
|
||||
if file_extension in {".xlsx", ".xls"}:
|
||||
extractor = ExcelExtractor(file_path)
|
||||
elif file_extension == ".pdf":
|
||||
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
extractor = PdfExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension in {".md", ".markdown", ".mdx"}:
|
||||
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension in {".htm", ".html"}:
|
||||
extractor = HtmlExtractor(file_path)
|
||||
elif file_extension == ".docx":
|
||||
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
|
||||
extractor = WordExtractor(
|
||||
file_path,
|
||||
upload_file.tenant_id if upload_file else None,
|
||||
upload_file.created_by if upload_file else None,
|
||||
)
|
||||
elif file_extension == ".csv":
|
||||
extractor = CSVExtractor(file_path, autodetect_encoding=True)
|
||||
elif file_extension == ".epub":
|
||||
@@ -177,7 +201,7 @@ class ExtractProcessor:
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
|
||||
assert extract_setting.website_info is not None, "website_info is required"
|
||||
if extract_setting.website_info.provider == "firecrawl":
|
||||
if extract_setting.website_info.provider == AuthType.FIRECRAWL:
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
@@ -186,7 +210,7 @@ class ExtractProcessor:
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "watercrawl":
|
||||
elif extract_setting.website_info.provider == AuthType.WATERCRAWL:
|
||||
extractor = WaterCrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
@@ -195,7 +219,7 @@ class ExtractProcessor:
|
||||
only_main_content=extract_setting.website_info.only_main_content,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.website_info.provider == "jinareader":
|
||||
elif extract_setting.website_info.provider == AuthType.JINA:
|
||||
extractor = JinaReaderWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class BaseExtractor(ABC):
|
||||
"""Interface for extract files."""
|
||||
|
||||
@abstractmethod
|
||||
def extract(self):
|
||||
def extract(self) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -30,7 +30,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
|
||||
For large files, reading only a sample is sufficient and prevents timeout.
|
||||
"""
|
||||
|
||||
def read_and_detect(filename: str):
|
||||
def read_and_detect(filename: str) -> list[FileEncoding]:
|
||||
rst = charset_normalizer.from_path(filename)
|
||||
best = rst.best()
|
||||
if best is None:
|
||||
|
||||
@@ -28,8 +28,8 @@ class PdfExtractor(BaseExtractor):
|
||||
|
||||
Args:
|
||||
file_path: Path to the PDF file.
|
||||
tenant_id: Workspace ID.
|
||||
user_id: ID of the user performing the extraction.
|
||||
tenant_id: Workspace ID used for extracted image persistence when available.
|
||||
user_id: ID of the user performing the extraction when available.
|
||||
file_cache_key: Optional cache key for the extracted text.
|
||||
"""
|
||||
|
||||
@@ -47,7 +47,13 @@ class PdfExtractor(BaseExtractor):
|
||||
]
|
||||
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
tenant_id: str | None,
|
||||
user_id: str | None,
|
||||
file_cache_key: str | None = None,
|
||||
):
|
||||
"""Initialize PdfExtractor."""
|
||||
self._file_path = file_path
|
||||
self._tenant_id = tenant_id
|
||||
@@ -116,6 +122,9 @@ class PdfExtractor(BaseExtractor):
|
||||
upload_files = []
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
if self._tenant_id is None or self._user_id is None:
|
||||
return ""
|
||||
|
||||
try:
|
||||
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
|
||||
for obj in image_objects:
|
||||
|
||||
@@ -9,6 +9,8 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from collections.abc import Iterable
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from docx import Document as DocxDocument
|
||||
@@ -35,7 +37,7 @@ class WordExtractor(BaseExtractor):
|
||||
file_path: Path to the file to load.
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, tenant_id: str, user_id: str):
|
||||
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None):
|
||||
"""Initialize with file path."""
|
||||
self.file_path = file_path
|
||||
self.tenant_id = tenant_id
|
||||
@@ -86,9 +88,12 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
def _extract_images_from_docx(self, doc):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
image_map: dict[object, str] = {}
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
if self.tenant_id is None or self.user_id is None:
|
||||
return image_map
|
||||
|
||||
for r_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.target_ref:
|
||||
image_count += 1
|
||||
@@ -264,7 +269,7 @@ class WordExtractor(BaseExtractor):
|
||||
def parse_docx(self, docx_path):
|
||||
doc = DocxDocument(docx_path)
|
||||
|
||||
content = []
|
||||
content: list[str] = []
|
||||
|
||||
image_map = self._extract_images_from_docx(doc)
|
||||
|
||||
@@ -362,10 +367,10 @@ class WordExtractor(BaseExtractor):
|
||||
if link_text:
|
||||
target_buffer.append(link_text)
|
||||
|
||||
paragraph_content = []
|
||||
paragraph_content: list[str] = []
|
||||
# State for legacy HYPERLINK fields
|
||||
hyperlink_field_url = None
|
||||
hyperlink_field_text_parts: list = []
|
||||
hyperlink_field_url: str | None = None
|
||||
hyperlink_field_text_parts: list[str] = []
|
||||
is_collecting_field_text = False
|
||||
# Iterate through paragraph elements in document order
|
||||
for child in paragraph._element:
|
||||
@@ -422,7 +427,8 @@ class WordExtractor(BaseExtractor):
|
||||
|
||||
paragraphs = doc.paragraphs.copy()
|
||||
tables = doc.tables.copy()
|
||||
for element in doc.element.body:
|
||||
body_elements = cast(Iterable[object], getattr(doc.element, "body", []))
|
||||
for element in body_elements:
|
||||
if hasattr(element, "tag"):
|
||||
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
|
||||
para = paragraphs.pop(0)
|
||||
|
||||
@@ -3,7 +3,7 @@ import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import delete, func, select
|
||||
@@ -20,6 +20,16 @@ from .processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexAndCleanResult(TypedDict):
|
||||
dataset_id: str
|
||||
dataset_name: str
|
||||
batch: str
|
||||
document_id: str
|
||||
document_name: str
|
||||
created_at: float
|
||||
display_status: Literal["completed"]
|
||||
|
||||
|
||||
class IndexProcessor:
|
||||
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
|
||||
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||
@@ -51,9 +61,9 @@ class IndexProcessor:
|
||||
document_id: str,
|
||||
original_document_id: str,
|
||||
chunks: Mapping[str, Any],
|
||||
batch: Any,
|
||||
batch: str,
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None,
|
||||
):
|
||||
) -> IndexAndCleanResult:
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
|
||||
@@ -122,6 +122,7 @@ class BaseIndexProcessor(ABC):
|
||||
"""
|
||||
Get the NodeParser object according to the processing rule.
|
||||
"""
|
||||
character_splitter: TextSplitter
|
||||
if processing_rule_mode in ["custom", "hierarchical"]:
|
||||
# The user-defined segmentation rule
|
||||
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
|
||||
@@ -147,7 +148,7 @@ class BaseIndexProcessor(ABC):
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
)
|
||||
|
||||
return character_splitter # type: ignore
|
||||
return character_splitter
|
||||
|
||||
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
|
||||
"""
|
||||
@@ -158,7 +159,7 @@ class BaseIndexProcessor(ABC):
|
||||
images = self._extract_markdown_images(text)
|
||||
if not images:
|
||||
return multi_model_documents
|
||||
upload_file_id_list = []
|
||||
upload_file_id_list: list[str] = []
|
||||
|
||||
for image in images:
|
||||
# Collect all upload_file_ids including duplicates to preserve occurrence count
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
|
||||
class IndexProcessorFactory:
|
||||
"""IndexProcessorInit."""
|
||||
|
||||
def __init__(self, index_type: str | None):
|
||||
def __init__(self, index_type: str | None) -> None:
|
||||
self._index_type = index_type
|
||||
|
||||
def init_index_processor(self) -> BaseIndexProcessor:
|
||||
@@ -19,11 +19,12 @@ class IndexProcessorFactory:
|
||||
if not self._index_type:
|
||||
raise ValueError("Index type must be specified.")
|
||||
|
||||
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
|
||||
return ParagraphIndexProcessor()
|
||||
elif self._index_type == IndexStructureType.QA_INDEX:
|
||||
return QAIndexProcessor()
|
||||
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
return ParentChildIndexProcessor()
|
||||
else:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
match self._index_type:
|
||||
case IndexStructureType.PARAGRAPH_INDEX:
|
||||
return ParagraphIndexProcessor()
|
||||
case IndexStructureType.QA_INDEX:
|
||||
return QAIndexProcessor()
|
||||
case IndexStructureType.PARENT_CHILD_INDEX:
|
||||
return ParentChildIndexProcessor()
|
||||
case _:
|
||||
raise ValueError(f"Index type {self._index_type} is not supported.")
|
||||
|
||||
@@ -30,7 +30,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> TS:
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
@@ -8,7 +8,7 @@ class BaseStorage(ABC):
|
||||
"""Interface for file storage."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, filename: str, data: bytes):
|
||||
def save(self, filename: str, data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -16,7 +16,7 @@ class BaseStorage(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
def load_stream(self, filename: str) -> Generator[bytes, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -28,10 +28,10 @@ class BaseStorage(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, filename: str):
|
||||
def delete(self, filename: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def scan(self, path, files=True, directories=False) -> list[str]:
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
"""
|
||||
Scan files and directories in the given path.
|
||||
This method is implemented only in some storage backends.
|
||||
|
||||
@@ -30,8 +30,8 @@ class GitHubEmailRecord(TypedDict, total=False):
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
name: NotRequired[str | None]
|
||||
email: NotRequired[str | None]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
@@ -138,7 +138,7 @@ class GitHubOAuth(OAuth):
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=payload.get("name") or "", email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
|
||||
@@ -20,7 +20,7 @@ else:
|
||||
class NotionPageSummary(TypedDict):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: dict[str, str] | None
|
||||
page_icon: dict[str, object] | None
|
||||
parent_id: str
|
||||
type: Literal["page", "database"]
|
||||
|
||||
|
||||
@@ -43,58 +43,6 @@ core/ops/tencent_trace/utils.py
|
||||
core/plugin/backwards_invocation/base.py
|
||||
core/plugin/backwards_invocation/model.py
|
||||
core/prompt/utils/extract_thread_messages.py
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/lindorm/lindorm_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/pdf_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/rag/retrieval/router/multi_dataset_function_call_router.py
|
||||
core/rag/summary_index/summary_index.py
|
||||
core/repositories/sqlalchemy_workflow_execution_repository.py
|
||||
@@ -140,27 +88,10 @@ dify_graph/nodes/variable_assigner/v2/node.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
services/auth/firecrawl/firecrawl.py
|
||||
services/auth/jina.py
|
||||
services/auth/jina/jina.py
|
||||
services/auth/watercrawl/watercrawl.py
|
||||
services/conversation_service.py
|
||||
services/dataset_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
@@ -188,3 +119,75 @@ tasks/disable_segment_from_index_task.py
|
||||
tasks/enable_segment_to_index_task.py
|
||||
tasks/remove_document_from_index_task.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
|
||||
# no need to fix for now: storage adapters
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
|
||||
# no need to fix for now: keyword adapters
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
|
||||
# no need to fix for now: vector db adapters
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
|
||||
# no need to fix for now: extractors
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
|
||||
# no need to fix for now: index processors
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class ApiKeyAuthConfig(TypedDict, total=False):
|
||||
api_key: str
|
||||
base_url: str
|
||||
|
||||
|
||||
class ApiKeyAuthCredentials(TypedDict):
|
||||
auth_type: object
|
||||
config: ApiKeyAuthConfig
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
credentials: ApiKeyAuthCredentials
|
||||
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials) -> None:
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self):
|
||||
def validate_credentials(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.auth_type import AuthType
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
from services.auth.auth_type import AuthProvider, AuthType
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
auth: ApiKeyAuthBase
|
||||
|
||||
def __init__(self, provider: AuthProvider, credentials: ApiKeyAuthCredentials) -> None:
|
||||
auth_factory = self.get_apikey_auth_factory(provider)
|
||||
self.auth = auth_factory(credentials)
|
||||
|
||||
def validate_credentials(self):
|
||||
def validate_credentials(self) -> bool:
|
||||
return self.auth.validate_credentials()
|
||||
|
||||
@staticmethod
|
||||
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
||||
match provider:
|
||||
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
|
||||
match ApiKeyAuthFactory._normalize_provider(provider):
|
||||
case AuthType.FIRECRAWL:
|
||||
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||
|
||||
@@ -27,3 +29,13 @@ class ApiKeyAuthFactory:
|
||||
return JinaAuth
|
||||
case _:
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_provider(provider: AuthProvider) -> AuthType | str:
|
||||
if isinstance(provider, AuthType):
|
||||
return provider
|
||||
|
||||
try:
|
||||
return AuthType(provider)
|
||||
except ValueError:
|
||||
return provider
|
||||
|
||||
@@ -1,40 +1,63 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthCredentials
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthCreateArgs(TypedDict):
|
||||
category: str
|
||||
provider: str
|
||||
credentials: ApiKeyAuthCredentials
|
||||
|
||||
|
||||
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
|
||||
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
def get_provider_auth_list(tenant_id: str) -> list[DataSourceApiKeyAuthBinding]:
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
return list(data_source_api_key_bindings)
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
def create_provider_auth(tenant_id: str, args: dict[str, object]) -> None:
|
||||
validated_args = ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
raw_credentials = ApiKeyAuthService._get_credentials_dict(args)
|
||||
auth_result = ApiKeyAuthFactory(
|
||||
validated_args["provider"], validated_args["credentials"]
|
||||
).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
api_key_value = validated_args["credentials"]["config"].get("api_key")
|
||||
if api_key_value is None:
|
||||
raise KeyError("api_key")
|
||||
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
|
||||
raw_config = ApiKeyAuthService._get_config_dict(raw_credentials)
|
||||
raw_config["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
tenant_id=tenant_id,
|
||||
category=validated_args["category"],
|
||||
provider=validated_args["provider"],
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
data_source_api_key_binding.credentials = json.dumps(raw_credentials, ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str) -> dict[str, object] | None:
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
@@ -50,10 +73,10 @@ class ApiKeyAuthService:
|
||||
if not data_source_api_key_bindings.credentials:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
return AUTH_CREDENTIALS_ADAPTER.validate_python(credentials)
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
@@ -63,8 +86,10 @@ class ApiKeyAuthService:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
@staticmethod
|
||||
def validate_api_key_auth_args(args: dict[str, object] | None) -> ApiKeyAuthCreateArgs:
|
||||
if args is None:
|
||||
raise TypeError("argument of type 'NoneType' is not iterable")
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
@@ -75,3 +100,18 @@ class ApiKeyAuthService:
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
return AUTH_CREATE_ARGS_ADAPTER.validate_python(args)
|
||||
|
||||
@staticmethod
|
||||
def _get_credentials_dict(args: dict[str, object]) -> dict[str, object]:
|
||||
credentials = args["credentials"]
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
return cast(dict[str, object], credentials)
|
||||
|
||||
@staticmethod
|
||||
def _get_config_dict(credentials: dict[str, object]) -> dict[str, object]:
|
||||
config = credentials["config"]
|
||||
if not isinstance(config, dict):
|
||||
raise TypeError(f"credentials['config'] must be a dictionary, got {type(config).__name__}")
|
||||
return cast(dict[str, object], config)
|
||||
|
||||
@@ -5,3 +5,6 @@ class AuthType(StrEnum):
|
||||
FIRECRAWL = "firecrawl"
|
||||
WATERCRAWL = "watercrawl"
|
||||
JINA = "jinareader"
|
||||
|
||||
|
||||
AuthProvider = AuthType | str
|
||||
|
||||
@@ -2,11 +2,11 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
||||
@@ -2,11 +2,11 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
||||
@@ -2,11 +2,11 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class JinaAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
|
||||
@@ -3,11 +3,11 @@ from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
|
||||
|
||||
|
||||
class WatercrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
def __init__(self, credentials: ApiKeyAuthCredentials):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "x-api-key":
|
||||
|
||||
@@ -200,6 +200,29 @@ class TestExtractProcessorFileRouting:
|
||||
with pytest.raises(AssertionError, match="upload_file is required"):
|
||||
ExtractProcessor.extract(setting)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("extension", "etl_type", "expected_extractor"),
|
||||
[
|
||||
(".pdf", "Unstructured", "PdfExtractor"),
|
||||
(".docx", "Unstructured", "WordExtractor"),
|
||||
(".pdf", "SelfHosted", "PdfExtractor"),
|
||||
(".docx", "SelfHosted", "WordExtractor"),
|
||||
],
|
||||
)
|
||||
def test_extract_allows_url_file_paths_without_upload_context(
|
||||
self, monkeypatch, extension: str, etl_type: str, expected_extractor: str
|
||||
):
|
||||
factory = _patch_all_extractors(monkeypatch)
|
||||
monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type)
|
||||
|
||||
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
|
||||
|
||||
docs = ExtractProcessor.extract(setting, file_path=f"/tmp/example{extension}")
|
||||
|
||||
assert docs[0].page_content == f"extracted-by-{expected_extractor}"
|
||||
assert factory.calls[-1][0] == expected_extractor
|
||||
assert factory.calls[-1][1] == (f"/tmp/example{extension}", None, None)
|
||||
|
||||
|
||||
class TestExtractProcessorDatasourceRouting:
|
||||
def test_extract_routes_notion_datasource(self, monkeypatch):
|
||||
|
||||
@@ -184,3 +184,21 @@ def test_extract_images_failures(mock_dependencies):
|
||||
assert len(saves) == 1
|
||||
assert saves[0][1] == jpeg_bytes
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_skips_persistence_without_upload_context(mock_dependencies):
|
||||
mock_page = MagicMock()
|
||||
mock_image_obj = MagicMock()
|
||||
mock_image_obj.extract.side_effect = lambda buf, fb_format=None: buf.write(b"\xff\xd8\xff image")
|
||||
mock_page.get_objects.return_value = [mock_image_obj]
|
||||
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id=None, user_id=None)
|
||||
|
||||
with patch("pypdfium2.raw", autospec=True) as mock_raw:
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
assert result == ""
|
||||
assert mock_dependencies.saves == []
|
||||
assert mock_dependencies.db.session.added == []
|
||||
assert mock_dependencies.db.session.committed is False
|
||||
|
||||
@@ -179,6 +179,27 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_from_docx_skips_persistence_without_upload_context(monkeypatch):
|
||||
saves: list[tuple[str, bytes]] = []
|
||||
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: saves.append((key, data))))
|
||||
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=lambda: None))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
|
||||
rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
|
||||
doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext}))
|
||||
|
||||
extractor = object.__new__(WordExtractor)
|
||||
extractor.tenant_id = None
|
||||
extractor.user_id = None
|
||||
|
||||
image_map = extractor._extract_images_from_docx(doc)
|
||||
|
||||
assert image_map == {}
|
||||
assert saves == []
|
||||
|
||||
|
||||
def test_extract_images_from_docx_uses_internal_files_url():
|
||||
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
|
||||
# Test the URL generation logic directly
|
||||
|
||||
@@ -13,8 +13,11 @@ class TestApiKeyAuthFactory:
|
||||
("provider", "auth_class_path"),
|
||||
[
|
||||
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
|
||||
(AuthType.FIRECRAWL.value, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
|
||||
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
|
||||
(AuthType.WATERCRAWL.value, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
|
||||
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
|
||||
(AuthType.JINA.value, "services.auth.jina.jina.JinaAuth"),
|
||||
],
|
||||
)
|
||||
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -68,7 +69,16 @@ class TestApiKeyAuthService:
|
||||
# Mock successful auth validation
|
||||
mock_auth_instance = Mock()
|
||||
mock_auth_instance.validate_credentials.return_value = True
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
captured_provider = None
|
||||
captured_credentials = None
|
||||
|
||||
def factory_side_effect(provider, credentials):
|
||||
nonlocal captured_provider, captured_credentials
|
||||
captured_provider = provider
|
||||
captured_credentials = deepcopy(credentials)
|
||||
return mock_auth_instance
|
||||
|
||||
mock_factory.side_effect = factory_side_effect
|
||||
|
||||
# Mock encryption
|
||||
encrypted_key = "encrypted_test_key_123"
|
||||
@@ -77,11 +87,14 @@ class TestApiKeyAuthService:
|
||||
# Mock database operations
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
expected_credentials = deepcopy(self.mock_credentials)
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
|
||||
|
||||
# Verify factory class calls
|
||||
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
|
||||
assert mock_factory.call_count == 1
|
||||
assert captured_provider == self.provider
|
||||
assert captured_credentials == expected_credentials
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
# Verify encryption calls
|
||||
|
||||
@@ -4,11 +4,13 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.vector_service as vector_service_module
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
|
||||
from services.vector_service import VectorService
|
||||
|
||||
|
||||
@@ -702,3 +704,105 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch:
|
||||
|
||||
logger_mock.exception.assert_called_once()
|
||||
db_mock.session.rollback.assert_called_once()
|
||||
|
||||
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
|
||||
def test_vector_create_normalizes_child_documents(mock_get_embeddings: Mock, mock_init_vector: Mock) -> None:
|
||||
dataset = _make_dataset()
|
||||
documents = [ChildDocument(page_content="Child content", metadata={"doc_id": "child-1", "dataset_id": "dataset-1"})]
|
||||
|
||||
mock_embeddings = Mock()
|
||||
mock_embeddings.embed_documents.return_value = [[0.1] * 1536]
|
||||
mock_get_embeddings.return_value = mock_embeddings
|
||||
|
||||
mock_vector_processor = Mock()
|
||||
mock_init_vector.return_value = mock_vector_processor
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
vector.create(texts=documents)
|
||||
|
||||
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
|
||||
assert isinstance(normalized_document, Document)
|
||||
assert normalized_document.page_content == "Child content"
|
||||
assert normalized_document.metadata["doc_id"] == "child-1"
|
||||
|
||||
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.storage")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.db.session")
|
||||
def test_vector_create_multimodal_normalizes_attachment_documents(
|
||||
mock_session: Mock,
|
||||
mock_storage: Mock,
|
||||
mock_get_embeddings: Mock,
|
||||
mock_init_vector: Mock,
|
||||
) -> None:
|
||||
dataset = _make_dataset()
|
||||
file_document = AttachmentDocument(
|
||||
page_content="Attachment content",
|
||||
provider="custom-provider",
|
||||
metadata={"doc_id": "file-1", "doc_type": "image/png"},
|
||||
)
|
||||
upload_file = Mock(id="file-1", key="upload-key")
|
||||
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.all.return_value = [upload_file]
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
mock_storage.load_once.return_value = b"binary-content"
|
||||
|
||||
mock_embeddings = Mock()
|
||||
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
|
||||
mock_get_embeddings.return_value = mock_embeddings
|
||||
|
||||
mock_vector_processor = Mock()
|
||||
mock_init_vector.return_value = mock_vector_processor
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
vector.create_multimodal(file_documents=[file_document])
|
||||
|
||||
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
|
||||
assert isinstance(normalized_document, Document)
|
||||
assert normalized_document.provider == "custom-provider"
|
||||
assert normalized_document.metadata["doc_id"] == "file-1"
|
||||
|
||||
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.storage")
|
||||
@patch("core.rag.datasource.vdb.vector_factory.db.session")
|
||||
def test_vector_create_multimodal_falls_back_to_dify_provider_when_attachment_provider_is_none(
|
||||
mock_session: Mock,
|
||||
mock_storage: Mock,
|
||||
mock_get_embeddings: Mock,
|
||||
mock_init_vector: Mock,
|
||||
) -> None:
|
||||
dataset = _make_dataset()
|
||||
file_document = AttachmentDocument(
|
||||
page_content="Attachment content",
|
||||
provider=None,
|
||||
metadata={"doc_id": "file-1", "doc_type": "image/png"},
|
||||
)
|
||||
upload_file = Mock(id="file-1", key="upload-key")
|
||||
|
||||
mock_scalars = Mock()
|
||||
mock_scalars.all.return_value = [upload_file]
|
||||
mock_session.scalars.return_value = mock_scalars
|
||||
mock_storage.load_once.return_value = b"binary-content"
|
||||
|
||||
mock_embeddings = Mock()
|
||||
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
|
||||
mock_get_embeddings.return_value = mock_embeddings
|
||||
|
||||
mock_vector_processor = Mock()
|
||||
mock_init_vector.return_value = mock_vector_processor
|
||||
|
||||
vector = Vector(dataset=dataset)
|
||||
|
||||
vector.create_multimodal(file_documents=[file_document])
|
||||
|
||||
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
|
||||
assert isinstance(normalized_document, Document)
|
||||
assert normalized_document.provider == "dify"
|
||||
|
||||
Reference in New Issue
Block a user