mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 01:33:59 +00:00
Compare commits
5 Commits
refactor/s
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a2093e40e | ||
|
|
aa800d838d | ||
|
|
4bd80683a4 | ||
|
|
c185a51bad | ||
|
|
4430a1b3da |
@@ -23,7 +23,7 @@ dependencies = [
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.2.1",
|
||||
"google-api-core==2.18.0",
|
||||
"google-api-python-client==2.90.0",
|
||||
"google-api-python-client==2.189.0",
|
||||
"google-auth==2.29.0",
|
||||
"google-auth-httplib2==0.2.0",
|
||||
"google-cloud-aiplatform==1.49.0",
|
||||
|
||||
@@ -155,11 +155,11 @@ class AsyncWorkflowService:
|
||||
|
||||
task: AsyncResult[Any] | None = None
|
||||
if queue_name == QueuePriority.PROFESSIONAL:
|
||||
task = execute_workflow_professional.delay(task_data_dict) # type: ignore
|
||||
task = execute_workflow_professional.delay(task_data_dict)
|
||||
elif queue_name == QueuePriority.TEAM:
|
||||
task = execute_workflow_team.delay(task_data_dict) # type: ignore
|
||||
task = execute_workflow_team.delay(task_data_dict)
|
||||
else: # SANDBOX
|
||||
task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore
|
||||
task = execute_workflow_sandbox.delay(task_data_dict)
|
||||
|
||||
# 10. Update trigger log with task info
|
||||
trigger_log.status = WorkflowTriggerStatus.QUEUED
|
||||
@@ -170,7 +170,7 @@ class AsyncWorkflowService:
|
||||
|
||||
return AsyncTriggerResponse(
|
||||
workflow_trigger_log_id=trigger_log.id,
|
||||
task_id=task.id, # type: ignore
|
||||
task_id=task.id,
|
||||
status="queued",
|
||||
queue=queue_name,
|
||||
)
|
||||
|
||||
@@ -1696,13 +1696,18 @@ class DocumentService:
|
||||
for document in documents
|
||||
if document.data_source_type == "upload_file" and document.data_source_info_dict
|
||||
]
|
||||
if dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
# Delete documents first, then dispatch cleanup task after commit
|
||||
# to avoid deadlock between main transaction and async task
|
||||
for document in documents:
|
||||
db.session.delete(document)
|
||||
db.session.commit()
|
||||
|
||||
# Dispatch cleanup task after commit to avoid lock contention
|
||||
# Task cleans up segments, files, and vector indexes
|
||||
if dataset.doc_form is not None:
|
||||
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)
|
||||
|
||||
@staticmethod
|
||||
def rename_document(dataset_id: str, document_id: str, name: str) -> Document:
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import shared_task
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Build index for annotation failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
@@ -5,7 +5,6 @@ import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
|
||||
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
|
||||
except Exception:
|
||||
logger.exception("Annotation deleted index failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery import shared_task
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Build index for annotation failed")
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
|
||||
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
|
||||
# Initialize variables with default values
|
||||
upload_file_key: str | None = None
|
||||
dataset_config: dict | None = None
|
||||
document_config: dict | None = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
|
||||
if not upload_file:
|
||||
raise ValueError("UploadFile not found.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file.key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file.key, file_path)
|
||||
dataset_config = {
|
||||
"id": dataset.id,
|
||||
"indexing_technique": dataset.indexing_technique,
|
||||
"tenant_id": dataset.tenant_id,
|
||||
"embedding_model_provider": dataset.embedding_model_provider,
|
||||
"embedding_model": dataset.embedding_model,
|
||||
}
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
document_config = {
|
||||
"id": dataset_document.id,
|
||||
"doc_form": dataset_document.doc_form,
|
||||
"word_count": dataset_document.word_count or 0,
|
||||
}
|
||||
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
)
|
||||
upload_file_key = upload_file.key
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
||||
texts=[segment["content"] for segment in content]
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
# Ensure required variables are set before proceeding
|
||||
if upload_file_key is None or dataset_config is None or document_config is None:
|
||||
logger.error("Required configuration not set due to session error")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
suffix = Path(upload_file_key).suffix
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
storage.download(upload_file_key, file_path)
|
||||
|
||||
df = pd.read_csv(file_path)
|
||||
content = []
|
||||
for _, row in df.iterrows():
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
data = {"content": row.iloc[0]}
|
||||
content.append(data)
|
||||
if len(content) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == dataset_document.id)
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
document_segments = []
|
||||
embedding_model = None
|
||||
if dataset_config["indexing_technique"] == "high_quality":
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=dataset_config["tenant_id"],
|
||||
provider=dataset_config["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset_config["embedding_model"],
|
||||
)
|
||||
|
||||
word_count_change = 0
|
||||
if embedding_model:
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
|
||||
else:
|
||||
tokens_list = [0] * len(content)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
for segment, tokens in zip(content, tokens_list):
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document_config["id"])
|
||||
.scalar()
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
index_node_id=doc_id,
|
||||
index_node_hash=segment_hash,
|
||||
position=max_position + 1 if max_position else 1,
|
||||
content=content,
|
||||
word_count=len(content),
|
||||
tokens=tokens,
|
||||
created_by=user_id,
|
||||
indexing_at=naive_utc_now(),
|
||||
status="completed",
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
if document_config["doc_form"] == "qa_model":
|
||||
segment_document.answer = segment["answer"]
|
||||
segment_document.word_count += len(segment["answer"])
|
||||
word_count_change += segment_document.word_count
|
||||
session.add(segment_document)
|
||||
document_segments.append(segment_document)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
dataset_document = session.get(Document, document_id)
|
||||
if dataset_document:
|
||||
assert dataset_document.word_count is not None
|
||||
dataset_document.word_count += word_count_change
|
||||
session.add(dataset_document)
|
||||
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
||||
session.commit()
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Segments batch created index failed")
|
||||
redis_client.setex(indexing_cache_key, 600, "error")
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if dataset:
|
||||
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
|
||||
|
||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
"""
|
||||
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
total_attachment_files = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
SegmentAttachmentBinding.document_id == document_id,
|
||||
)
|
||||
).all()
|
||||
# check segment is exist
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
|
||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
|
||||
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
segment_contents = [segment.content for segment in segments]
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
return
|
||||
|
||||
# check segment is exist
|
||||
if index_node_ids:
|
||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||
with session_factory.create_session() as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset:
|
||||
index_processor.clean(
|
||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||
image_files = session.scalars(
|
||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
).all()
|
||||
for image_file in image_files:
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file.id,
|
||||
)
|
||||
total_image_files = []
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
for segment_content in segment_contents:
|
||||
image_upload_file_ids = get_image_upload_file_ids(segment_content)
|
||||
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
|
||||
total_image_files.extend([image_file.key for image_file in image_files])
|
||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(image_file_delete_stmt)
|
||||
|
||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||
session.execute(image_file_delete_stmt)
|
||||
session.delete(segment)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||
session.execute(segment_delete_stmt)
|
||||
|
||||
session.commit()
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
session.delete(file)
|
||||
# delete segment attachments
|
||||
if attachments_with_bindings:
|
||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||
for binding, attachment_file in attachments_with_bindings:
|
||||
try:
|
||||
storage.delete(attachment_file.key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
binding.attachment_id,
|
||||
)
|
||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
session.execute(attachment_file_delete_stmt)
|
||||
|
||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
|
||||
SegmentAttachmentBinding.id.in_(binding_ids)
|
||||
)
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
for image_file_key in total_image_files:
|
||||
try:
|
||||
storage.delete(image_file_key)
|
||||
except Exception:
|
||||
logger.exception("Cleaned document when document deleted failed")
|
||||
logger.exception(
|
||||
"Delete image_files failed when storage deleted, \
|
||||
image_upload_file_is: %s",
|
||||
image_file_key,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
if file_id:
|
||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
try:
|
||||
storage.delete(file.key)
|
||||
except Exception:
|
||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||
session.delete(file)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# delete segment attachments
|
||||
if attachment_ids:
|
||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||
session.execute(attachment_file_delete_stmt)
|
||||
|
||||
if binding_ids:
|
||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
|
||||
session.execute(binding_delete_stmt)
|
||||
|
||||
for attachment_file_key in total_attachment_files:
|
||||
try:
|
||||
storage.delete(attachment_file_key)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Delete attachment_file failed when storage deleted, \
|
||||
attachment_file_id: %s",
|
||||
attachment_file_key,
|
||||
)
|
||||
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
# delete dataset metadata binding
|
||||
session.query(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
session.commit()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
# Phase 1: Update status to parsing (short transaction)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
documents = (
|
||||
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
if document:
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
session.commit()
|
||||
# Transaction committed and closed
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
|
||||
has_error = False
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
has_error = True
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
has_error = True
|
||||
|
||||
if not has_error:
|
||||
with session_factory.create_session() as session:
|
||||
# Trigger summary index generation for completed documents if enabled
|
||||
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
||||
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
||||
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
# expire all session to get latest document's indexing status
|
||||
session.expire_all()
|
||||
# Check each document's indexing status and trigger summary generation if completed
|
||||
for document_id in document_ids:
|
||||
# Re-query document to get latest status (IndexingRunner may have updated it)
|
||||
document = (
|
||||
session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
documents = (
|
||||
session.query(Document)
|
||||
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
if document:
|
||||
logger.info(
|
||||
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
|
||||
document_id,
|
||||
document.id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
and document.need_summary is True
|
||||
):
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document_id, None)
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info(
|
||||
"Queued summary index generation task for document %s in dataset %s "
|
||||
"after indexing completed",
|
||||
document_id,
|
||||
document.id,
|
||||
dataset.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document_id,
|
||||
document.id,
|
||||
)
|
||||
# Don't fail the entire indexing process if summary task queuing fails
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping summary generation for document %s: "
|
||||
"status=%s, doc_form=%s, need_summary=%s",
|
||||
document_id,
|
||||
document.id,
|
||||
document.indexing_status,
|
||||
document.doc_form,
|
||||
document.need_summary,
|
||||
)
|
||||
else:
|
||||
logger.warning("Document %s not found after indexing", document_id)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
|
||||
dataset.id,
|
||||
summary_index_setting.get("enable") if summary_index_setting else None,
|
||||
)
|
||||
logger.warning("Document %s not found after indexing", document.id)
|
||||
else:
|
||||
logger.info(
|
||||
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
|
||||
dataset.id,
|
||||
dataset.indexing_technique,
|
||||
)
|
||||
except DocumentIsPausedError as ex:
|
||||
logger.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
|
||||
@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
|
||||
"""
|
||||
|
||||
from celery import shared_task # type: ignore[import-untyped]
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
|
||||
|
||||
|
||||
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
|
||||
self,
|
||||
deletions: list[DraftVarFileDeletion],
|
||||
):
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
srv = WorkflowDraftVariableService(session=session)
|
||||
srv.delete_workflow_draft_variable_file(deletions=deletions)
|
||||
|
||||
@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
# Execute the task
|
||||
# Execute the task - should raise ValueError for empty CSV
|
||||
job_id = str(uuid.uuid4())
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
with pytest.raises(ValueError, match="The CSV file is empty"):
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
# Verify error handling
|
||||
# Check Redis cache was set to error status
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_value = redis_client.get(cache_key)
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
# Since exception was raised, no segments should be created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
|
||||
@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests that expect session.close() to be called can observe it via the context manager
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
|
||||
sessions = [] # Track all created sessions
|
||||
# Shared mock data that all sessions will access
|
||||
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
def create_session_side_effect():
|
||||
session = MagicMock()
|
||||
session.close = MagicMock()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
# Track commit calls
|
||||
commit_mock = MagicMock()
|
||||
session.commit = commit_mock
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
|
||||
# Support session.begin() for transactions
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
def begin_exit_side_effect(*args, **kwargs):
|
||||
# Auto-commit on transaction exit (like SQLAlchemy)
|
||||
session.commit()
|
||||
# Also mark wrapper's commit as called
|
||||
if sessions:
|
||||
sessions[0].commit()
|
||||
|
||||
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
|
||||
session.begin = MagicMock(return_value=begin_cm)
|
||||
|
||||
sessions.append(session)
|
||||
|
||||
# Setup query with side_effect to handle both Dataset and Document queries
|
||||
def query_side_effect(*args):
|
||||
query = MagicMock()
|
||||
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.first.return_value = shared_mock_data["dataset"]
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||
# Support both .first() and .all() calls with chaining
|
||||
where_result = MagicMock()
|
||||
where_result.where = MagicMock(return_value=where_result)
|
||||
|
||||
# Create an iterator for .first() calls if not exists
|
||||
if shared_mock_data["doc_iter"] is None:
|
||||
docs = shared_mock_data["documents"] or [None]
|
||||
shared_mock_data["doc_iter"] = iter(docs)
|
||||
|
||||
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||
docs_or_empty = shared_mock_data["documents"] or []
|
||||
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
else:
|
||||
query.where = MagicMock(return_value=query)
|
||||
return query
|
||||
|
||||
session.query = MagicMock(side_effect=query_side_effect)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = create_session_side_effect
|
||||
|
||||
# Create a wrapper that behaves like the first session but has access to all sessions
|
||||
class SessionWrapper:
|
||||
def __init__(self):
|
||||
self._sessions = sessions
|
||||
self._shared_data = shared_mock_data
|
||||
# Create a default session for setup phase
|
||||
self._default_session = MagicMock()
|
||||
self._default_session.close = MagicMock()
|
||||
self._default_session.commit = MagicMock()
|
||||
|
||||
# Support session.begin() for default session too
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = self._default_session
|
||||
|
||||
def default_begin_exit_side_effect(*args, **kwargs):
|
||||
self._default_session.commit()
|
||||
|
||||
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
|
||||
self._default_session.begin = MagicMock(return_value=begin_cm)
|
||||
|
||||
def default_query_side_effect(*args):
|
||||
query = MagicMock()
|
||||
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.first.return_value = shared_mock_data["dataset"]
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.where = MagicMock(return_value=where_result)
|
||||
|
||||
if shared_mock_data["doc_iter"] is None:
|
||||
docs = shared_mock_data["documents"] or [None]
|
||||
shared_mock_data["doc_iter"] = iter(docs)
|
||||
|
||||
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||
docs_or_empty = shared_mock_data["documents"] or []
|
||||
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
else:
|
||||
query.where = MagicMock(return_value=query)
|
||||
return query
|
||||
|
||||
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Forward all attribute access to the first session, or default if none created yet
|
||||
target_session = self._sessions[0] if self._sessions else self._default_session
|
||||
return getattr(target_session, name)
|
||||
|
||||
@property
|
||||
def all_sessions(self):
|
||||
"""Access all created sessions for testing."""
|
||||
return self._sessions
|
||||
|
||||
wrapper = SessionWrapper()
|
||||
yield wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
|
||||
use the deprecated function.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -304,21 +399,9 @@ class TestBatchProcessing:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create an iterator for documents
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -357,19 +440,9 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
@@ -407,19 +480,9 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
@@ -444,7 +507,10 @@ class TestBatchProcessing:
|
||||
"""
|
||||
# Arrange
|
||||
document_ids = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Set shared mock data with empty documents list
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = []
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -482,19 +548,9 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -528,19 +584,9 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -635,19 +681,9 @@ class TestErrorHandling:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set up to trigger vector space limit error
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -674,17 +710,9 @@ class TestErrorHandling:
|
||||
Errors during indexing should be caught and logged, but not crash the task.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
|
||||
@@ -708,17 +736,9 @@ class TestErrorHandling:
|
||||
but not treated as a failure.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise DocumentIsPausedError
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
|
||||
@@ -853,17 +873,9 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen in finally block.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -883,17 +895,9 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen even when errors occur.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Test error")
|
||||
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
|
||||
document_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
|
||||
# Create only 2 documents (simulate one missing)
|
||||
# The new code uses .all() which will only return existing documents
|
||||
mock_documents = []
|
||||
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
|
||||
doc = MagicMock(spec=Document)
|
||||
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create iterator that returns None for missing document
|
||||
doc_responses = [mock_documents[0], None, mock_documents[1]]
|
||||
doc_iter = iter(doc_responses)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data - .all() will only return existing documents
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set vector space exactly at limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Billing disabled - limits should not be checked
|
||||
mock_feature_service.get_features.return_value.billing.enabled = False
|
||||
@@ -1273,19 +1246,9 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1321,19 +1284,9 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set vector space limit to 0 (unlimited)
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set negative vector space limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Configure billing with sufficient limits
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1826,19 +1733,9 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
|
||||
@@ -1866,7 +1763,7 @@ class TestRobustness:
|
||||
- No exceptions occur
|
||||
|
||||
Expected behavior:
|
||||
- Database session is closed
|
||||
- All database sessions are closed
|
||||
- No connection leaks
|
||||
"""
|
||||
# Arrange
|
||||
@@ -1879,19 +1776,9 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1899,10 +1786,11 @@ class TestRobustness:
|
||||
# Act
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert mock_db_session.close.called
|
||||
# Verify close is called exactly once
|
||||
assert mock_db_session.close.call_count == 1
|
||||
# Assert - All created sessions should be closed
|
||||
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
|
||||
assert len(mock_db_session.all_sessions) >= 1
|
||||
for session in mock_db_session.all_sessions:
|
||||
assert session.close.called, "All sessions should be closed"
|
||||
|
||||
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
|
||||
"""
|
||||
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@@ -1594,7 +1594,7 @@ requires-dist = [
|
||||
{ name = "gevent", specifier = "~=25.9.1" },
|
||||
{ name = "gmpy2", specifier = "~=2.2.1" },
|
||||
{ name = "google-api-core", specifier = "==2.18.0" },
|
||||
{ name = "google-api-python-client", specifier = "==2.90.0" },
|
||||
{ name = "google-api-python-client", specifier = "==2.189.0" },
|
||||
{ name = "google-auth", specifier = "==2.29.0" },
|
||||
{ name = "google-auth-httplib2", specifier = "==0.2.0" },
|
||||
{ name = "google-cloud-aiplatform", specifier = "==1.49.0" },
|
||||
@@ -2306,7 +2306,7 @@ grpc = [
|
||||
|
||||
[[package]]
|
||||
name = "google-api-python-client"
|
||||
version = "2.90.0"
|
||||
version = "2.189.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "google-api-core" },
|
||||
@@ -2315,9 +2315,9 @@ dependencies = [
|
||||
{ name = "httplib2" },
|
||||
{ name = "uritemplate" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/8b/d990f947c261304a5c1599d45717d02c27d46af5f23e1fee5dc19c8fa79d/google-api-python-client-2.90.0.tar.gz", hash = "sha256:cbcb3ba8be37c6806676a49df16ac412077e5e5dc7fa967941eff977b31fba03", size = 10891311, upload-time = "2023-06-20T16:29:25.008Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6f/f8/0783aeca3410ee053d4dd1fccafd85197847b8f84dd038e036634605d083/google_api_python_client-2.189.0.tar.gz", hash = "sha256:45f2d8559b5c895dde6ad3fb33de025f5cb2c197fa5862f18df7f5295a172741", size = 13979470, upload-time = "2026-02-03T19:24:55.432Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/39/03/209b5c36a621ae644dc7d4743746cd3b38b18e133f8779ecaf6b95cc01ce/google_api_python_client-2.90.0-py2.py3-none-any.whl", hash = "sha256:4a41ffb7797d4f28e44635fb1e7076240b741c6493e7c3233c0e4421cec7c913", size = 11379891, upload-time = "2023-06-20T16:29:19.532Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/44/3677ff27998214f2fa7957359da48da378a0ffff1bd0bdaba42e752bc13e/google_api_python_client-2.189.0-py3-none-any.whl", hash = "sha256:a258c09660a49c6159173f8bbece171278e917e104a11f0640b34751b79c8a1a", size = 14547633, upload-time = "2026-02-03T19:24:52.845Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -194,11 +194,11 @@ const ConfigContent: FC<Props> = ({
|
||||
</div>
|
||||
{type === RETRIEVE_TYPE.multiWay && (
|
||||
<>
|
||||
<div className="my-2 flex h-6 items-center py-1">
|
||||
<div className="system-xs-semibold-uppercase mr-2 shrink-0 text-text-secondary">
|
||||
<div className="my-2 flex flex-col items-center py-1">
|
||||
<div className="system-xs-semibold-uppercase mb-2 mr-2 shrink-0 text-text-secondary">
|
||||
{t('rerankSettings', { ns: 'dataset' })}
|
||||
</div>
|
||||
<Divider bgStyle="gradient" className="mx-0 !h-px" />
|
||||
<Divider bgStyle="gradient" className="m-0 !h-px" />
|
||||
</div>
|
||||
{
|
||||
selectedDatasetsMode.inconsistentEmbeddingModel
|
||||
|
||||
Reference in New Issue
Block a user