Compare commits

...

25 Commits

Author SHA1 Message Date
autofix-ci[bot]
e5e5fe14ff [autofix.ci] apply automated fixes 2026-03-18 17:40:49 +00:00
Yanli 盐粒
aa30eeaf27 🐛 fix: clarify invalid auth config errors 2026-03-19 01:38:45 +08:00
Yanli 盐粒
0745f573e6 fix: add missing word extractor type annotation 2026-03-19 01:06:18 +08:00
Yanli 盐粒
94b05b2ca1 refactor: remove resolved pyrefly excludes 2026-03-19 00:55:35 +08:00
Yanli 盐粒
c3b17fc833 🐛 fix: preserve URL-based PDF and DOCX extraction 2026-03-19 00:43:23 +08:00
Yanli 盐粒
8f99dc1ac1 refactor(api): remove unused vector payload alias 2026-03-19 00:10:25 +08:00
Yanli 盐粒
2028e2c3b8 Default attachment provider during normalization 2026-03-18 23:40:17 +08:00
Yanli 盐粒
7ff470d8a0 Preserve auth service interface compatibility 2026-03-18 23:34:24 +08:00
Yanli 盐粒
a4dbb76d3a Keep typing changes scoped in auth and extractor tests 2026-03-18 23:13:31 +08:00
Yanli 盐粒
134ae75a9b chore(api): keep auth api key validation aligned with main 2026-03-18 22:22:52 +08:00
Yanli 盐粒
2bbb45e97f fix(api): require api key in auth create validation 2026-03-18 20:05:50 +08:00
Yanli 盐粒
a0017183b6 Address API review follow-ups 2026-03-18 18:31:09 +08:00
Yanli 盐粒
db7d5e30cb Merge origin/main into yanli/pyrefly-fix-plan-v2 2026-03-18 18:16:42 +08:00
Yanli 盐粒
295587718d Stabilize API key auth validation tests 2026-03-18 18:13:14 +08:00
Yanli 盐粒
b85d010e42 Handle legacy website crawl mode values 2026-03-18 18:02:41 +08:00
Yanli 盐粒
c71e407c39 Fix OAuth payload validation regressions 2026-03-18 17:56:38 +08:00
Yanli 盐粒
1e5e65326e refactor(api): keep extractor typing changes behavior-neutral 2026-03-17 20:34:34 +08:00
Yanli 盐粒
fe18405f1d refactor(api): remove dead metadata guard 2026-03-17 20:11:10 +08:00
Yanli 盐粒
7ccc736929 refactor(api): preserve pydantic auth validation errors 2026-03-17 20:10:05 +08:00
Yanli 盐粒
bcac77c212 refactor(api): relax vector metadata id lookup contract 2026-03-17 20:07:07 +08:00
Yanli 盐粒
2b53f1bfea fix(api): annotate splitter variable for mypy 2026-03-17 20:06:07 +08:00
Yanli 盐粒
eec9c76b7b chore(api): clarify deferred pyrefly exclude comments 2026-03-17 19:59:01 +08:00
autofix-ci[bot]
98a94019c4 [autofix.ci] apply automated fixes 2026-03-17 11:55:31 +00:00
Yanli 盐粒
e8f120d87b fix(api): use typing_extensions TypedDict in auth 2026-03-17 19:53:19 +08:00
Yanli 盐粒
7572db15ff refactor(api): tighten shared adapter typing contracts 2026-03-17 19:47:16 +08:00
31 changed files with 571 additions and 198 deletions

View File

