Compare commits

...

12 Commits

Author SHA1 Message Date
wangxiaolei
b62965034e refactor: document_indexing_sync_task split db session (#32129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 17:16:17 +08:00
wangxiaolei
016d72a8c6 fix: fix trigger output schema miss (#32116) 2026-02-09 17:16:08 +08:00
wangxiaolei
125f7e3ab4 refactor: document_indexing_update_task split database session (#32105)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 10:51:45 +08:00
wangxiaolei
400ed2fd72 refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-08 21:05:03 +08:00
QuantumGhost
840a8f3fc2 perf: use batch delete method instead of single delete (#32036)
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: FFXN <lizy@dify.ai>
2026-02-06 15:13:17 +08:00
wangxiaolei
b4a5296fd1 fix: fix tool type is miss (#32042) 2026-02-06 14:38:54 +08:00
wangxiaolei
fcb53383df fix: fix agent node tool type is not right (#32008)
Infer the real tool type by querying the relevant database tables.

The root cause of the incorrect `type` field is still unclear.
2026-02-06 11:25:29 +08:00
QuantumGhost
540e1db83c perf(api): Optimize the response time of AppListApi endpoint (#31999) 2026-02-06 10:46:25 +08:00
wangxiaolei
2f75e38c08 fix: fix miss use db.session (#31971) 2026-02-05 15:59:37 +08:00
wangxiaolei
cd03e0a9ef fix: fix delete_draft_variables_batch cycle forever (#31934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 19:42:50 +08:00
zxhlyh
df2421d187 fix: auto summary env (#31930) 2026-02-04 19:42:26 +08:00
QuantumGhost
0ba321d840 chore: bump version in docker-compose and package manager to 1.12.1 (#31947) 2026-02-04 19:41:50 +08:00
38 changed files with 1477 additions and 813 deletions

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:
for _, node_data in workflow.walk_nodes():
for node_id, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue
for app in app_pagination.items:

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -81,7 +81,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -5,7 +5,6 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -14,6 +14,9 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []
try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
total_image_upload_file_ids.extend(image_upload_file_ids)
# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
# Query storage keys for document files
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
)
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)
# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
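
Steps 4 and 5 above delete records in fixed-size batches, each in its own short transaction, and count failed batches rather than aborting the whole task. A minimal sketch of that pattern, assuming the same session_factory.create_session() API used in the diff and a SQLAlchemy model with an id column; the helper name is illustrative:

import logging

from sqlalchemy import delete

logger = logging.getLogger(__name__)

BATCH_SIZE = 1000  # keep each DELETE transaction short


def delete_ids_in_batches(session_factory, model, ids: list[str]) -> int:
    """Delete rows by id in fixed-size batches; each batch commits independently."""
    failed_batches = 0
    for i in range(0, len(ids), BATCH_SIZE):
        batch = ids[i : i + BATCH_SIZE]
        try:
            with session_factory.create_session() as session:
                session.execute(delete(model).where(model.id.in_(batch)))
                session.commit()
        except Exception:
            failed_batches += 1
            logger.exception("Failed to delete batch %d-%d", i, i + len(batch))
    return failed_batches

A batch that fails only loses its own rows; every other batch still commits, which is why the task logs a warning with failed/total counts instead of raising.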

View File

@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
upload_file_key = upload_file.key
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return
# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file_key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
tokens_list = [0] * len(content)
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)
with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
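
The rewrite above also copies the Dataset and Document rows into plain dicts (dataset_config, document_config) while the first session is still open, so the CSV parsing and embedding phases never touch a detached ORM instance. A minimal sketch of that step, using the Dataset model from the diff; the function name is illustrative:

def load_dataset_config(session_factory, dataset_id: str) -> dict | None:
    with session_factory.create_session() as session:
        dataset = session.get(Dataset, dataset_id)
        if dataset is None:
            return None
        # Copy scalar values out while the session is open; after it closes,
        # only this plain dict is used, never the detached ORM object.
        return {
            "id": dataset.id,
            "tenant_id": dataset.tenant_id,
            "indexing_technique": dataset.indexing_technique,
            "embedding_model_provider": dataset.embedding_model_provider,
            "embedding_model": dataset.embedding_model,
        }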

View File

@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []
with session_factory.create_session() as session:
try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return
# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception("Cleaned document when document deleted failed")
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
)
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)
for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
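
A recurring split in this refactor is visible above: storage keys are collected inside short database transactions, and the storage.delete() calls run afterwards with no transaction open, so slow object-store I/O cannot hold database locks. A condensed sketch of that split, reusing the UploadFile model and storage helper from the diff; the function name is illustrative:

import logging

from sqlalchemy import delete, select

logger = logging.getLogger(__name__)


def purge_upload_files(session_factory, storage, file_ids: list[str]) -> None:
    keys: list[str] = []
    # Short transaction: read the storage keys and delete the rows.
    with session_factory.create_session() as session, session.begin():
        files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
        keys = [f.key for f in files if f.key]
        session.execute(delete(UploadFile).where(UploadFile.id.in_(file_ids)))
    # Storage I/O after the commit; failures are logged, never fatal.
    for key in keys:
        try:
            storage.delete(key)
        except Exception:
            logger.exception("Failed to delete file from storage, key: %s", key)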

View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
"""
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
total_index_node_ids = []
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()
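
The change above replaces the per-object session.delete(binding) loop with chunked bulk DELETE statements: the ORM unit-of-work path has to track and delete each instance individually, while the bulk path issues one statement per 1000-id slice. A hedged sketch of the chunked form, using the SegmentAttachmentBinding model from the diff:

from sqlalchemy import delete

CHUNK_SIZE = 1000


def bulk_delete_bindings(session, binding_ids: list[str]) -> None:
    # One DELETE ... WHERE id IN (...) per chunk instead of one delete per row.
    for i in range(0, len(binding_ids), CHUNK_SIZE):
        stmt = delete(SegmentAttachmentBinding).where(
            SegmentAttachmentBinding.id.in_(binding_ids[i : i + CHUNK_SIZE])
        )
        session.execute(stmt)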

View File

@@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
"""
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
tenant_id = None
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
if document.data_source_type != "notion_import":
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
return
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
tenant_id = document.tenant_id
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
tenant_id,
credential_id,
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time == page_edited_time:
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
return
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
except Exception:
logger.exception("Failed to clean vector index for document %s", document_id)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()

View File

@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)
for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
# Transaction committed and closed
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True
if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)
for document in documents:
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
document.id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
document.id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
logger.warning("Document %s not found after indexing", document.id)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(
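
The restructured _document_indexing above works in phases: a short transaction flips every document to "parsing", the long-running IndexingRunner then executes with no transaction held, and the summary-index checks run later in a fresh session. A minimal sketch of the first two phases, mirroring the structure in the diff and the names from the file's existing imports (the summary-generation phase is omitted):

def _index_documents(session_factory, dataset_id: str, document_ids: list[str]) -> None:
    # Phase 1: short transaction; session.begin() commits on normal exit.
    with session_factory.create_session() as session, session.begin():
        documents = (
            session.query(Document)
            .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
            .all()
        )
        for document in documents:
            document.indexing_status = "parsing"
            document.processing_started_at = naive_utc_now()

    # Phase 2: no open session here; IndexingRunner manages its own sessions.
    try:
        IndexingRunner().run(documents)
    except DocumentIsPausedError as ex:
        logger.info(str(ex))
    except Exception:
        logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)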

View File

@@ -8,7 +8,6 @@ from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
@@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
clean_success = False
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logger.info(
click.style(
@@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
clean_success = True
except Exception:
logger.exception("Failed to clean document index during update, document_id: %s", document_id)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
if clean_success:
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(workflow_archive_log_id: str):
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
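
In delete_draft_variables_batch the loop body now runs under "with session_factory.create_session() as session, session.begin():", so each batch commits on normal exit and rolls back if an exception escapes, with no explicit session.commit(). A simplified sketch of that loop over the WorkflowDraftVariable model (the real function also cleans up offloaded variable files and storage, omitted here; session_factory is assumed to be the module-level import shown in the diff):

from sqlalchemy import delete


def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    total_deleted = 0
    while True:
        # session.begin() commits when the block exits normally and rolls
        # back if an exception escapes, so every batch is its own transaction.
        with session_factory.create_session() as session, session.begin():
            rows = (
                session.query(WorkflowDraftVariable.id)
                .filter_by(app_id=app_id)
                .limit(batch_size)
                .all()
            )
            ids = [row[0] for row in rows]
            if not ids:
                break
            session.execute(
                delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(ids))
            )
            total_deleted += len(ids)
    return total_deleted

This is the behaviour the new TestDeleteDraftVariablesSessionCommit tests exercise: deletions are visible from a fresh session after each call, and batch_size <= 0 raises ValueError.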

View File

@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
_delete_draft_variables,
delete_draft_variables_batch,
)
@pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.return_value = None
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
if app2_obj:
session.delete(app2_obj)
session.commit()
class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
tenant, app = app_and_tenant
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
yield data
with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id
# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert deleted_count == 10
assert remaining_count == 0
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
assert deleted_count == 10
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
)
assert remaining_vars == 0
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
"""Test session behavior when deleting from an empty dataset."""
nonexistent_app_id = str(uuid.uuid4())
# Should not raise any errors and should return 0
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
assert deleted_count == 0
def test_session_commit_with_single_batch(self, setup_commit_test_data):
"""Test that commit happens correctly when all data fits in a single batch."""
data = setup_commit_test_data
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Delete all in a single batch
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
assert deleted_count == 10
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
"""Test that invalid batch size raises ValueError."""
data = setup_commit_test_data
app_id = data["app"].id
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=0)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=-1)
@patch("extensions.ext_storage.storage")
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that session commits correctly when cleaning up offload data."""
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
mock_storage.delete.return_value = None
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete variables with offload data
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count == 3
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage cleanup was called
assert mock_storage.delete.call_count == 2
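Taken together, these assertions pin down a commit-per-batch deletion loop: validate batch_size, delete up to batch_size rows per iteration, commit after each batch, and return the total. The sketch below is a reconstruction from those assertions, reusing the session_factory and WorkflowDraftVariable names from the tests; it is not the production implementation, and the offload cleanup (WorkflowDraftVariableFile and UploadFile rows, storage.delete) is omitted for brevity.

def delete_draft_variables_batch_sketch(app_id: str, batch_size: int = 1000) -> int:
    # Hypothetical reconstruction of the loop the tests above exercise.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    deleted = 0
    while True:
        with session_factory.create_session() as session:
            ids = [
                row[0]
                for row in session.query(WorkflowDraftVariable.id)
                .filter_by(app_id=app_id)
                .limit(batch_size)
            ]
            if not ids:
                return deleted
            session.query(WorkflowDraftVariable).filter(
                WorkflowDraftVariable.id.in_(ids)
            ).delete(synchronize_session=False)
            # Commit each batch so earlier batches survive a later failure.
            session.commit()
            deleted += len(ids)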

View File

@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download
# Execute the task
# Execute the task - should raise ValueError for empty CSV
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
# Verify error handling
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client
cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"
# Verify no segments were created
# Since exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
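The assertions above imply an error path that records the failure in Redis before re-raising. A minimal sketch of that shape, assuming the cache key format read by the test; the helper name and the TTL are hypothetical:

from extensions.ext_redis import redis_client

def mark_batch_import_failed(job_id: str, message: str) -> None:
    # Hypothetical helper: record the error status under the key the test reads,
    # then surface the failure (here, the empty-CSV ValueError) to the caller.
    redis_client.setex(f"segment_batch_import_{job_id}", 600, "error")
    raise ValueError(message)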

View File

@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task(document_ids, dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
@@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
== 0
)
# Verify index processor was called for each document
# Verify index processor was called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
assert mock_processor.clean.call_count == len(document_ids)
mock_processor.clean.assert_called_once()
# This test successfully verifies:
# 1. Document records are properly deleted from the database
@@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
non_existent_dataset_id = str(uuid.uuid4())
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Execute cleanup task with non-existent dataset
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Execute cleanup task with non-existent dataset - expect exception
with pytest.raises(Exception, match="Document has no dataset"):
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Verify that the index processor was not called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
# Verify that the index processor factory was not used
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task with empty document list
clean_notion_document_task([], dataset.id)
# Verify that the index processor was not called
# Verify that the index processor was called once with empty node list
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
assert mock_processor.clean.call_count == 1
args, kwargs = mock_processor.clean.call_args
# args: (dataset, total_index_node_ids)
assert isinstance(args[0], Dataset)
assert args[1] == []
def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies cleanup with different document types.
# The task properly handles various index types and document configurations.
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
@@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0
@@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task(documents_to_clean, dataset.id)
# Verify only specified documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Mock index processor to raise an exception
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error")
# Execute cleanup task - it should handle the exception gracefully
clean_notion_document_task([document.id], dataset.id)
# Execute cleanup task - current implementation propagates the exception
with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully.
@@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task([target_document.id], target_dataset.id)
# Verify only documents from target dataset are deleted
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
@@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted regardless of status
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0
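Read as a whole, the updated expectations describe a single-pass cleanup: resolve the dataset or fail, collect every segment's index node id, issue one clean() call (even for an empty list), then delete the rows, with any clean() error propagating to the caller. A hedged outline under those assumptions, reusing the model and factory names from the tests; it is not the task's actual body:

def clean_notion_document_outline(document_ids: list[str], dataset_id: str) -> None:
    # Hypothetical outline matching the assertions above.
    with session_factory.create_session() as session:
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if dataset is None:
            raise ValueError("Document has no dataset")
        segments = (
            session.query(DocumentSegment)
            .where(DocumentSegment.document_id.in_(document_ids))
            .all()
        )
        index_node_ids = [seg.index_node_id for seg in segments]
        processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        # One clean() call covering all node ids, even when the list is empty.
        processor.clean(dataset, index_node_ids)
        session.query(DocumentSegment).where(
            DocumentSegment.document_id.in_(document_ids)
        ).delete(synchronize_session=False)
        session.query(Document).where(Document.id.in_(document_ids)).delete(synchronize_session=False)
        session.commit()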

View File

@@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task
class TestDocumentIndexingUpdateTask:
@pytest.fixture
def mock_external_dependencies(self):
"""Patch external collaborators used by the update task.
- IndexProcessorFactory.init_index_processor().clean(...)
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
yield {
"factory": mock_factory,
"processor": processor_instance,
"runner": mock_runner,
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Dataset and document
dataset = Dataset(
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
document = Document(
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Segments
node_ids = []
for i in range(segment_count):
node_id = f"node-{i + 1}"
seg = DocumentSegment(
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=32),
answer=None,
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
created_by=account.id,
)
db_session_with_containers.add(seg)
node_ids.append(node_id)
db_session_with_containers.commit()
# Refresh so the ORM state reflects the committed rows
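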
db_session_with_containers.refresh(dataset)
db_session_with_containers.refresh(document)
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining == 0
# Assert index processor clean was called with expected args
clean_call = mock_external_dependencies["processor"].clean.call_args
assert clean_call is not None
args, kwargs = clean_call
# args[0] is a Dataset instance (from another session) — validate by id
assert getattr(args[0], "id", None) == dataset.id
# args[1] should contain our node_ids
assert set(args[1]) == set(node_ids)
assert kwargs.get("with_keywords") is True
assert kwargs.get("delete_child_chunks") is True
# Assert indexing runner invoked with the updated document
run_call = mock_external_dependencies["runner_instance"].run.call_args
assert run_call is not None
run_docs = run_call[0][0]
assert len(run_docs) == 1
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Indexing should still be triggered
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
# Neither processor nor runner should be called
mock_external_dependencies["processor"].clean.assert_not_called()
mock_external_dependencies["runner_instance"].run.assert_not_called()
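These three tests encode the task's contract: return silently when the document is missing, mark the document as parsing, clean the old index nodes and delete the segments (continuing on failure), then hand the document to IndexingRunner. A hedged outline under those assumptions; the logger, the timestamp call, and the function name are illustrative, while session_factory, IndexProcessorFactory, and IndexingRunner are the collaborators mocked above:

import logging
from datetime import datetime, timezone

logger = logging.getLogger(__name__)

def document_indexing_update_outline(dataset_id: str, document_id: str) -> None:
    with session_factory.create_session() as session:
        document = session.query(Document).where(Document.id == document_id).first()
        if document is None:
            return  # no-op: neither clean() nor run() is invoked
        document.indexing_status = "parsing"
        document.processing_started_at = datetime.now(timezone.utc)
        session.commit()
        try:
            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
            segments = session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
            node_ids = [seg.index_node_id for seg in segments]
            processor = IndexProcessorFactory(document.doc_form).init_index_processor()
            processor.clean(dataset, node_ids, with_keywords=True, delete_child_chunks=True)
            for seg in segments:  # delete rows only after the index clean succeeded
                session.delete(seg)
            session.commit()
        except Exception:
            logger.exception("clean failed; continuing with re-indexing")
    IndexingRunner().run([document])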

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4
import pytest
from hypothesis import given, settings
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
)
@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value
@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)

View File

@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
sessions = [] # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
def _exit_side_effect(*args, **kwargs):
session.close()
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)
sessions.append(session)
# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
session.query = MagicMock(side_effect=query_side_effect)
return cm
mock_sf.create_session.side_effect = create_session_side_effect
# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()
# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session
def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)
def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)
@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions
wrapper = SessionWrapper()
yield wrapper
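For orientation, the production-side pattern the wrapper mimics looks roughly like the following; the helper name and phase boundaries are assumptions, but the create_session()/begin() shape is exactly what the mocked context managers reproduce:

def _mark_documents_parsing(document_ids: list[str]) -> None:
    # Each phase opens its own short-lived session; session.begin() commits on a
    # clean exit, which the begin_cm.__exit__ side effect above mirrors.
    with session_factory.create_session() as session:
        with session.begin():
            documents = session.query(Document).where(Document.id.in_(document_ids)).all()
            for document in documents:
                document.indexing_status = "parsing"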
@pytest.fixture
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
use the deprecated function.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -304,21 +399,9 @@ class TestBatchProcessing:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create an iterator for documents
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -357,19 +440,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@@ -407,19 +480,9 @@ class TestBatchProcessing:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@@ -444,7 +507,10 @@ class TestBatchProcessing:
"""
# Arrange
document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -482,19 +548,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -528,19 +584,9 @@ class TestProgressTracking:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -635,19 +681,9 @@ class TestErrorHandling:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -674,17 +710,9 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@@ -708,17 +736,9 @@ class TestErrorHandling:
but not treated as a failure.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@@ -853,17 +873,9 @@ class TestTaskCancellation:
Session cleanup should happen in finally block.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -883,17 +895,9 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur.
"""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error")
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)]
# Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
doc = MagicMock(spec=Document)
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Create iterator that returns None for missing document
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
doc.stopped_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False
@@ -1273,19 +1246,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1321,19 +1284,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = [mock_document]
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1826,19 +1733,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
# Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@@ -1866,7 +1763,7 @@ class TestRobustness:
- No exceptions occur
Expected behavior:
- Database session is closed
- All database sessions are closed
- No connection leaks
"""
# Arrange
@@ -1879,19 +1776,9 @@ class TestRobustness:
doc.processing_started_at = None
mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = mock_documents
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False
@@ -1899,10 +1786,11 @@ class TestRobustness:
# Act
_document_indexing(dataset_id, document_ids)
# Assert
assert mock_db_session.close.called
# Verify close is called exactly once
assert mock_db_session.close.call_count == 1
# Assert - All created sessions should be closed
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
assert len(mock_db_session.all_sessions) >= 1
for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
"""

View File

@@ -109,25 +109,87 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
"""Mock database session via session_factory.create_session().
After the session split refactor, the task calls create_session() multiple times.
This fixture creates shared query mocks so every session uses the same
query configuration, simulating database persistence across sessions.
The fixture also converts list-valued side_effect assignments into an infinite
cycle to prevent StopIteration: tests configure mocks the same way as before,
but behind the scenes the values are cycled for all sessions.
"""
from itertools import cycle
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
sessions = []
def _exit_side_effect(*args, **kwargs):
session.close()
# Shared query mocks - all sessions use these
shared_query = MagicMock()
shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
# Create custom first mock that auto-cycles side_effect
class CyclicMock(MagicMock):
def __setattr__(self, name, value):
if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
shared_query.where.return_value.first = CyclicMock()
shared_filter_by.first = CyclicMock()
def _create_session():
"""Create a new mock session for each create_session() call."""
session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(exc_type, exc, tb):
# commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
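The cycling trick is easiest to see in isolation. This standalone snippet (not part of the test file) shows why a cycled side_effect matters once several sessions call .first() more times than a fixed list anticipates:

from itertools import cycle
from unittest.mock import MagicMock

first = MagicMock()
first.side_effect = cycle(["document", "dataset"])
assert [first() for _ in range(5)] == ["document", "dataset", "document", "dataset", "document"]
# With a plain two-element list as side_effect, the third call would raise StopIteration.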
@pytest.fixture
@@ -186,8 +248,8 @@ class TestDocumentIndexingSyncTask:
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_db_session.close.assert_called_once()
# Assert - at least one session should have been closed
assert mock_db_session.any_close_called()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
@@ -230,6 +292,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
@@ -239,8 +302,8 @@ class TestDocumentIndexingSyncTask:
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called()
mock_db_session.close.assert_called()
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
def test_page_not_updated(
self,
@@ -254,6 +317,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
@@ -263,8 +327,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
# At least one session should have been closed via context manager teardown
assert mock_db_session.any_close_called()
def test_successful_sync_when_page_updated(
self,
@@ -281,7 +345,20 @@ class TestDocumentIndexingSyncTask:
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
# Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
@@ -299,28 +376,40 @@ class TestDocumentIndexingSyncTask:
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
# Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations
assert mock_db_session.commit.called
mock_db_session.close.assert_called_once()
# Verify session operations (across any created session)
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_indexing_runner,
mock_dataset,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@@ -329,8 +418,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# Session should be closed after error
mock_db_session.close.assert_called_once()
# At least one session should be closed after error
assert mock_db_session.any_close_called()
def test_cleaning_error_continues_to_indexing(
self,
@@ -346,8 +435,14 @@ class TestDocumentIndexingSyncTask:
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Make the cleaning step fail but not the segment fetch
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
processor.clean.side_effect = Exception("Cleaning error")
mock_db_session.scalars.return_value.all.return_value = []
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
@@ -356,7 +451,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_indexing_runner_document_paused_error(
self,
@@ -373,7 +468,10 @@ class TestDocumentIndexingSyncTask:
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -383,7 +481,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after handling error
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_indexing_runner_general_error(
self,
@@ -400,7 +498,10 @@ class TestDocumentIndexingSyncTask:
):
"""Test that general exceptions during indexing are handled."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -410,7 +511,7 @@ class TestDocumentIndexingSyncTask:
# Assert
# Session should be closed after error
mock_db_session.close.assert_called_once()
assert mock_db_session.any_close_called()
def test_notion_extractor_initialized_with_correct_params(
self,
@@ -517,7 +618,14 @@ class TestDocumentIndexingSyncTask:
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
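Every happy-path test pivots on the last-edited-time comparison. A tiny hedged helper capturing that decision (the helper itself is hypothetical; the extractor call is the one mocked above):

def should_resync(stored_last_edited_time: str, notion_extractor) -> bool:
    # Re-index only when Notion reports a different edit time than the one stored on the document.
    return notion_extractor.get_notion_last_edited_time() != stored_last_edited_time

When this returns False the task leaves the document untouched, which is exactly what test_page_not_updated asserts.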

View File

@@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
delete_func("log-1")
delete_func(mock_db.session, "log-1")
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()

api/uv.lock generated
View File

@@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@@ -1653,7 +1653,7 @@ requires-dist = [
{ name = "starlette", specifier = "==0.49.1" },
{ name = "tiktoken", specifier = "~=0.9.0" },
{ name = "transformers", specifier = "~=4.56.1" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
{ name = "weave", specifier = ">=0.52.16" },
{ name = "weaviate-client", specifier = "==4.17.0" },
{ name = "webvtt-py", specifier = "~=0.5.1" },
@@ -6814,12 +6814,12 @@ wheels = [
[[package]]
name = "unstructured"
version = "0.16.25"
version = "0.18.31"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "backoff" },
{ name = "beautifulsoup4" },
{ name = "chardet" },
{ name = "charset-normalizer" },
{ name = "dataclasses-json" },
{ name = "emoji" },
{ name = "filetype" },
@@ -6827,6 +6827,7 @@ dependencies = [
{ name = "langdetect" },
{ name = "lxml" },
{ name = "nltk" },
{ name = "numba" },
{ name = "numpy" },
{ name = "psutil" },
{ name = "python-iso639" },
@@ -6839,9 +6840,9 @@ dependencies = [
{ name = "unstructured-client" },
{ name = "wrapt" },
]
sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
{ url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
]
[package.optional-dependencies]

View File

@@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
-image: langgenius/dify-api:1.12.0
+image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
-image: langgenius/dify-web:1.12.0
+image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -707,7 +707,7 @@ services:
# API service
api:
-image: langgenius/dify-api:1.12.0
+image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -749,7 +749,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
-image: langgenius/dify-api:1.12.0
+image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -788,7 +788,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
-image: langgenius/dify-api:1.12.0
+image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -818,7 +818,7 @@ services:
# Frontend web application.
web:
-image: langgenius/dify-web:1.12.0
+image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -109,6 +109,7 @@ const AgentTools: FC = () => {
tool_parameters: paramsWithDefaultValue,
notAuthor: !tool.is_team_authorization,
enabled: true,
+type: tool.provider_type as CollectionType,
}
}
const handleSelectTool = (tool: ToolDefaultValue) => {
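The hunk above makes the selected tool record its provider type at creation time rather than leaving downstream code to infer it. A minimal sketch of the pattern, with illustrative type names (Dify's actual `CollectionType` union may differ):

```ts
// Illustrative stand-ins; the real CollectionType union may have other members.
type CollectionType = 'builtin' | 'custom' | 'workflow' | 'mcp'

type ToolDefaultValue = {
  provider_type: string
  is_team_authorization: boolean
}

// Mirrors the change above: persist the provider type on the stored tool
// value instead of re-deriving it later.
function toAgentTool(tool: ToolDefaultValue, paramsWithDefaultValue: Record<string, unknown>) {
  return {
    tool_parameters: paramsWithDefaultValue,
    notAuthor: !tool.is_team_authorization,
    enabled: true,
    type: tool.provider_type as CollectionType,
  }
}
```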

View File

@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

View File

@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
+import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

View File

@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
+import { IS_CE_EDITION } from '@/config'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@@ -263,10 +264,14 @@ const Operations = ({
<span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
</div>
)}
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
{
IS_CE_EDITION && (
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
)
}
<Divider className="my-1" />
</>
)}

View File

@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
+import { IS_CE_EDITION } from '@/config'
import { cn } from '@/utils/classnames'
const i18nPrefix = 'batchAction'
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
<span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
</Button>
)}
{onBatchSummary && (
{onBatchSummary && IS_CE_EDITION && (
<Button
variant="ghost"
className="gap-x-0.5 px-3"

View File

@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
+import { IS_CE_EDITION } from '@/config'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
@@ -359,7 +360,7 @@ const Form = () => {
{
indexMethod === IndexingType.QUALIFIED
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
&& (
&& IS_CE_EDITION && (
<>
<Divider
type="horizontal"

View File

@@ -129,6 +129,7 @@ export const useToolSelectorState = ({
extra: {
description: tool.tool_description,
},
+type: tool.provider_type,
}
}, [])

View File

@@ -87,6 +87,7 @@ export type ToolValue = {
enabled?: boolean
extra?: { description?: string } & Record<string, unknown>
credential_id?: string
+type?: string
}
export type DataSourceItem = {
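Because the new `type` field is optional, tool values saved before this change still satisfy the type, so readers need a fallback. A sketch of that read path; `lookUpProviderType` is hypothetical here, standing in for the server-side table lookup the agent-node fix performs:

```ts
type ToolValue = {
  enabled?: boolean
  credential_id?: string
  type?: string // newly added; absent on legacy saved values
}

// Hypothetical fallback; the real fix infers the type by querying
// provider tables on the server.
declare function lookUpProviderType(tool: ToolValue): string

function resolveToolType(tool: ToolValue): string {
  // Prefer the recorded type; fall back for values persisted before the fix.
  return tool.type ?? lookUpProviderType(tool)
}
```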

View File

@@ -18,6 +18,7 @@ import {
Group,
} from '@/app/components/workflow/nodes/_base/components/layout'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
+import { IS_CE_EDITION } from '@/config'
import Split from '../_base/components/split'
import ChunkStructure from './components/chunk-structure'
import EmbeddingModel from './components/embedding-model'
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
{
data.indexing_technique === IndexMethodEnum.QUALIFIED
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
&& (
&& IS_CE_EDITION && (
<>
<SummaryIndexSetting
summaryIndexSetting={data.summary_index_setting}

View File

@@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.12.0",
"version": "1.12.1",
"private": true,
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {

View File

@@ -9,7 +9,6 @@ import type {
} from '@/types/workflow'
import { get, post } from './base'
import { getFlowPrefix } from './utils'
-import { sanitizeWorkflowDraftPayload } from './workflow-payload'
export const fetchWorkflowDraft = (url: string) => {
return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
url: string
params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
-const sanitized = sanitizeWorkflowDraftPayload(params)
-return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
+return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}
export const fetchNodesDefaultConfigs = (url: string) => {
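Reassembled from the hunk above: the sync helper now posts the caller's draft payload unchanged, with the sanitize pass and its import removed. A self-contained sketch of the resulting function; `post` and `CommonResponse` are declared here only so the snippet compiles and approximate the real service helpers:

```ts
type CommonResponse = { result: string } // assumed shape for Dify's CommonResponse

type SyncWorkflowDraftParams = {
  graph: unknown
  features: unknown
  environment_variables: unknown[]
  conversation_variables: unknown[]
}

declare function post<T>(
  url: string,
  init: { body: unknown },
  opts: { silent: boolean },
): Promise<T>

// After the change, the draft payload goes over the wire as-is: no
// sanitizeWorkflowDraftPayload step between caller and request body.
export const syncWorkflowDraft = ({ url, params }: {
  url: string
  params: SyncWorkflowDraftParams
}) => post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
```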