mirror of
https://github.com/langgenius/dify.git
synced 2026-03-07 00:05:12 +00:00
Compare commits
4 Commits
yanli/docx
...
refactor/t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ac461d882 | ||
|
|
fa763216d0 | ||
|
|
d546210040 | ||
|
|
4e0a7a7f9e |
@@ -34,6 +34,7 @@ from .dataset import (
|
||||
metadata,
|
||||
segment,
|
||||
)
|
||||
from .dataset.rag_pipeline import rag_pipeline_workflow
|
||||
from .end_user import end_user
|
||||
from .workspace import models
|
||||
|
||||
@@ -53,6 +54,7 @@ __all__ = [
|
||||
"message",
|
||||
"metadata",
|
||||
"models",
|
||||
"rag_pipeline_workflow",
|
||||
"segment",
|
||||
"site",
|
||||
"workflow",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +10,7 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -41,7 +40,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@@ -76,7 +75,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
return datasource_plugins, 200
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -131,7 +130,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -232,12 +231,4 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
return serialize_upload_file(upload_file), 201
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Serialization helpers for Service API knowledge pipeline endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
|
||||
}
|
||||
@@ -217,6 +217,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
@@ -253,12 +255,18 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
|
||||
@@ -114,6 +114,7 @@ class PdfExtractor(BaseExtractor):
|
||||
"""
|
||||
image_content = []
|
||||
upload_files = []
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
try:
|
||||
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
|
||||
@@ -163,7 +164,7 @@ class PdfExtractor(BaseExtractor):
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
upload_files.append(upload_file)
|
||||
image_content.append(f"")
|
||||
image_content.append(f"")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract image from PDF: %s", e)
|
||||
continue
|
||||
|
||||
@@ -87,6 +87,7 @@ class WordExtractor(BaseExtractor):
|
||||
def _extract_images_from_docx(self, doc):
|
||||
image_count = 0
|
||||
image_map = {}
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
|
||||
for r_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.target_ref:
|
||||
@@ -125,7 +126,7 @@ class WordExtractor(BaseExtractor):
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
image_map[r_id] = f""
|
||||
image_map[r_id] = f""
|
||||
else:
|
||||
image_ext = rel.target_ref.split(".")[-1]
|
||||
if image_ext is None:
|
||||
@@ -153,7 +154,7 @@ class WordExtractor(BaseExtractor):
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
db.session.add(upload_file)
|
||||
image_map[rel.target_part] = f""
|
||||
image_map[rel.target_part] = f""
|
||||
db.session.commit()
|
||||
return image_map
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from operator import itemgetter
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -805,53 +804,41 @@ class DocumentSegment(Base):
|
||||
def sign_content(self) -> str:
|
||||
return self.get_sign_content()
|
||||
|
||||
@staticmethod
|
||||
def _build_signed_query_params(*, sign_target: str, upload_file_id: str) -> str:
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"{sign_target}|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
return f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
def _get_accessible_upload_file_ids(self, upload_file_ids: set[str]) -> set[str]:
|
||||
if not upload_file_ids:
|
||||
return set()
|
||||
|
||||
matched_upload_file_ids = db.session.scalars(
|
||||
select(UploadFile.id).where(
|
||||
UploadFile.tenant_id == self.tenant_id,
|
||||
UploadFile.id.in_(list(upload_file_ids)),
|
||||
)
|
||||
).all()
|
||||
return {str(upload_file_id) for upload_file_id in matched_upload_file_ids}
|
||||
|
||||
def get_sign_content(self) -> str:
|
||||
signed_urls: list[tuple[int, int, str]] = []
|
||||
text = self.content
|
||||
|
||||
upload_file_preview_patterns = {
|
||||
"image-preview": r"(?:https?://[^\s\)\"\']+)?/files/([a-f0-9\-]+)/image-preview(?:\?[^\s\)\"\']*)?",
|
||||
"file-preview": r"(?:https?://[^\s\)\"\']+)?/files/([a-f0-9\-]+)/file-preview(?:\?[^\s\)\"\']*)?",
|
||||
}
|
||||
upload_file_matches: list[tuple[re.Match[str], str, str]] = []
|
||||
upload_file_ids: set[str] = set()
|
||||
# For data before v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
upload_file_id = match.group(1)
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
for preview_type, pattern in upload_file_preview_patterns.items():
|
||||
for match in re.finditer(pattern, text):
|
||||
upload_file_id = match.group(1)
|
||||
upload_file_matches.append((match, preview_type, upload_file_id))
|
||||
upload_file_ids.add(upload_file_id)
|
||||
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
base_url = f"/files/{upload_file_id}/image-preview"
|
||||
signed_url = f"{base_url}?{params}"
|
||||
signed_urls.append((match.start(), match.end(), signed_url))
|
||||
|
||||
accessible_upload_file_ids = self._get_accessible_upload_file_ids(upload_file_ids)
|
||||
# For data after v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
|
||||
matches = re.finditer(pattern, text)
|
||||
for match in matches:
|
||||
upload_file_id = match.group(1)
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
for match, preview_type, upload_file_id in upload_file_matches:
|
||||
if upload_file_id not in accessible_upload_file_ids:
|
||||
continue
|
||||
|
||||
params = self._build_signed_query_params(sign_target=preview_type, upload_file_id=upload_file_id)
|
||||
base_url = f"/files/{upload_file_id}/{preview_type}"
|
||||
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
base_url = f"/files/{upload_file_id}/file-preview"
|
||||
signed_url = f"{base_url}?{params}"
|
||||
signed_urls.append((match.start(), match.end(), signed_url))
|
||||
|
||||
@@ -862,13 +849,19 @@ class DocumentSegment(Base):
|
||||
for match in matches:
|
||||
upload_file_id = match.group(1)
|
||||
file_extension = match.group(2)
|
||||
params = self._build_signed_query_params(sign_target="file-preview", upload_file_id=upload_file_id)
|
||||
nonce = os.urandom(16).hex()
|
||||
timestamp = str(int(time.time()))
|
||||
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
base_url = f"/files/tools/{upload_file_id}.{file_extension}"
|
||||
signed_url = f"{base_url}?{params}"
|
||||
signed_urls.append((match.start(), match.end(), signed_url))
|
||||
|
||||
# Reconstruct the text with signed URLs
|
||||
signed_urls.sort(key=itemgetter(0))
|
||||
offset = 0
|
||||
for start, end, signed_url in signed_urls:
|
||||
text = text[: start + offset] + signed_url + text[end + offset :]
|
||||
|
||||
@@ -1329,10 +1329,24 @@ class RagPipelineService:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
@@ -1413,10 +1427,24 @@ class RagPipelineService:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset: Dataset | None = (
|
||||
db.session.query(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
||||
pipeline: Pipeline | None = (
|
||||
db.session.query(Pipeline)
|
||||
.where(
|
||||
Pipeline.id == dataset.pipeline_id,
|
||||
Pipeline.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
return pipeline
|
||||
|
||||
@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
|
||||
"""
|
||||
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
total_index_node_ids = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
if not dataset:
|
||||
raise Exception("Document has no dataset")
|
||||
index_type = dataset.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
|
||||
session.execute(document_delete_stmt)
|
||||
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
|
||||
session.execute(document_delete_stmt)
|
||||
|
||||
for document_id in document_ids:
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
for document_id in document_ids:
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
total_index_node_ids.extend([segment.index_node_id for segment in segments])
|
||||
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
session.commit()
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Clean document when import form notion document deleted end :: {} latency: {}".format(
|
||||
dataset_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when import form notion document deleted failed")
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Clean document when import form notion document deleted end :: {} latency: {}".format(
|
||||
dataset_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
tenant_id = None
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
@@ -35,94 +36,120 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
|
||||
return
|
||||
|
||||
if document.indexing_status == "parsing":
|
||||
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
if document.data_source_type == "notion_import":
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
if document.data_source_type != "notion_import":
|
||||
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
|
||||
return
|
||||
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if (
|
||||
not data_source_info
|
||||
or "notion_page_id" not in data_source_info
|
||||
or "notion_workspace_id" not in data_source_info
|
||||
):
|
||||
raise ValueError("no notion page found")
|
||||
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
document.tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
workspace_id = data_source_info["notion_workspace_id"]
|
||||
page_id = data_source_info["notion_page_id"]
|
||||
page_type = data_source_info["type"]
|
||||
page_edited_time = data_source_info["last_edited_time"]
|
||||
credential_id = data_source_info.get("credential_id")
|
||||
tenant_id = document.tenant_id
|
||||
index_type = document.doc_form
|
||||
|
||||
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# Get credentials from datasource provider
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
if not credential:
|
||||
logger.error(
|
||||
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
|
||||
document_id,
|
||||
tenant_id,
|
||||
credential_id,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
|
||||
document.stopped_at = naive_utc_now()
|
||||
return
|
||||
return
|
||||
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=document.tenant_id,
|
||||
)
|
||||
loader = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
last_edited_time = loader.get_notion_last_edited_time()
|
||||
if last_edited_time == page_edited_time:
|
||||
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
|
||||
return
|
||||
|
||||
# check the page is updated
|
||||
if last_edited_time != page_edited_time:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
|
||||
|
||||
# delete all document segment and index
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise Exception("Dataset not found")
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
try:
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Failed to clean vector index for document %s", document_id)
|
||||
|
||||
segments = session.scalars(
|
||||
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
).all()
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if not document:
|
||||
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
|
||||
return
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = data_source_info
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
|
||||
session.execute(segment_delete_stmt)
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
"Cleaned document when document update data source or process rule: {} latency: {}".format(
|
||||
document_id, end_at - start_at
|
||||
),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document update data source or process rule failed")
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
|
||||
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
with session_factory.create_session() as session:
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception as e:
|
||||
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
document = session.query(Document).filter_by(id=document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
|
||||
@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task(document_ids, dataset.id)
|
||||
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
|
||||
# Verify segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id.in_(document_ids))
|
||||
@@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
|
||||
== 0
|
||||
)
|
||||
|
||||
# Verify index processor was called for each document
|
||||
# Verify index processor was called
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
assert mock_processor.clean.call_count == len(document_ids)
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# This test successfully verifies:
|
||||
# 1. Document records are properly deleted from the database
|
||||
@@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
|
||||
non_existent_dataset_id = str(uuid.uuid4())
|
||||
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
# Execute cleanup task with non-existent dataset
|
||||
clean_notion_document_task(document_ids, non_existent_dataset_id)
|
||||
# Execute cleanup task with non-existent dataset - expect exception
|
||||
with pytest.raises(Exception, match="Document has no dataset"):
|
||||
clean_notion_document_task(document_ids, non_existent_dataset_id)
|
||||
|
||||
# Verify that the index processor was not called
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_not_called()
|
||||
# Verify that the index processor factory was not used
|
||||
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
|
||||
|
||||
def test_clean_notion_document_task_empty_document_list(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
@@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task with empty document list
|
||||
clean_notion_document_task([], dataset.id)
|
||||
|
||||
# Verify that the index processor was not called
|
||||
# Verify that the index processor was called once with empty node list
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_not_called()
|
||||
assert mock_processor.clean.call_count == 1
|
||||
args, kwargs = mock_processor.clean.call_args
|
||||
# args: (dataset, total_index_node_ids)
|
||||
assert isinstance(args[0], Dataset)
|
||||
assert args[1] == []
|
||||
|
||||
def test_clean_notion_document_task_with_different_index_types(
|
||||
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
|
||||
@@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
|
||||
# Note: This test successfully verifies cleanup with different document types.
|
||||
# The task properly handles various index types and document configurations.
|
||||
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
# Verify segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id == document.id)
|
||||
@@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
# Verify segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
|
||||
== 0
|
||||
@@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
|
||||
|
||||
clean_notion_document_task(documents_to_clean, dataset.id)
|
||||
|
||||
# Verify only specified documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
|
||||
# Verify only specified documents' segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id.in_(documents_to_clean))
|
||||
@@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock index processor to raise an exception
|
||||
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
|
||||
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_index_processor.clean.side_effect = Exception("Index processor error")
|
||||
|
||||
# Execute cleanup task - it should handle the exception gracefully
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
# Execute cleanup task - current implementation propagates the exception
|
||||
with pytest.raises(Exception, match="Index processor error"):
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Note: This test demonstrates the task's error handling capability.
|
||||
# Even with external service errors, the database operations complete successfully.
|
||||
@@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
|
||||
all_document_ids = [doc.id for doc in documents]
|
||||
clean_notion_document_task(all_document_ids, dataset.id)
|
||||
|
||||
# Verify all documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
|
||||
# Verify all segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
|
||||
== 0
|
||||
@@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
|
||||
|
||||
clean_notion_document_task([target_document.id], target_dataset.id)
|
||||
|
||||
# Verify only documents from target dataset are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
|
||||
# Verify only documents' segments from target dataset are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter(DocumentSegment.document_id == target_document.id)
|
||||
@@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
|
||||
all_document_ids = [doc.id for doc in documents]
|
||||
clean_notion_document_task(all_document_ids, dataset.id)
|
||||
|
||||
# Verify all documents and segments are deleted regardless of status
|
||||
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
|
||||
# Verify all segments are deleted regardless of status
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
|
||||
== 0
|
||||
@@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
|
||||
# Execute cleanup task
|
||||
clean_notion_document_task([document.id], dataset.id)
|
||||
|
||||
# Verify documents and segments are deleted
|
||||
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
|
||||
# Verify segments are deleted
|
||||
assert (
|
||||
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
|
||||
== 0
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Unit tests for Service API knowledge pipeline file-upload serialization.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FakeUploadFile:
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
created_by: str
|
||||
created_at: datetime | None
|
||||
|
||||
|
||||
def _load_serialize_upload_file():
|
||||
api_dir = Path(__file__).resolve().parents[5]
|
||||
serializers_path = api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "serializers.py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location("rag_pipeline_serializers", serializers_path)
|
||||
assert spec
|
||||
assert spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore[attr-defined]
|
||||
return module.serialize_upload_file
|
||||
|
||||
|
||||
def test_file_upload_created_at_is_isoformat_string():
|
||||
serialize_upload_file = _load_serialize_upload_file()
|
||||
|
||||
created_at = datetime(2026, 2, 8, 12, 0, 0, tzinfo=UTC)
|
||||
upload_file = FakeUploadFile()
|
||||
upload_file.id = "file-1"
|
||||
upload_file.name = "test.pdf"
|
||||
upload_file.size = 123
|
||||
upload_file.extension = "pdf"
|
||||
upload_file.mime_type = "application/pdf"
|
||||
upload_file.created_by = "account-1"
|
||||
upload_file.created_at = created_at
|
||||
|
||||
result = serialize_upload_file(upload_file)
|
||||
assert result["created_at"] == created_at.isoformat()
|
||||
|
||||
|
||||
def test_file_upload_created_at_none_serializes_to_null():
|
||||
serialize_upload_file = _load_serialize_upload_file()
|
||||
|
||||
upload_file = FakeUploadFile()
|
||||
upload_file.id = "file-1"
|
||||
upload_file.name = "test.pdf"
|
||||
upload_file.size = 123
|
||||
upload_file.extension = "pdf"
|
||||
upload_file.mime_type = "application/pdf"
|
||||
upload_file.created_by = "account-1"
|
||||
upload_file.created_at = None
|
||||
|
||||
result = serialize_upload_file(upload_file)
|
||||
assert result["created_at"] is None
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Unit tests for Service API knowledge pipeline route registration.
|
||||
"""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_rag_pipeline_routes_registered():
|
||||
api_dir = Path(__file__).resolve().parents[5]
|
||||
|
||||
service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
|
||||
rag_pipeline_workflow = (
|
||||
api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
|
||||
)
|
||||
|
||||
assert service_api_init.exists()
|
||||
assert rag_pipeline_workflow.exists()
|
||||
|
||||
init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
|
||||
import_found = False
|
||||
for node in ast.walk(init_tree):
|
||||
if not isinstance(node, ast.ImportFrom):
|
||||
continue
|
||||
if node.module != "dataset.rag_pipeline" or node.level != 1:
|
||||
continue
|
||||
if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
|
||||
import_found = True
|
||||
break
|
||||
assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
|
||||
|
||||
workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
|
||||
route_paths: set[str] = set()
|
||||
|
||||
for node in ast.walk(workflow_tree):
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
for decorator in node.decorator_list:
|
||||
if not isinstance(decorator, ast.Call):
|
||||
continue
|
||||
if not isinstance(decorator.func, ast.Attribute):
|
||||
continue
|
||||
if decorator.func.attr != "route":
|
||||
continue
|
||||
if not decorator.args:
|
||||
continue
|
||||
first_arg = decorator.args[0]
|
||||
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
|
||||
route_paths.add(first_arg.value)
|
||||
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
|
||||
assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
|
||||
assert "/datasets/pipeline/file-upload" in route_paths
|
||||
@@ -87,7 +87,7 @@ def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, exp
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
assert f"" in result
|
||||
assert f"" in result
|
||||
assert len(saves) == 1
|
||||
assert saves[0][1] == image_bytes
|
||||
assert len(db_stub.session.added) == 1
|
||||
@@ -180,7 +180,7 @@ def test_extract_images_failures(mock_dependencies):
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
# Should have one success
|
||||
assert "" in result
|
||||
assert "" in result
|
||||
assert len(saves) == 1
|
||||
assert saves[0][1] == jpeg_bytes
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
@@ -121,7 +121,8 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
db_stub = SimpleNamespace(session=DummySession())
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
|
||||
# Patch config value used in this code path
|
||||
# Patch config values used for URL composition and storage type
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
# Patch UploadFile to avoid real DB models
|
||||
@@ -163,7 +164,7 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
|
||||
# Returned map should contain entries for external (keyed by rId) and internal (keyed by target_part)
|
||||
assert set(image_map.keys()) == {"rId1", internal_part}
|
||||
assert all(v.startswith(" and v.endswith("/file-preview)") for v in image_map.values())
|
||||
assert all(v.startswith(" and v.endswith("/file-preview)") for v in image_map.values())
|
||||
|
||||
# Storage should receive both payloads
|
||||
payloads = {data for _, data in saves}
|
||||
@@ -175,6 +176,39 @@ def test_extract_images_from_docx(monkeypatch):
|
||||
assert db_stub.session.committed is True
|
||||
|
||||
|
||||
def test_extract_images_from_docx_uses_internal_files_url():
|
||||
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
|
||||
# Test the URL generation logic directly
|
||||
from configs import dify_config
|
||||
|
||||
# Mock the configuration values
|
||||
original_files_url = getattr(dify_config, "FILES_URL", None)
|
||||
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
|
||||
|
||||
try:
|
||||
# Set both URLs - INTERNAL should take precedence
|
||||
dify_config.FILES_URL = "http://external.example.com"
|
||||
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
|
||||
|
||||
# Test the URL generation logic (same as in word_extractor.py)
|
||||
upload_file_id = "test_file_id"
|
||||
|
||||
# This is the pattern we fixed in the word extractor
|
||||
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
|
||||
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
|
||||
|
||||
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
|
||||
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
|
||||
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
if original_files_url is not None:
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
|
||||
def test_extract_hyperlinks(monkeypatch):
|
||||
# Mock db and storage to avoid issues during image extraction (even if no images are present)
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
|
||||
)
|
||||
|
||||
|
||||
@settings(max_examples=50)
|
||||
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
|
||||
@given(_scalar_value())
|
||||
def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
seg = variable_factory.build_segment(value)
|
||||
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
|
||||
assert seg.value == value
|
||||
|
||||
|
||||
@settings(max_examples=50)
|
||||
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
|
||||
@given(values=st.lists(_scalar_value(), max_size=20))
|
||||
def test_build_segment_and_extract_values_for_array_types(values):
|
||||
seg = variable_factory.build_segment(values)
|
||||
|
||||
@@ -15,7 +15,6 @@ from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import models.dataset as dataset_module
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
ChildChunk,
|
||||
@@ -490,15 +489,6 @@ class TestDocumentModelRelationships:
|
||||
class TestDocumentSegmentIndexing:
|
||||
"""Test suite for DocumentSegment model indexing and operations."""
|
||||
|
||||
@staticmethod
|
||||
def _mock_scalars_result(upload_file_ids: list[str]):
|
||||
class _ScalarsResult:
|
||||
@staticmethod
|
||||
def all() -> list[str]:
|
||||
return upload_file_ids
|
||||
|
||||
return _ScalarsResult()
|
||||
|
||||
def test_document_segment_creation_with_required_fields(self):
|
||||
"""Test creating a document segment with all required fields."""
|
||||
# Arrange
|
||||
@@ -557,139 +547,6 @@ class TestDocumentSegmentIndexing:
|
||||
assert segment.index_node_hash == index_node_hash
|
||||
assert segment.keywords == keywords
|
||||
|
||||
def test_document_segment_sign_content_strips_absolute_files_host(self):
|
||||
"""Test that sign_content strips scheme/host from absolute /files URLs and returns a signed relative URL."""
|
||||
# Arrange
|
||||
upload_file_id = "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
segment = DocumentSegment(
|
||||
tenant_id=str(uuid4()),
|
||||
dataset_id=str(uuid4()),
|
||||
document_id=str(uuid4()),
|
||||
position=1,
|
||||
content=f"",
|
||||
word_count=1,
|
||||
tokens=1,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
mock_scalars_result = self._mock_scalars_result([upload_file_id])
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(dataset_module.dify_config, "SECRET_KEY", "secret", create=True),
|
||||
patch("models.dataset.db.session.scalars", return_value=mock_scalars_result),
|
||||
patch("models.dataset.time.time", return_value=1700000000),
|
||||
patch("models.dataset.os.urandom", return_value=b"\x00" * 16),
|
||||
):
|
||||
signed = segment.get_sign_content()
|
||||
|
||||
# Assert
|
||||
assert "internal.docker:5001" not in signed
|
||||
assert f"/files/{upload_file_id}/file-preview?timestamp=" in signed
|
||||
assert "&nonce=" in signed
|
||||
assert "&sign=" in signed
|
||||
|
||||
def test_document_segment_sign_content_strips_absolute_files_host_for_image_preview(self):
|
||||
"""Test that sign_content strips scheme/host from absolute image-preview URLs."""
|
||||
# Arrange
|
||||
upload_file_id = "e2a4f7b1-1234-5678-9abc-def012345678"
|
||||
segment = DocumentSegment(
|
||||
tenant_id=str(uuid4()),
|
||||
dataset_id=str(uuid4()),
|
||||
document_id=str(uuid4()),
|
||||
position=1,
|
||||
content=f"",
|
||||
word_count=1,
|
||||
tokens=1,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
mock_scalars_result = self._mock_scalars_result([upload_file_id])
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(dataset_module.dify_config, "SECRET_KEY", "secret", create=True),
|
||||
patch("models.dataset.db.session.scalars", return_value=mock_scalars_result),
|
||||
patch("models.dataset.time.time", return_value=1700000000),
|
||||
patch("models.dataset.os.urandom", return_value=b"\x00" * 16),
|
||||
):
|
||||
signed = segment.get_sign_content()
|
||||
|
||||
# Assert
|
||||
assert "internal.docker:5001" not in signed
|
||||
assert f"/files/{upload_file_id}/image-preview?timestamp=" in signed
|
||||
assert "&nonce=" in signed
|
||||
assert "&sign=" in signed
|
||||
|
||||
def test_document_segment_sign_content_skips_upload_files_outside_tenant(self):
|
||||
"""Test that sign_content only signs upload files belonging to the segment tenant."""
|
||||
# Arrange
|
||||
allowed_upload_file_id = "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
denied_upload_file_id = "f8f35fca-568f-4626-adf0-4f30de96aa32"
|
||||
segment = DocumentSegment(
|
||||
tenant_id=str(uuid4()),
|
||||
dataset_id=str(uuid4()),
|
||||
document_id=str(uuid4()),
|
||||
position=1,
|
||||
content=(
|
||||
f"allowed:  "
|
||||
f"denied: "
|
||||
),
|
||||
word_count=1,
|
||||
tokens=1,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
mock_scalars_result = self._mock_scalars_result([allowed_upload_file_id])
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(dataset_module.dify_config, "SECRET_KEY", "secret", create=True),
|
||||
patch("models.dataset.db.session.scalars", return_value=mock_scalars_result),
|
||||
patch("models.dataset.time.time", return_value=1700000000),
|
||||
patch("models.dataset.os.urandom", return_value=b"\x00" * 16),
|
||||
):
|
||||
signed = segment.get_sign_content()
|
||||
|
||||
# Assert
|
||||
assert f"/files/{allowed_upload_file_id}/file-preview?timestamp=" in signed
|
||||
assert f"/files/{denied_upload_file_id}/file-preview?timestamp=" not in signed
|
||||
assert f"/files/{denied_upload_file_id}/file-preview)" in signed
|
||||
|
||||
def test_document_segment_sign_content_handles_mixed_preview_order(self):
|
||||
"""Test that sign_content preserves content when file-preview appears before image-preview."""
|
||||
# Arrange
|
||||
file_preview_id = "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
image_preview_id = "e2a4f7b1-1234-5678-9abc-def012345678"
|
||||
segment = DocumentSegment(
|
||||
tenant_id=str(uuid4()),
|
||||
dataset_id=str(uuid4()),
|
||||
document_id=str(uuid4()),
|
||||
position=1,
|
||||
content=(
|
||||
f"file-first:  "
|
||||
f"then-image: "
|
||||
),
|
||||
word_count=1,
|
||||
tokens=1,
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
mock_scalars_result = self._mock_scalars_result([file_preview_id, image_preview_id])
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch.object(dataset_module.dify_config, "SECRET_KEY", "secret", create=True),
|
||||
patch("models.dataset.db.session.scalars", return_value=mock_scalars_result),
|
||||
patch("models.dataset.time.time", return_value=1700000000),
|
||||
patch("models.dataset.os.urandom", return_value=b"\x00" * 16),
|
||||
):
|
||||
signed = segment.get_sign_content()
|
||||
|
||||
# Assert
|
||||
file_signed = f"/files/{file_preview_id}/file-preview?timestamp="
|
||||
image_signed = f"/files/{image_preview_id}/image-preview?timestamp="
|
||||
assert file_signed in signed
|
||||
assert image_signed in signed
|
||||
assert signed.index(file_signed) < signed.index(image_signed)
|
||||
assert signed.count("&sign=") == 2
|
||||
|
||||
def test_document_segment_with_answer_field(self):
|
||||
"""Test creating a document segment with answer field for QA model."""
|
||||
# Arrange
|
||||
|
||||
@@ -109,40 +109,87 @@ def mock_document_segments(document_id):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
"""Mock database session via session_factory.create_session().
|
||||
|
||||
After session split refactor, the code calls create_session() multiple times.
|
||||
This fixture creates shared query mocks so all sessions use the same
|
||||
query configuration, simulating database persistence across sessions.
|
||||
|
||||
The fixture automatically converts side_effect to cycle to prevent StopIteration.
|
||||
Tests configure mocks the same way as before, but behind the scenes the values
|
||||
are cycled infinitely for all sessions.
|
||||
"""
|
||||
from itertools import cycle
|
||||
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests can observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
sessions = []
|
||||
|
||||
# Mock session.begin() context manager to auto-commit on exit
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
# Shared query mocks - all sessions use these
|
||||
shared_query = MagicMock()
|
||||
shared_filter_by = MagicMock()
|
||||
shared_scalars_result = MagicMock()
|
||||
|
||||
def _begin_exit_side_effect(*args, **kwargs):
|
||||
# session.begin().__exit__() should commit if no exception
|
||||
if args[0] is None: # No exception
|
||||
session.commit()
|
||||
# Create custom first mock that auto-cycles side_effect
|
||||
class CyclicMock(MagicMock):
|
||||
def __setattr__(self, name, value):
|
||||
if name == "side_effect" and value is not None:
|
||||
# Convert list/tuple to infinite cycle
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = cycle(value)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
begin_cm.__exit__.side_effect = _begin_exit_side_effect
|
||||
session.begin.return_value = begin_cm
|
||||
shared_query.where.return_value.first = CyclicMock()
|
||||
shared_filter_by.first = CyclicMock()
|
||||
|
||||
# Mock create_session() context manager
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
def _create_session():
|
||||
"""Create a new mock session for each create_session() call."""
|
||||
session = MagicMock()
|
||||
session.close = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
# Mock session.begin() context manager
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
def _begin_exit_side_effect(exc_type, exc, tb):
|
||||
# commit on success
|
||||
if exc_type is None:
|
||||
session.commit()
|
||||
# return False to propagate exceptions
|
||||
return False
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
begin_cm.__exit__.side_effect = _begin_exit_side_effect
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
# Mock create_session() context manager
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(exc_type, exc, tb):
|
||||
session.close()
|
||||
return False
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
|
||||
# All sessions use the same shared query mocks
|
||||
session.query.return_value = shared_query
|
||||
shared_query.where.return_value = shared_query
|
||||
shared_query.filter_by.return_value = shared_filter_by
|
||||
session.scalars.return_value = shared_scalars_result
|
||||
|
||||
sessions.append(session)
|
||||
# Attach helpers on the first created session for assertions across all sessions
|
||||
if len(sessions) == 1:
|
||||
session.get_all_sessions = lambda: sessions
|
||||
session.any_close_called = lambda: any(s.close.called for s in sessions)
|
||||
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = _create_session
|
||||
|
||||
# Create first session and return it
|
||||
_create_session()
|
||||
yield sessions[0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -201,8 +248,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_db_session.close.assert_called_once()
|
||||
# Assert - at least one session should have been closed
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
@@ -245,6 +292,7 @@ class TestDocumentIndexingSyncTask:
|
||||
"""Test that task handles missing credentials by updating document status."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
@@ -254,8 +302,8 @@ class TestDocumentIndexingSyncTask:
|
||||
assert mock_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in mock_document.error
|
||||
assert mock_document.stopped_at is not None
|
||||
mock_db_session.commit.assert_called()
|
||||
mock_db_session.close.assert_called()
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_page_not_updated(
|
||||
self,
|
||||
@@ -269,6 +317,7 @@ class TestDocumentIndexingSyncTask:
|
||||
"""Test that task does nothing when page has not been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Return same time as stored in document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
@@ -278,8 +327,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# Session should still be closed via context manager teardown
|
||||
assert mock_db_session.close.called
|
||||
# At least one session should have been closed via context manager teardown
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
@@ -296,7 +345,20 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test successful sync flow when Notion page has been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
# Set exact sequence of returns across calls to `.first()`:
|
||||
# 1) document (initial fetch)
|
||||
# 2) dataset (pre-check)
|
||||
# 3) dataset (cleaning phase)
|
||||
# 4) document (pre-indexing update)
|
||||
# 5) document (indexing runner fetch)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
@@ -314,28 +376,40 @@ class TestDocumentIndexingSyncTask:
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
# Aggregate execute calls across all created sessions
|
||||
execute_sqls = []
|
||||
for s in mock_db_session.get_all_sessions():
|
||||
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
# Verify session operations
|
||||
assert mock_db_session.commit.called
|
||||
mock_db_session.close.assert_called_once()
|
||||
# Verify session operations (across any created session)
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_dataset_not_found_during_cleaning(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_indexing_runner,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles dataset not found during cleaning phase."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
|
||||
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
None,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
@@ -344,8 +418,8 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Document should still be set to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
# At least one session should be closed after error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(
|
||||
self,
|
||||
@@ -361,8 +435,14 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test that indexing continues even if cleaning fails."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Make the cleaning step fail but not the segment fetch
|
||||
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
processor.clean.side_effect = Exception("Cleaning error")
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
@@ -371,7 +451,7 @@ class TestDocumentIndexingSyncTask:
|
||||
# Assert
|
||||
# Indexing should still be attempted despite cleaning error
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
mock_db_session.close.assert_called_once()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_document_paused_error(
|
||||
self,
|
||||
@@ -388,7 +468,10 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
@@ -398,7 +481,7 @@ class TestDocumentIndexingSyncTask:
|
||||
|
||||
# Assert
|
||||
# Session should be closed after handling error
|
||||
mock_db_session.close.assert_called_once()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_general_error(
|
||||
self,
|
||||
@@ -415,7 +498,10 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test that general exceptions during indexing are handled."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
@@ -425,7 +511,7 @@ class TestDocumentIndexingSyncTask:
|
||||
|
||||
# Assert
|
||||
# Session should be closed after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_notion_extractor_initialized_with_correct_params(
|
||||
self,
|
||||
@@ -532,7 +618,14 @@ class TestDocumentIndexingSyncTask:
|
||||
):
|
||||
"""Test that index processor clean is called with correct parameters."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ const DeprecationNotice: FC<DeprecationNoticeProps> = ({
|
||||
iconWrapperClassName,
|
||||
textClassName,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { t } = useTranslation('plugin')
|
||||
|
||||
const deprecatedReasonKey = useMemo(() => {
|
||||
if (!deprecatedReason)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
import type { Resource } from 'i18next'
|
||||
import type { Locale } from '.'
|
||||
import type { NamespaceCamelCase, NamespaceKebabCase } from './resources'
|
||||
import type { Namespace, NamespaceInFileName } from './resources'
|
||||
import { kebabCase } from 'es-toolkit/string'
|
||||
import { createInstance } from 'i18next'
|
||||
import resourcesToBackend from 'i18next-resources-to-backend'
|
||||
@@ -14,7 +14,7 @@ export function createI18nextInstance(lng: Locale, resources: Resource) {
|
||||
.use(initReactI18next)
|
||||
.use(resourcesToBackend((
|
||||
language: Locale,
|
||||
namespace: NamespaceKebabCase | NamespaceCamelCase,
|
||||
namespace: NamespaceInFileName | Namespace,
|
||||
) => {
|
||||
const namespaceKebab = kebabCase(namespace)
|
||||
return import(`../i18n/${language}/${namespaceKebab}.json`)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
'use client'
|
||||
|
||||
import type { NamespaceCamelCase } from './resources'
|
||||
import type { Namespace } from './resources'
|
||||
import { useTranslation as useTranslationOriginal } from 'react-i18next'
|
||||
|
||||
export function useTranslation(ns?: NamespaceCamelCase) {
|
||||
export function useTranslation<T extends Namespace | undefined = undefined>(ns?: T) {
|
||||
return useTranslationOriginal(ns)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import type { NamespaceCamelCase } from './resources'
|
||||
import type { Namespace } from './resources'
|
||||
import { use } from 'react'
|
||||
import { getLocaleOnServer, getTranslation } from './server'
|
||||
|
||||
async function getI18nConfig(ns?: NamespaceCamelCase) {
|
||||
async function getI18nConfig<T extends Namespace | undefined = undefined>(ns?: T) {
|
||||
const lang = await getLocaleOnServer()
|
||||
return getTranslation(lang, ns)
|
||||
}
|
||||
|
||||
export function useTranslation(ns?: NamespaceCamelCase) {
|
||||
export function useTranslation<T extends Namespace | undefined = undefined>(ns?: T) {
|
||||
return use(getI18nConfig(ns))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { kebabCase } from 'es-toolkit/string'
|
||||
import { kebabCase } from 'string-ts'
|
||||
import { ObjectKeys } from '@/utils/object'
|
||||
import appAnnotation from '../i18n/en-US/app-annotation.json'
|
||||
import appApi from '../i18n/en-US/app-api.json'
|
||||
import appDebug from '../i18n/en-US/app-debug.json'
|
||||
@@ -64,19 +65,10 @@ const resources = {
|
||||
workflow,
|
||||
}
|
||||
|
||||
export type KebabCase<S extends string> = S extends `${infer T}${infer U}`
|
||||
? T extends Lowercase<T>
|
||||
? `${T}${KebabCase<U>}`
|
||||
: `-${Lowercase<T>}${KebabCase<U>}`
|
||||
: S
|
||||
|
||||
export type CamelCase<S extends string> = S extends `${infer T}-${infer U}`
|
||||
? `${T}${Capitalize<CamelCase<U>>}`
|
||||
: S
|
||||
|
||||
export type Resources = typeof resources
|
||||
export type NamespaceCamelCase = keyof Resources
|
||||
export type NamespaceKebabCase = KebabCase<NamespaceCamelCase>
|
||||
|
||||
export const namespacesCamelCase = Object.keys(resources) as NamespaceCamelCase[]
|
||||
export const namespacesKebabCase = namespacesCamelCase.map(ns => kebabCase(ns)) as NamespaceKebabCase[]
|
||||
export const namespaces = ObjectKeys(resources)
|
||||
export type Namespace = typeof namespaces[number]
|
||||
|
||||
export const namespacesInFileName = namespaces.map(ns => kebabCase(ns))
|
||||
export type NamespaceInFileName = typeof namespacesInFileName[number]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { i18n as I18nInstance, Resource, ResourceLanguage } from 'i18next'
|
||||
import type { Locale } from '.'
|
||||
import type { NamespaceCamelCase, NamespaceKebabCase } from './resources'
|
||||
import type { Namespace, NamespaceInFileName } from './resources'
|
||||
import { match } from '@formatjs/intl-localematcher'
|
||||
import { kebabCase } from 'es-toolkit/compat'
|
||||
import { camelCase } from 'es-toolkit/string'
|
||||
@@ -12,7 +12,7 @@ import { cache } from 'react'
|
||||
import { initReactI18next } from 'react-i18next/initReactI18next'
|
||||
import { serverOnlyContext } from '@/utils/server-only-context'
|
||||
import { i18n } from '.'
|
||||
import { namespacesKebabCase } from './resources'
|
||||
import { namespacesInFileName } from './resources'
|
||||
import { getInitOptions } from './settings'
|
||||
|
||||
const [getLocaleCache, setLocaleCache] = serverOnlyContext<Locale | null>(null)
|
||||
@@ -26,8 +26,8 @@ const getOrCreateI18next = async (lng: Locale) => {
|
||||
instance = createInstance()
|
||||
await instance
|
||||
.use(initReactI18next)
|
||||
.use(resourcesToBackend((language: Locale, namespace: NamespaceCamelCase | NamespaceKebabCase) => {
|
||||
const fileNamespace = kebabCase(namespace) as NamespaceKebabCase
|
||||
.use(resourcesToBackend((language: Locale, namespace: Namespace | NamespaceInFileName) => {
|
||||
const fileNamespace = kebabCase(namespace)
|
||||
return import(`../i18n/${language}/${fileNamespace}.json`)
|
||||
}))
|
||||
.init({
|
||||
@@ -38,7 +38,7 @@ const getOrCreateI18next = async (lng: Locale) => {
|
||||
return instance
|
||||
}
|
||||
|
||||
export async function getTranslation(lng: Locale, ns?: NamespaceCamelCase) {
|
||||
export async function getTranslation<T extends Namespace>(lng: Locale, ns?: T) {
|
||||
const i18nextInstance = await getOrCreateI18next(lng)
|
||||
|
||||
if (ns && !i18nextInstance.hasLoadedNamespace(ns))
|
||||
@@ -84,7 +84,7 @@ export const getResources = cache(async (lng: Locale): Promise<Resource> => {
|
||||
const messages = {} as ResourceLanguage
|
||||
|
||||
await Promise.all(
|
||||
(namespacesKebabCase).map(async (ns) => {
|
||||
(namespacesInFileName).map(async (ns) => {
|
||||
const mod = await import(`../i18n/${lng}/${ns}.json`)
|
||||
messages[camelCase(ns)] = mod.default
|
||||
}),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { InitOptions } from 'i18next'
|
||||
import { namespacesCamelCase } from './resources'
|
||||
import { namespaces } from './resources'
|
||||
|
||||
export function getInitOptions(): InitOptions {
|
||||
return {
|
||||
@@ -8,7 +8,7 @@ export function getInitOptions(): InitOptions {
|
||||
fallbackLng: 'en-US',
|
||||
partialBundledLanguages: true,
|
||||
keySeparator: false,
|
||||
ns: namespacesCamelCase,
|
||||
ns: namespaces,
|
||||
interpolation: {
|
||||
escapeValue: false,
|
||||
},
|
||||
|
||||
7
web/types/i18n.d.ts
vendored
7
web/types/i18n.d.ts
vendored
@@ -1,17 +1,16 @@
|
||||
import type { NamespaceCamelCase, Resources } from '../i18n-config/resources'
|
||||
import type { Namespace, Resources } from '../i18n-config/resources'
|
||||
import 'i18next'
|
||||
|
||||
declare module 'i18next' {
|
||||
// eslint-disable-next-line ts/consistent-type-definitions
|
||||
interface CustomTypeOptions {
|
||||
defaultNS: 'common'
|
||||
resources: Resources
|
||||
keySeparator: false
|
||||
}
|
||||
}
|
||||
|
||||
export type I18nKeysByPrefix<
|
||||
NS extends NamespaceCamelCase,
|
||||
NS extends Namespace,
|
||||
Prefix extends string = '',
|
||||
> = Prefix extends ''
|
||||
? keyof Resources[NS]
|
||||
@@ -22,7 +21,7 @@ export type I18nKeysByPrefix<
|
||||
: never
|
||||
|
||||
export type I18nKeysWithPrefix<
|
||||
NS extends NamespaceCamelCase,
|
||||
NS extends Namespace,
|
||||
Prefix extends string = '',
|
||||
> = Prefix extends ''
|
||||
? keyof Resources[NS]
|
||||
|
||||
7
web/utils/object.ts
Normal file
7
web/utils/object.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export function ObjectFromEntries<const T extends ReadonlyArray<readonly [PropertyKey, unknown]>>(entries: T): { [K in T[number]as K[0]]: K[1] } {
|
||||
return Object.fromEntries(entries) as { [K in T[number]as K[0]]: K[1] }
|
||||
}
|
||||
|
||||
export function ObjectKeys<const T extends Record<string, unknown>>(obj: T): (keyof T)[] {
|
||||
return Object.keys(obj) as (keyof T)[]
|
||||
}
|
||||
Reference in New Issue
Block a user