@@ -88,7 +88,7 @@ class LindormVectorStore(BaseVector):
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
) -> list[str]:
logger.info("Total documents to add: %s", len(documents))
uuids = self._get_uuids(documents)
@@ -130,8 +130,11 @@ class LindormVectorStore(BaseVector):
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[ROUTING_FIELD] = self._routing
routing = self._routing
if routing is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
action_header["index"]["routing"] = routing
action_values[ROUTING_FIELD] = routing
actions.append(action_header)
actions.append(action_values)
@@ -147,6 +150,8 @@ class LindormVectorStore(BaseVector):
logger.exception("Failed to process batch %s", batch_num + 1)
raise
return uuids
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
@@ -378,18 +383,21 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
raise ValueError("LINDORM_USING_UGC is not set")
routing_value = None
if dataset.index_struct:
index_struct_dict = dataset.index_struct_dict
if index_struct_dict is None:
raise ValueError("dataset.index_struct_dict is missing")
# if an existing record's index_struct_dict doesn't contain the using_ugc field,
# it is actually stored in the normal index format
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
stored_in_ugc: bool = index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
dimension = index_struct_dict["dimension"]
index_type = index_struct_dict["index_type"]
distance_type = index_struct_dict["distance_type"]
routing_value = index_struct_dict["vector_store"]["class_prefix"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
index_name = index_struct_dict["vector_store"]["class_prefix"].lower()
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)

View File

@@ -7,7 +7,9 @@ from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
_collection_name: str
def __init__(self, collection_name: str) -> None:
self._collection_name = collection_name
@abstractmethod
@@ -30,7 +32,7 @@ class BaseVector(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str):
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
raise NotImplementedError
@abstractmethod
@@ -63,5 +65,5 @@ class BaseVector(ABC):
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):
def collection_name(self) -> str:
return self._collection_name

View File

@@ -2,7 +2,8 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Sequence
from typing import Any, TypedDict
from sqlalchemy import select
@@ -13,7 +14,7 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -24,19 +25,34 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
class VectorStoreIndexConfig(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: VectorType
vector_store: VectorStoreIndexConfig
VectorDocumentInput = Document | ChildDocument | AttachmentDocument
class AbstractVectorFactory(ABC):
@abstractmethod
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
def init_vector(self, dataset: Dataset, attributes: list[str], embeddings: Embeddings) -> BaseVector:
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
def __init__(self, dataset: Dataset, attributes: list[str] | None = None) -> None:
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
self._dataset = dataset
@@ -198,12 +214,12 @@ class Vector:
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: list | None = None, **kwargs):
def create(self, texts: Sequence[Document | ChildDocument] | None = None, **kwargs: Any) -> None:
if texts:
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
total_batches = len(texts) + batch_size - 1
total_batches = (len(texts) + batch_size - 1) // batch_size
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
batch_start = time.time()
@@ -212,29 +228,33 @@ class Vector:
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
self._vector_processor.create(
texts=self._normalize_documents(batch), embeddings=batch_embeddings, **kwargs
)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
def create_multimodal(self, file_documents: list[AttachmentDocument] | None = None, **kwargs: Any) -> None:
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = len(file_documents) + batch_size - 1
total_batches = (len(file_documents) + batch_size - 1) // batch_size
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
attachment_ids = [doc.metadata["doc_id"] for doc in batch if doc.metadata is not None]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list = []
real_batch = []
file_base64_list: list[dict[str, str]] = []
real_batch: list[AttachmentDocument] = []
for document in batch:
if document.metadata is None:
continue
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
@@ -249,14 +269,20 @@ class Vector:
}
)
real_batch.append(document)
if not real_batch:
continue
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
self._vector_processor.create(
texts=self._normalize_documents(real_batch),
embeddings=batch_embeddings,
**kwargs,
)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs):
def add_texts(self, documents: list[Document], **kwargs: Any) -> None:
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@@ -266,10 +292,10 @@ class Vector:
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]):
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str):
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
@@ -295,7 +321,7 @@ class Vector:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self):
def delete(self) -> None:
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
@@ -325,7 +351,26 @@ class Vector:
return texts
def __getattr__(self, name):
@staticmethod
def _normalize_documents(documents: Sequence[VectorDocumentInput]) -> list[Document]:
normalized_documents: list[Document] = []
for document in documents:
if isinstance(document, Document):
normalized_documents.append(document)
continue
normalized_documents.append(
Document(
page_content=document.page_content,
vector=document.vector,
metadata=document.metadata,
provider=(document.provider or "dify") if isinstance(document, AttachmentDocument) else "dify",
)
)
return normalized_documents
def __getattr__(self, name: str) -> Any:
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):

