mirror of
https://github.com/langgenius/dify.git
synced 2026-02-26 11:25:10 +00:00
Compare commits
1 Commits
inject-red
...
inject-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0be7ac6b9b |
@@ -56,6 +56,8 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
@@ -103,11 +105,11 @@ forbidden_modules =
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
@@ -240,6 +242,7 @@ ignore_imports =
|
||||
core.workflow.variable_loader -> core.variables
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
|
||||
@@ -33,7 +33,6 @@ from core.workflow.enums import NodeType
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@@ -741,7 +740,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ from core.errors.error import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.app_fields import (
|
||||
app_detail_fields_with_site,
|
||||
deleted_tool_fields,
|
||||
@@ -226,7 +225,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
@@ -101,6 +100,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -31,7 +31,6 @@ from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
@@ -281,7 +280,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@@ -122,6 +121,6 @@ class WorkflowTaskStopApi(WebApiResource):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from typing import TYPE_CHECKING, final
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@@ -13,7 +14,8 @@ from core.workflow.enums import NodeType
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
@@ -27,6 +29,24 @@ if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
language: CodeLanguage,
|
||||
code: str,
|
||||
inputs: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
return CodeExecutor.execute_workflow_code_template(
|
||||
language=language,
|
||||
code=code,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
def is_execution_error(self, error: Exception) -> bool:
|
||||
return isinstance(error, CodeExecutionError)
|
||||
|
||||
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
@@ -43,7 +63,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
) -> None:
|
||||
self.graph_init_params = graph_init_params
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self._code_executor: type[CodeExecutor] = CodeExecutor
|
||||
self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor()
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = CodeNode.default_code_providers()
|
||||
self._code_limits = CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
|
||||
@@ -7,28 +7,12 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, final
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
|
||||
class RedisPipelineProtocol(Protocol):
|
||||
"""Minimal Redis pipeline contract used by the command channel."""
|
||||
|
||||
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
def execute(self) -> list[Any]: ...
|
||||
def rpush(self, name: str, *values: str) -> Any: ...
|
||||
def expire(self, name: str, time: int) -> Any: ...
|
||||
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
||||
def get(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
|
||||
"""Redis client contract required by the command channel."""
|
||||
|
||||
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
|
||||
@final
|
||||
@@ -42,7 +26,7 @@ class RedisChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: RedisClientProtocol,
|
||||
redis_client: "RedisClientWrapper",
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
|
||||
@@ -3,14 +3,13 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Callers must provide a Redis client dependency from outside the workflow package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
@@ -18,6 +17,7 @@ from core.workflow.graph_engine.entities.commands import (
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,12 +31,8 @@ class GraphEngineManager:
|
||||
by sending commands through Redis channels, without user validation.
|
||||
"""
|
||||
|
||||
_redis_client: RedisClientProtocol
|
||||
|
||||
def __init__(self, redis_client: RedisClientProtocol) -> None:
|
||||
self._redis_client = redis_client
|
||||
|
||||
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
@@ -45,31 +41,34 @@ class GraphEngineManager:
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
self._send_command(task_id, abort_command)
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
|
||||
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
self._send_command(task_id, pause_command)
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
|
||||
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
self._send_command(task_id, update_command)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
|
||||
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(self._redis_client, channel_key)
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
@@ -11,7 +10,7 @@ from core.variables.types import SegmentType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
@@ -25,6 +24,18 @@ if TYPE_CHECKING:
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class WorkflowCodeExecutor(Protocol):
|
||||
def execute(
|
||||
self,
|
||||
*,
|
||||
language: CodeLanguage,
|
||||
code: str,
|
||||
inputs: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
def is_execution_error(self, error: Exception) -> bool: ...
|
||||
|
||||
|
||||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
|
||||
@@ -40,7 +51,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
code_executor: type[CodeExecutor] | None = None,
|
||||
code_executor: WorkflowCodeExecutor,
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits,
|
||||
) -> None:
|
||||
@@ -50,7 +61,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
|
||||
self._code_executor: WorkflowCodeExecutor = code_executor
|
||||
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
|
||||
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
|
||||
)
|
||||
@@ -98,7 +109,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
# Run code
|
||||
try:
|
||||
_ = self._select_code_provider(code_language)
|
||||
result = self._code_executor.execute_workflow_code_template(
|
||||
result = self._code_executor.execute(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables,
|
||||
@@ -106,7 +117,13 @@ class CodeNode(Node[CodeNodeData]):
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
except CodeNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
except Exception as e:
|
||||
if not self._code_executor.is_execution_error(e):
|
||||
raise
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
)
|
||||
|
||||
@@ -111,7 +111,6 @@ class RedisClientWrapper:
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
if self._client is None:
|
||||
|
||||
@@ -8,7 +8,6 @@ new GraphEngine command channel mechanism.
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@@ -43,4 +42,4 @@ class AppTaskService:
|
||||
# New mechanism: Send stop command via GraphEngine for workflow-based apps
|
||||
# This ensures proper workflow status recording in the persistence layer
|
||||
if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
@@ -68,6 +68,7 @@ def init_code_node(code_config: dict):
|
||||
config=code_config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
code_executor=node_factory._code_executor,
|
||||
code_limits=CodeNodeLimits(
|
||||
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
|
||||
max_number=dify_config.CODE_MAX_NUMBER,
|
||||
|
||||
@@ -596,8 +596,7 @@ class TestWorkflowTaskStopApiPost:
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
mock_graph_mgr.assert_called_once()
|
||||
mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")
|
||||
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
|
||||
|
||||
def test_stop_workflow_task_wrong_app_mode(self, app):
|
||||
"""Test NotWorkflowAppError when app mode is not workflow."""
|
||||
|
||||
@@ -24,6 +24,16 @@ DEFAULT_CODE_LIMITS = CodeNodeLimits(
|
||||
)
|
||||
|
||||
|
||||
class _NoopCodeExecutor:
|
||||
def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]:
|
||||
_ = (language, code, inputs)
|
||||
return {}
|
||||
|
||||
def is_execution_error(self, error: Exception) -> bool:
|
||||
_ = error
|
||||
return False
|
||||
|
||||
|
||||
class TestMockTemplateTransformNode:
|
||||
"""Test cases for MockTemplateTransformNode."""
|
||||
|
||||
@@ -319,6 +329,7 @@ class TestMockCodeNode:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_executor=_NoopCodeExecutor(),
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
@@ -384,6 +395,7 @@ class TestMockCodeNode:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_executor=_NoopCodeExecutor(),
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
@@ -453,6 +465,7 @@ class TestMockCodeNode:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
code_executor=_NoopCodeExecutor(),
|
||||
code_limits=DEFAULT_CODE_LIMITS,
|
||||
)
|
||||
|
||||
|
||||
@@ -32,26 +32,25 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
manager = GraphEngineManager(mock_redis)
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
# Execute
|
||||
GraphEngineManager.send_stop_command(task_id, reason="Test stop")
|
||||
|
||||
# Execute
|
||||
manager.send_stop_command(task_id, reason="Test stop")
|
||||
# Verify
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
|
||||
# Verify
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
# Check that rpush was called with correct arguments
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
|
||||
# Check that rpush was called with correct arguments
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
# Verify the channel key
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
|
||||
# Verify the channel key
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
|
||||
# Verify the command data
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "Test stop"
|
||||
# Verify the command data
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.ABORT
|
||||
assert command_data["reason"] == "Test stop"
|
||||
|
||||
def test_graph_engine_manager_sends_pause_command(self):
|
||||
"""Test that GraphEngineManager correctly sends pause command through Redis."""
|
||||
@@ -63,18 +62,18 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
manager = GraphEngineManager(mock_redis)
|
||||
manager.send_pause_command(task_id, reason="Awaiting resources")
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")
|
||||
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
calls = mock_pipeline.rpush.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0][0][0] == expected_channel_key
|
||||
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.PAUSE.value
|
||||
assert command_data["reason"] == "Awaiting resources"
|
||||
command_json = calls[0][0][1]
|
||||
command_data = json.loads(command_json)
|
||||
assert command_data["command_type"] == CommandType.PAUSE.value
|
||||
assert command_data["reason"] == "Awaiting resources"
|
||||
|
||||
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
|
||||
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
|
||||
@@ -83,13 +82,13 @@ class TestRedisStopIntegration:
|
||||
# Mock redis client to raise exception
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed")
|
||||
manager = GraphEngineManager(mock_redis)
|
||||
|
||||
# Should not raise exception
|
||||
try:
|
||||
manager.send_stop_command(task_id)
|
||||
except Exception as e:
|
||||
pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
|
||||
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
|
||||
# Should not raise exception
|
||||
try:
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
except Exception as e:
|
||||
pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
|
||||
|
||||
def test_app_queue_manager_no_user_check(self):
|
||||
"""Test that AppQueueManager.set_stop_flag_no_user_check works without user validation."""
|
||||
@@ -252,10 +251,13 @@ class TestRedisStopIntegration:
|
||||
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
|
||||
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
|
||||
with (
|
||||
patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis),
|
||||
patch("core.workflow.graph_engine.manager.redis_client", mock_redis),
|
||||
):
|
||||
# Execute both stop mechanisms
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(mock_redis).send_stop_command(task_id)
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
# Verify legacy stop flag was set
|
||||
expected_stop_flag_key = f"generate_task_stopped:{task_id}"
|
||||
|
||||
@@ -44,10 +44,9 @@ class TestAppTaskService:
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
if should_call_graph_engine:
|
||||
mock_graph_engine_manager.assert_called_once()
|
||||
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
else:
|
||||
mock_graph_engine_manager.assert_not_called()
|
||||
mock_graph_engine_manager.send_stop_command.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invoke_from",
|
||||
@@ -77,8 +76,7 @@ class TestAppTaskService:
|
||||
|
||||
# Assert
|
||||
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
|
||||
mock_graph_engine_manager.assert_called_once()
|
||||
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
|
||||
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
|
||||
|
||||
@patch("services.app_task_service.GraphEngineManager")
|
||||
@patch("services.app_task_service.AppQueueManager")
|
||||
@@ -98,7 +96,7 @@ class TestAppTaskService:
|
||||
app_mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
# Simulate GraphEngine failure
|
||||
mock_graph_engine_manager.return_value.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
|
||||
|
||||
# Act & Assert - should raise the exception since it's not caught
|
||||
with pytest.raises(Exception, match="GraphEngine error"):
|
||||
|
||||
Reference in New Issue
Block a user