Mirror of https://github.com/langgenius/dify.git (synced 2026-02-09 09:44:00 +00:00)

Compare commits: 24 commits, deploy/dev ... release/e-
| Author | SHA1 | Date | |
|---|---|---|---|
| | 08b8eff933 | | |
| | 579cdea820 | | |
| | 125f7e3ab4 | | |
| | 400ed2fd72 | | |
| | 840a8f3fc2 | | |
| | b4a5296fd1 | | |
| | d7c3ae50dc | | |
| | b921711e9e | | |
| | fb38ad84e1 | | |
| | 91c854b5be | | |
| | d35b231941 | | |
| | 849b4b8c40 | | |
| | 990e8feee8 | | |
| | 53641019b1 | | |
| | d1f10ff301 | | |
| | c8027e168b | | |
| | aae3f76999 | | |
| | 2860c72b03 | | |
| | fcb53383df | | |
| | 540e1db83c | | |
| | 2f75e38c08 | | |
| | cd03e0a9ef | | |
| | df2421d187 | | |
| | 0ba321d840 | | |
@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co

register_enum_models(console_ns, IconType)

_logger = logging.getLogger(__name__)


class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:
for _, node_data in workflow.walk_nodes():
for node_id, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue

for app in app_pagination.items:
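The hunk above changes the node walk to capture `node_id` so a failure can be logged with context while scanning draft workflows for trigger-type nodes. A minimal sketch of the same scan pattern follows; `walk_nodes()` yielding `(node_id, node_data)` pairs and the type strings in `TRIGGER_NODE_TYPES` are assumptions for illustration, not the exact Dify API.

```python
import logging

logger = logging.getLogger(__name__)

# Assumed type strings; the real code compares against NodeType members.
TRIGGER_NODE_TYPES = {"trigger-plugin", "trigger-schedule"}


def apps_with_draft_triggers(draft_workflows) -> set[str]:
    """Collect app IDs whose draft workflow contains at least one trigger node."""
    app_ids: set[str] = set()
    for workflow in draft_workflows:
        node_id = None
        try:
            for node_id, node_data in workflow.walk_nodes():  # yields (id, dict) pairs
                if node_data.get("type") in TRIGGER_NODE_TYPES:
                    app_ids.add(str(workflow.app_id))
                    break  # one trigger node is enough for this workflow
        except Exception:
            # Keep scanning the remaining workflows; node_id tells us where it broke.
            logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
            continue
    return app_ids
```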
@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
tenant_id=current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"],
account=current_user,
)
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
requires-python = ">=3.11,<3.13"

dependencies = [
@@ -81,7 +81,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",
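The unstructured pin moves from `~=0.16.1` to `~=0.18.18`. As a reminder of what the compatible-release operator allows, here is a small check with the packaging library; it is an illustration of PEP 440 semantics, not part of the diff.

```python
from packaging.specifiers import SpecifierSet

# "~=0.18.18" is equivalent to ">=0.18.18, ==0.18.*" for a three-part version.
spec = SpecifierSet("~=0.18.18")
print(spec.contains("0.18.25"))  # True: patch-level upgrades are allowed
print(spec.contains("0.19.0"))   # False: the minor version stays capped
```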
@@ -327,6 +327,12 @@ class AccountService:
@staticmethod
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
# Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
from services.enterprise.account_deletion_sync import sync_account_deletion

sync_account_deletion(account_id=account.id, source="account_deleted")

# Now proceed with async account deletion
delete_account_task.delay(account.id)

@staticmethod
@@ -1230,6 +1236,11 @@ class TenantService:
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(tenant.id)

# Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
from services.enterprise.account_deletion_sync import sync_workspace_member_removal

sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")

@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
api/services/enterprise/account_deletion_sync.py (new file, 115 lines)
@@ -0,0 +1,115 @@
import json
import logging
import uuid
from datetime import UTC, datetime

from redis import RedisError

from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin

logger = logging.getLogger(__name__)

ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"


def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
    """
    Queue an account deletion sync task to Redis.

    Internal helper function. Do not call directly - use the public functions instead.

    Args:
        workspace_id: The workspace/tenant ID to sync
        member_id: The member/account ID that was removed
        source: Source of the sync request (for debugging/tracking)

    Returns:
        bool: True if task was queued successfully, False otherwise
    """
    try:
        task = {
            "task_id": str(uuid.uuid4()),
            "workspace_id": workspace_id,
            "member_id": member_id,
            "retry_count": 0,
            "created_at": datetime.now(UTC).isoformat(),
            "source": source,
            "type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
        }

        # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
        redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))

        logger.info(
            "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
            workspace_id,
            member_id,
            task["task_id"],
            source,
        )
        return True

    except (RedisError, TypeError) as e:
        logger.error(
            "Failed to queue account deletion sync for workspace %s, member %s: %s",
            workspace_id,
            member_id,
            str(e),
            exc_info=True,
        )
        # Don't raise - we don't want to fail member deletion if queueing fails
        return False


def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
    """
    Sync a single workspace member removal (enterprise only).

    Queues a task for the enterprise backend to reassign resources from the removed member.
    Handles enterprise edition check internally. Safe to call in community edition (no-op).

    Args:
        workspace_id: The workspace/tenant ID
        member_id: The member/account ID that was removed
        source: Source of the sync request (e.g., "workspace_member_removed")

    Returns:
        bool: True if task was queued (or skipped in community), False if queueing failed
    """
    if not dify_config.ENTERPRISE_ENABLED:
        return True

    return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)


def sync_account_deletion(account_id: str, *, source: str) -> bool:
    """
    Sync full account deletion across all workspaces (enterprise only).

    Fetches all workspace memberships for the account and queues a sync task for each.
    Handles enterprise edition check internally. Safe to call in community edition (no-op).

    Args:
        account_id: The account ID being deleted
        source: Source of the sync request (e.g., "account_deleted")

    Returns:
        bool: True if all tasks were queued (or skipped in community), False if any queueing failed
    """
    if not dify_config.ENTERPRISE_ENABLED:
        return True

    # Fetch all workspaces the account belongs to
    workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()

    # Queue sync task for each workspace
    success = True
    for join in workspace_joins:
        if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
            success = False

    return success
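The new module only enqueues tasks; its comment notes that a worker consumes from the tail with RPOP. A minimal sketch of such a consumer is shown below, assuming the same queue name and JSON payload shape. The consumer itself is not part of this change and lives in the enterprise backend, so treat this purely as an illustration of the queue contract.

```python
import json

import redis

QUEUE = "enterprise:member:sync:queue"


def consume_one(client: redis.Redis) -> dict | None:
    """Pop one sync task from the tail of the queue (LPUSH head / RPOP tail gives FIFO order)."""
    raw = client.rpop(QUEUE)
    if raw is None:
        return None  # queue is empty
    task = json.loads(raw)
    # The producer fills task_id, workspace_id, member_id, retry_count, created_at, source, type.
    return task


if __name__ == "__main__":
    client = redis.Redis()  # assumes a local Redis for the sake of the example
    task = consume_one(client)
    if task:
        print(task["type"], task["workspace_id"], task["member_id"])
```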
@@ -2,7 +2,10 @@ import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from models.account import Account

from sqlalchemy import exists, select
from sqlalchemy.orm import Session
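This import hunk switches to the `typing.TYPE_CHECKING` guard so `Account` is imported only for type checkers, which is the usual way to annotate a parameter without adding a runtime (and potentially circular) import. A generic sketch of the pattern, with a hypothetical function name:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; no runtime dependency on models.account.
    from models.account import Account


def set_default(provider: str, account: "Account | None" = None) -> None:
    # The quoted annotation keeps this valid at runtime even though Account
    # was never imported outside the TYPE_CHECKING block.
    ...
```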
@@ -406,20 +409,37 @@ class BuiltinToolManageService:
return {"result": "success"}

@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
# get provider (verify tenant ownership to prevent IDOR)
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None:
raise ValueError("provider not found")

# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
if dify_config.ENTERPRISE_ENABLED:
# Enterprise: verify admin permission for tenant-wide operation
from models.account import TenantAccountRole

if account is None:
# In enterprise mode, an account context is required to perform permission checks
raise ValueError("Account is required to set default credentials in enterprise mode")

if not TenantAccountRole.is_privileged_role(account.current_role):
raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")
# Enterprise: clear ALL defaults for this provider in the tenant
# (regardless of user_id, since enterprise credentials may have different user_id)
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, provider=provider, is_default=True
).update({"is_default": False})
else:
# Non-enterprise: only clear defaults for the current user
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})

# set new default provider
target_provider.is_default = True
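The service change does two things: the credential lookup is now scoped to the caller's tenant (closing the IDOR the comment mentions), and in enterprise mode a privileged role is required before tenant-wide defaults are cleared. A reduced sketch of that control flow, with a placeholder dataclass standing in for the real ORM model:

```python
from dataclasses import dataclass


@dataclass
class Credential:  # placeholder for BuiltinToolProvider
    id: str
    tenant_id: str
    is_default: bool = False


def set_default_credential(
    credentials: list[Credential],
    credential_id: str,
    tenant_id: str,
    *,
    enterprise: bool,
    caller_is_admin: bool,
) -> Credential:
    # Look up the credential inside the caller's tenant only, so an ID copied
    # from another tenant cannot be promoted (the IDOR fix).
    target = next(
        (c for c in credentials if c.id == credential_id and c.tenant_id == tenant_id),
        None,
    )
    if target is None:
        raise ValueError("provider not found")

    if enterprise and not caller_is_admin:
        # Enterprise defaults apply tenant-wide, so require an admin/owner role.
        raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")

    # Clear existing defaults in scope, then promote the target.
    for c in credentials:
        if c.tenant_id == tenant_id:
            c.is_default = False
    target.is_default = True
    return target
```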
@@ -6,7 +6,6 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

@@ -5,7 +5,6 @@ import click
from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

@@ -6,7 +6,6 @@ from celery import shared_task

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService

@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()
@@ -14,6 +14,9 @@ from models.model import UploadFile

logger = logging.getLogger(__name__)

# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000


@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")

with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

if not dataset:
raise Exception("Document has no dataset")

session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []

try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist

if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]

# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
total_image_upload_file_ids.extend(image_upload_file_ids)

# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])

# Query storage keys for document files
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
storage_keys_to_delete.extend([f.key for f in files if f and f.key])

session.commit()

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
)

# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)

# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)

# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)

# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)

# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
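Steps 4-6 of the rewritten task delete rows in chunks of `BATCH_SIZE`, each chunk in its own short transaction, so one failing chunk only loses that batch instead of rolling back everything. A generic sketch of that chunking loop with SQLAlchemy follows; the model and session-factory arguments are placeholders, not the project's actual helpers.

```python
from collections.abc import Callable, Sequence

from sqlalchemy import delete
from sqlalchemy.orm import Session

BATCH_SIZE = 1000


def delete_in_batches(ids: Sequence[str], model, make_session: Callable[[], Session]) -> int:
    """Delete rows by primary key in short, independent transactions; return failed batch count."""
    failed_batches = 0
    for start in range(0, len(ids), BATCH_SIZE):
        batch = ids[start : start + BATCH_SIZE]
        try:
            # Each batch commits (or rolls back) on its own when the begin() block exits.
            with make_session() as session, session.begin():
                session.execute(delete(model).where(model.id.in_(batch)))
        except Exception:
            failed_batches += 1  # keep going; later batches are unaffected
    return failed_batches
```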
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(

indexing_cache_key = f"segment_batch_import_{job_id}"

# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None

with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
if not upload_file:
raise ValueError("UploadFile not found.")

with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}

df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}

document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
upload_file_key = upload_file.key

word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return

# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return

with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file_key, file_path)

df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
tokens_list = [0] * len(content)
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")

for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)

word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)

with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)

with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)

VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])

redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
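The refactor above copies the ORM attributes it needs (`dataset_config`, `document_config`, `upload_file_key`) into plain dicts inside the first session, runs the CSV parsing and embedding work with no session held open, and writes the results in fresh short sessions. A toy, self-contained sketch of that three-phase shape; the stub session factory and the data handling are illustrative stand-ins, not the real task code.

```python
from contextlib import contextmanager


@contextmanager
def make_session():
    """Hypothetical stand-in for session_factory.create_session()."""
    yield object()  # a real SQLAlchemy Session in the actual task


def run_batch_import(rows: list[dict]) -> list[dict]:
    # 1) Short session: copy the ORM fields needed later into a plain dict, then let it close.
    with make_session():
        dataset_config = {"indexing_technique": "high_quality", "tenant_id": "t-1"}

    # 2) Heavy work (CSV parsing, token counting, embeddings) with no session open;
    #    only the plain dict is read, so nothing can become detached or expire.
    prepared = [{"content": r["content"], "tokens": len(r["content"])} for r in rows]

    # 3) Fresh short session for the writes (session.add(DocumentSegment(...)) in the real task).
    with make_session():
        return prepared


print(run_batch_import([{"content": "hello"}, {"content": "world"}]))
```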
@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []

with session_factory.create_session() as session:
try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()

attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])

index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return

# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)

for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)

image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)

session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)

binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)

# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception("Cleaned document when document deleted failed")
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
)

with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)

with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)

if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)

for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)

with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()

end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
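Here the task first collects storage keys (`total_image_files`, `total_attachment_files`) inside the database transactions and only calls `storage.delete()` afterwards, so slow or failing object-storage calls can no longer hold a transaction open. A small sketch of that ordering, tolerating per-file failures; the key names and the injected delete function are illustrative.

```python
import logging

logger = logging.getLogger(__name__)


def delete_files_after_commit(keys: list[str], delete_fn) -> int:
    """Delete storage objects after the DB work is done; count failures instead of raising."""
    failures = 0
    for key in keys:
        try:
            delete_fn(key)
        except Exception:
            failures += 1
            logger.exception("Failed to delete file from storage, key: %s", key)
    return failures


# Example: keys gathered during the DB phase, deletions deferred to the very end.
collected_keys = ["uploads/a.png", "uploads/b.png"]
print(delete_files_after_commit(collected_keys, delete_fn=lambda k: None))
```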
@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import delete

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]

for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)

# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()
@@ -28,7 +28,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()

with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

if not document:
@@ -68,7 +68,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return

loader = NotionExtractor(
@@ -85,7 +84,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()

# delete all document segment and index
try:
@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return

for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))

document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)

for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
# Transaction committed and closed

try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True

if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)

documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)

for document in documents:
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
document.id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
document.id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
logger.warning("Document %s not found after indexing", document.id)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)


def _document_indexing_with_tenant_queue(
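`_document_indexing` now splits the work: a short transaction flips every document to "parsing", the long-running `IndexingRunner` then runs outside any transaction, and the summary-generation follow-up only happens when no error was recorded. A compressed sketch of that phase structure, using placeholder callables rather than the real runner and queries:

```python
def index_documents(mark_parsing, run_indexing, queue_summaries) -> bool:
    """Phase 1: short status transaction; Phase 2: long indexing run; Phase 3: follow-up work."""
    mark_parsing()  # with session.begin(): set indexing_status = "parsing" for the batch

    has_error = False
    try:
        run_indexing()  # IndexingRunner().run(documents); it opens its own sessions
    except Exception:
        has_error = True  # paused or failed; skip the follow-up phase

    if not has_error:
        queue_summaries()  # re-query documents and enqueue summary index tasks
    return not has_error


# Toy usage with no real database involved.
print(index_documents(lambda: None, lambda: None, lambda: None))
```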
@@ -8,7 +8,6 @@ from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment

@@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()

with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

if not document:
@@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):

document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()

# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return

index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]

segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
clean_success = False
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logger.info(
click.style(
@@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
clean_success = True
except Exception:
logger.exception("Failed to clean document index during update, document_id: %s", document_id)

try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
if clean_success:
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)

try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
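The update task now records whether the vector-index cleanup succeeded and only deletes the `DocumentSegment` rows when it did, so a failed external cleanup never strands vectors whose database rows are already gone. A small sketch of that guard, again with placeholder callables:

```python
def rebuild_index(clean_external_index, delete_segment_rows, rerun_indexing) -> None:
    clean_success = False
    try:
        clean_external_index()  # index_processor.clean(...) against the vector store
        clean_success = True
    except Exception:
        # Leave the DB rows in place so the segments can still be matched to vectors later.
        pass

    if clean_success:
        delete_segment_rows()  # short transaction: delete(DocumentSegment).where(...)

    rerun_indexing()  # IndexingRunner().run([document])


# Toy usage.
rebuild_index(lambda: None, lambda: None, lambda: None)
```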
@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):


def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(workflow_archive_log_id: str):
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)

@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0

while True:
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""

from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session

from extensions.ext_database import db
from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService


@@ -17,6 +16,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)
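Several hunks in this set replace `Session(bind=db.engine)` and manual `session.commit()` with `session_factory.create_session()` plus `session.begin()`, which commits on a clean exit and rolls back on an exception. A generic SQLAlchemy illustration of the same context-manager behaviour, using an in-memory SQLite engine and a throwaway table name as stand-ins:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite+pysqlite:///:memory:")
SessionLocal = sessionmaker(bind=engine)  # rough analogue of session_factory.create_session

with SessionLocal() as session, session.begin():
    # Commits automatically when the begin() block exits without an exception,
    # rolls back automatically if one is raised inside it.
    session.execute(text("CREATE TABLE demo (id INTEGER PRIMARY KEY)"))
    session.execute(text("INSERT INTO demo (id) VALUES (1)"))

with SessionLocal() as session:
    print(session.execute(text("SELECT COUNT(*) FROM demo")).scalar_one())  # 1
```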
@@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
_delete_draft_variables,
delete_draft_variables_batch,
)


@pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.return_value = None

with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0

with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0

@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]

deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0

with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
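The test changes above stop counting whole tables and instead count only the rows created by the fixture, so the assertions survive unrelated `WorkflowDraftVariableFile` or `UploadFile` rows left behind by other tests. The same idea in a self-contained SQLAlchemy example; `Thing` is a hypothetical table, not a project model.

```python
from sqlalchemy import Column, String, create_engine, func, select
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Thing(Base):  # hypothetical table standing in for WorkflowDraftVariableFile
    __tablename__ = "thing"
    id = Column(String, primary_key=True)


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)

with SessionLocal() as session, session.begin():
    session.add_all([Thing(id="a"), Thing(id="b"), Thing(id="other")])

fixture_ids = ["a", "b"]
with SessionLocal() as session:
    # Scoped count: rows created elsewhere ("other") do not affect the assertion.
    scoped = session.execute(
        select(func.count()).select_from(Thing).where(Thing.id.in_(fixture_ids))
    ).scalar_one()
    print(scoped)  # 2
```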
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
if app2_obj:
session.delete(app2_obj)
session.commit()


class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""

@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now

tenant, app = app_and_tenant

with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()

var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()

draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()

data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}

yield data

with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()

@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []

with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]

yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}

with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()

def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id

# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)

# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()

assert deleted_count == 10
assert remaining_count == 0

def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Perform deletion with small batch size to force multiple commits
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
|
||||
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify all data is deleted in a new session (proves commits worked)
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
# Verify specific IDs are deleted
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = (
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
|
||||
)
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
|
||||
"""Test session behavior when deleting from an empty dataset."""
|
||||
nonexistent_app_id = str(uuid.uuid4())
|
||||
|
||||
# Should not raise any errors and should return 0
|
||||
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
|
||||
assert deleted_count == 0
|
||||
|
||||
def test_session_commit_with_single_batch(self, setup_commit_test_data):
|
||||
"""Test that commit happens correctly when all data fits in a single batch."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert initial_count == 10
|
||||
|
||||
# Delete all in a single batch
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
|
||||
assert deleted_count == 10
|
||||
|
||||
# Verify data is persisted
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
assert final_count == 0
|
||||
|
||||
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
|
||||
"""Test that invalid batch size raises ValueError."""
|
||||
data = setup_commit_test_data
|
||||
app_id = data["app"].id
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=0)
|
||||
|
||||
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||
delete_draft_variables_batch(app_id, batch_size=-1)
|
||||
|
||||
@patch("extensions.ext_storage.storage")
|
||||
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
|
||||
"""Test that session commits correctly when cleaning up offload data."""
|
||||
data = setup_offload_test_data
|
||||
app_id = data["app"].id
|
||||
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
|
||||
# Delete variables with offload data
|
||||
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||
assert deleted_count == 3
|
||||
|
||||
# Verify all data is persisted (deleted) in new session
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_after == 0
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
# Verify storage cleanup was called
|
||||
assert mock_storage.delete.call_count == 2
|
||||
|
||||
@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
mock_storage.download.side_effect = mock_download
|
||||
|
||||
# Execute the task
|
||||
# Execute the task - should raise ValueError for empty CSV
|
||||
job_id = str(uuid.uuid4())
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
with pytest.raises(ValueError, match="The CSV file is empty"):
|
||||
batch_create_segment_to_index_task(
|
||||
job_id=job_id,
|
||||
upload_file_id=upload_file.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
tenant_id=tenant.id,
|
||||
user_id=account.id,
|
||||
)
|
||||
|
||||
# Verify error handling
|
||||
# Check Redis cache was set to error status
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_value = redis_client.get(cache_key)
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
# Since exception was raised, no segments should be created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
|
||||
|
||||
class TestDocumentIndexingUpdateTask:
|
||||
@pytest.fixture
|
||||
def mock_external_dependencies(self):
|
||||
"""Patch external collaborators used by the update task.
|
||||
- IndexProcessorFactory.init_index_processor().clean(...)
|
||||
- IndexingRunner.run([...])
|
||||
"""
|
||||
with (
|
||||
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
|
||||
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
|
||||
):
|
||||
processor_instance = MagicMock()
|
||||
mock_factory.return_value.init_index_processor.return_value = processor_instance
|
||||
|
||||
runner_instance = MagicMock()
|
||||
mock_runner.return_value = runner_instance
|
||||
|
||||
yield {
|
||||
"factory": mock_factory,
|
||||
"processor": processor_instance,
|
||||
"runner": mock_runner,
|
||||
"runner_instance": runner_instance,
|
||||
}
|
||||
|
||||
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
|
||||
fake = Faker()
|
||||
|
||||
# Account and tenant
|
||||
account = Account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=fake.company(), status="normal")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Dataset and document
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant.id,
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=64),
|
||||
data_source_type="upload_file",
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
document = Document(
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
position=0,
|
||||
data_source_type="upload_file",
|
||||
batch="test_batch",
|
||||
name=fake.file_name(),
|
||||
created_from="upload_file",
|
||||
created_by=account.id,
|
||||
indexing_status="waiting",
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Segments
|
||||
node_ids = []
|
||||
for i in range(segment_count):
|
||||
node_id = f"node-{i + 1}"
|
||||
seg = DocumentSegment(
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
position=i,
|
||||
content=fake.text(max_nb_chars=32),
|
||||
answer=None,
|
||||
word_count=10,
|
||||
tokens=5,
|
||||
index_node_id=node_id,
|
||||
status="completed",
|
||||
created_by=account.id,
|
||||
)
|
||||
db_session_with_containers.add(seg)
|
||||
node_ids.append(node_id)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh to ensure ORM state
|
||||
db_session_with_containers.refresh(dataset)
|
||||
db_session_with_containers.refresh(document)
|
||||
|
||||
return dataset, document, node_ids
|
||||
|
||||
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
|
||||
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
document_indexing_update_task(dataset.id, document.id)
|
||||
|
||||
# Ensure we see committed changes from another session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert document status updated before reindex
|
||||
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
|
||||
assert updated.indexing_status == "parsing"
|
||||
assert updated.processing_started_at is not None
|
||||
|
||||
# Segments should be deleted
|
||||
remaining = (
|
||||
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
|
||||
)
|
||||
assert remaining == 0
|
||||
|
||||
# Assert index processor clean was called with expected args
|
||||
clean_call = mock_external_dependencies["processor"].clean.call_args
|
||||
assert clean_call is not None
|
||||
args, kwargs = clean_call
|
||||
# args[0] is a Dataset instance (from another session) — validate by id
|
||||
assert getattr(args[0], "id", None) == dataset.id
|
||||
# args[1] should contain our node_ids
|
||||
assert set(args[1]) == set(node_ids)
|
||||
assert kwargs.get("with_keywords") is True
|
||||
assert kwargs.get("delete_child_chunks") is True
|
||||
|
||||
# Assert indexing runner invoked with the updated document
|
||||
run_call = mock_external_dependencies["runner_instance"].run.call_args
|
||||
assert run_call is not None
|
||||
run_docs = run_call[0][0]
|
||||
assert len(run_docs) == 1
|
||||
first = run_docs[0]
|
||||
assert getattr(first, "id", None) == document.id
|
||||
|
||||
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
|
||||
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
|
||||
|
||||
# Force clean to raise; task should continue to indexing
|
||||
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
|
||||
|
||||
document_indexing_update_task(dataset.id, document.id)
|
||||
|
||||
# Ensure we see committed changes from another session
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Indexing should still be triggered
|
||||
mock_external_dependencies["runner_instance"].run.assert_called_once()
|
||||
|
||||
# Segments should remain (since clean failed before DB delete)
|
||||
remaining = (
|
||||
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
|
||||
)
|
||||
assert remaining > 0
|
||||
|
||||
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
|
||||
fake = Faker()
|
||||
# Act with non-existent document id
|
||||
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
|
||||
|
||||
# Neither processor nor runner should be called
|
||||
mock_external_dependencies["processor"].clean.assert_not_called()
|
||||
mock_external_dependencies["runner_instance"].run.assert_not_called()
|
||||
@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||
session = MagicMock()
|
||||
# Ensure tests that expect session.close() to be called can observe it via the context manager
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
|
||||
sessions = [] # Track all created sessions
|
||||
# Shared mock data that all sessions will access
|
||||
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
def create_session_side_effect():
|
||||
session = MagicMock()
|
||||
session.close = MagicMock()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
# Track commit calls
|
||||
commit_mock = MagicMock()
|
||||
session.commit = commit_mock
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
|
||||
# Support session.begin() for transactions
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
def begin_exit_side_effect(*args, **kwargs):
|
||||
# Auto-commit on transaction exit (like SQLAlchemy)
|
||||
session.commit()
|
||||
# Also mark wrapper's commit as called
|
||||
if sessions:
|
||||
sessions[0].commit()
|
||||
|
||||
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
|
||||
session.begin = MagicMock(return_value=begin_cm)
|
||||
|
||||
sessions.append(session)
|
||||
|
||||
# Setup query with side_effect to handle both Dataset and Document queries
|
||||
def query_side_effect(*args):
|
||||
query = MagicMock()
|
||||
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.first.return_value = shared_mock_data["dataset"]
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||
# Support both .first() and .all() calls with chaining
|
||||
where_result = MagicMock()
|
||||
where_result.where = MagicMock(return_value=where_result)
|
||||
|
||||
# Create an iterator for .first() calls if not exists
|
||||
if shared_mock_data["doc_iter"] is None:
|
||||
docs = shared_mock_data["documents"] or [None]
|
||||
shared_mock_data["doc_iter"] = iter(docs)
|
||||
|
||||
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||
docs_or_empty = shared_mock_data["documents"] or []
|
||||
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
else:
|
||||
query.where = MagicMock(return_value=query)
|
||||
return query
|
||||
|
||||
session.query = MagicMock(side_effect=query_side_effect)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = create_session_side_effect
|
||||
|
||||
# Create a wrapper that behaves like the first session but has access to all sessions
|
||||
class SessionWrapper:
|
||||
def __init__(self):
|
||||
self._sessions = sessions
|
||||
self._shared_data = shared_mock_data
|
||||
# Create a default session for setup phase
|
||||
self._default_session = MagicMock()
|
||||
self._default_session.close = MagicMock()
|
||||
self._default_session.commit = MagicMock()
|
||||
|
||||
# Support session.begin() for default session too
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = self._default_session
|
||||
|
||||
def default_begin_exit_side_effect(*args, **kwargs):
|
||||
self._default_session.commit()
|
||||
|
||||
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
|
||||
self._default_session.begin = MagicMock(return_value=begin_cm)
|
||||
|
||||
def default_query_side_effect(*args):
|
||||
query = MagicMock()
|
||||
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.first.return_value = shared_mock_data["dataset"]
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||
where_result = MagicMock()
|
||||
where_result.where = MagicMock(return_value=where_result)
|
||||
|
||||
if shared_mock_data["doc_iter"] is None:
|
||||
docs = shared_mock_data["documents"] or [None]
|
||||
shared_mock_data["doc_iter"] = iter(docs)
|
||||
|
||||
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||
docs_or_empty = shared_mock_data["documents"] or []
|
||||
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||
query.where = MagicMock(return_value=where_result)
|
||||
else:
|
||||
query.where = MagicMock(return_value=query)
|
||||
return query
|
||||
|
||||
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Forward all attribute access to the first session, or default if none created yet
|
||||
target_session = self._sessions[0] if self._sessions else self._default_session
|
||||
return getattr(target_session, name)
|
||||
|
||||
@property
|
||||
def all_sessions(self):
|
||||
"""Access all created sessions for testing."""
|
||||
return self._sessions
|
||||
|
||||
wrapper = SessionWrapper()
|
||||
yield wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
|
||||
use the deprecated function.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -304,21 +399,9 @@ class TestBatchProcessing:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create an iterator for documents
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
# Return documents one by one for each call
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -357,19 +440,9 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
@@ -407,19 +480,9 @@ class TestBatchProcessing:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
@@ -444,7 +507,10 @@ class TestBatchProcessing:
|
||||
"""
|
||||
# Arrange
|
||||
document_ids = []
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Set shared mock data with empty documents list
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = []
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -482,19 +548,9 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -528,19 +584,9 @@ class TestProgressTracking:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -635,19 +681,9 @@ class TestErrorHandling:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set up to trigger vector space limit error
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -674,17 +710,9 @@ class TestErrorHandling:
|
||||
Errors during indexing should be caught and logged, but not crash the task.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
|
||||
@@ -708,17 +736,9 @@ class TestErrorHandling:
|
||||
but not treated as a failure.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise DocumentIsPausedError
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
|
||||
@@ -853,17 +873,9 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen in finally block.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -883,17 +895,9 @@ class TestTaskCancellation:
|
||||
Session cleanup should happen even when errors occur.
|
||||
"""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first.side_effect = mock_documents
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = Exception("Test error")
|
||||
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
|
||||
document_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
|
||||
# Create only 2 documents (simulate one missing)
|
||||
# The new code uses .all() which will only return existing documents
|
||||
mock_documents = []
|
||||
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
|
||||
doc = MagicMock(spec=Document)
|
||||
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
# Create iterator that returns None for missing document
|
||||
doc_responses = [mock_documents[0], None, mock_documents[1]]
|
||||
doc_iter = iter(doc_responses)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data - .all() will only return existing documents
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
|
||||
doc.stopped_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set vector space exactly at limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Billing disabled - limits should not be checked
|
||||
mock_feature_service.get_features.return_value.billing.enabled = False
|
||||
@@ -1273,19 +1246,9 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1321,19 +1284,9 @@ class TestIntegration:
|
||||
|
||||
# Set up rpop to return None for concurrency check (no more tasks)
|
||||
mock_redis.rpop.side_effect = [None]
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
|
||||
mock_document.indexing_status = "waiting"
|
||||
mock_document.processing_started_at = None
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: mock_document
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = [mock_document]
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set vector space limit to 0 (unlimited)
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Set negative vector space limit
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Configure billing with sufficient limits
|
||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||
@@ -1826,19 +1733,9 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
# Make IndexingRunner raise an exception
|
||||
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
|
||||
@@ -1866,7 +1763,7 @@ class TestRobustness:
|
||||
- No exceptions occur
|
||||
|
||||
Expected behavior:
|
||||
- Database session is closed
|
||||
- All database sessions are closed
|
||||
- No connection leaks
|
||||
"""
|
||||
# Arrange
|
||||
@@ -1879,19 +1776,9 @@ class TestRobustness:
|
||||
doc.processing_started_at = None
|
||||
mock_documents.append(doc)
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
doc_iter = iter(mock_documents)
|
||||
|
||||
def mock_query_side_effect(*args):
|
||||
mock_query = MagicMock()
|
||||
if args[0] == Dataset:
|
||||
mock_query.where.return_value.first.return_value = mock_dataset
|
||||
elif args[0] == Document:
|
||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
||||
return mock_query
|
||||
|
||||
mock_db_session.query.side_effect = mock_query_side_effect
|
||||
# Set shared mock data so all sessions can access it
|
||||
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||
mock_db_session._shared_data["documents"] = mock_documents
|
||||
|
||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value.billing.enabled = False
|
||||
@@ -1899,10 +1786,11 @@ class TestRobustness:
|
||||
# Act
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
assert mock_db_session.close.called
|
||||
# Verify close is called exactly once
|
||||
assert mock_db_session.close.call_count == 1
|
||||
# Assert - All created sessions should be closed
|
||||
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
|
||||
assert len(mock_db_session.all_sessions) >= 1
|
||||
for session in mock_db_session.all_sessions:
|
||||
assert session.close.called, "All sessions should be closed"
|
||||
|
||||
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
|
||||
"""
|
||||
|
||||
@@ -114,6 +114,21 @@ def mock_db_session():
|
||||
session = MagicMock()
|
||||
# Ensure tests can observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
|
||||
# Mock session.begin() context manager to auto-commit on exit
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
def _begin_exit_side_effect(*args, **kwargs):
|
||||
# session.begin().__exit__() should commit if no exception
|
||||
if args[0] is None: # No exception
|
||||
session.commit()
|
||||
|
||||
begin_cm.__exit__.side_effect = _begin_exit_side_effect
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
# Mock create_session() context manager
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
|
||||
@@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
|
||||
mock_query.where.return_value = mock_delete_query
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
delete_func("log-1")
|
||||
delete_func(mock_db.session, "log-1")
|
||||
|
||||
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
|
||||
mock_query.where.assert_called_once()
|
||||
|
||||
19
api/uv.lock
generated
19
api/uv.lock
generated
@@ -1368,7 +1368,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.12.0"
|
||||
version = "1.12.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
@@ -1653,7 +1653,7 @@ requires-dist = [
|
||||
{ name = "starlette", specifier = "==0.49.1" },
|
||||
{ name = "tiktoken", specifier = "~=0.9.0" },
|
||||
{ name = "transformers", specifier = "~=4.56.1" },
|
||||
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
|
||||
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
|
||||
{ name = "weave", specifier = ">=0.52.16" },
|
||||
{ name = "weaviate-client", specifier = "==4.17.0" },
|
||||
{ name = "webvtt-py", specifier = "~=0.5.1" },
|
||||
@@ -4433,15 +4433,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pdfminer-six"
|
||||
version = "20251230"
|
||||
version = "20260107"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, upload-time = "2026-01-07T13:29:10.742Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6814,12 +6814,12 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "unstructured"
|
||||
version = "0.16.25"
|
||||
version = "0.18.31"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "backoff" },
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "chardet" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "dataclasses-json" },
|
||||
{ name = "emoji" },
|
||||
{ name = "filetype" },
|
||||
@@ -6827,6 +6827,7 @@ dependencies = [
|
||||
{ name = "langdetect" },
|
||||
{ name = "lxml" },
|
||||
{ name = "nltk" },
|
||||
{ name = "numba" },
|
||||
{ name = "numpy" },
|
||||
{ name = "psutil" },
|
||||
{ name = "python-iso639" },
|
||||
@@ -6839,9 +6840,9 @@ dependencies = [
|
||||
{ name = "unstructured-client" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.12.0
|
||||
image: langgenius/dify-web:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -707,7 +707,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -749,7 +749,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -788,7 +788,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.12.0
|
||||
image: langgenius/dify-api:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -818,7 +818,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.12.0
|
||||
image: langgenius/dify-web:1.12.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -109,6 +109,7 @@ const AgentTools: FC = () => {
|
||||
tool_parameters: paramsWithDefaultValue,
|
||||
notAuthor: !tool.is_team_authorization,
|
||||
enabled: true,
|
||||
type: tool.provider_type as CollectionType,
|
||||
}
|
||||
}
|
||||
const handleSelectTool = (tool: ToolDefaultValue) => {
|
||||
|
||||
@@ -19,19 +19,21 @@ import { useBoolean, useSessionStorageState } from 'ahooks'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import { Generator } from '@/app/components/base/icons/src/vender/other'
import Loading from '@/app/components/base/loading'

import Loading from '@/app/components/base/loading'
import Modal from '@/app/components/base/modal'
import Toast from '@/app/components/base/toast'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'

import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import { generateBasicAppFirstTimeRule, generateRule } from '@/service/debug'
import { useGenerateRuleTemplate } from '@/service/use-apps'
import { useStore } from '../../../store'
import IdeaOutput from './idea-output'
import InstructionEditorInBasic from './instruction-editor'
import InstructionEditorInWorkflow from './instruction-editor-in-workflow'
@@ -83,6 +85,9 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
onFinished,
}) => {
const { t } = useTranslation()
const { appDetail } = useStore(useShallow(state => ({
appDetail: state.appDetail,
})))
const localModel = localStorage.getItem('auto-gen-model')
? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model
: null
@@ -235,6 +240,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
instruction,
model_config: model,
no_variable: false,
app_id: appDetail?.id,
})
apiRes = {
...res,
@@ -256,6 +262,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
instruction,
ideal_output: ideaOutput,
model_config: model,
app_id: appDetail?.id,
})
apiRes = res
if (error) {

@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { IS_CE_EDITION } from '@/config'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@@ -263,10 +264,14 @@ const Operations = ({
<span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
</div>
)}
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
{
IS_CE_EDITION && (
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
)
}
<Divider className="my-1" />
</>
)}

@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
import { IS_CE_EDITION } from '@/config'
import { cn } from '@/utils/classnames'

const i18nPrefix = 'batchAction'
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
<span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
</Button>
)}
{onBatchSummary && (
{onBatchSummary && IS_CE_EDITION && (
<Button
variant="ghost"
className="gap-x-0.5 px-3"

@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { IS_CE_EDITION } from '@/config'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
@@ -359,7 +360,7 @@ const Form = () => {
{
indexMethod === IndexingType.QUALIFIED
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
&& (
&& IS_CE_EDITION && (
<>
<Divider
type="horizontal"

@@ -104,7 +104,7 @@ const MembersPage = () => {
<UpgradeBtn className="mr-2" loc="member-invite" />
)}
<div className="shrink-0">
<InviteButton disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)} />
{isCurrentWorkspaceManager && <InviteButton disabled={isMemberFull} onClick={() => setInviteModalVisible(true)} />}
</div>
</div>
<div className="overflow-visible lg:overflow-visible">

@@ -129,6 +129,7 @@ export const useToolSelectorState = ({
extra: {
description: tool.tool_description,
},
type: tool.provider_type,
}
}, [])

@@ -87,6 +87,7 @@ export type ToolValue = {
enabled?: boolean
extra?: { description?: string } & Record<string, unknown>
credential_id?: string
type?: string
}

export type DataSourceItem = {

@@ -18,6 +18,7 @@ import {
Group,
} from '@/app/components/workflow/nodes/_base/components/layout'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
import { IS_CE_EDITION } from '@/config'
import Split from '../_base/components/split'
import ChunkStructure from './components/chunk-structure'
import EmbeddingModel from './components/embedding-model'
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
{
data.indexing_technique === IndexMethodEnum.QUALIFIED
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
&& (
&& IS_CE_EDITION && (
<>
<SummaryIndexSetting
summaryIndexSetting={data.summary_index_setting}

@@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.12.0",
"version": "1.12.1",
"private": true,
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {

2205 web/pnpm-lock.yaml (generated)
File diff suppressed because it is too large

@@ -9,7 +9,6 @@ import type {
} from '@/types/workflow'
import { get, post } from './base'
import { getFlowPrefix } from './utils'
import { sanitizeWorkflowDraftPayload } from './workflow-payload'

export const fetchWorkflowDraft = (url: string) => {
return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
url: string
params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
const sanitized = sanitizeWorkflowDraftPayload(params)
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}

export const fetchNodesDefaultConfigs = (url: string) => {