View File

@@ -1,7 +1,11 @@
from pydantic import BaseModel, ConfigDict
from typing import Literal
from pydantic import BaseModel, ConfigDict, field_validator
from core.rag.extractor.entity.datasource_type import DatasourceType
from models.dataset import Document
from models.model import UploadFile
from services.auth.auth_type import AuthType
class NotionInfo(BaseModel):
@@ -12,7 +16,7 @@ class NotionInfo(BaseModel):
credential_id: str | None = None
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: str
notion_page_type: Literal["database", "page"]
document: Document | None = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -25,20 +29,27 @@ class WebsiteInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
provider: str
provider: AuthType
job_id: str
url: str
mode: str
mode: Literal["crawl", "crawl_return_urls", "scrape"]
tenant_id: str
only_main_content: bool = False
@field_validator("mode", mode="before")
@classmethod
def _normalize_legacy_mode(cls, value: str) -> str:
if value == "single":
return "crawl"
return value
class ExtractSetting(BaseModel):
"""
Model class for provider response.
"""
datasource_type: str
datasource_type: DatasourceType
upload_file: UploadFile | None = None
notion_info: NotionInfo | None = None
website_info: WebsiteInfo | None = None

View File

@@ -1,7 +1,8 @@
import os
import re
import tempfile
from pathlib import Path
from typing import Union
from typing import TypeAlias
from urllib.parse import unquote
from configs import dify_config
@@ -31,19 +32,27 @@ from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
from services.auth.auth_type import AuthType
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
USER_AGENT = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
" Safari/537.36"
)
ExtractProcessorOutput: TypeAlias = list[Document] | str
class ExtractProcessor:
@staticmethod
def _build_temp_file_path(temp_dir: str, suffix: str) -> str:
file_descriptor, file_path = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
os.close(file_descriptor)
return file_path
@classmethod
def load_from_upload_file(
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> Union[list[Document], str]:
) -> ExtractProcessorOutput:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
@@ -54,7 +63,7 @@ class ExtractProcessor:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
def load_from_url(cls, url: str, return_text: bool = False) -> ExtractProcessorOutput:
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
with tempfile.TemporaryDirectory() as temp_dir:
@@ -65,17 +74,16 @@ class ExtractProcessor:
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
else:
content_disposition = response.headers.get("Content-Disposition")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
# Generate a temporary filename under the created temp_dir and ensure the directory exists
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
if content_disposition:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
file_path = cls._build_temp_file_path(temp_dir, suffix)
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
@@ -94,13 +102,13 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
upload_file = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
file_path = cls._build_temp_file_path(temp_dir, suffix)
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
@@ -113,7 +121,11 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -123,7 +135,11 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".csv":
@@ -149,13 +165,21 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".epub":
@@ -177,7 +201,7 @@ class ExtractProcessor:
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
if extract_setting.website_info.provider == AuthType.FIRECRAWL:
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@@ -186,7 +210,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "watercrawl":
elif extract_setting.website_info.provider == AuthType.WATERCRAWL:
extractor = WaterCrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@@ -195,7 +219,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == "jinareader":
elif extract_setting.website_info.provider == AuthType.JINA:
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,

View File

@@ -2,10 +2,12 @@
from abc import ABC, abstractmethod
from core.rag.models.document import Document
class BaseExtractor(ABC):
"""Interface for extract files."""
@abstractmethod
def extract(self):
def extract(self) -> list[Document]:
raise NotImplementedError

View File

