Compare commits

..

6 Commits

Author SHA1 Message Date
CodingOnStar
2e0a7857f0 feat: enhance candidate node interactions with iteration constraints and add block functionality 2025-10-24 14:47:29 +08:00
-LAN-
dc7ce125ad chore: disable postgres timeouts for docker workflows (#27397) 2025-10-24 13:46:36 +08:00
Novice
eabdb09f8e fix: support webapp passport token with end_user_id in web API auth (#27396) 2025-10-24 13:29:47 +08:00
Yunlu Wen
fa6d03c979 Fix/refresh token (#27381) 2025-10-24 13:09:34 +08:00
Novice
634fb192ef fix: remove unnecessary Flask context preservation to avoid circular import in audio service (#27380) 2025-10-24 10:41:14 +08:00
crazywoola
a4b38e7521 Revert "Sync log detail drawer with conversation_id query parameter, so that we can share a specific conversation" (#27382) 2025-10-24 10:40:41 +08:00
36 changed files with 398 additions and 3796 deletions

View File

@@ -1107,13 +1107,6 @@ class SwaggerUIConfig(BaseSettings):
)
class TenantSelfTaskQueueConfig(BaseSettings):
TENANT_SELF_TASK_QUEUE_PULL_SIZE: int = Field(
description="Default batch size for tenant self task queue pull operations",
default=1,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@@ -1138,7 +1131,6 @@ class FeatureConfig(
RagEtlConfig,
RepositoryConfig,
SecurityConfig,
TenantSelfTaskQueueConfig,
ToolConfig,
UpdateConfig,
WorkflowConfig,
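For readability, the removed setting reconstructed from the hunk above as standalone code; the BaseSettings import from pydantic-settings is an assumption based on how the surrounding config classes are usually declared, so treat this as a sketch rather than the repository's exact file.

```python
# Reconstruction of the removed config class as a standalone sketch.
# Assumption: BaseSettings comes from pydantic-settings, as in the
# neighbouring config classes; only the field shown in the diff is taken as given.
from pydantic import Field
from pydantic_settings import BaseSettings


class TenantSelfTaskQueueConfig(BaseSettings):
    TENANT_SELF_TASK_QUEUE_PULL_SIZE: int = Field(
        description="Default batch size for tenant self task queue pull operations",
        default=1,
    )
```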

View File

@@ -29,6 +29,7 @@ from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_refresh_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
@@ -270,7 +271,7 @@ class EmailCodeLoginApi(Resource):
class RefreshTokenApi(Resource):
def post(self):
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
refresh_token = extract_refresh_token(request)
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401

View File

@@ -41,14 +41,18 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.datasource_provider_service import DatasourceProviderService
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
@@ -244,7 +248,34 @@ class PipelineGenerator(BaseAppGenerator):
)
if rag_pipeline_invoke_entities:
RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
# Store the rag_pipeline_invoke_entities in object storage
text = [item.model_dump() for item in rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
features = FeatureService.get_features(dataset.tenant_id)
if features.billing.enabled and features.billing.subscription.plan == "sandbox":
tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
if redis_client.get(tenant_pipeline_task_key):
# Add to waiting queue using List operations (lpush)
redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
else:
# Set flag and execute task
redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
else:
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file.id,
tenant_id=dataset.tenant_id,
)
# return batch, dataset, documents
return {
"batch": batch,

View File

@@ -1,13 +0,0 @@
from dataclasses import dataclass
@dataclass
class DocumentTask:
"""Document task entity for document indexing operations.
This class represents a document indexing task that can be queued
and processed by the document indexing system.
"""
tenant_id: str
dataset_id: str
document_ids: list[str]

View File

@@ -1,89 +0,0 @@
import json
from dataclasses import dataclass
from typing import Any
from extensions.ext_redis import redis_client
TASK_WRAPPER_PREFIX = "__WRAPPER__:"
@dataclass
class TaskWrapper:
data: Any
def serialize(self) -> str:
return json.dumps(self.data, ensure_ascii=False)
@classmethod
def deserialize(cls, serialized_data: str) -> 'TaskWrapper':
data = json.loads(serialized_data)
return cls(data)
class TenantSelfTaskQueue:
"""
Simple per-tenant task queue used for tenant-level task isolation.
It uses a Redis list to store tasks and a Redis key as the task-waiting flag.
Supports any task that can be serialized as JSON.
"""
DEFAULT_TASK_TTL = 60 * 60
def __init__(self, tenant_id: str, unique_key: str):
self.tenant_id = tenant_id
self.unique_key = unique_key
self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"
self.task_key = f"tenant_{unique_key}_task:{tenant_id}"
def get_task_key(self):
return redis_client.get(self.task_key)
def set_task_waiting_time(self, ttl: int | None = None):
ttl = ttl or self.DEFAULT_TASK_TTL
redis_client.setex(self.task_key, ttl, 1)
def delete_task_key(self):
redis_client.delete(self.task_key)
def push_tasks(self, tasks: list):
serialized_tasks = []
for task in tasks:
# Store plain strings directly, keeping full compatibility with pipeline scenarios
if isinstance(task, str):
serialized_tasks.append(task)
else:
# Use TaskWrapper for JSON serialization and add a prefix for identification
wrapper = TaskWrapper(task)
serialized_data = wrapper.serialize()
serialized_tasks.append(f"{TASK_WRAPPER_PREFIX}{serialized_data}")
redis_client.lpush(self.queue, *serialized_tasks)
def pull_tasks(self, count: int = 1) -> list:
if count <= 0:
return []
tasks = []
for _ in range(count):
serialized_task = redis_client.rpop(self.queue)
if not serialized_task:
break
if isinstance(serialized_task, bytes):
serialized_task = serialized_task.decode('utf-8')
# Check whether the task was wrapped with TaskWrapper
if serialized_task.startswith(TASK_WRAPPER_PREFIX):
try:
wrapper_data = serialized_task[len(TASK_WRAPPER_PREFIX):]
wrapper = TaskWrapper.deserialize(wrapper_data)
tasks.append(wrapper.data)
except (json.JSONDecodeError, TypeError, ValueError):
tasks.append(serialized_task)
else:
tasks.append(serialized_task)
return tasks
def get_next_task(self):
tasks = self.pull_tasks(1)
return tasks[0] if tasks else None
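A usage sketch for the queue class removed above, assuming it is importable from core.rag.pipeline.queue on the base branch and that a Redis client is configured; the tenant id and task payload are made-up values.

```python
# Requires the Dify codebase and a reachable Redis instance; identifiers
# below ("tenant-123", the dataset/document ids) are illustrative only.
from core.rag.pipeline.queue import TenantSelfTaskQueue

queue = TenantSelfTaskQueue(tenant_id="tenant-123", unique_key="document_indexing")

# Producer side: if a task is already running for the tenant, park the work;
# otherwise set the busy flag (default TTL: one hour) and dispatch directly.
if queue.get_task_key():
    queue.push_tasks([{"dataset_id": "ds-1", "document_ids": ["doc-1", "doc-2"]}])
else:
    queue.set_task_waiting_time()
    # ...send the first task to Celery here...

# Consumer side: when a worker finishes, drain the next parked task or
# release the flag so the tenant is no longer considered busy.
next_task = queue.get_next_task()
if next_task is None:
    queue.delete_task_key()
```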

View File

@@ -6,10 +6,11 @@ from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from libs.token import extract_access_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
@@ -61,14 +62,30 @@ def load_user_from_request(request_from_flask_login):
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
elif request.blueprint == "web":
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
if webapp_token:
decoded = PassportService().verify(webapp_token)
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
raise Unauthorized("Invalid Authorization token for web API.")
elif request.blueprint == "mcp":
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:
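A condensed sketch of the resolution order introduced for the web blueprint above: prefer a webapp passport tied to the app-code header, fall back to the bearer token, and accept only tokens that carry an end_user_id claim. The verify and lookup callables stand in for PassportService().verify and the EndUser query, and the exception types are simplified relative to the Unauthorized/NotFound responses in the diff.

```python
from collections.abc import Callable
from typing import Any


def resolve_web_end_user(
    app_code: str | None,
    webapp_token: str | None,
    auth_token: str | None,
    verify: Callable[[str], dict[str, Any]],
    find_end_user: Callable[[str], Any],
) -> Any:
    # Prefer the webapp passport when an app code is present, else the bearer token.
    token = webapp_token if (app_code and webapp_token) else auth_token
    if not token:
        raise PermissionError("Invalid Authorization token.")
    decoded = verify(token)
    end_user_id = decoded.get("end_user_id")
    if not end_user_id:
        raise PermissionError("Invalid Authorization token for web API.")
    end_user = find_end_user(end_user_id)
    if end_user is None:
        raise LookupError("End user not found.")
    return end_user
```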

View File

@@ -38,9 +38,6 @@ def _real_cookie_name(cookie_name: str) -> str:
def _try_extract_from_header(request: Request) -> str | None:
"""
Try to extract access token from header
"""
auth_header = request.headers.get("Authorization")
if auth_header:
if " " not in auth_header:
@@ -55,27 +52,19 @@ def _try_extract_from_header(request: Request) -> str | None:
return None
def extract_refresh_token(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
def extract_csrf_token(request: Request) -> str | None:
"""
Try to extract CSRF token from header or cookie.
"""
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
def extract_csrf_token_from_cookie(request: Request) -> str | None:
"""
Try to extract CSRF token from cookie.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
"""
Try to extract access token from cookie, header or params.
Access token is either for console session or webapp passport exchange.
"""
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
@@ -83,20 +72,10 @@ def extract_access_token(request: Request) -> str | None:
def extract_webapp_access_token(request: Request) -> str | None:
"""
Try to extract webapp access token from cookie, then header.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
Webapp access token (part of passport) is only used for webapp session.
"""
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
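A minimal sketch of the header extraction referenced above: an Authorization value without a scheme/token separator is rejected. The Bearer-scheme check is an assumption for illustration; the hunk only shows the missing-space guard.

```python
from flask import Request


def try_extract_bearer_token(request: Request) -> str | None:
    # Reject a missing header or one without a "<scheme> <token>" shape.
    auth_header = request.headers.get("Authorization")
    if not auth_header or " " not in auth_header:
        return None
    scheme, _, token = auth_header.partition(" ")
    # Assumed for the sketch: only the Bearer scheme is accepted.
    if scheme.lower() != "bearer" or not token:
        return None
    return token
```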

View File

@@ -82,54 +82,51 @@ class AudioService:
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:
workflow = app_model.workflow
if (
workflow is None
or "text_to_speech" not in workflow.features_dict
or not workflow.features_dict["text_to_speech"].get("enabled")
):
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if is_draft:
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
else:
workflow = app_model.workflow
if (
workflow is None
or "text_to_speech" not in workflow.features_dict
or not workflow.features_dict["text_to_speech"].get("enabled")
):
raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
voice = text_to_speech_dict.get("voice")
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled")
voice = text_to_speech_dict.get("voice")
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
)
try:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
)
except Exception as e:
raise e
if message_id:
try:

View File

@@ -49,7 +49,6 @@ from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@@ -79,6 +78,7 @@ from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.document_indexing_task import document_indexing_task
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
@@ -1687,7 +1687,7 @@ class DocumentService:
# trigger async task
if document_ids:
DocumentIndexingTaskProxy(dataset.tenant_id, dataset.id, document_ids).delay()
document_indexing_task.delay(dataset.id, document_ids)
if duplicate_document_ids:
duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)

View File

@@ -1,82 +0,0 @@
import logging
from collections.abc import Callable
from dataclasses import asdict
from functools import cached_property
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.feature_service import FeatureService
from tasks.document_indexing_task import normal_document_indexing_task, priority_document_indexing_task
logger = logging.getLogger(__name__)
class DocumentIndexingTaskProxy:
def __init__(self, tenant_id: str, dataset_id: str, document_ids: list[str]):
self.tenant_id = tenant_id
self.dataset_id = dataset_id
self.document_ids = document_ids
self.tenant_self_task_queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
@cached_property
def features(self):
return FeatureService.get_features(self.tenant_id)
def _send_to_direct_queue(self, task_func: Callable):
logger.info("send dataset %s to direct queue", self.dataset_id)
task_func.delay( # type: ignore
tenant_id=self.tenant_id,
dataset_id=self.dataset_id,
document_ids=self.document_ids
)
def _send_to_tenant_queue(self, task_func: Callable):
logger.info("send dataset %s to tenant queue", self.dataset_id)
if self.tenant_self_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self.tenant_self_task_queue.push_tasks([
asdict(
DocumentTask(tenant_id=self.tenant_id, dataset_id=self.dataset_id, document_ids=self.document_ids)
)
])
logger.info("push tasks: %s - %s", self.dataset_id, self.document_ids)
else:
# Set flag and execute task
self.tenant_self_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=self.tenant_id,
dataset_id=self.dataset_id,
document_ids=self.document_ids
)
logger.info("init tasks: %s - %s", self.dataset_id, self.document_ids)
def _send_to_default_tenant_queue(self):
self._send_to_tenant_queue(normal_document_indexing_task)
def _send_to_priority_tenant_queue(self):
self._send_to_tenant_queue(priority_document_indexing_task)
def _send_to_priority_direct_queue(self):
self._send_to_direct_queue(priority_document_indexing_task)
def _dispatch(self):
logger.info(
"dispatch args: %s - %s - %s",
self.tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan
)
# dispatch to different indexing queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == "sandbox":
# dispatch to normal pipeline queue with tenant self sub queue for sandbox plan
self._send_to_default_tenant_queue()
else:
# dispatch to priority pipeline queue with tenant self sub queue for other plans
self._send_to_priority_tenant_queue()
else:
# dispatch to priority queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue()
def delay(self):
self._dispatch()
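The routing removed above reduces to a three-way decision; a sketch with the three senders injected as callables, so it stays self-contained:

```python
from collections.abc import Callable


def route_document_indexing(
    billing_enabled: bool,
    plan: str,
    send_default_tenant: Callable[[], None],   # sandbox: normal tenant-isolated queue
    send_priority_tenant: Callable[[], None],  # other billed plans: priority tenant queue
    send_priority_direct: Callable[[], None],  # self-hosted/enterprise: priority, no isolation
) -> None:
    if billing_enabled:
        if plan == "sandbox":
            send_default_tenant()
        else:
            send_priority_tenant()
    else:
        send_priority_direct()
```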

View File

@@ -1,101 +0,0 @@
import json
import logging
from collections.abc import Callable
from functools import cached_property
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from extensions.ext_database import db
from services.feature_service import FeatureService
from services.file_service import FileService
from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
logger = logging.getLogger(__name__)
class RagPipelineTaskProxy:
def __init__(
self, dataset_tenant_id: str, user_id: str, rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity]
):
self.dataset_tenant_id = dataset_tenant_id
self.user_id = user_id
self.rag_pipeline_invoke_entities = rag_pipeline_invoke_entities
self.tenant_self_pipeline_task_queue = TenantSelfTaskQueue(dataset_tenant_id, "pipeline")
@cached_property
def features(self):
return FeatureService.get_features(self.dataset_tenant_id)
def _upload_invoke_entities(self) -> str:
text = [item.model_dump() for item in self.rag_pipeline_invoke_entities]
name = "rag_pipeline_invoke_entities.json"
# Convert list to proper JSON string
json_text = json.dumps(text)
upload_file = FileService(db.engine).upload_text(json_text, name, self.user_id, self.dataset_tenant_id)
return upload_file.id
def _send_to_direct_queue(self, upload_file_id: str, task_func: Callable):
logger.info("send file %s to direct queue", upload_file_id)
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self.dataset_tenant_id,
)
def _send_to_tenant_queue(self, upload_file_id: str, task_func: Callable):
logger.info("send file %s to tenant queue", upload_file_id)
if self.tenant_self_pipeline_task_queue.get_task_key():
# Add to waiting queue using List operations (lpush)
self.tenant_self_pipeline_task_queue.push_tasks([upload_file_id])
logger.info("push tasks: %s", upload_file_id)
else:
# Set flag and execute task
self.tenant_self_pipeline_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id=self.dataset_tenant_id,
)
logger.info("init tasks: %s", upload_file_id)
def _send_to_default_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, rag_pipeline_run_task)
def _send_to_priority_tenant_queue(self, upload_file_id: str):
self._send_to_tenant_queue(upload_file_id, priority_rag_pipeline_run_task)
def _send_to_priority_direct_queue(self, upload_file_id: str):
self._send_to_direct_queue(upload_file_id, priority_rag_pipeline_run_task)
def _dispatch(self):
upload_file_id = self._upload_invoke_entities()
if not upload_file_id:
raise ValueError("upload_file_id is empty")
logger.info(
"dispatch args: %s - %s - %s",
self.dataset_tenant_id,
self.features.billing.enabled,
self.features.billing.subscription.plan
)
# dispatch to different pipeline queue with tenant isolation when billing enabled
if self.features.billing.enabled:
if self.features.billing.subscription.plan == "sandbox":
# dispatch to normal pipeline queue with tenant isolation for sandbox plan
self._send_to_default_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue with tenant isolation for other plans
self._send_to_priority_tenant_queue(upload_file_id)
else:
# dispatch to priority pipeline queue without tenant isolation for others, e.g.: self-hosted or enterprise
self._send_to_priority_direct_queue(upload_file_id)
def delay(self):
if not self.rag_pipeline_invoke_entities:
logger.warning(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
self.dataset_tenant_id,
self.user_id
)
return
self._dispatch()
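A sketch of the "serialize first, enqueue only a file id" step used by the proxy above: the invoke entities are dumped to JSON, stored through the file service, and the Celery task receives just the upload file id plus the tenant id. The file service and entity objects are assumed to behave as shown in the diff (pydantic models with model_dump, upload_text returning an object with an id).

```python
import json
from typing import Any


def upload_invoke_entities(file_service: Any, entities: list[Any], user_id: str, tenant_id: str) -> str:
    # Dump the pydantic entities to a single JSON document...
    payload = json.dumps([item.model_dump() for item in entities])
    # ...store it through the file service...
    upload_file = file_service.upload_text(payload, "rag_pipeline_invoke_entities.json", user_id, tenant_id)
    # ...and hand only the resulting file id to the Celery task.
    return upload_file.id
```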

View File

@@ -1,14 +1,11 @@
import logging
import time
from collections.abc import Callable
import click
from celery import shared_task
from configs import dify_config
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantSelfTaskQueue
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
@@ -24,24 +21,8 @@ def document_indexing_task(dataset_id: str, document_ids: list):
:param dataset_id:
:param document_ids:
.. warning:: TO BE DEPRECATED
This function will be deprecated and removed in a future version.
Use normal_document_indexing_task or priority_document_indexing_task instead.
Usage: document_indexing_task.delay(dataset_id, document_ids)
"""
logger.warning("document indexing legacy mode received: %s - %s", dataset_id, document_ids)
_document_indexing(dataset_id, document_ids)
def _document_indexing(dataset_id: str, document_ids: list):
"""
Process document for tasks
:param dataset_id:
:param document_ids:
Usage: _document_indexing(dataset_id, document_ids)
"""
documents = []
start_at = time.perf_counter()
@@ -105,61 +86,3 @@ def _document_indexing(dataset_id: str, document_ids: list):
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
def _document_indexing_with_tenant_queue(tenant_id: str, dataset_id: str, document_ids: list, task_func: Callable):
try:
_document_indexing(dataset_id, document_ids)
except Exception:
logger.exception("Error processing document indexing %s for tenant %s: %s", dataset_id, tenant_id)
finally:
tenant_self_task_queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_tasks = tenant_self_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("document indexing tenant isolation queue next tasks: %s", next_tasks)
if next_tasks:
for next_task in next_tasks:
document_task = DocumentTask(**next_task)
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_task_queue.set_task_waiting_time()
task_func.delay( # type: ignore
tenant_id=document_task.tenant_id,
dataset_id=document_task.dataset_id,
document_ids=document_task.document_ids,
)
else:
# No more waiting tasks, clear the flag
tenant_self_task_queue.delete_task_key()
@shared_task(queue="dataset")
def normal_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: list):
"""
Async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: normal_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("normal document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, normal_document_indexing_task)
@shared_task(queue="priority_dataset")
def priority_document_indexing_task(tenant_id: str, dataset_id: str, document_ids: list):
"""
Priority async process document
:param tenant_id:
:param dataset_id:
:param document_ids:
Usage: priority_document_indexing_task.delay(tenant_id, dataset_id, document_ids)
"""
logger.info("priority document indexing task received: %s - %s - %s", tenant_id, dataset_id, document_ids)
_document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, priority_document_indexing_task)
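The two task variants above share one chain-on-finish step: after processing, pull up to the configured batch of parked tasks and re-dispatch them while keeping the busy flag alive, or clear the flag when nothing is waiting. A sketch, with the queue and dispatcher injected:

```python
from collections.abc import Callable
from typing import Any


def finish_and_chain(queue: Any, pull_size: int, dispatch: Callable[[Any], None]) -> None:
    # Queue API as in the diff: pull_tasks, set_task_waiting_time, delete_task_key.
    next_tasks = queue.pull_tasks(count=pull_size)
    if next_tasks:
        for task in next_tasks:
            queue.set_task_waiting_time()  # keep the per-tenant busy flag set
            dispatch(task)                 # e.g. re-delay the Celery task with this payload
    else:
        queue.delete_task_key()            # nothing waiting: release the flag
```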

View File

@@ -12,10 +12,8 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models import Account, Tenant
@@ -24,8 +22,6 @@ from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
@@ -73,27 +69,6 @@ def priority_rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = TenantSelfTaskQueue(tenant_id, "pipeline")
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_self_pipeline_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("priority rag pipeline tenant isolation queue next files: %s", next_file_ids)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_pipeline_task_queue.set_task_waiting_time()
priority_rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_self_pipeline_task_queue.delete_task_key()
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()

View File

@@ -12,20 +12,17 @@ from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
logger = logging.getLogger(__name__)
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
@@ -73,27 +70,26 @@ def rag_pipeline_run_task(
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = TenantSelfTaskQueue(tenant_id, "pipeline")
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_ids = tenant_self_pipeline_task_queue.pull_tasks(count=dify_config.TENANT_SELF_TASK_QUEUE_PULL_SIZE)
logger.info("rag pipeline tenant isolation queue next files: %s", next_file_ids)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_ids:
for next_file_id in next_file_ids:
# Process the next waiting task
# Keep the flag set to indicate a task is running
tenant_self_pipeline_task_queue.set_task_waiting_time()
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
tenant_self_pipeline_task_queue.delete_task_key()
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
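The re-added finally block above implements the same chaining directly against Redis; a compact sketch of that variant, assuming a redis-py style client and with the task sender injected:

```python
from collections.abc import Callable


def chain_next_pipeline_run(redis_client, tenant_id: str, send_task: Callable[..., None]) -> None:
    waiting_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
    flag_key = f"tenant_pipeline_task:{tenant_id}"

    next_file_id = redis_client.rpop(waiting_queue)
    if next_file_id:
        # Keep the busy flag alive for another hour and run the next parked file.
        redis_client.setex(flag_key, 60 * 60, 1)
        if isinstance(next_file_id, bytes):
            next_file_id = next_file_id.decode("utf-8")
        send_task(rag_pipeline_invoke_entities_file_id=next_file_id, tenant_id=tenant_id)
    else:
        # No more waiting tasks: clear the flag.
        redis_client.delete(flag_key)
```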

View File

@@ -1 +0,0 @@
# Test containers integration tests for core RAG pipeline components

View File

@@ -1 +0,0 @@
# Test containers integration tests for RAG pipeline

View File

@@ -1 +0,0 @@
# Test containers integration tests for RAG pipeline queue

View File

@@ -1,663 +0,0 @@
"""
Integration tests for TenantSelfTaskQueue using testcontainers.
These tests verify the Redis-based task queue functionality with real Redis instances,
testing tenant isolation, task serialization, and queue operations in a realistic environment.
Includes compatibility tests for migrating from legacy string-only queues.
All tests use generic naming to avoid coupling to specific business implementations.
"""
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
import pytest
from faker import Faker
from core.rag.pipeline.queue import TASK_WRAPPER_PREFIX, TaskWrapper, TenantSelfTaskQueue
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@dataclass
class TestTask:
"""Test task data structure for testing complex object serialization."""
task_id: str
tenant_id: str
data: dict[str, Any]
metadata: dict[str, Any]
class TestTenantSelfTaskQueueIntegration:
"""Integration tests for TenantSelfTaskQueue using testcontainers."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantSelfTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantSelfTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "test-key")
assert queue.tenant_id == tenant.id
assert queue.unique_key == "test-key"
assert queue.queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue.task_key == f"tenant_test-key_task:{tenant.id}"
assert queue.DEFAULT_TASK_TTL == 60 * 60
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
# Create second tenant
tenant2 = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant2)
db_session_with_containers.commit()
queue1 = TenantSelfTaskQueue(tenant1.id, "same-key")
queue2 = TenantSelfTaskQueue(tenant2.id, "same-key")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2.queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantSelfTaskQueue(tenant.id, "key1")
queue2 = TenantSelfTaskQueue(tenant.id, "key2")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == f"tenant_self_key1_task_queue:{tenant.id}"
assert queue2.queue == f"tenant_self_key2_task_queue:{tenant.id}"
def test_task_key_operations(self, test_queue):
"""Test task key operations (get, set, delete)."""
# Initially no task key should exist
assert test_queue.get_task_key() is None
# Set task waiting time with default TTL
test_queue.set_task_waiting_time()
task_key = test_queue.get_task_key()
# Redis returns bytes, convert to string for comparison
assert task_key in (b"1", "1")
# Set task waiting time with custom TTL
custom_ttl = 30
test_queue.set_task_waiting_time(custom_ttl)
task_key = test_queue.get_task_key()
assert task_key in (b"1", "1")
# Delete task key
test_queue.delete_task_key()
assert test_queue.get_task_key() is None
def test_push_and_pull_string_tasks(self, test_queue):
"""Test pushing and pulling string tasks."""
tasks = ["task1", "task2", "task3"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull tasks one by one (FIFO order)
pulled_tasks = []
for _ in range(3):
task = test_queue.get_next_task()
if task:
pulled_tasks.append(task)
# Should get tasks in FIFO order (lpush + rpop = FIFO)
assert pulled_tasks == ["task1", "task2", "task3"]
def test_push_and_pull_multiple_tasks(self, test_queue):
"""Test pushing and pulling multiple tasks at once."""
tasks = ["task1", "task2", "task3", "task4", "task5"]
# Push tasks
test_queue.push_tasks(tasks)
# Pull multiple tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
assert pulled_tasks == ["task1", "task2", "task3"]
# Pull remaining tasks
remaining_tasks = test_queue.pull_tasks(5)
assert len(remaining_tasks) == 2
assert remaining_tasks == ["task4", "task5"]
def test_push_and_pull_complex_objects(self, test_queue, fake):
"""Test pushing and pulling complex object tasks."""
# Create complex task objects as dictionaries (not dataclass instances)
tasks = [
{
"task_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"data": {
"file_id": str(uuid4()),
"content": fake.text(),
"metadata": {"size": fake.random_int(1000, 10000)}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": fake.words(3)
}
},
{
"task_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"data": {
"file_id": str(uuid4()),
"content": "测试中文内容",
"metadata": {"size": fake.random_int(1000, 10000)}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": ["中文", "测试", "emoji🚀"]
}
}
]
# Push complex tasks
test_queue.push_tasks(tasks)
# Pull tasks
pulled_tasks = test_queue.pull_tasks(2)
assert len(pulled_tasks) == 2
# Verify deserialized tasks match original (FIFO order)
for i, pulled_task in enumerate(pulled_tasks):
original_task = tasks[i] # FIFO order
assert isinstance(pulled_task, dict)
assert pulled_task["task_id"] == original_task["task_id"]
assert pulled_task["tenant_id"] == original_task["tenant_id"]
assert pulled_task["data"] == original_task["data"]
assert pulled_task["metadata"] == original_task["metadata"]
def test_mixed_task_types(self, test_queue, fake):
"""Test pushing and pulling mixed string and object tasks."""
string_task = "simple_string_task"
object_task = {
"task_id": str(uuid4()),
"dataset_id": str(uuid4()),
"document_ids": [str(uuid4()) for _ in range(3)]
}
tasks = [string_task, object_task, "another_string"]
# Push mixed tasks
test_queue.push_tasks(tasks)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(3)
assert len(pulled_tasks) == 3
# Verify types and content
assert pulled_tasks[0] == string_task
assert isinstance(pulled_tasks[1], dict)
assert pulled_tasks[1] == object_task
assert pulled_tasks[2] == "another_string"
def test_empty_queue_operations(self, test_queue):
"""Test operations on empty queue."""
# Pull from empty queue
tasks = test_queue.pull_tasks(5)
assert tasks == []
# Get next task from empty queue
task = test_queue.get_next_task()
assert task is None
# Pull zero or negative count
assert test_queue.pull_tasks(0) == []
assert test_queue.pull_tasks(-1) == []
def test_task_ttl_expiration(self, test_queue):
"""Test task key TTL expiration."""
# Set task with short TTL
short_ttl = 2
test_queue.set_task_waiting_time(short_ttl)
# Verify task key exists
assert test_queue.get_task_key() == b"1" or test_queue.get_task_key() == "1"
# Wait for TTL to expire
time.sleep(short_ttl + 1)
# Verify task key has expired
assert test_queue.get_task_key() is None
def test_large_task_batch(self, test_queue, fake):
"""Test handling large batches of tasks."""
# Create large batch of tasks
large_batch = []
for i in range(100):
task = {
"task_id": str(uuid4()),
"index": i,
"data": fake.text(max_nb_chars=100),
"metadata": {"batch_id": str(uuid4())}
}
large_batch.append(task)
# Push large batch
test_queue.push_tasks(large_batch)
# Pull all tasks
pulled_tasks = test_queue.pull_tasks(100)
assert len(pulled_tasks) == 100
# Verify all tasks were retrieved correctly (FIFO order)
for i, task in enumerate(pulled_tasks):
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
# Create multiple queues for the same tenant
queue1 = TenantSelfTaskQueue(tenant.id, "queue1")
queue2 = TenantSelfTaskQueue(tenant.id, "queue2")
# Push tasks to different queues
queue1.push_tasks(["task1_queue1", "task2_queue1"])
queue2.push_tasks(["task1_queue2", "task2_queue2"])
# Verify queues are isolated
tasks1 = queue1.pull_tasks(2)
tasks2 = queue2.pull_tasks(2)
assert tasks1 == ["task1_queue1", "task2_queue1"]
assert tasks2 == ["task1_queue2", "task2_queue2"]
assert tasks1 != tasks2
def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
"""Test TaskWrapper serialization and deserialization roundtrip."""
# Create complex nested data
complex_data = {
"id": str(uuid4()),
"nested": {
"deep": {
"value": "test",
"numbers": [1, 2, 3, 4, 5],
"unicode": "测试中文",
"emoji": "🚀"
}
},
"metadata": {
"created_at": fake.iso8601(),
"tags": ["tag1", "tag2", "tag3"]
}
}
# Create wrapper and serialize
wrapper = TaskWrapper(complex_data)
serialized = wrapper.serialize()
# Verify serialization
assert isinstance(serialized, str)
assert "测试中文" in serialized
assert "🚀" in serialized
# Deserialize and verify
deserialized_wrapper = TaskWrapper.deserialize(serialized)
assert deserialized_wrapper.data == complex_data
def test_error_handling_invalid_json(self, test_queue):
"""Test error handling for invalid JSON in wrapped tasks."""
# Manually create invalid wrapped task
invalid_wrapped_task = f"{TASK_WRAPPER_PREFIX}invalid json data"
# Push invalid task directly to Redis
redis_client.lpush(test_queue.queue, invalid_wrapped_task)
# Pull task - should fall back to string
task = test_queue.get_next_task()
assert task == invalid_wrapped_task
def test_real_world_processing_scenario(self, test_queue, fake):
"""Test realistic task processing scenario."""
# Simulate various task types
tasks = []
for i in range(5):
task = {
"tenant_id": test_queue.tenant_id,
"resource_id": str(uuid4()),
"resource_ids": [str(uuid4()) for _ in range(fake.random_int(1, 5))]
}
tasks.append(task)
# Push all tasks
test_queue.push_tasks(tasks)
# Simulate processing tasks one by one
processed_tasks = []
while True:
task = test_queue.get_next_task()
if task is None:
break
processed_tasks.append(task)
# Simulate task processing time
time.sleep(0.01)
# Verify all tasks were processed
assert len(processed_tasks) == 5
# Verify task content
for task in processed_tasks:
assert isinstance(task, dict)
assert "tenant_id" in task
assert "resource_id" in task
assert "resource_ids" in task
assert task["tenant_id"] == test_queue.tenant_id
def test_real_world_batch_processing_scenario(self, test_queue, fake):
"""Test realistic batch processing scenario."""
# Simulate batch processing tasks
batch_tasks = []
for i in range(3):
task = {
"file_id": str(uuid4()),
"tenant_id": test_queue.tenant_id,
"user_id": str(uuid4()),
"processing_config": {
"model": fake.random_element(["model_a", "model_b", "model_c"]),
"temperature": fake.random.uniform(0.1, 1.0),
"max_tokens": fake.random_int(1000, 4000)
},
"metadata": {
"source": fake.random_element(["upload", "api", "webhook"]),
"priority": fake.random_element(["low", "normal", "high"])
}
}
batch_tasks.append(task)
# Push tasks
test_queue.push_tasks(batch_tasks)
# Process tasks in batches
batch_size = 2
processed_tasks = []
while True:
batch = test_queue.pull_tasks(batch_size)
if not batch:
break
processed_tasks.extend(batch)
# Verify all tasks were processed
assert len(processed_tasks) == 3
# Verify task structure
for task in processed_tasks:
assert isinstance(task, dict)
assert "file_id" in task
assert "tenant_id" in task
assert "processing_config" in task
assert "metadata" in task
assert task["tenant_id"] == test_queue.tenant_id
class TestTenantSelfTaskQueueCompatibility:
"""Compatibility tests for migrating from legacy string-only queues."""
@pytest.fixture
def fake(self):
"""Faker instance for generating test data."""
return Faker()
@pytest.fixture
def test_tenant_and_account(self, db_session_with_containers, fake):
"""Create test tenant and account for testing."""
# Create account
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
"""
Test compatibility with legacy queues containing only string data.
This simulates the scenario where Redis queues already contain string data
from the old architecture, and we need to ensure the new code can read them.
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "legacy_queue")
# Simulate legacy string data in Redis queue (using old format)
legacy_strings = [
"legacy_task_1",
"legacy_task_2",
"legacy_task_3",
"legacy_task_4",
"legacy_task_5"
]
# Manually push legacy strings directly to Redis (simulating old system)
for legacy_string in legacy_strings:
redis_client.lpush(queue.queue, legacy_string)
# Verify new code can read legacy string data
pulled_tasks = queue.pull_tasks(5)
assert len(pulled_tasks) == 5
# Verify all tasks are strings (not wrapped)
for task in pulled_tasks:
assert isinstance(task, str)
assert task.startswith("legacy_task_")
# Verify order (FIFO from Redis list)
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
"""
Test complete migration scenario from legacy to new system.
This simulates the real-world scenario where:
1. Legacy system has string data in Redis
2. New system starts processing the same queue
3. Both legacy and new tasks coexist during migration
4. New system can handle both formats seamlessly
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "migration_queue")
# Phase 1: Legacy system has data
legacy_tasks = [f"legacy_resource_{i}" for i in range(1, 6)]
redis_client.lpush(queue.queue, *legacy_tasks)
# Phase 2: New system starts processing legacy data
processed_legacy = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_legacy.append(task)
# Verify legacy data was processed correctly
assert len(processed_legacy) == 5
for task in processed_legacy:
assert isinstance(task, str)
assert task.startswith("legacy_resource_")
# Phase 3: New system adds new tasks (mixed types)
new_string_tasks = ["new_resource_1", "new_resource_2"]
new_object_tasks = [
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]}
},
{
"resource_id": str(uuid4()),
"tenant_id": tenant.id,
"processing_type": "new_system",
"metadata": {"version": "2.0", "features": ["ai", "ml"]}
}
]
# Push new tasks using new system
queue.push_tasks(new_string_tasks)
queue.push_tasks(new_object_tasks)
# Phase 4: Process all new tasks
processed_new = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_new.append(task)
# Verify new tasks were processed correctly
assert len(processed_new) == 4
string_tasks = [task for task in processed_new if isinstance(task, str)]
object_tasks = [task for task in processed_new if isinstance(task, dict)]
assert len(string_tasks) == 2
assert len(object_tasks) == 2
# Verify string tasks
for task in string_tasks:
assert task.startswith("new_resource_")
# Verify object tasks
for task in object_tasks:
assert isinstance(task, dict)
assert "resource_id" in task
assert "tenant_id" in task
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
"""
Test error recovery when legacy queue contains malformed data.
This ensures the new system can gracefully handle corrupted or
malformed legacy data without crashing.
"""
tenant, _ = test_tenant_and_account
queue = TenantSelfTaskQueue(tenant.id, "error_recovery_queue")
# Create mix of valid and malformed legacy data
mixed_legacy_data = [
"valid_legacy_task_1",
"valid_legacy_task_2",
"malformed_data_without_prefix", # This should be treated as string
"valid_legacy_task_3",
f"{TASK_WRAPPER_PREFIX}invalid_json_data", # This should fall back to string
"valid_legacy_task_4"
]
# Manually push mixed data directly to Redis
redis_client.lpush(queue.queue, *mixed_legacy_data)
# Process all tasks
processed_tasks = []
while True:
task = queue.get_next_task()
if task is None:
break
processed_tasks.append(task)
# Verify all tasks were processed (no crashes)
assert len(processed_tasks) == 6
# Verify all tasks are strings (malformed data falls back to string)
for task in processed_tasks:
assert isinstance(task, str)
# Verify valid tasks are preserved
valid_tasks = [task for task in processed_tasks if task.startswith("valid_legacy_task_")]
assert len(valid_tasks) == 4
# Verify malformed data is handled gracefully
malformed_tasks = [task for task in processed_tasks if not task.startswith("valid_legacy_task_")]
assert len(malformed_tasks) == 2
assert "malformed_data_without_prefix" in malformed_tasks
assert f"{TASK_WRAPPER_PREFIX}invalid_json_data" in malformed_tasks

View File

@@ -1,32 +1,16 @@
from dataclasses import asdict
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from core.entities.document_task import DocumentTask
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
_document_indexing, # Core function
_document_indexing_with_tenant_queue, # Tenant queue wrapper function
document_indexing_task, # Deprecated old interface
normal_document_indexing_task, # New normal task
priority_document_indexing_task, # New priority task
)
from tasks.document_indexing_task import document_indexing_task
class TestDocumentIndexingTasks:
"""Integration tests for document indexing tasks using testcontainers.
This test class covers:
- Core _document_indexing function
- Deprecated document_indexing_task function
- New normal_document_indexing_task function
- New priority_document_indexing_task function
- Tenant queue wrapper _document_indexing_with_tenant_queue function
"""
class TestDocumentIndexingTask:
"""Integration tests for document_indexing_task using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
@@ -239,7 +223,7 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
@@ -247,11 +231,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@@ -277,7 +260,7 @@ class TestDocumentIndexingTasks:
document_ids = [fake.uuid4() for _ in range(3)]
# Act: Execute the task with non-existent dataset
_document_indexing(non_existent_dataset_id, document_ids)
document_indexing_task(non_existent_dataset_id, document_ids)
# Assert: Verify no processing occurred
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@@ -307,18 +290,17 @@ class TestDocumentIndexingTasks:
all_document_ids = existing_document_ids + non_existent_document_ids
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
document_indexing_task(dataset.id, all_document_ids)
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with only existing documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@@ -350,7 +332,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@@ -358,11 +340,10 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_mixed_document_states(
self, db_session_with_containers, mock_external_service_dependencies
@@ -425,18 +406,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
# Verify the run method was called with all documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
@@ -489,16 +469,15 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
assert updated_document.stopped_at is not None
for document in all_documents:
db.session.refresh(document)
assert document.indexing_status == "error"
assert document.error is not None
assert "batch upload" in document.error
assert document.stopped_at is not None
# Verify no indexing runner was called
mock_external_service_dependencies["indexing_runner"].assert_not_called()
@@ -523,18 +502,17 @@ class TestDocumentIndexingTasks:
document_ids = [doc.id for doc in documents]
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
def test_document_indexing_task_document_is_paused_error(
self, db_session_with_containers, mock_external_service_dependencies
@@ -562,7 +540,7 @@ class TestDocumentIndexingTasks:
)
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
@@ -570,314 +548,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
# Act: Execute the deprecated task (it only takes 2 parameters)
document_indexing_task(dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_normal_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test normal_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new normal task
normal_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_priority_document_indexing_task_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test priority_document_indexing_task basic functionality.
This test verifies:
- Task function calls the wrapper correctly
- Basic parameter passing works
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Act: Execute the new priority task
priority_document_indexing_task(tenant_id, dataset.id, document_ids)
# Assert: Verify processing occurred (core logic is tested in _document_indexing tests)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
def test_document_indexing_with_tenant_queue_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with no waiting tasks.
This test verifies:
- Core indexing logic execution (same as _document_indexing)
- Tenant queue cleanup when no waiting tasks
- Task function parameter passing
- Queue management after processing
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=2
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify the run method was called with correct documents
call_args = mock_external_service_dependencies["indexing_runner_instance"].run.call_args
assert call_args is not None
processed_documents = call_args[0][0]
assert len(processed_documents) == 2
# Verify task function was not called (no waiting tasks)
mock_task_func.delay.assert_not_called()
def test_document_indexing_with_tenant_queue_with_waiting_tasks(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
This test verifies:
- Core indexing logic execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dataset_id = dataset.id
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting tasks to the real Redis queue
waiting_tasks = [
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"]),
DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-2"])
]
# Convert DocumentTask objects to dictionaries for serialization
waiting_task_dicts = [asdict(task) for task in waiting_tasks]
queue.push_tasks(waiting_task_dicts)
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify core processing occurred
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify the task function was re-dispatched for one waiting task (default pull size is 1)
assert mock_task_func.delay.call_count == 1
# Verify correct parameters for each call
calls = mock_task_func.delay.call_args_list
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify one waiting task remains in the queue (only one task is pulled per batch)
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
assert len(remaining_tasks) == 1
def test_document_indexing_with_tenant_queue_error_handling(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test error handling in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
dataset, documents = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
document_ids = [doc.id for doc in documents]
tenant_id = dataset.tenant_id
dataset_id = dataset.id
# Mock IndexingRunner to raise an exception
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = Exception("Test error")
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create real queue instance
queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
# Add waiting task to the real Redis queue
waiting_task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset.id, document_ids=["waiting-doc-1"])
queue.push_tasks([asdict(waiting_task)])
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).filter(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
# Verify waiting task was still processed despite core processing error
mock_task_func.delay.assert_called_once()
# Verify correct parameters for the call
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_document_indexing_with_tenant_queue_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
dataset1, documents1 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
dataset2, documents2 = self._create_test_dataset_and_documents(
db_session_with_containers, mock_external_service_dependencies, document_count=1
)
tenant1_id = dataset1.tenant_id
tenant2_id = dataset2.tenant_id
dataset1_id = dataset1.id
dataset2_id = dataset2.id
document_ids1 = [doc.id for doc in documents1]
document_ids2 = [doc.id for doc in documents2]
# Mock the task function
from unittest.mock import MagicMock
mock_task_func = MagicMock()
# Use real Redis for TenantSelfTaskQueue
from core.rag.pipeline.queue import TenantSelfTaskQueue
# Create queue instances for both tenants
queue1 = TenantSelfTaskQueue(tenant1_id, "document_indexing")
queue2 = TenantSelfTaskQueue(tenant2_id, "document_indexing")
# Add waiting tasks to both queues
waiting_task1 = DocumentTask(tenant_id=tenant1_id, dataset_id=dataset1.id, document_ids=["tenant1-doc-1"])
waiting_task2 = DocumentTask(tenant_id=tenant2_id, dataset_id=dataset2.id, document_ids=["tenant2-doc-1"])
queue1.push_tasks([asdict(waiting_task1)])
queue2.push_tasks([asdict(waiting_task2)])
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Verify only tenant1's waiting task was processed
mock_task_func.delay.assert_called_once()
call = mock_task_func.delay.call_args
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
for document in documents:
db.session.refresh(document)
assert document.indexing_status == "parsing"
assert document.processing_started_at is not None
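
The deleted tests above pin down the behaviour of the tenant-queue wrapper rather than its implementation: run the core indexing, log rather than re-raise errors, then re-dispatch at most one waiting task per batch. A minimal sketch consistent with those assertions, assuming _document_indexing is the core helper the tests reference (the real module layout and error handling may differ):

import logging

from core.rag.pipeline.queue import TenantSelfTaskQueue

logger = logging.getLogger(__name__)


def _document_indexing(dataset_id: str, document_ids: list[str]) -> None:
    # Stand-in for the core indexing helper exercised by the tests above.
    ...


def _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, task_func):
    try:
        _document_indexing(dataset_id, document_ids)
    except Exception:
        # The error-handling test expects the wrapper not to re-raise.
        logger.exception("Document indexing failed for dataset %s", dataset_id)
    # Hand at most one waiting task back to Celery; the tests assert a single
    # task_func.delay call even when several tasks are queued.
    queue = TenantSelfTaskQueue(tenant_id, "document_indexing")
    for task in queue.pull_tasks():
        task_func.delay(
            tenant_id=task["tenant_id"],
            dataset_id=task["dataset_id"],
            document_ids=task["document_ids"],
        )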

View File

@@ -1,939 +0,0 @@
import json
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from tasks.rag_pipeline.priority_rag_pipeline_run_task import (
priority_rag_pipeline_run_task,
run_single_rag_pipeline_task,
)
from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
class TestRagPipelineRunTasks:
"""Integration tests for RAG pipeline run tasks using testcontainers.
This test class covers:
- priority_rag_pipeline_run_task function
- rag_pipeline_run_task function
- run_single_rag_pipeline_task function
- Real Redis-based TenantSelfTaskQueue operations
- PipelineGenerator._generate method mocking and parameter validation
- File operations and cleanup
- Error handling and queue management
"""
@pytest.fixture
def mock_pipeline_generator(self):
"""Mock PipelineGenerator._generate method."""
with patch("core.app.apps.pipeline.pipeline_generator.PipelineGenerator._generate") as mock_generate:
# Mock the _generate method to return a simple response
mock_generate.return_value = {
"answer": "Test response",
"metadata": {"test": "data"}
}
yield mock_generate
@pytest.fixture
def mock_file_service(self):
"""Mock FileService for file operations."""
with (
patch("services.file_service.FileService.get_file_content") as mock_get_content,
patch("services.file_service.FileService.delete_file") as mock_delete_file,
):
yield {
"get_content": mock_get_content,
"delete_file": mock_delete_file,
}
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
"""
Helper method to create test pipeline and workflow for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
Returns:
tuple: (account, tenant, pipeline, workflow) - Created entities
"""
fake = Faker()
# Create account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
# Create tenant-account join
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph="{}",
features="{}",
marked_name=fake.company(),
marked_comment=fake.text(max_nb_chars=100),
created_by=account.id,
environment_variables=[],
conversation_variables=[],
rag_pipeline_variables=[],
)
db.session.add(workflow)
db.session.commit()
# Create pipeline
pipeline = Pipeline(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
workflow_id=workflow.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
created_by=account.id,
)
db.session.add(pipeline)
db.session.commit()
# Refresh entities to ensure they're properly loaded
db.session.refresh(account)
db.session.refresh(tenant)
db.session.refresh(workflow)
db.session.refresh(pipeline)
return account, tenant, pipeline, workflow
def _create_rag_pipeline_invoke_entities(self, account, tenant, pipeline, workflow, count=2):
"""
Helper method to create RAG pipeline invoke entities for testing.
Args:
account: Account instance
tenant: Tenant instance
pipeline: Pipeline instance
workflow: Workflow instance
count: Number of entities to create
Returns:
list: List of RagPipelineInvokeEntity instances
"""
fake = Faker()
entities = []
for i in range(count):
# Create application generate entity
app_config = {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
}
application_generate_entity = {
"task_id": str(uuid.uuid4()),
"app_config": app_config,
"inputs": {"query": f"Test query {i}"},
"files": [],
"user_id": account.id,
"stream": False,
"invoke_from": "published",
"workflow_execution_id": str(uuid.uuid4()),
"pipeline_config": {
"app_id": str(uuid.uuid4()),
"app_name": fake.company(),
"mode": "workflow",
"workflow_id": workflow.id,
"tenant_id": tenant.id,
"app_mode": "workflow",
},
"datasource_type": "upload_file",
"datasource_info": {},
"dataset_id": str(uuid.uuid4()),
"batch": "test_batch",
}
entity = RagPipelineInvokeEntity(
pipeline_id=pipeline.id,
application_generate_entity=application_generate_entity,
user_id=account.id,
tenant_id=tenant.id,
workflow_id=workflow.id,
streaming=False,
workflow_execution_id=str(uuid.uuid4()),
workflow_thread_pool_id=str(uuid.uuid4()),
)
entities.append(entity)
return entities
def _create_file_content_for_entities(self, entities):
"""
Helper method to create file content for RAG pipeline invoke entities.
Args:
entities: List of RagPipelineInvokeEntity instances
Returns:
str: JSON string containing serialized entities
"""
entities_data = [entity.model_dump() for entity in entities]
return json.dumps(entities_data)
def test_priority_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful priority RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=2)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 2
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test successful regular RAG pipeline run task execution.
This test verifies:
- Task execution with multiple RAG pipeline invoke entities
- File content retrieval and parsing
- PipelineGenerator._generate method calls with correct parameters
- Thread pool execution
- File cleanup after execution
- Queue management with no waiting tasks
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=3)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify expected outcomes
# Verify file operations
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
# Verify PipelineGenerator._generate was called for each entity
assert mock_pipeline_generator.call_count == 3
# Verify call parameters for each entity
calls = mock_pipeline_generator.call_args_list
for call in calls:
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(2)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; one task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_ids[0]
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
def test_rag_pipeline_run_task_legacy_compatibility(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
This test simulates the scenario where:
- Old code writes file IDs directly to Redis list using lpush
- New worker processes these legacy queue entries
- Ensures backward compatibility during deployment transition
Legacy format: redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
New format: TenantSelfTaskQueue.push_tasks([file_id])
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Simulate legacy Redis queue format - direct file IDs in Redis list
from extensions.ext_redis import redis_client
# Legacy queue key format (old code)
legacy_queue_key = f"tenant_self_pipeline_task_queue:{tenant.id}"
legacy_task_key = f"tenant_pipeline_task:{tenant.id}"
# Add legacy format data to Redis (simulating old code behavior)
legacy_file_ids = [str(uuid.uuid4()) for _ in range(3)]
for file_id_legacy in legacy_file_ids:
redis_client.lpush(legacy_queue_key, file_id_legacy)
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; one task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == legacy_file_ids[0]
assert call_kwargs.get('tenant_id') == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantSelfTaskQueue should be able to read from the legacy format
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
# Cleanup: Remove legacy test data
redis_client.delete(legacy_queue_key)
redis_client.delete(legacy_task_key)
def test_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
This test verifies:
- Core task execution
- Real Redis-based tenant queue processing of waiting tasks
- Task function calls for waiting tasks
- Queue management with multiple tasks using actual Redis operations
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting tasks to the real Redis queue
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify core processing occurred
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed; one task is pulled at a time by default
assert mock_delay.call_count == 1
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_ids[0]
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
def test_priority_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in priority RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task (should not raise exception)
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in regular RAG pipeline run task using real Redis.
This test verifies:
- Exception handling during core processing
- Tenant queue cleanup even on errors using real Redis
- Proper error logging
- Function completes without raising exceptions
- Queue management continues despite core processing errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
file_content = self._create_file_content_for_entities(entities)
# Mock file service
file_id = str(uuid.uuid4())
mock_file_service["get_content"].return_value = file_content
# Mock PipelineGenerator to raise an exception
mock_pipeline_generator.side_effect = Exception("Pipeline generation failed")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_priority_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in priority RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantSelfTaskQueue
queue1 = TenantSelfTaskQueue(tenant1.id, "pipeline")
queue2 = TenantSelfTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act: Execute the priority task for tenant1 only
priority_rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id1
assert call_kwargs.get('tenant_id') == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
def test_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in regular RAG pipeline run task using real Redis.
This test verifies:
- Different tenants have isolated queues
- Tasks from one tenant don't affect another tenant's queue
- Queue operations are properly scoped to tenant
"""
# Arrange: Create test data for two different tenants
account1, tenant1, pipeline1, workflow1 = self._create_test_pipeline_and_workflow(db_session_with_containers)
account2, tenant2, pipeline2, workflow2 = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities1 = self._create_rag_pipeline_invoke_entities(account1, tenant1, pipeline1, workflow1, count=1)
entities2 = self._create_rag_pipeline_invoke_entities(account2, tenant2, pipeline2, workflow2, count=1)
file_content1 = self._create_file_content_for_entities(entities1)
file_content2 = self._create_file_content_for_entities(entities2)
# Mock file service
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = [file_content1, file_content2]
# Use real Redis for TenantSelfTaskQueue
queue1 = TenantSelfTaskQueue(tenant1.id, "pipeline")
queue2 = TenantSelfTaskQueue(tenant2.id, "pipeline")
# Add waiting tasks to both queues
waiting_file_id1 = str(uuid.uuid4())
waiting_file_id2 = str(uuid.uuid4())
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
# Assert: Verify core processing occurred for tenant1
assert mock_file_service["get_content"].call_count == 1
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id1
assert call_kwargs.get('tenant_id') == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
assert len(remaining_tasks1) == 0
# Verify tenant2's queue still has its task (isolation)
remaining_tasks2 = queue2.pull_tasks(count=10)
assert len(remaining_tasks2) == 1
# Verify queue keys are different
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test successful run_single_rag_pipeline_task execution.
This test verifies:
- Single RAG pipeline task execution within Flask app context
- Entity validation and database queries
- PipelineGenerator._generate method call with correct parameters
- Proper Flask context handling
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
entities = self._create_rag_pipeline_invoke_entities(account, tenant, pipeline, workflow, count=1)
entity_data = entities[0].model_dump()
# Act: Execute the single task
with flask_app_with_containers.app_context():
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Verify expected outcomes
# Verify PipelineGenerator._generate was called
assert mock_pipeline_generator.call_count == 1
# Verify call parameters
call = mock_pipeline_generator.call_args
call_kwargs = call[1] # Get keyword arguments
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with invalid entity data.
This test verifies:
- Proper error handling for invalid entity data
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create entity data with valid UUIDs but non-existent entities
fake = Faker()
invalid_entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(invalid_entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with non-existent database entities.
This test verifies:
- Proper error handling for missing database entities
- Exception logging
- Function raises ValueError for missing entities
"""
# Arrange: Create test data with non-existent IDs
fake = Faker()
entity_data = {
"pipeline_id": str(uuid.uuid4()),
"application_generate_entity": {
"app_config": {
"app_id": str(uuid.uuid4()),
"app_name": "Test App",
"mode": "workflow",
"workflow_id": str(uuid.uuid4()),
},
"inputs": {"query": "Test query"},
"query": "Test query",
"response_mode": "blocking",
"user": str(uuid.uuid4()),
"files": [],
"conversation_id": str(uuid.uuid4()),
},
"user_id": str(uuid.uuid4()),
"tenant_id": str(uuid.uuid4()),
"workflow_id": str(uuid.uuid4()),
"streaming": False,
"workflow_execution_id": str(uuid.uuid4()),
"workflow_thread_pool_id": str(uuid.uuid4()),
}
# Act & Assert: Execute the single task with non-existent entities (should raise ValueError)
with flask_app_with_containers.app_context():
with pytest.raises(ValueError, match="Account .* not found"):
run_single_rag_pipeline_task(entity_data, flask_app_with_containers)
# Assert: Pipeline generator should not be called
mock_pipeline_generator.assert_not_called()
def test_priority_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch(
"tasks.rag_pipeline.priority_rag_pipeline_run_task.priority_rag_pipeline_run_task.delay"
) as mock_delay:
# Act & Assert: Execute the priority task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
priority_rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with non-existent file.
This test verifies:
- Proper error handling for missing files
- Exception logging
- Function raises Exception for file errors
- Queue management continues despite file errors
"""
# Arrange: Create test data
account, tenant, pipeline, workflow = self._create_test_pipeline_and_workflow(db_session_with_containers)
# Mock file service to raise exception
file_id = str(uuid.uuid4())
mock_file_service["get_content"].side_effect = Exception("File not found")
# Use real Redis for TenantSelfTaskQueue
queue = TenantSelfTaskQueue(tenant.id, "pipeline")
# Add waiting task to the real Redis queue
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
# Assert: Verify error was handled gracefully
mock_file_service["get_content"].assert_called_once_with(file_id)
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get('rag_pipeline_invoke_entities_file_id') == waiting_file_id
assert call_kwargs.get('tenant_id') == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
assert len(remaining_tasks) == 0
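
For orientation after the file deletion above: these tests constrain the run tasks to load the serialized RagPipelineInvokeEntity list from FileService, feed each entity to the pipeline generator, delete the file, and re-dispatch the next waiting file id for the tenant even when loading or generation fails. A rough sketch under those assumptions (the real task uses a thread pool and run_single_rag_pipeline_task; run_one below is a placeholder and the loop is a simplification):

import json
import logging

from celery import shared_task

from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.file_service import FileService

logger = logging.getLogger(__name__)


def run_one(entity_data: dict) -> None:
    # Placeholder for run_single_rag_pipeline_task, which rebuilds the
    # pipeline, workflow and user from the database before generating.
    ...


@shared_task
def rag_pipeline_run_task(rag_pipeline_invoke_entities_file_id: str, tenant_id: str):
    try:
        content = FileService.get_file_content(rag_pipeline_invoke_entities_file_id)
        for entity_data in json.loads(content):
            try:
                run_one(entity_data)
            except Exception:
                # Per-entity failures are logged, not re-raised (error-handling tests).
                logger.exception("RAG pipeline entity failed")
        FileService.delete_file(rag_pipeline_invoke_entities_file_id)
    finally:
        # Always re-dispatch the next waiting file id for this tenant, which is
        # why the file-not-found tests still observe a .delay call.
        queue = TenantSelfTaskQueue(tenant_id, "pipeline")
        for next_file_id in queue.pull_tasks():
            rag_pipeline_run_task.delay(
                rag_pipeline_invoke_entities_file_id=next_file_id,
                tenant_id=tenant_id,
            )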

View File

@@ -1,369 +0,0 @@
"""
Unit tests for TenantSelfTaskQueue.
These tests verify the Redis-based task queue functionality for tenant-specific
task management with proper serialization and deserialization.
"""
import json
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.rag.pipeline.queue import TASK_WRAPPER_PREFIX, TaskWrapper, TenantSelfTaskQueue
class TestTaskWrapper:
"""Test cases for TaskWrapper serialization/deserialization."""
def test_serialize_simple_data(self):
"""Test serialization of simple data types."""
data = {"key": "value", "number": 42, "list": [1, 2, 3]}
wrapper = TaskWrapper(data)
serialized = wrapper.serialize()
assert isinstance(serialized, str)
# Verify it's valid JSON
parsed = json.loads(serialized)
assert parsed == data
def test_serialize_complex_data(self):
"""Test serialization of complex nested data."""
data = {
"nested": {
"deep": {
"value": "test",
"numbers": [1, 2, 3, 4, 5]
}
},
"unicode": "测试中文",
"special_chars": "!@#$%^&*()"
}
wrapper = TaskWrapper(data)
serialized = wrapper.serialize()
parsed = json.loads(serialized)
assert parsed == data
def test_deserialize_valid_data(self):
"""Test deserialization of valid JSON data."""
original_data = {"key": "value", "number": 42}
serialized = json.dumps(original_data, ensure_ascii=False)
wrapper = TaskWrapper.deserialize(serialized)
assert wrapper.data == original_data
def test_deserialize_invalid_json(self):
"""Test deserialization handles invalid JSON gracefully."""
invalid_json = "{invalid json}"
with pytest.raises(json.JSONDecodeError):
TaskWrapper.deserialize(invalid_json)
def test_serialize_ensure_ascii_false(self):
"""Test that serialization preserves Unicode characters."""
data = {"chinese": "中文测试", "emoji": "🚀"}
wrapper = TaskWrapper(data)
serialized = wrapper.serialize()
assert "中文测试" in serialized
assert "🚀" in serialized
class TestTenantSelfTaskQueue:
"""Test cases for TenantSelfTaskQueue functionality."""
@pytest.fixture
def mock_redis_client(self):
"""Mock Redis client for testing."""
mock_redis = MagicMock()
return mock_redis
@pytest.fixture
def sample_queue(self, mock_redis_client):
"""Create a sample TenantSelfTaskQueue instance."""
return TenantSelfTaskQueue("tenant-123", "test-key")
def test_initialization(self, sample_queue):
"""Test queue initialization with correct key generation."""
assert sample_queue.tenant_id == "tenant-123"
assert sample_queue.unique_key == "test-key"
assert sample_queue.queue == "tenant_self_test-key_task_queue:tenant-123"
assert sample_queue.task_key == "tenant_test-key_task:tenant-123"
assert sample_queue.DEFAULT_TASK_TTL == 60 * 60
@patch('core.rag.pipeline.queue.redis_client')
def test_get_task_key_exists(self, mock_redis, sample_queue):
"""Test getting task key when it exists."""
mock_redis.get.return_value = "1"
result = sample_queue.get_task_key()
assert result == "1"
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch('core.rag.pipeline.queue.redis_client')
def test_get_task_key_not_exists(self, mock_redis, sample_queue):
"""Test getting task key when it doesn't exist."""
mock_redis.get.return_value = None
result = sample_queue.get_task_key()
assert result is None
mock_redis.get.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch('core.rag.pipeline.queue.redis_client')
def test_set_task_waiting_time_default_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with default TTL."""
sample_queue.set_task_waiting_time()
mock_redis.setex.assert_called_once_with(
"tenant_test-key_task:tenant-123",
3600, # DEFAULT_TASK_TTL
1
)
@patch('core.rag.pipeline.queue.redis_client')
def test_set_task_waiting_time_custom_ttl(self, mock_redis, sample_queue):
"""Test setting task waiting flag with custom TTL."""
custom_ttl = 1800
sample_queue.set_task_waiting_time(custom_ttl)
mock_redis.setex.assert_called_once_with(
"tenant_test-key_task:tenant-123",
custom_ttl,
1
)
@patch('core.rag.pipeline.queue.redis_client')
def test_delete_task_key(self, mock_redis, sample_queue):
"""Test deleting task key."""
sample_queue.delete_task_key()
mock_redis.delete.assert_called_once_with("tenant_test-key_task:tenant-123")
@patch('core.rag.pipeline.queue.redis_client')
def test_push_tasks_string_list(self, mock_redis, sample_queue):
"""Test pushing string tasks directly."""
tasks = ["task1", "task2", "task3"]
sample_queue.push_tasks(tasks)
mock_redis.lpush.assert_called_once_with(
"tenant_self_test-key_task_queue:tenant-123",
"task1", "task2", "task3"
)
@patch('core.rag.pipeline.queue.redis_client')
def test_push_tasks_mixed_types(self, mock_redis, sample_queue):
"""Test pushing mixed string and object tasks."""
tasks = [
"string_task",
{"object_task": "data", "id": 123},
"another_string"
]
sample_queue.push_tasks(tasks)
# Verify lpush was called
mock_redis.lpush.assert_called_once()
call_args = mock_redis.lpush.call_args
# Check queue name
assert call_args[0][0] == "tenant_self_test-key_task_queue:tenant-123"
# Check serialized tasks
serialized_tasks = call_args[0][1:]
assert len(serialized_tasks) == 3
assert serialized_tasks[0] == "string_task"
assert serialized_tasks[2] == "another_string"
# Check object task is wrapped
assert serialized_tasks[1].startswith(TASK_WRAPPER_PREFIX)
wrapper_data = serialized_tasks[1][len(TASK_WRAPPER_PREFIX):]
parsed_data = json.loads(wrapper_data)
assert parsed_data == {"object_task": "data", "id": 123}
@patch('core.rag.pipeline.queue.redis_client')
def test_push_tasks_empty_list(self, mock_redis, sample_queue):
"""Test pushing empty task list."""
sample_queue.push_tasks([])
mock_redis.lpush.assert_called_once_with(
"tenant_self_test-key_task_queue:tenant-123"
)
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_default_count(self, mock_redis, sample_queue):
"""Test pulling tasks with default count (1)."""
mock_redis.rpop.side_effect = ["task1", None]
result = sample_queue.pull_tasks()
assert result == ["task1"]
assert mock_redis.rpop.call_count == 1
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_custom_count(self, mock_redis, sample_queue):
"""Test pulling tasks with custom count."""
# First test: pull 3 tasks
mock_redis.rpop.side_effect = ["task1", "task2", "task3", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2", "task3"]
assert mock_redis.rpop.call_count == 3
# Reset mock for second test
mock_redis.reset_mock()
mock_redis.rpop.side_effect = ["task1", "task2", None]
result = sample_queue.pull_tasks(3)
assert result == ["task1", "task2"]
assert mock_redis.rpop.call_count == 3
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_zero_count(self, mock_redis, sample_queue):
"""Test pulling tasks with zero count returns empty list."""
result = sample_queue.pull_tasks(0)
assert result == []
mock_redis.rpop.assert_not_called()
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_negative_count(self, mock_redis, sample_queue):
"""Test pulling tasks with negative count returns empty list."""
result = sample_queue.pull_tasks(-1)
assert result == []
mock_redis.rpop.assert_not_called()
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_with_wrapped_objects(self, mock_redis, sample_queue):
"""Test pulling tasks that include wrapped objects."""
# Create a wrapped task
wrapper = TaskWrapper({"task_id": 123, "data": "test"})
wrapped_task = f"{TASK_WRAPPER_PREFIX}{wrapper.serialize()}"
mock_redis.rpop.side_effect = [
"string_task",
wrapped_task.encode('utf-8'), # Simulate bytes from Redis
None
]
result = sample_queue.pull_tasks(2)
assert len(result) == 2
assert result[0] == "string_task"
assert result[1] == {"task_id": 123, "data": "test"}
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_with_invalid_wrapped_data(self, mock_redis, sample_queue):
"""Test pulling tasks with invalid wrapped data falls back to string."""
invalid_wrapped = f"{TASK_WRAPPER_PREFIX}invalid json"
mock_redis.rpop.side_effect = [invalid_wrapped, None]
result = sample_queue.pull_tasks(1)
assert result == [invalid_wrapped]
@patch('core.rag.pipeline.queue.redis_client')
def test_pull_tasks_bytes_decoding(self, mock_redis, sample_queue):
"""Test pulling tasks handles bytes from Redis correctly."""
mock_redis.rpop.side_effect = [
b"task1", # bytes
"task2", # string
None
]
result = sample_queue.pull_tasks(2)
assert result == ["task1", "task2"]
@patch('core.rag.pipeline.queue.redis_client')
def test_get_next_task_success(self, mock_redis, sample_queue):
"""Test getting next single task."""
mock_redis.rpop.side_effect = ["task1", None]
result = sample_queue.get_next_task()
assert result == "task1"
assert mock_redis.rpop.call_count == 1
@patch('core.rag.pipeline.queue.redis_client')
def test_get_next_task_empty_queue(self, mock_redis, sample_queue):
"""Test getting next task when queue is empty."""
mock_redis.rpop.return_value = None
result = sample_queue.get_next_task()
assert result is None
mock_redis.rpop.assert_called_once()
def test_tenant_isolation(self):
"""Test that different tenants have isolated queues."""
queue1 = TenantSelfTaskQueue("tenant-1", "key")
queue2 = TenantSelfTaskQueue("tenant-2", "key")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == "tenant_self_key_task_queue:tenant-1"
assert queue2.queue == "tenant_self_key_task_queue:tenant-2"
def test_key_isolation(self):
"""Test that different keys have isolated queues."""
queue1 = TenantSelfTaskQueue("tenant", "key1")
queue2 = TenantSelfTaskQueue("tenant", "key2")
assert queue1.queue != queue2.queue
assert queue1.task_key != queue2.task_key
assert queue1.queue == "tenant_self_key1_task_queue:tenant"
assert queue2.queue == "tenant_self_key2_task_queue:tenant"
@patch('core.rag.pipeline.queue.redis_client')
def test_complex_object_serialization_roundtrip(self, mock_redis, sample_queue):
"""Test complex object serialization and deserialization roundtrip."""
complex_task = {
"id": uuid4().hex,
"data": {
"nested": {
"deep": [1, 2, 3],
"unicode": "测试中文",
"special": "!@#$%^&*()"
}
},
"metadata": {
"created_at": "2024-01-01T00:00:00Z",
"tags": ["tag1", "tag2", "tag3"]
}
}
# Push the complex task
sample_queue.push_tasks([complex_task])
# Verify it was wrapped
call_args = mock_redis.lpush.call_args
wrapped_task = call_args[0][1]
assert wrapped_task.startswith(TASK_WRAPPER_PREFIX)
# Simulate pulling it back
mock_redis.rpop.return_value = wrapped_task
result = sample_queue.pull_tasks(1)
assert len(result) == 1
assert result[0] == complex_task
@patch('core.rag.pipeline.queue.redis_client')
def test_large_task_list_handling(self, mock_redis, sample_queue):
"""Test handling of large task lists."""
large_task_list = [f"task_{i}" for i in range(1000)]
sample_queue.push_tasks(large_task_list)
# Verify all tasks were pushed
call_args = mock_redis.lpush.call_args
assert len(call_args[0]) == 1001 # queue name + 1000 tasks
assert list(call_args[0][1:]) == large_task_list
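The deleted tests above pin down a small but complete queue contract: LPUSH on push, repeated RPOP on pull, JSON-wrapping of non-string tasks behind a marker prefix, bytes decoding, and a plain-string fallback when a wrapped payload fails to parse. For reference, here is a minimal sketch that would satisfy those assertions; the prefix value, the standalone `redis.Redis()` client, and the `TaskWrapper` shape are illustrative stand-ins, not the removed implementation.

```python
# Hypothetical sketch consistent with the assertions above; not the removed code.
import json
import redis

TASK_WRAPPER_PREFIX = "task_wrapper:"  # placeholder marker value
redis_client = redis.Redis()  # the real code uses the app-wide Redis client


class TaskWrapper:
    def __init__(self, data):
        self.data = data

    def serialize(self) -> str:
        return json.dumps(self.data)

    @classmethod
    def deserialize(cls, raw: str) -> "TaskWrapper":
        return cls(json.loads(raw))


class TenantSelfTaskQueue:
    def __init__(self, tenant_id: str, unique_key: str):
        self.tenant_id = tenant_id
        self.unique_key = unique_key
        self.queue = f"tenant_self_{unique_key}_task_queue:{tenant_id}"

    def push_tasks(self, tasks: list) -> None:
        # Strings are pushed as-is; everything else is JSON-wrapped behind the prefix.
        payload = [
            t if isinstance(t, str) else f"{TASK_WRAPPER_PREFIX}{TaskWrapper(t).serialize()}"
            for t in tasks
        ]
        redis_client.lpush(self.queue, *payload)

    def pull_tasks(self, count: int) -> list:
        if count <= 0:
            return []
        results = []
        for _ in range(count):
            raw = redis_client.rpop(self.queue)
            if raw is None:
                break
            if isinstance(raw, bytes):
                raw = raw.decode("utf-8")
            if raw.startswith(TASK_WRAPPER_PREFIX):
                try:
                    results.append(TaskWrapper.deserialize(raw[len(TASK_WRAPPER_PREFIX):]).data)
                    continue
                except ValueError:
                    pass  # fall back to returning the raw string on invalid JSON
            results.append(raw)
        return results

    def get_next_task(self):
        tasks = self.pull_tasks(1)
        return tasks[0] if tasks else None
```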

View File

@@ -1,332 +0,0 @@
from unittest.mock import Mock, patch
from core.entities.document_task import DocumentTask
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.document_indexing_task_proxy import DocumentIndexingTaskProxy
class DocumentIndexingTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for DocumentIndexingTaskProxy tests."""
@staticmethod
def create_mock_features(
billing_enabled: bool = False,
plan: str = "sandbox"
) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantSelfTaskQueue."""
queue = Mock(spec=TenantSelfTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_document_task_proxy(
tenant_id: str = "tenant-123",
dataset_id: str = "dataset-456",
document_ids: list[str] | None = None
) -> DocumentIndexingTaskProxy:
"""Create DocumentIndexingTaskProxy instance for testing."""
if document_ids is None:
document_ids = ["doc-1", "doc-2", "doc-3"]
return DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
class TestDocumentIndexingTaskProxy:
"""Test cases for DocumentIndexingTaskProxy class."""
def test_initialization(self):
"""Test DocumentIndexingTaskProxy initialization."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2", "doc-3"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy.tenant_id == tenant_id
assert proxy.dataset_id == dataset_id
assert proxy.document_ids == document_ids
assert isinstance(proxy.tenant_self_task_queue, TenantSelfTaskQueue)
assert proxy.tenant_self_task_queue.tenant_id == tenant_id
assert proxy.tenant_self_task_queue.unique_key == "document_indexing"
@patch('services.document_indexing_task_proxy.FeatureService')
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch('services.document_indexing_task_proxy.normal_document_indexing_task')
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(mock_task)
# Assert
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123",
dataset_id="dataset-456",
document_ids=["doc-1", "doc-2", "doc-3"]
)
@patch('services.document_indexing_task_proxy.normal_document_indexing_task')
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy.tenant_self_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy.tenant_self_task_queue.push_tasks.assert_called_once()
pushed_tasks = proxy.tenant_self_task_queue.push_tasks.call_args[0][0]
assert len(pushed_tasks) == 1
assert isinstance(DocumentTask(**pushed_tasks[0]), DocumentTask)
assert pushed_tasks[0]["tenant_id"] == "tenant-123"
assert pushed_tasks[0]["dataset_id"] == "dataset-456"
assert pushed_tasks[0]["document_ids"] == ["doc-1", "doc-2", "doc-3"]
mock_task.delay.assert_not_called()
@patch('services.document_indexing_task_proxy.normal_document_indexing_task')
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy.tenant_self_task_queue = DocumentIndexingTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(mock_task)
# Assert
proxy.tenant_self_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
tenant_id="tenant-123",
dataset_id="dataset-456",
document_ids=["doc-1", "doc-2", "doc-3"]
)
proxy.tenant_self_task_queue.push_tasks.assert_not_called()
@patch('services.document_indexing_task_proxy.normal_document_indexing_task')
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_default_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch('services.document_indexing_task_proxy.priority_document_indexing_task')
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_tenant_queue = Mock()
# Act
proxy._send_to_priority_tenant_queue()
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(mock_task)
@patch('services.document_indexing_task_proxy.priority_document_indexing_task')
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_direct_queue = Mock()
# Act
proxy._send_to_priority_direct_queue()
# Assert
proxy._send_to_direct_queue.assert_called_once_with(mock_task)
@patch('services.document_indexing_task_proxy.FeatureService')
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_default_tenant_queue.assert_called_once()
@patch('services.document_indexing_task_proxy.FeatureService')
def test_dispatch_with_billing_enabled_non_sandbox_plan(self, mock_feature_service):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="team"
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# If billing is enabled with a non-sandbox plan, should send to the priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch('services.document_indexing_task_proxy.FeatureService')
def test_dispatch_with_billing_disabled(self, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=False
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
# Act
proxy._dispatch()
# If billing is disabled (e.g. self-hosted or enterprise), should send to the priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once()
@patch('services.document_indexing_task_proxy.FeatureService')
def test_delay_method(self, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
# Act
proxy.delay()
# Assert
# If billing enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once()
def test_document_task_dataclass(self):
"""Test DocumentTask dataclass."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1", "doc-2"]
# Act
task = DocumentTask(tenant_id=tenant_id, dataset_id=dataset_id, document_ids=document_ids)
# Assert
assert task.tenant_id == tenant_id
assert task.dataset_id == dataset_id
assert task.document_ids == document_ids
@patch('services.document_indexing_task_proxy.FeatureService')
def test_dispatch_edge_case_empty_plan(self, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=""
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
@patch('services.document_indexing_task_proxy.FeatureService')
def test_dispatch_edge_case_none_plan(self, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = DocumentIndexingTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=None
)
mock_feature_service.get_features.return_value = mock_features
proxy = DocumentIndexingTaskProxyTestDataFactory.create_document_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once()
def test_initialization_with_empty_document_ids(self):
"""Test initialization with empty document_ids list."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = []
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy.tenant_id == tenant_id
assert proxy.dataset_id == dataset_id
assert proxy.document_ids == document_ids
def test_initialization_with_single_document_id(self):
"""Test initialization with single document_id."""
# Arrange
tenant_id = "tenant-123"
dataset_id = "dataset-456"
document_ids = ["doc-1"]
# Act
proxy = DocumentIndexingTaskProxy(tenant_id, dataset_id, document_ids)
# Assert
assert proxy.tenant_id == tenant_id
assert proxy.dataset_id == dataset_id
assert proxy.document_ids == document_ids
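Read together, the deleted dispatch tests describe a three-way routing matrix keyed on the tenant's billing features: sandbox tenants go to the default tenant queue, non-sandbox plans (including the empty and None edge cases) go to the priority tenant queue, and deployments without billing skip tenant queuing entirely. A runnable restatement of that matrix, with the route names chosen here purely as labels:

```python
# Hypothetical summary of the routing asserted by the tests above; the strings are labels, not code identifiers.
def choose_route(billing_enabled: bool, plan: str | None) -> str:
    """Decide where a document indexing task would be dispatched."""
    if not billing_enabled:
        # Self-hosted / enterprise deployments bypass tenant queuing.
        return "priority_direct_queue"
    if plan == "sandbox":
        return "default_tenant_queue"
    # Any non-sandbox plan, including "" and None per the edge-case tests, is prioritized.
    return "priority_tenant_queue"


assert choose_route(True, "sandbox") == "default_tenant_queue"
assert choose_route(True, "team") == "priority_tenant_queue"
assert choose_route(True, "") == "priority_tenant_queue"
assert choose_route(True, None) == "priority_tenant_queue"
assert choose_route(False, "sandbox") == "priority_direct_queue"
```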

View File

@@ -1,499 +0,0 @@
import json
from unittest.mock import Mock, patch
import pytest
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantSelfTaskQueue
from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
class RagPipelineTaskProxyTestDataFactory:
"""Factory class for creating test data and mock objects for RagPipelineTaskProxy tests."""
@staticmethod
def create_mock_features(
billing_enabled: bool = False,
plan: str = "sandbox"
) -> Mock:
"""Create mock features with billing configuration."""
features = Mock()
features.billing = Mock()
features.billing.enabled = billing_enabled
features.billing.subscription = Mock()
features.billing.subscription.plan = plan
return features
@staticmethod
def create_mock_tenant_queue(has_task_key: bool = False) -> Mock:
"""Create mock TenantSelfTaskQueue."""
queue = Mock(spec=TenantSelfTaskQueue)
queue.get_task_key.return_value = "task_key" if has_task_key else None
queue.push_tasks = Mock()
queue.set_task_waiting_time = Mock()
return queue
@staticmethod
def create_rag_pipeline_invoke_entity(
pipeline_id: str = "pipeline-123",
user_id: str = "user-456",
tenant_id: str = "tenant-789",
workflow_id: str = "workflow-101",
streaming: bool = True,
workflow_execution_id: str | None = None,
workflow_thread_pool_id: str | None = None
) -> RagPipelineInvokeEntity:
"""Create RagPipelineInvokeEntity instance for testing."""
return RagPipelineInvokeEntity(
pipeline_id=pipeline_id,
application_generate_entity={"key": "value"},
user_id=user_id,
tenant_id=tenant_id,
workflow_id=workflow_id,
streaming=streaming,
workflow_execution_id=workflow_execution_id,
workflow_thread_pool_id=workflow_thread_pool_id
)
@staticmethod
def create_rag_pipeline_task_proxy(
dataset_tenant_id: str = "tenant-123",
user_id: str = "user-456",
rag_pipeline_invoke_entities: list[RagPipelineInvokeEntity] | None = None
) -> RagPipelineTaskProxy:
"""Create RagPipelineTaskProxy instance for testing."""
if rag_pipeline_invoke_entities is None:
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()
]
return RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
@staticmethod
def create_mock_upload_file(file_id: str = "file-123") -> Mock:
"""Create mock upload file."""
upload_file = Mock()
upload_file.id = file_id
return upload_file
class TestRagPipelineTaskProxy:
"""Test cases for RagPipelineTaskProxy class."""
def test_initialization(self):
"""Test RagPipelineTaskProxy initialization."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity()
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy.dataset_tenant_id == dataset_tenant_id
assert proxy.user_id == user_id
assert proxy.rag_pipeline_invoke_entities == rag_pipeline_invoke_entities
assert isinstance(proxy.tenant_self_pipeline_task_queue, TenantSelfTaskQueue)
assert proxy.tenant_self_pipeline_task_queue.tenant_id == dataset_tenant_id
assert proxy.tenant_self_pipeline_task_queue.unique_key == "pipeline"
def test_initialization_with_empty_entities(self):
"""Test initialization with empty rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = []
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert proxy.dataset_tenant_id == dataset_tenant_id
assert proxy.user_id == user_id
assert proxy.rag_pipeline_invoke_entities == []
def test_initialization_with_multiple_entities(self):
"""Test initialization with multiple rag_pipeline_invoke_entities."""
# Arrange
dataset_tenant_id = "tenant-123"
user_id = "user-456"
rag_pipeline_invoke_entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-3")
]
# Act
proxy = RagPipelineTaskProxy(dataset_tenant_id, user_id, rag_pipeline_invoke_entities)
# Assert
assert len(proxy.rag_pipeline_invoke_entities) == 3
assert proxy.rag_pipeline_invoke_entities[0].pipeline_id == "pipeline-1"
assert proxy.rag_pipeline_invoke_entities[1].pipeline_id == "pipeline-2"
assert proxy.rag_pipeline_invoke_entities[2].pipeline_id == "pipeline-3"
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
def test_features_property(self, mock_feature_service):
"""Test cached_property features."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features()
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
# Act
features1 = proxy.features
features2 = proxy.features # Second call should use cached property
# Assert
assert features1 == mock_features
assert features2 == mock_features
assert features1 is features2 # Should be the same instance due to caching
mock_feature_service.get_features.assert_called_once_with("tenant-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_upload_invoke_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-123"
mock_file_service_class.assert_called_once_with(mock_db.engine)
# Verify upload_text was called with correct parameters
mock_file_service.upload_text.assert_called_once()
call_args = mock_file_service.upload_text.call_args
json_text, name, user_id, tenant_id = call_args[0]
assert name == "rag_pipeline_invoke_entities.json"
assert user_id == "user-456"
assert tenant_id == "tenant-123"
# Verify JSON content
parsed_json = json.loads(json_text)
assert len(parsed_json) == 1
assert parsed_json[0]["pipeline_id"] == "pipeline-123"
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_upload_invoke_entities_with_multiple_entities(self, mock_db, mock_file_service_class):
"""Test _upload_invoke_entities method with multiple entities."""
# Arrange
entities = [
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-1"),
RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_invoke_entity(pipeline_id="pipeline-2")
]
proxy = RagPipelineTaskProxy("tenant-123", "user-456", entities)
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-456")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
result = proxy._upload_invoke_entities()
# Assert
assert result == "file-456"
# Verify JSON content contains both entities
call_args = mock_file_service.upload_text.call_args
json_text = call_args[0][0]
parsed_json = json.loads(json_text)
assert len(parsed_json) == 2
assert parsed_json[0]["pipeline_id"] == "pipeline-1"
assert parsed_json[1]["pipeline_id"] == "pipeline-2"
@patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
def test_send_to_direct_queue(self, mock_task):
"""Test _send_to_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue()
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_direct_queue(upload_file_id, mock_task)
# If sent to direct queue, tenant_self_pipeline_task_queue should not be called
proxy.tenant_self_pipeline_task_queue.push_tasks.assert_not_called()
# Celery should be called directly
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id="tenant-123"
)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
def test_send_to_tenant_queue_with_existing_task_key(self, mock_task):
"""Test _send_to_tenant_queue when task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=True
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If task key exists, should push tasks to the queue
proxy.tenant_self_pipeline_task_queue.push_tasks.assert_called_once_with([upload_file_id])
# Celery should not be called directly
mock_task.delay.assert_not_called()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
def test_send_to_tenant_queue_without_task_key(self, mock_task):
"""Test _send_to_tenant_queue when no task key exists."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy.tenant_self_pipeline_task_queue = RagPipelineTaskProxyTestDataFactory.create_mock_tenant_queue(
has_task_key=False
)
upload_file_id = "file-123"
mock_task.delay = Mock()
# Act
proxy._send_to_tenant_queue(upload_file_id, mock_task)
# If no task key, should set task waiting time key first
proxy.tenant_self_pipeline_task_queue.set_task_waiting_time.assert_called_once()
mock_task.delay.assert_called_once_with(
rag_pipeline_invoke_entities_file_id=upload_file_id,
tenant_id="tenant-123"
)
# The first task should be sent to Celery directly, so push_tasks should not be called
proxy.tenant_self_pipeline_task_queue.push_tasks.assert_not_called()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.rag_pipeline_run_task')
def test_send_to_default_tenant_queue(self, mock_task):
"""Test _send_to_default_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_default_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task')
def test_send_to_priority_tenant_queue(self, mock_task):
"""Test _send_to_priority_tenant_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_tenant_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_tenant_queue(upload_file_id)
# Assert
proxy._send_to_tenant_queue.assert_called_once_with(upload_file_id, mock_task)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.priority_rag_pipeline_run_task')
def test_send_to_priority_direct_queue(self, mock_task):
"""Test _send_to_priority_direct_queue method."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_direct_queue = Mock()
upload_file_id = "file-123"
# Act
proxy._send_to_priority_direct_queue(upload_file_id)
# Assert
proxy._send_to_direct_queue.assert_called_once_with(upload_file_id, mock_task)
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_enabled_sandbox_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is enabled with sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_default_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with sandbox plan, should send to default tenant queue
proxy._send_to_default_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_enabled_non_sandbox_plan(
self, mock_db, mock_file_service_class, mock_feature_service
):
"""Test _dispatch method when billing is enabled with non-sandbox plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="team"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is enabled with non-sandbox plan, should send to priority tenant queue
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_billing_disabled(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method when billing is disabled."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=False
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_direct_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# If billing is disabled (e.g. self-hosted or enterprise), should send to the priority direct queue
proxy._send_to_priority_direct_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_with_empty_upload_file_id(self, mock_db, mock_file_service_class):
"""Test _dispatch method when upload_file_id is empty."""
# Arrange
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = Mock()
mock_upload_file.id = "" # Empty file ID
mock_file_service.upload_text.return_value = mock_upload_file
# Act & Assert
with pytest.raises(ValueError, match="upload_file_id is empty"):
proxy._dispatch()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_edge_case_empty_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with empty plan string."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=""
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_dispatch_edge_case_none_plan(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test _dispatch method with None plan."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan=None
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._send_to_priority_tenant_queue = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy._dispatch()
# Assert
proxy._send_to_priority_tenant_queue.assert_called_once_with("file-123")
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FeatureService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.FileService')
@patch('services.rag_pipeline.rag_pipeline_task_proxy.db')
def test_delay_method(self, mock_db, mock_file_service_class, mock_feature_service):
"""Test delay method integration."""
# Arrange
mock_features = RagPipelineTaskProxyTestDataFactory.create_mock_features(
billing_enabled=True, plan="sandbox"
)
mock_feature_service.get_features.return_value = mock_features
proxy = RagPipelineTaskProxyTestDataFactory.create_rag_pipeline_task_proxy()
proxy._dispatch = Mock()
mock_file_service = Mock()
mock_file_service_class.return_value = mock_file_service
mock_upload_file = RagPipelineTaskProxyTestDataFactory.create_mock_upload_file("file-123")
mock_file_service.upload_text.return_value = mock_upload_file
# Act
proxy.delay()
# Assert
proxy._dispatch.assert_called_once()
@patch('services.rag_pipeline.rag_pipeline_task_proxy.logger')
def test_delay_method_with_empty_entities(self, mock_logger):
"""Test delay method with empty rag_pipeline_invoke_entities."""
# Arrange
proxy = RagPipelineTaskProxy("tenant-123", "user-456", [])
# Act
proxy.delay()
# Assert
mock_logger.warning.assert_called_once_with(
"Received empty rag pipeline invoke entities, no tasks delivered: %s %s",
"tenant-123",
"user-456"
)
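Both deleted proxy test suites assert the same tenant-queue handoff: if a task key already exists, the new work is buffered onto the tenant's Redis queue; otherwise the queue is marked as waiting and the first task is handed to Celery directly. A small illustrative sketch of that branch, exercised the same way the tests above do (function and parameter names are assumptions, not the removed code):

```python
# Hypothetical sketch of the handoff the tests assert; not the removed implementation.
from typing import Any, Callable
from unittest.mock import Mock


def send_to_tenant_queue(queue: Any, upload_file_id: str, tenant_id: str, delay: Callable[..., None]) -> None:
    if queue.get_task_key():
        # A consumer already holds this tenant's task key: buffer the file id on the queue.
        queue.push_tasks([upload_file_id])
    else:
        # No active consumer: mark the queue as waiting and send the first task to Celery directly.
        queue.set_task_waiting_time()
        delay(rag_pipeline_invoke_entities_file_id=upload_file_id, tenant_id=tenant_id)


queue, delay = Mock(), Mock()
queue.get_task_key.return_value = None
send_to_tenant_queue(queue, "file-123", "tenant-123", delay)
delay.assert_called_once_with(rag_pipeline_invoke_entities_file_id="file-123", tenant_id="tenant-123")
```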

View File

@@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,priority_dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline

View File

@@ -260,16 +260,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
# Sets the maximum allowed duration of any statement before termination.
# Default is 60000 milliseconds.
# Default is 0 (no timeout).
#
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
POSTGRES_STATEMENT_TIMEOUT=60000
# A value of 0 prevents the server from timing out statements.
POSTGRES_STATEMENT_TIMEOUT=0
# Sets the maximum allowed duration of any idle in-transaction session before termination.
# Default is 60000 milliseconds.
# Default is 0 (no timeout).
#
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
# A value of 0 prevents the server from terminating idle sessions.
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
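Setting both values to 0 disables PostgreSQL's statement and idle-in-transaction timeouts for the Docker deployment. A quick, illustrative way to confirm the running container picked the values up (connection parameters below are placeholders; adjust them to your .env):

```python
# Illustrative check that PostgreSQL picked up the new settings; requires psycopg2.
# Connection parameters are placeholders -- replace them with your docker/.env values.
import psycopg2

conn = psycopg2.connect(host="localhost", port=5432, dbname="dify", user="postgres", password="difyai123456")
cur = conn.cursor()
for guc in ("statement_timeout", "idle_in_transaction_session_timeout"):
    cur.execute("SHOW " + guc)
    print(guc, "=", cur.fetchone()[0])  # expect "0" once the timeouts are disabled
cur.close()
conn.close()
```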
# ------------------------------
# Redis Configuration

View File

@@ -115,8 +115,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes:
- ./volumes/db/data:/var/lib/postgresql/data
healthcheck:

View File

@@ -15,8 +15,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes:
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
ports:

View File

@@ -68,8 +68,8 @@ x-shared-env: &shared-api-worker-env
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000}
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-0}
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}
REDIS_HOST: ${REDIS_HOST:-redis}
REDIS_PORT: ${REDIS_PORT:-6379}
REDIS_USERNAME: ${REDIS_USERNAME:-}
@@ -724,8 +724,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes:
- ./volumes/db/data:/var/lib/postgresql/data
healthcheck:

View File

@@ -41,16 +41,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
# Sets the maximum allowed duration of any statement before termination.
# Default is 60000 milliseconds.
# Default is 0 (no timeout).
#
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
POSTGRES_STATEMENT_TIMEOUT=60000
# A value of 0 prevents the server from timing out statements.
POSTGRES_STATEMENT_TIMEOUT=0
# Sets the maximum allowed duration of any idle in-transaction session before termination.
# Default is 60000 milliseconds.
# Default is 0 (no timeout).
#
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
# A value of 0 prevents the server from terminating idle sessions.
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
# -----------------------------
# Environment Variables for redis Service

View File

@@ -14,7 +14,6 @@ import timezone from 'dayjs/plugin/timezone'
import { createContext, useContext } from 'use-context-selector'
import { useShallow } from 'zustand/react/shallow'
import { useTranslation } from 'react-i18next'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import type { ChatItemInTree } from '../../base/chat/types'
import Indicator from '../../header/indicator'
import VarPanel from './var-panel'
@@ -43,10 +42,6 @@ import cn from '@/utils/classnames'
import { noop } from 'lodash-es'
import PromptLogModal from '../../base/prompt-log-modal'
type AppStoreState = ReturnType<typeof useAppStore.getState>
type ConversationListItem = ChatConversationGeneralDetail | CompletionConversationGeneralDetail
type ConversationSelection = ConversationListItem | { id: string; isPlaceholder?: true }
dayjs.extend(utc)
dayjs.extend(timezone)
@@ -206,7 +201,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
const { formatTime } = useTimestamp()
const { onClose, appDetail } = useContext(DrawerContext)
const { notify } = useContext(ToastContext)
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({
currentLogItem: state.currentLogItem,
setCurrentLogItem: state.setCurrentLogItem,
showMessageLogModal: state.showMessageLogModal,
@@ -898,113 +893,20 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
const { t } = useTranslation()
const { formatTime } = useTimestamp()
const router = useRouter()
const pathname = usePathname()
const searchParams = useSearchParams()
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
const [currentConversation, setCurrentConversation] = useState<ConversationSelection | undefined>() // Currently selected conversation
const closingConversationIdRef = useRef<string | null>(null)
const pendingConversationIdRef = useRef<string | null>(null)
const pendingConversationCacheRef = useRef<ConversationSelection | undefined>(undefined)
const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app
const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow((state: AppStoreState) => ({
const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow(state => ({
setShowPromptLogModal: state.setShowPromptLogModal,
setShowAgentLogModal: state.setShowAgentLogModal,
setShowMessageLogModal: state.setShowMessageLogModal,
})))
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
const buildUrlWithConversation = useCallback((conversationId?: string) => {
const params = new URLSearchParams(searchParams.toString())
if (conversationId)
params.set('conversation_id', conversationId)
else
params.delete('conversation_id')
const queryString = params.toString()
return queryString ? `${pathname}?${queryString}` : pathname
}, [pathname, searchParams])
const handleRowClick = useCallback((log: ConversationListItem) => {
if (conversationIdInUrl === log.id) {
if (!showDrawer)
setShowDrawer(true)
if (!currentConversation || currentConversation.id !== log.id)
setCurrentConversation(log)
return
}
pendingConversationIdRef.current = log.id
pendingConversationCacheRef.current = log
if (!showDrawer)
setShowDrawer(true)
if (currentConversation?.id !== log.id)
setCurrentConversation(undefined)
router.push(buildUrlWithConversation(log.id), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
const currentConversationId = currentConversation?.id
useEffect(() => {
if (!conversationIdInUrl) {
if (pendingConversationIdRef.current)
return
if (showDrawer || currentConversationId) {
setShowDrawer(false)
setCurrentConversation(undefined)
}
closingConversationIdRef.current = null
pendingConversationCacheRef.current = undefined
return
}
if (closingConversationIdRef.current === conversationIdInUrl)
return
if (pendingConversationIdRef.current === conversationIdInUrl)
pendingConversationIdRef.current = null
const matchedConversation = logs?.data?.find((item: ConversationListItem) => item.id === conversationIdInUrl)
const nextConversation: ConversationSelection = matchedConversation
?? pendingConversationCacheRef.current
?? { id: conversationIdInUrl, isPlaceholder: true }
if (!showDrawer)
setShowDrawer(true)
if (!currentConversation || currentConversation.id !== conversationIdInUrl || (matchedConversation && currentConversation !== matchedConversation))
setCurrentConversation(nextConversation)
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
pendingConversationCacheRef.current = undefined
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
const onCloseDrawer = useCallback(() => {
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
setShowMessageLogModal(false)
pendingConversationIdRef.current = null
pendingConversationCacheRef.current = undefined
closingConversationIdRef.current = conversationIdInUrl ?? null
if (conversationIdInUrl)
router.replace(buildUrlWithConversation(), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
// Annotated data needs to be highlighted
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
return (
@@ -1023,6 +925,15 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
)
}
const onCloseDrawer = () => {
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
setShowMessageLogModal(false)
}
if (!logs)
return <Loading />
@@ -1049,8 +960,11 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer')
return <tr
key={log.id}
className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', activeConversationId !== log.id ? '' : 'bg-background-default-hover')}
onClick={() => handleRowClick(log)}>
className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', currentConversation?.id !== log.id ? '' : 'bg-background-default-hover')}
onClick={() => {
setShowDrawer(true)
setCurrentConversation(log)
}}>
<td className='h-4'>
{!log.read_at && (
<div className='flex items-center p-3 pr-0.5'>

View File

@@ -13,7 +13,7 @@ import {
useWorkflowStore,
} from './store'
import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks'
import { CUSTOM_NODE } from './constants'
import { CUSTOM_NODE, ITERATION_PADDING } from './constants'
import { getIterationStartNode, getLoopStartNode } from './utils'
import CustomNode from './nodes'
import CustomNoteNode from './note-node'
@@ -41,7 +41,33 @@ const CandidateNode = () => {
} = store.getState()
const { screenToFlowPosition } = reactflow
const nodes = getNodes()
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
// Get mouse position in flow coordinates (this is where the top-left corner should be)
let { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
// If the node has a parent (e.g., inside iteration), apply constraints and convert to relative position
if (candidateNode.parentId) {
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.position) {
// Apply boundary constraints for iteration nodes
if (candidateNode.data.isInIteration) {
const nodeWidth = candidateNode.width || 0
const nodeHeight = candidateNode.height || 0
const minX = parentNode.position.x + ITERATION_PADDING.left
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
const minY = parentNode.position.y + ITERATION_PADDING.top
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
// Constrain position
x = Math.max(minX, Math.min(maxX, x))
y = Math.max(minY, Math.min(maxY, y))
}
// Convert to relative position
x = x - parentNode.position.x
y = y - parentNode.position.y
}
}
const newNodes = produce(nodes, (draft) => {
draft.push({
...candidateNode,
@@ -59,6 +85,20 @@ const CandidateNode = () => {
if (candidateNode.data.type === BlockEnum.Loop)
draft.push(getLoopStartNode(candidateNode.id))
// Update parent iteration node's _children array
if (candidateNode.parentId && candidateNode.data.isInIteration) {
const parentNode = draft.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.data.type === BlockEnum.Iteration) {
if (!parentNode.data._children)
parentNode.data._children = []
parentNode.data._children.push({
nodeId: candidateNode.id,
nodeType: candidateNode.data.type,
})
}
}
})
setNodes(newNodes)
if (candidateNode.type === CUSTOM_NOTE_NODE)
@@ -84,6 +124,34 @@ const CandidateNode = () => {
if (!candidateNode)
return null
// Apply boundary constraints if node is inside iteration
if (candidateNode.parentId && candidateNode.data.isInIteration) {
const { getNodes } = store.getState()
const nodes = getNodes()
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.position) {
const { screenToFlowPosition, flowToScreenPosition } = reactflow
// Get mouse position in flow coordinates
const flowPosition = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
// Calculate boundaries in flow coordinates
const nodeWidth = candidateNode.width || 0
const nodeHeight = candidateNode.height || 0
const minX = parentNode.position.x + ITERATION_PADDING.left
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
const minY = parentNode.position.y + ITERATION_PADDING.top
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
// Constrain position
const constrainedX = Math.max(minX, Math.min(maxX, flowPosition.x))
const constrainedY = Math.max(minY, Math.min(maxY, flowPosition.y))
// Convert back to screen coordinates
flowToScreenPosition({ x: constrainedX, y: constrainedY })
}
}
return (
<div
className='absolute z-10'
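The clamping added above is plain coordinate math: constrain the absolute drop position to the iteration body (parent bounds minus padding minus node size), then convert it to a parent-relative position, which is what React Flow expects for child nodes. A small restatement of that math in Python, with every name and the example padding values chosen here for illustration:

```python
# Plain-math restatement of the clamp above; all names and padding values are illustrative.
def constrain_to_iteration(x, y,
                           parent_x, parent_y, parent_w, parent_h,
                           node_w, node_h,
                           pad_top, pad_right, pad_bottom, pad_left):
    """Clamp an absolute flow position into the iteration body, then make it parent-relative."""
    min_x, max_x = parent_x + pad_left, parent_x + parent_w - pad_right - node_w
    min_y, max_y = parent_y + pad_top, parent_y + parent_h - pad_bottom - node_h
    x = max(min_x, min(max_x, x))
    y = max(min_y, min(max_y, y))
    # React Flow stores child node positions relative to their parent.
    return x - parent_x, y - parent_y


# A 100x50 node dropped at (0, 0) over a 400x300 iteration at (200, 100), with example padding:
print(constrain_to_iteration(0, 0, 200, 100, 400, 300, 100, 50, 65, 16, 20, 16))  # -> (16, 65): snapped to the padded top-left corner
```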

View File

@@ -737,7 +737,7 @@ export const useNodesInteractions = () => {
targetHandle = 'target',
toolDefaultValue,
},
{ prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle },
{ prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle, skipAutoConnect },
) => {
if (getNodesReadOnly()) return
@@ -830,7 +830,7 @@ export const useNodesInteractions = () => {
}
let newEdge = null
if (nodeType !== BlockEnum.DataSource) {
if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
newEdge = {
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
type: CUSTOM_EDGE,
@@ -970,6 +970,7 @@ export const useNodesInteractions = () => {
nodeType !== BlockEnum.IfElse
&& nodeType !== BlockEnum.QuestionClassifier
&& nodeType !== BlockEnum.LoopEnd
&& !skipAutoConnect
) {
newEdge = {
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
@@ -1119,7 +1120,7 @@ export const useNodesInteractions = () => {
)
let newPrevEdge = null
if (nodeType !== BlockEnum.DataSource) {
if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
newPrevEdge = {
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
type: CUSTOM_EDGE,
@@ -1159,6 +1160,7 @@ export const useNodesInteractions = () => {
nodeType !== BlockEnum.IfElse
&& nodeType !== BlockEnum.QuestionClassifier
&& nodeType !== BlockEnum.LoopEnd
&& !skipAutoConnect
) {
newNextEdge = {
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,

View File

@@ -15,7 +15,9 @@ import {
useNodesSyncDraft,
} from '@/app/components/workflow/hooks'
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import type { Node } from '@/app/components/workflow/types'
import { BlockEnum, type Node } from '@/app/components/workflow/types'
import PanelAddBlock from '@/app/components/workflow/nodes/iteration/panel-add-block'
import type { IterationNodeType } from '@/app/components/workflow/nodes/iteration/types'
type PanelOperatorPopupProps = {
id: string
@@ -51,6 +53,9 @@ const PanelOperatorPopup = ({
(showChangeBlock || canRunBySingle(data.type, isChildNode)) && (
<>
<div className='p-1'>
{data.type === BlockEnum.Iteration && (
<PanelAddBlock iterationNodeData={data as IterationNodeType} onClosePopup={onClosePopup}/>
)}
{
canRunBySingle(data.type, isChildNode) && (
<div

View File

@@ -0,0 +1,116 @@
import {
memo,
useCallback,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import type { OffsetOptions } from '@floating-ui/react'
import { useStoreApi } from 'reactflow'
import BlockSelector from '@/app/components/workflow/block-selector'
import type {
OnSelectBlock,
} from '@/app/components/workflow/types'
import {
BlockEnum,
} from '@/app/components/workflow/types'
import { useAvailableBlocks, useNodesMetaData, useNodesReadOnly, usePanelInteractions } from '../../hooks'
import type { IterationNodeType } from './types'
import { useWorkflowStore } from '../../store'
import { generateNewNode, getNodeCustomTypeByNodeDataType } from '../../utils'
import { ITERATION_CHILDREN_Z_INDEX } from '../../constants'
type AddBlockProps = {
renderTrigger?: (open: boolean) => React.ReactNode
offset?: OffsetOptions
iterationNodeData: IterationNodeType
onClosePopup: () => void
}
const AddBlock = ({
offset,
iterationNodeData,
onClosePopup,
}: AddBlockProps) => {
const { t } = useTranslation()
const store = useStoreApi()
const workflowStore = useWorkflowStore()
const { nodesReadOnly } = useNodesReadOnly()
const { handlePaneContextmenuCancel } = usePanelInteractions()
const [open, setOpen] = useState(false)
const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, false)
const { nodesMap: nodesMetaDataMap } = useNodesMetaData()
const handleOpenChange = useCallback((open: boolean) => {
setOpen(open)
if (!open)
handlePaneContextmenuCancel()
}, [handlePaneContextmenuCancel])
const handleSelect = useCallback<OnSelectBlock>((type, toolDefaultValue) => {
const { getNodes } = store.getState()
const nodes = getNodes()
const nodesWithSameType = nodes.filter(node => node.data.type === type)
const { defaultValue } = nodesMetaDataMap![type]
// Find the parent iteration node
const parentIterationNode = nodes.find(node => node.data.start_node_id === iterationNodeData.start_node_id)
const { newNode } = generateNewNode({
type: getNodeCustomTypeByNodeDataType(type),
data: {
...(defaultValue as any),
title: nodesWithSameType.length > 0 ? `${defaultValue.title} ${nodesWithSameType.length + 1}` : defaultValue.title,
...toolDefaultValue,
_isCandidate: true,
// Set iteration-specific properties
isInIteration: true,
iteration_id: parentIterationNode?.id,
},
position: {
x: 0,
y: 0,
},
})
// Set parent and z-index for iteration child
if (parentIterationNode) {
newNode.parentId = parentIterationNode.id
newNode.extent = 'parent' as any
newNode.zIndex = ITERATION_CHILDREN_Z_INDEX
}
workflowStore.setState({
candidateNode: newNode,
})
onClosePopup()
}, [store, workflowStore, nodesMetaDataMap, iterationNodeData.start_node_id, onClosePopup])
const renderTrigger = () => {
return (
<div
className='flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
>
{t('workflow.common.addBlock')}
</div>
)
}
return (
<BlockSelector
open={open}
onOpenChange={handleOpenChange}
disabled={nodesReadOnly}
onSelect={handleSelect}
placement='right-start'
offset={offset ?? {
mainAxis: 4,
crossAxis: -8,
}}
trigger={renderTrigger}
popupClassName='!min-w-[256px]'
availableBlocksTypes={availableNextBlocks}
/>
)
}
export default memo(AddBlock)

View File

@@ -379,6 +379,7 @@ export type OnNodeAdd = (
prevNodeSourceHandle?: string
nextNodeId?: string
nextNodeTargetHandle?: string
skipAutoConnect?: boolean
},
) => void