Compare commits

...

4 Commits

Author SHA1 Message Date
Vlad D
4ac461d882 fix(api): serialize pipeline file-upload created_at (#32098)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 17:50:29 +08:00
Vlad D
fa763216d0 fix(api): register knowledge pipeline service API routes (#32097)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Trigger i18n Sync on Push / trigger (push) Has been cancelled
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
2026-02-09 17:43:36 +08:00
wangxiaolei
d546210040 refactor: document_indexing_sync_task split db session (#32129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 17:12:16 +08:00
Stephen Zhou
4e0a7a7f9e chore: fix type for useTranslation in #i18n (#32134)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-09 16:42:53 +08:00
21 changed files with 522 additions and 240 deletions

View File

@@ -34,6 +34,7 @@ from .dataset import (
metadata, metadata,
segment, segment,
) )
from .dataset.rag_pipeline import rag_pipeline_workflow
from .end_user import end_user from .end_user import end_user
from .workspace import models from .workspace import models
@@ -53,6 +54,7 @@ __all__ = [
"message", "message",
"metadata", "metadata",
"models", "models",
"rag_pipeline_workflow",
"segment", "segment",
"site", "site",
"workflow", "workflow",

View File

@@ -1,5 +1,3 @@
import string
import uuid
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
@@ -12,6 +10,7 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
from controllers.common.schema import register_schema_model from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import PipelineRunError from controllers.service_api.dataset.error import PipelineRunError
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@@ -41,7 +40,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
register_schema_model(service_api_ns, PipelineRunApiEntity) register_schema_model(service_api_ns, PipelineRunApiEntity)
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins") @service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
class DatasourcePluginsApi(DatasetApiResource): class DatasourcePluginsApi(DatasetApiResource):
"""Resource for datasource plugins.""" """Resource for datasource plugins."""
@@ -76,7 +75,7 @@ class DatasourcePluginsApi(DatasetApiResource):
return datasource_plugins, 200 return datasource_plugins, 200
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run") @service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
class DatasourceNodeRunApi(DatasetApiResource): class DatasourceNodeRunApi(DatasetApiResource):
"""Resource for datasource node run.""" """Resource for datasource node run."""
@@ -131,7 +130,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
) )
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run") @service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
class PipelineRunApi(DatasetApiResource): class PipelineRunApi(DatasetApiResource):
"""Resource for datasource node run.""" """Resource for datasource node run."""
@@ -232,12 +231,4 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
return { return serialize_upload_file(upload_file), 201
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@@ -0,0 +1,22 @@
"""
Serialization helpers for Service API knowledge pipeline endpoints.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import UploadFile
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
}

View File

@@ -217,6 +217,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]): def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view) @wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs): def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs # get url path dataset_id from positional args or kwargs
# Flask passes URL path parameters as positional arguments # Flask passes URL path parameters as positional arguments
dataset_id = None dataset_id = None
@@ -253,12 +255,18 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
# Validate dataset if dataset_id is provided # Validate dataset if dataset_id is provided
if dataset_id: if dataset_id:
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() dataset = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
)
.first()
)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
if not dataset.enable_api: if not dataset.enable_api:
raise Forbidden("Dataset api access is not enabled.") raise Forbidden("Dataset api access is not enabled.")
api_token = validate_and_get_api_token("dataset")
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id) .where(Tenant.id == api_token.tenant_id)

View File

@@ -1329,10 +1329,24 @@ class RagPipelineService:
""" """
Get datasource plugins Get datasource plugins
""" """
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first() dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first() pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")
@@ -1413,10 +1427,24 @@ class RagPipelineService:
""" """
Get pipeline Get pipeline
""" """
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first() dataset: Dataset | None = (
db.session.query(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == tenant_id,
)
.first()
)
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first() pipeline: Pipeline | None = (
db.session.query(Pipeline)
.where(
Pipeline.id == dataset.pipeline_id,
Pipeline.tenant_id == tenant_id,
)
.first()
)
if not pipeline: if not pipeline:
raise ValueError("Pipeline not found") raise ValueError("Pipeline not found")
return pipeline return pipeline

View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
""" """
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green")) logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter() start_at = time.perf_counter()
total_index_node_ids = []
with session_factory.create_session() as session: with session_factory.create_session() as session:
try: dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise Exception("Document has no dataset") raise Exception("Document has no dataset")
index_type = dataset.doc_form index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids)) document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt) session.execute(document_delete_stmt)
for document_id in document_ids: for document_id in document_ids:
segments = session.scalars( segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
select(DocumentSegment).where(DocumentSegment.document_id == document_id) total_index_node_ids.extend([segment.index_node_id for segment in segments])
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean( with session_factory.create_session() as session:
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
) if dataset:
segment_ids = [segment.id for segment in segments] index_processor.clean(
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
) )
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed") with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)

View File

@@ -27,6 +27,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
""" """
logger.info(click.style(f"Start sync document: {document_id}", fg="green")) logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter() start_at = time.perf_counter()
tenant_id = None
with session_factory.create_session() as session, session.begin(): with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
@@ -35,94 +36,120 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Document not found: {document_id}", fg="red")) logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return return
if document.indexing_status == "parsing":
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import": if document.data_source_type != "notion_import":
if ( logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
not data_source_info return
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
# Get credentials from datasource provider if (
datasource_provider_service = DatasourceProviderService() not data_source_info
credential = datasource_provider_service.get_datasource_credentials( or "notion_page_id" not in data_source_info
tenant_id=document.tenant_id, or "notion_workspace_id" not in data_source_info
credential_id=credential_id, ):
provider="notion_datasource", raise ValueError("no notion page found")
plugin_id="langgenius/notion_datasource",
)
if not credential: workspace_id = data_source_info["notion_workspace_id"]
logger.error( page_id = data_source_info["notion_page_id"]
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", page_type = data_source_info["type"]
document_id, page_edited_time = data_source_info["last_edited_time"]
document.tenant_id, credential_id = data_source_info.get("credential_id")
credential_id, tenant_id = document.tenant_id
) index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
tenant_id,
credential_id,
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error" document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace." document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now() document.stopped_at = naive_utc_now()
return return
loader = NotionExtractor( loader = NotionExtractor(
notion_workspace_id=workspace_id, notion_workspace_id=workspace_id,
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"), notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id, tenant_id=tenant_id,
) )
last_edited_time = loader.get_notion_last_edited_time() last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time == page_edited_time:
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
return
# check the page is updated logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
# delete all document segment and index try:
try: index_processor = IndexProcessorFactory(index_type).init_index_processor()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() with session_factory.create_session() as session:
if not dataset: dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
raise Exception("Dataset not found") if dataset:
index_type = document.doc_form index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor = IndexProcessorFactory(index_type).init_index_processor() logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
except Exception:
logger.exception("Failed to clean vector index for document %s", document_id)
segments = session.scalars( with session_factory.create_session() as session, session.begin():
select(DocumentSegment).where(DocumentSegment.document_id == document_id) document = session.query(Document).filter_by(id=document_id).first()
).all() if not document:
index_node_ids = [segment.index_node_id for segment in segments] logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
# delete from vector index data_source_info = document.data_source_info_dict
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
segment_ids = [segment.id for segment in segments] document.indexing_status = "parsing"
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) document.processing_started_at = naive_utc_now()
session.execute(segment_delete_stmt)
end_at = time.perf_counter() segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
logger.info( session.execute(segment_delete_stmt)
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try: logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
indexing_runner = IndexingRunner()
indexing_runner.run([document]) try:
end_at = time.perf_counter() indexing_runner = IndexingRunner()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) with session_factory.create_session() as session:
except DocumentIsPausedError as ex: document = session.query(Document).filter_by(id=document_id).first()
logger.info(click.style(str(ex), fg="yellow")) if document:
except Exception: indexing_runner.run([document])
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) end_at = time.perf_counter()
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()

View File

@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task # Execute cleanup task
clean_notion_document_task(document_ids, dataset.id) clean_notion_document_task(document_ids, dataset.id)
# Verify documents and segments are deleted # Verify segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids)) .filter(DocumentSegment.document_id.in_(document_ids))
@@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
== 0 == 0
) )
# Verify index processor was called for each document # Verify index processor was called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
assert mock_processor.clean.call_count == len(document_ids) mock_processor.clean.assert_called_once()
# This test successfully verifies: # This test successfully verifies:
# 1. Document records are properly deleted from the database # 1. Document records are properly deleted from the database
@@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
non_existent_dataset_id = str(uuid.uuid4()) non_existent_dataset_id = str(uuid.uuid4())
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())] document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Execute cleanup task with non-existent dataset # Execute cleanup task with non-existent dataset - expect exception
clean_notion_document_task(document_ids, non_existent_dataset_id) with pytest.raises(Exception, match="Document has no dataset"):
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Verify that the index processor was not called # Verify that the index processor factory was not used
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
mock_processor.clean.assert_not_called()
def test_clean_notion_document_task_empty_document_list( def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task with empty document list # Execute cleanup task with empty document list
clean_notion_document_task([], dataset.id) clean_notion_document_task([], dataset.id)
# Verify that the index processor was not called # Verify that the index processor was called once with empty node list
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called() assert mock_processor.clean.call_count == 1
args, kwargs = mock_processor.clean.call_args
# args: (dataset, total_index_node_ids)
assert isinstance(args[0], Dataset)
assert args[1] == []
def test_clean_notion_document_task_with_different_index_types( def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies cleanup with different document types. # Note: This test successfully verifies cleanup with different document types.
# The task properly handles various index types and document configurations. # The task properly handles various index types and document configurations.
# Verify documents and segments are deleted # Verify segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id) .filter(DocumentSegment.document_id == document.id)
@@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task # Execute cleanup task
clean_notion_document_task([document.id], dataset.id) clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted # Verify segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0 == 0
@@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task(documents_to_clean, dataset.id) clean_notion_document_task(documents_to_clean, dataset.id)
# Verify only specified documents and segments are deleted # Verify only specified documents' segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean)) .filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit() db_session_with_containers.commit()
# Mock index processor to raise an exception # Mock index processor to raise an exception
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error") mock_index_processor.clean.side_effect = Exception("Index processor error")
# Execute cleanup task - it should handle the exception gracefully # Execute cleanup task - current implementation propagates the exception
clean_notion_document_task([document.id], dataset.id) with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability. # Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully. # Even with external service errors, the database operations complete successfully.
@@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents] all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id) clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted # Verify all segments are deleted
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0 == 0
@@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task([target_document.id], target_dataset.id) clean_notion_document_task([target_document.id], target_dataset.id)
# Verify only documents from target dataset are deleted # Verify only documents' segments from target dataset are deleted
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment) db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id) .filter(DocumentSegment.document_id == target_document.id)
@@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents] all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id) clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted regardless of status # Verify all segments are deleted regardless of status
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0 == 0
@@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task # Execute cleanup task
clean_notion_document_task([document.id], dataset.id) clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted # Verify segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
assert ( assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0 == 0

View File

@@ -0,0 +1,62 @@
"""
Unit tests for Service API knowledge pipeline file-upload serialization.
"""
import importlib.util
from datetime import UTC, datetime
from pathlib import Path
class FakeUploadFile:
id: str
name: str
size: int
extension: str
mime_type: str
created_by: str
created_at: datetime | None
def _load_serialize_upload_file():
api_dir = Path(__file__).resolve().parents[5]
serializers_path = api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "serializers.py"
spec = importlib.util.spec_from_file_location("rag_pipeline_serializers", serializers_path)
assert spec
assert spec.loader
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore[attr-defined]
return module.serialize_upload_file
def test_file_upload_created_at_is_isoformat_string():
serialize_upload_file = _load_serialize_upload_file()
created_at = datetime(2026, 2, 8, 12, 0, 0, tzinfo=UTC)
upload_file = FakeUploadFile()
upload_file.id = "file-1"
upload_file.name = "test.pdf"
upload_file.size = 123
upload_file.extension = "pdf"
upload_file.mime_type = "application/pdf"
upload_file.created_by = "account-1"
upload_file.created_at = created_at
result = serialize_upload_file(upload_file)
assert result["created_at"] == created_at.isoformat()
def test_file_upload_created_at_none_serializes_to_null():
serialize_upload_file = _load_serialize_upload_file()
upload_file = FakeUploadFile()
upload_file.id = "file-1"
upload_file.name = "test.pdf"
upload_file.size = 123
upload_file.extension = "pdf"
upload_file.mime_type = "application/pdf"
upload_file.created_by = "account-1"
upload_file.created_at = None
result = serialize_upload_file(upload_file)
assert result["created_at"] is None

View File

@@ -0,0 +1,54 @@
"""
Unit tests for Service API knowledge pipeline route registration.
"""
import ast
from pathlib import Path
def test_rag_pipeline_routes_registered():
api_dir = Path(__file__).resolve().parents[5]
service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
rag_pipeline_workflow = (
api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
)
assert service_api_init.exists()
assert rag_pipeline_workflow.exists()
init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
import_found = False
for node in ast.walk(init_tree):
if not isinstance(node, ast.ImportFrom):
continue
if node.module != "dataset.rag_pipeline" or node.level != 1:
continue
if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
import_found = True
break
assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
route_paths: set[str] = set()
for node in ast.walk(workflow_tree):
if not isinstance(node, ast.ClassDef):
continue
for decorator in node.decorator_list:
if not isinstance(decorator, ast.Call):
continue
if not isinstance(decorator.func, ast.Attribute):
continue
if decorator.func.attr != "route":
continue
if not decorator.args:
continue
first_arg = decorator.args[0]
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
route_paths.add(first_arg.value)
assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
assert "/datasets/pipeline/file-upload" in route_paths

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from hypothesis import given, settings from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
) )
@settings(max_examples=50) @settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(_scalar_value()) @given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value): def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value) seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value assert seg.value == value
@settings(max_examples=50) @settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(values=st.lists(_scalar_value(), max_size=20)) @given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values): def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values) seg = variable_factory.build_segment(values)

View File

@@ -109,40 +109,87 @@ def mock_document_segments(document_id):
@pytest.fixture @pytest.fixture
def mock_db_session(): def mock_db_session():
"""Mock database session via session_factory.create_session().""" """Mock database session via session_factory.create_session().
After session split refactor, the code calls create_session() multiple times.
This fixture creates shared query mocks so all sessions use the same
query configuration, simulating database persistence across sessions.
The fixture automatically converts side_effect to cycle to prevent StopIteration.
Tests configure mocks the same way as before, but behind the scenes the values
are cycled infinitely for all sessions.
"""
from itertools import cycle
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf: with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock() sessions = []
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager to auto-commit on exit # Shared query mocks - all sessions use these
begin_cm = MagicMock() shared_query = MagicMock()
begin_cm.__enter__.return_value = session shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
def _begin_exit_side_effect(*args, **kwargs): # Create custom first mock that auto-cycles side_effect
# session.begin().__exit__() should commit if no exception class CyclicMock(MagicMock):
if args[0] is None: # No exception def __setattr__(self, name, value):
session.commit() if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
begin_cm.__exit__.side_effect = _begin_exit_side_effect shared_query.where.return_value.first = CyclicMock()
session.begin.return_value = begin_cm shared_filter_by.first = CyclicMock()
# Mock create_session() context manager def _create_session():
cm = MagicMock() """Create a new mock session for each create_session() call."""
cm.__enter__.return_value = session session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
def _exit_side_effect(*args, **kwargs): # Mock session.begin() context manager
session.close() begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
cm.__exit__.side_effect = _exit_side_effect def _begin_exit_side_effect(exc_type, exc, tb):
mock_sf.create_session.return_value = cm # commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
query = MagicMock() begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.query.return_value = query session.begin.return_value = begin_cm
query.where.return_value = query
session.scalars.return_value = MagicMock() # Mock create_session() context manager
yield session cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
@pytest.fixture @pytest.fixture
@@ -201,8 +248,8 @@ class TestDocumentIndexingSyncTask:
# Act # Act
document_indexing_sync_task(dataset_id, document_id) document_indexing_sync_task(dataset_id, document_id)
# Assert # Assert - at least one session should have been closed
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing.""" """Test that task raises error when notion_workspace_id is missing."""
@@ -245,6 +292,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task handles missing credentials by updating document status.""" """Test that task handles missing credentials by updating document status."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act # Act
@@ -254,8 +302,8 @@ class TestDocumentIndexingSyncTask:
assert mock_document.indexing_status == "error" assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called() assert mock_db_session.any_commit_called()
mock_db_session.close.assert_called() assert mock_db_session.any_close_called()
def test_page_not_updated( def test_page_not_updated(
self, self,
@@ -269,6 +317,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task does nothing when page has not been updated.""" """Test that task does nothing when page has not been updated."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document # Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
@@ -278,8 +327,8 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Document status should remain unchanged # Document status should remain unchanged
assert mock_document.indexing_status == "completed" assert mock_document.indexing_status == "completed"
# Session should still be closed via context manager teardown # At least one session should have been closed via context manager teardown
assert mock_db_session.close.called assert mock_db_session.any_close_called()
def test_successful_sync_when_page_updated( def test_successful_sync_when_page_updated(
self, self,
@@ -296,7 +345,20 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test successful sync flow when Notion page has been updated.""" """Test successful sync flow when Notion page has been updated."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] # Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time # NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
@@ -314,28 +376,40 @@ class TestDocumentIndexingSyncTask:
mock_processor.clean.assert_called_once() mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments) # Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] # Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called # Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document]) mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations # Verify session operations (across any created session)
assert mock_db_session.commit.called assert mock_db_session.any_commit_called()
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_dataset_not_found_during_cleaning( def test_dataset_not_found_during_cleaning(
self, self,
mock_db_session, mock_db_session,
mock_datasource_provider_service, mock_datasource_provider_service,
mock_notion_extractor, mock_notion_extractor,
mock_indexing_runner,
mock_document, mock_document,
dataset_id, dataset_id,
document_id, document_id,
): ):
"""Test that task handles dataset not found during cleaning phase.""" """Test that task handles dataset not found during cleaning phase."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] # Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act # Act
@@ -344,8 +418,8 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Document should still be set to parsing # Document should still be set to parsing
assert mock_document.indexing_status == "parsing" assert mock_document.indexing_status == "parsing"
# Session should be closed after error # At least one session should be closed after error
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_cleaning_error_continues_to_indexing( def test_cleaning_error_continues_to_indexing(
self, self,
@@ -361,8 +435,14 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test that indexing continues even if cleaning fails.""" """Test that indexing continues even if cleaning fails."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] from itertools import cycle
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Make the cleaning step fail but not the segment fetch
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
processor.clean.side_effect = Exception("Cleaning error")
mock_db_session.scalars.return_value.all.return_value = []
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act # Act
@@ -371,7 +451,7 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Indexing should still be attempted despite cleaning error # Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document]) mock_indexing_runner.run.assert_called_once_with([mock_document])
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_indexing_runner_document_paused_error( def test_indexing_runner_document_paused_error(
self, self,
@@ -388,7 +468,10 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test that DocumentIsPausedError is handled gracefully.""" """Test that DocumentIsPausedError is handled gracefully."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -398,7 +481,7 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Session should be closed after handling error # Session should be closed after handling error
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_indexing_runner_general_error( def test_indexing_runner_general_error(
self, self,
@@ -415,7 +498,10 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test that general exceptions during indexing are handled.""" """Test that general exceptions during indexing are handled."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error") mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -425,7 +511,7 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Session should be closed after error # Session should be closed after error
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_notion_extractor_initialized_with_correct_params( def test_notion_extractor_initialized_with_correct_params(
self, self,
@@ -532,7 +618,14 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test that index processor clean is called with correct parameters.""" """Test that index processor clean is called with correct parameters."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

View File

@@ -38,7 +38,7 @@ const DeprecationNotice: FC<DeprecationNoticeProps> = ({
iconWrapperClassName, iconWrapperClassName,
textClassName, textClassName,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation('plugin')
const deprecatedReasonKey = useMemo(() => { const deprecatedReasonKey = useMemo(() => {
if (!deprecatedReason) if (!deprecatedReason)

View File

@@ -1,7 +1,7 @@
'use client' 'use client'
import type { Resource } from 'i18next' import type { Resource } from 'i18next'
import type { Locale } from '.' import type { Locale } from '.'
import type { NamespaceCamelCase, NamespaceKebabCase } from './resources' import type { Namespace, NamespaceInFileName } from './resources'
import { kebabCase } from 'es-toolkit/string' import { kebabCase } from 'es-toolkit/string'
import { createInstance } from 'i18next' import { createInstance } from 'i18next'
import resourcesToBackend from 'i18next-resources-to-backend' import resourcesToBackend from 'i18next-resources-to-backend'
@@ -14,7 +14,7 @@ export function createI18nextInstance(lng: Locale, resources: Resource) {
.use(initReactI18next) .use(initReactI18next)
.use(resourcesToBackend(( .use(resourcesToBackend((
language: Locale, language: Locale,
namespace: NamespaceKebabCase | NamespaceCamelCase, namespace: NamespaceInFileName | Namespace,
) => { ) => {
const namespaceKebab = kebabCase(namespace) const namespaceKebab = kebabCase(namespace)
return import(`../i18n/${language}/${namespaceKebab}.json`) return import(`../i18n/${language}/${namespaceKebab}.json`)

View File

@@ -1,9 +1,9 @@
'use client' 'use client'
import type { NamespaceCamelCase } from './resources' import type { Namespace } from './resources'
import { useTranslation as useTranslationOriginal } from 'react-i18next' import { useTranslation as useTranslationOriginal } from 'react-i18next'
export function useTranslation(ns?: NamespaceCamelCase) { export function useTranslation<T extends Namespace | undefined = undefined>(ns?: T) {
return useTranslationOriginal(ns) return useTranslationOriginal(ns)
} }

View File

@@ -1,13 +1,13 @@
import type { NamespaceCamelCase } from './resources' import type { Namespace } from './resources'
import { use } from 'react' import { use } from 'react'
import { getLocaleOnServer, getTranslation } from './server' import { getLocaleOnServer, getTranslation } from './server'
async function getI18nConfig(ns?: NamespaceCamelCase) { async function getI18nConfig<T extends Namespace | undefined = undefined>(ns?: T) {
const lang = await getLocaleOnServer() const lang = await getLocaleOnServer()
return getTranslation(lang, ns) return getTranslation(lang, ns)
} }
export function useTranslation(ns?: NamespaceCamelCase) { export function useTranslation<T extends Namespace | undefined = undefined>(ns?: T) {
return use(getI18nConfig(ns)) return use(getI18nConfig(ns))
} }

View File

@@ -1,4 +1,5 @@
import { kebabCase } from 'es-toolkit/string' import { kebabCase } from 'string-ts'
import { ObjectKeys } from '@/utils/object'
import appAnnotation from '../i18n/en-US/app-annotation.json' import appAnnotation from '../i18n/en-US/app-annotation.json'
import appApi from '../i18n/en-US/app-api.json' import appApi from '../i18n/en-US/app-api.json'
import appDebug from '../i18n/en-US/app-debug.json' import appDebug from '../i18n/en-US/app-debug.json'
@@ -64,19 +65,10 @@ const resources = {
workflow, workflow,
} }
export type KebabCase<S extends string> = S extends `${infer T}${infer U}`
? T extends Lowercase<T>
? `${T}${KebabCase<U>}`
: `-${Lowercase<T>}${KebabCase<U>}`
: S
export type CamelCase<S extends string> = S extends `${infer T}-${infer U}`
? `${T}${Capitalize<CamelCase<U>>}`
: S
export type Resources = typeof resources export type Resources = typeof resources
export type NamespaceCamelCase = keyof Resources
export type NamespaceKebabCase = KebabCase<NamespaceCamelCase>
export const namespacesCamelCase = Object.keys(resources) as NamespaceCamelCase[] export const namespaces = ObjectKeys(resources)
export const namespacesKebabCase = namespacesCamelCase.map(ns => kebabCase(ns)) as NamespaceKebabCase[] export type Namespace = typeof namespaces[number]
export const namespacesInFileName = namespaces.map(ns => kebabCase(ns))
export type NamespaceInFileName = typeof namespacesInFileName[number]

View File

@@ -1,6 +1,6 @@
import type { i18n as I18nInstance, Resource, ResourceLanguage } from 'i18next' import type { i18n as I18nInstance, Resource, ResourceLanguage } from 'i18next'
import type { Locale } from '.' import type { Locale } from '.'
import type { NamespaceCamelCase, NamespaceKebabCase } from './resources' import type { Namespace, NamespaceInFileName } from './resources'
import { match } from '@formatjs/intl-localematcher' import { match } from '@formatjs/intl-localematcher'
import { kebabCase } from 'es-toolkit/compat' import { kebabCase } from 'es-toolkit/compat'
import { camelCase } from 'es-toolkit/string' import { camelCase } from 'es-toolkit/string'
@@ -12,7 +12,7 @@ import { cache } from 'react'
import { initReactI18next } from 'react-i18next/initReactI18next' import { initReactI18next } from 'react-i18next/initReactI18next'
import { serverOnlyContext } from '@/utils/server-only-context' import { serverOnlyContext } from '@/utils/server-only-context'
import { i18n } from '.' import { i18n } from '.'
import { namespacesKebabCase } from './resources' import { namespacesInFileName } from './resources'
import { getInitOptions } from './settings' import { getInitOptions } from './settings'
const [getLocaleCache, setLocaleCache] = serverOnlyContext<Locale | null>(null) const [getLocaleCache, setLocaleCache] = serverOnlyContext<Locale | null>(null)
@@ -26,8 +26,8 @@ const getOrCreateI18next = async (lng: Locale) => {
instance = createInstance() instance = createInstance()
await instance await instance
.use(initReactI18next) .use(initReactI18next)
.use(resourcesToBackend((language: Locale, namespace: NamespaceCamelCase | NamespaceKebabCase) => { .use(resourcesToBackend((language: Locale, namespace: Namespace | NamespaceInFileName) => {
const fileNamespace = kebabCase(namespace) as NamespaceKebabCase const fileNamespace = kebabCase(namespace)
return import(`../i18n/${language}/${fileNamespace}.json`) return import(`../i18n/${language}/${fileNamespace}.json`)
})) }))
.init({ .init({
@@ -38,7 +38,7 @@ const getOrCreateI18next = async (lng: Locale) => {
return instance return instance
} }
export async function getTranslation(lng: Locale, ns?: NamespaceCamelCase) { export async function getTranslation<T extends Namespace>(lng: Locale, ns?: T) {
const i18nextInstance = await getOrCreateI18next(lng) const i18nextInstance = await getOrCreateI18next(lng)
if (ns && !i18nextInstance.hasLoadedNamespace(ns)) if (ns && !i18nextInstance.hasLoadedNamespace(ns))
@@ -84,7 +84,7 @@ export const getResources = cache(async (lng: Locale): Promise<Resource> => {
const messages = {} as ResourceLanguage const messages = {} as ResourceLanguage
await Promise.all( await Promise.all(
(namespacesKebabCase).map(async (ns) => { (namespacesInFileName).map(async (ns) => {
const mod = await import(`../i18n/${lng}/${ns}.json`) const mod = await import(`../i18n/${lng}/${ns}.json`)
messages[camelCase(ns)] = mod.default messages[camelCase(ns)] = mod.default
}), }),

View File

@@ -1,5 +1,5 @@
import type { InitOptions } from 'i18next' import type { InitOptions } from 'i18next'
import { namespacesCamelCase } from './resources' import { namespaces } from './resources'
export function getInitOptions(): InitOptions { export function getInitOptions(): InitOptions {
return { return {
@@ -8,7 +8,7 @@ export function getInitOptions(): InitOptions {
fallbackLng: 'en-US', fallbackLng: 'en-US',
partialBundledLanguages: true, partialBundledLanguages: true,
keySeparator: false, keySeparator: false,
ns: namespacesCamelCase, ns: namespaces,
interpolation: { interpolation: {
escapeValue: false, escapeValue: false,
}, },

7
web/types/i18n.d.ts vendored
View File

@@ -1,17 +1,16 @@
import type { NamespaceCamelCase, Resources } from '../i18n-config/resources' import type { Namespace, Resources } from '../i18n-config/resources'
import 'i18next' import 'i18next'
declare module 'i18next' { declare module 'i18next' {
// eslint-disable-next-line ts/consistent-type-definitions // eslint-disable-next-line ts/consistent-type-definitions
interface CustomTypeOptions { interface CustomTypeOptions {
defaultNS: 'common'
resources: Resources resources: Resources
keySeparator: false keySeparator: false
} }
} }
export type I18nKeysByPrefix< export type I18nKeysByPrefix<
NS extends NamespaceCamelCase, NS extends Namespace,
Prefix extends string = '', Prefix extends string = '',
> = Prefix extends '' > = Prefix extends ''
? keyof Resources[NS] ? keyof Resources[NS]
@@ -22,7 +21,7 @@ export type I18nKeysByPrefix<
: never : never
export type I18nKeysWithPrefix< export type I18nKeysWithPrefix<
NS extends NamespaceCamelCase, NS extends Namespace,
Prefix extends string = '', Prefix extends string = '',
> = Prefix extends '' > = Prefix extends ''
? keyof Resources[NS] ? keyof Resources[NS]

7
web/utils/object.ts Normal file
View File

@@ -0,0 +1,7 @@
export function ObjectFromEntries<const T extends ReadonlyArray<readonly [PropertyKey, unknown]>>(entries: T): { [K in T[number]as K[0]]: K[1] } {
return Object.fromEntries(entries) as { [K in T[number]as K[0]]: K[1] }
}
export function ObjectKeys<const T extends Record<string, unknown>>(obj: T): (keyof T)[] {
return Object.keys(obj) as (keyof T)[]
}