@@ -30,7 +30,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
For large files, reading only a sample is sufficient and prevents timeout.
"""
def read_and_detect(filename: str):
def read_and_detect(filename: str) -> list[FileEncoding]:
rst = charset_normalizer.from_path(filename)
best = rst.best()
if best is None:

View File

@@ -28,8 +28,8 @@ class PdfExtractor(BaseExtractor):
Args:
file_path: Path to the PDF file.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
tenant_id: Workspace ID used for extracted image persistence when available.
user_id: ID of the user performing the extraction when available.
file_cache_key: Optional cache key for the extracted text.
"""
@@ -47,7 +47,13 @@ class PdfExtractor(BaseExtractor):
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
def __init__(
self,
file_path: str,
tenant_id: str | None,
user_id: str | None,
file_cache_key: str | None = None,
):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
@@ -116,6 +122,9 @@ class PdfExtractor(BaseExtractor):
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
if self._tenant_id is None or self._user_id is None:
return ""
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:

View File

@@ -9,6 +9,8 @@ import os
import re
import tempfile
import uuid
from collections.abc import Iterable
from typing import cast
from urllib.parse import urlparse
from docx import Document as DocxDocument
@@ -35,7 +37,7 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, tenant_id: str, user_id: str):
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None):
"""Initialize with file path."""
self.file_path = file_path
self.tenant_id = tenant_id
@@ -86,9 +88,12 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
image_map: dict[object, str] = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
if self.tenant_id is None or self.user_id is None:
return image_map
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
@@ -264,7 +269,7 @@ class WordExtractor(BaseExtractor):
def parse_docx(self, docx_path):
doc = DocxDocument(docx_path)
content = []
content: list[str] = []
image_map = self._extract_images_from_docx(doc)
@@ -362,10 +367,10 @@ class WordExtractor(BaseExtractor):
if link_text:
target_buffer.append(link_text)
paragraph_content = []
paragraph_content: list[str] = []
# State for legacy HYPERLINK fields
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
hyperlink_field_url: str | None = None
hyperlink_field_text_parts: list[str] = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:
@@ -422,7 +427,8 @@ class WordExtractor(BaseExtractor):
paragraphs = doc.paragraphs.copy()
tables = doc.tables.copy()
for element in doc.element.body:
body_elements = cast(Iterable[object], getattr(doc.element, "body", []))
for element in body_elements:
if hasattr(element, "tag"):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)

View File

@@ -3,7 +3,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from typing import Any, Literal, TypedDict
from flask import current_app
from sqlalchemy import delete, func, select
@@ -20,6 +20,16 @@ from .processor.paragraph_index_processor import ParagraphIndexProcessor
logger = logging.getLogger(__name__)
class IndexAndCleanResult(TypedDict):
dataset_id: str
dataset_name: str
batch: str
document_id: str
document_name: str
created_at: float
display_status: Literal["completed"]
class IndexProcessor:
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
@@ -51,9 +61,9 @@ class IndexProcessor:
document_id: str,
original_document_id: str,
chunks: Mapping[str, Any],
batch: Any,
batch: str,
summary_index_setting: SummaryIndexSettingDict | None = None,
):
) -> IndexAndCleanResult:
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:

View File

@@ -122,6 +122,7 @@ class BaseIndexProcessor(ABC):
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -147,7 +148,7 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter # type: ignore
return character_splitter
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
@@ -158,7 +159,7 @@ class BaseIndexProcessor(ABC):
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list = []
upload_file_id_list: list[str] = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count

View File

@@ -10,7 +10,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit."""
def __init__(self, index_type: str | None):
def __init__(self, index_type: str | None) -> None:
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
@@ -19,11 +19,12 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")
match self._index_type:
case IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
case IndexStructureType.QA_INDEX:
return QAIndexProcessor()
case IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
case _:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -30,7 +30,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
) -> TS:
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

View File

@@ -8,7 +8,7 @@ class BaseStorage(ABC):
"""Interface for file storage."""
@abstractmethod
def save(self, filename: str, data: bytes):
def save(self, filename: str, data: bytes) -> None:
raise NotImplementedError
@abstractmethod
@@ -16,7 +16,7 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def load_stream(self, filename: str) -> Generator:
def load_stream(self, filename: str) -> Generator[bytes, None, None]:
raise NotImplementedError
@abstractmethod
@@ -28,10 +28,10 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def delete(self, filename: str):
def delete(self, filename: str) -> None:
raise NotImplementedError
def scan(self, path, files=True, directories=False) -> list[str]:
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
"""
Scan files and directories in the given path.
This method is implemented only in some storage backends.

View File

@@ -30,8 +30,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str]
email: NotRequired[str]
name: NotRequired[str | None]
email: NotRequired[str | None]
class GoogleRawUserInfo(TypedDict):
@@ -138,7 +138,7 @@ class GitHubOAuth(OAuth):
email = payload.get("email")
if not email:
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
return OAuthUserInfo(id=str(payload["id"]), name=payload.get("name") or "", email=email)
class GoogleOAuth(OAuth):

View File

@@ -20,7 +20,7 @@ else:
class NotionPageSummary(TypedDict):
page_id: str
page_name: str
page_icon: dict[str, str] | None
page_icon: dict[str, object] | None
parent_id: str
type: Literal["page", "database"]

View File

@@ -43,58 +43,6 @@ core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -140,27 +88,10 @@ dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
libs/gmpy2_pkcs10aep_cipher.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
@@ -188,3 +119,75 @@ tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py
# no need to fix for now: storage adapters
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
# no need to fix for now: keyword adapters
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
# no need to fix for now: vector db adapters
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
# no need to fix for now: extractors
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
# no need to fix for now: index processors
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py

View File

@@ -1,10 +1,24 @@
from abc import ABC, abstractmethod
from typing_extensions import TypedDict
class ApiKeyAuthConfig(TypedDict, total=False):
api_key: str
base_url: str
class ApiKeyAuthCredentials(TypedDict):
auth_type: object
config: ApiKeyAuthConfig
class ApiKeyAuthBase(ABC):
def __init__(self, credentials: dict):
credentials: ApiKeyAuthCredentials
def __init__(self, credentials: ApiKeyAuthCredentials) -> None:
self.credentials = credentials
@abstractmethod
def validate_credentials(self):
def validate_credentials(self) -> bool:
raise NotImplementedError

View File

@@ -1,18 +1,20 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.auth_type import AuthProvider, AuthType
class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict):
auth: ApiKeyAuthBase
def __init__(self, provider: AuthProvider, credentials: ApiKeyAuthCredentials) -> None:
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)
def validate_credentials(self):
def validate_credentials(self) -> bool:
return self.auth.validate_credentials()
@staticmethod
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
match provider:
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
match ApiKeyAuthFactory._normalize_provider(provider):
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
@@ -27,3 +29,13 @@ class ApiKeyAuthFactory:
return JinaAuth
case _:
raise ValueError("Invalid provider")
@staticmethod
def _normalize_provider(provider: AuthProvider) -> AuthType | str:
if isinstance(provider, AuthType):
return provider
try:
return AuthType(provider)
except ValueError:
return provider

View File

@@ -1,40 +1,63 @@
import json
from typing import cast
from pydantic import TypeAdapter
from sqlalchemy import select
from typing_extensions import TypedDict
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_base import ApiKeyAuthCredentials
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthCreateArgs(TypedDict):
category: str
provider: str
credentials: ApiKeyAuthCredentials
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(dict[str, object])
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
def get_provider_auth_list(tenant_id: str) -> list[DataSourceApiKeyAuthBinding]:
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings
return list(data_source_api_key_bindings)
@staticmethod
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
def create_provider_auth(tenant_id: str, args: dict[str, object]) -> None:
validated_args = ApiKeyAuthService.validate_api_key_auth_args(args)
raw_credentials = ApiKeyAuthService._get_credentials_dict(args)
auth_result = ApiKeyAuthFactory(
validated_args["provider"], validated_args["credentials"]
).validate_credentials()
if auth_result:
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
api_key_value = validated_args["credentials"]["config"].get("api_key")
if api_key_value is None:
raise KeyError("api_key")
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
raw_config = ApiKeyAuthService._get_config_dict(raw_credentials)
raw_config["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
tenant_id=tenant_id,
category=validated_args["category"],
provider=validated_args["provider"],
)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
data_source_api_key_binding.credentials = json.dumps(raw_credentials, ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
def get_auth_credentials(tenant_id: str, category: str, provider: str) -> dict[str, object] | None:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
@@ -50,10 +73,10 @@ class ApiKeyAuthService:
if not data_source_api_key_bindings.credentials:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return credentials
return AUTH_CREDENTIALS_ADAPTER.validate_python(credentials)
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
@@ -63,8 +86,10 @@ class ApiKeyAuthService:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):
@staticmethod
def validate_api_key_auth_args(args: dict[str, object] | None) -> ApiKeyAuthCreateArgs:
if args is None:
raise TypeError("argument of type 'NoneType' is not iterable")
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
@@ -75,3 +100,18 @@ class ApiKeyAuthService:
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")
return AUTH_CREATE_ARGS_ADAPTER.validate_python(args)
@staticmethod
def _get_credentials_dict(args: dict[str, object]) -> dict[str, object]:
credentials = args["credentials"]
if not isinstance(credentials, dict):
raise ValueError("credentials must be a dictionary")
return cast(dict[str, object], credentials)
@staticmethod
def _get_config_dict(credentials: dict[str, object]) -> dict[str, object]:
config = credentials["config"]
if not isinstance(config, dict):
raise TypeError(f"credentials['config'] must be a dictionary, got {type(config).__name__}")
return cast(dict[str, object], config)

View File

@@ -5,3 +5,6 @@ class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"
AuthProvider = AuthType | str

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: dict):
def __init__(self, credentials: ApiKeyAuthCredentials):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@@ -200,6 +200,29 @@ class TestExtractProcessorFileRouting:
with pytest.raises(AssertionError, match="upload_file is required"):
ExtractProcessor.extract(setting)
@pytest.mark.parametrize(
("extension", "etl_type", "expected_extractor"),
[
(".pdf", "Unstructured", "PdfExtractor"),
(".docx", "Unstructured", "WordExtractor"),
(".pdf", "SelfHosted", "PdfExtractor"),
(".docx", "SelfHosted", "WordExtractor"),
],
)
def test_extract_allows_url_file_paths_without_upload_context(
self, monkeypatch, extension: str, etl_type: str, expected_extractor: str
):
factory = _patch_all_extractors(monkeypatch)
monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type)
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
docs = ExtractProcessor.extract(setting, file_path=f"/tmp/example{extension}")
assert docs[0].page_content == f"extracted-by-{expected_extractor}"
assert factory.calls[-1][0] == expected_extractor
assert factory.calls[-1][1] == (f"/tmp/example{extension}", None, None)
class TestExtractProcessorDatasourceRouting:
def test_extract_routes_notion_datasource(self, monkeypatch):

View File

@@ -184,3 +184,21 @@ def test_extract_images_failures(mock_dependencies):
assert len(saves) == 1
assert saves[0][1] == jpeg_bytes
assert db_stub.session.committed is True
def test_extract_images_skips_persistence_without_upload_context(mock_dependencies):
mock_page = MagicMock()
mock_image_obj = MagicMock()
mock_image_obj.extract.side_effect = lambda buf, fb_format=None: buf.write(b"\xff\xd8\xff image")
mock_page.get_objects.return_value = [mock_image_obj]
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id=None, user_id=None)
with patch("pypdfium2.raw", autospec=True) as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
assert result == ""
assert mock_dependencies.saves == []
assert mock_dependencies.db.session.added == []
assert mock_dependencies.db.session.committed is False

View File

@@ -179,6 +179,27 @@ def test_extract_images_from_docx(monkeypatch):
assert db_stub.session.committed is True
def test_extract_images_from_docx_skips_persistence_without_upload_context(monkeypatch):
saves: list[tuple[str, bytes]] = []
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: saves.append((key, data))))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=lambda: None))
monkeypatch.setattr(we, "db", db_stub)
rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext}))
extractor = object.__new__(WordExtractor)
extractor.tenant_id = None
extractor.user_id = None
image_map = extractor._extract_images_from_docx(doc)
assert image_map == {}
assert saves == []
def test_extract_images_from_docx_uses_internal_files_url():
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
# Test the URL generation logic directly

View File

@@ -13,8 +13,11 @@ class TestApiKeyAuthFactory:
("provider", "auth_class_path"),
[
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.FIRECRAWL.value, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.WATERCRAWL.value, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
(AuthType.JINA.value, "services.auth.jina.jina.JinaAuth"),
],
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):

View File

@@ -1,4 +1,5 @@
import json
from copy import deepcopy
from unittest.mock import Mock, patch
import pytest
@@ -68,7 +69,16 @@ class TestApiKeyAuthService:
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
mock_factory.return_value = mock_auth_instance
captured_provider = None
captured_credentials = None
def factory_side_effect(provider, credentials):
nonlocal captured_provider, captured_credentials
captured_provider = provider
captured_credentials = deepcopy(credentials)
return mock_auth_instance
mock_factory.side_effect = factory_side_effect
# Mock encryption
encrypted_key = "encrypted_test_key_123"
@@ -77,11 +87,14 @@ class TestApiKeyAuthService:
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
expected_credentials = deepcopy(self.mock_credentials)
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
assert mock_factory.call_count == 1
assert captured_provider == self.provider
assert captured_credentials == expected_credentials
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls

View File

@@ -4,11 +4,13 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock, patch
import pytest
import services.vector_service as vector_service_module
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.vector_service import VectorService
@@ -702,3 +704,105 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch:
logger_mock.exception.assert_called_once()
db_mock.session.rollback.assert_called_once()
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
def test_vector_create_normalizes_child_documents(mock_get_embeddings: Mock, mock_init_vector: Mock) -> None:
dataset = _make_dataset()
documents = [ChildDocument(page_content="Child content", metadata={"doc_id": "child-1", "dataset_id": "dataset-1"})]
mock_embeddings = Mock()
mock_embeddings.embed_documents.return_value = [[0.1] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create(texts=documents)
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.page_content == "Child content"
assert normalized_document.metadata["doc_id"] == "child-1"
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
@patch("core.rag.datasource.vdb.vector_factory.storage")
@patch("core.rag.datasource.vdb.vector_factory.db.session")
def test_vector_create_multimodal_normalizes_attachment_documents(
mock_session: Mock,
mock_storage: Mock,
mock_get_embeddings: Mock,
mock_init_vector: Mock,
) -> None:
dataset = _make_dataset()
file_document = AttachmentDocument(
page_content="Attachment content",
provider="custom-provider",
metadata={"doc_id": "file-1", "doc_type": "image/png"},
)
upload_file = Mock(id="file-1", key="upload-key")
mock_scalars = Mock()
mock_scalars.all.return_value = [upload_file]
mock_session.scalars.return_value = mock_scalars
mock_storage.load_once.return_value = b"binary-content"
mock_embeddings = Mock()
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create_multimodal(file_documents=[file_document])
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.provider == "custom-provider"
assert normalized_document.metadata["doc_id"] == "file-1"
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
@patch("core.rag.datasource.vdb.vector_factory.storage")
@patch("core.rag.datasource.vdb.vector_factory.db.session")
def test_vector_create_multimodal_falls_back_to_dify_provider_when_attachment_provider_is_none(
mock_session: Mock,
mock_storage: Mock,
mock_get_embeddings: Mock,
mock_init_vector: Mock,
) -> None:
dataset = _make_dataset()
file_document = AttachmentDocument(
page_content="Attachment content",
provider=None,
metadata={"doc_id": "file-1", "doc_type": "image/png"},
)
upload_file = Mock(id="file-1", key="upload-key")
mock_scalars = Mock()
mock_scalars.all.return_value = [upload_file]
mock_session.scalars.return_value = mock_scalars
mock_storage.load_once.return_value = b"binary-content"
mock_embeddings = Mock()
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create_multimodal(file_documents=[file_document])
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.provider == "dify"