Compare commits

...

1 Commits

Author SHA1 Message Date
-LAN-
6a853d75ea refactor(workflow): inject redis into graph engine manager 2026-02-26 15:30:25 +08:00
13 changed files with 90 additions and 69 deletions

View File

@@ -56,8 +56,6 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -105,7 +103,6 @@ forbidden_modules =
core.variables
ignore_imports =
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager
@@ -243,7 +240,6 @@ ignore_imports =
core.workflow.variable_loader -> core.variables
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database

View File

@@ -33,6 +33,7 @@ from core.workflow.enums import NodeType
from core.workflow.file.models import File
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@@ -740,7 +741,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -44,6 +44,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.app_fields import (
app_detail_fields_with_site,
deleted_tool_fields,
@@ -225,7 +226,7 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -23,6 +23,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
@@ -100,6 +101,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -31,6 +31,7 @@ from core.model_runtime.errors.invoke import InvokeError
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import OptionalTimestampField, TimestampField
@@ -280,7 +281,7 @@ class WorkflowTaskStopApi(Resource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -24,6 +24,7 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
@@ -121,6 +122,6 @@ class WorkflowTaskStopApi(WebApiResource):
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@@ -7,12 +7,28 @@ Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Any, final
from contextlib import AbstractContextManager
from typing import Any, Protocol, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
class RedisPipelineProtocol(Protocol):
"""Minimal Redis pipeline contract used by the command channel."""
def lrange(self, name: str, start: int, end: int) -> Any: ...
def delete(self, *names: str) -> Any: ...
def execute(self) -> list[Any]: ...
def rpush(self, name: str, *values: str) -> Any: ...
def expire(self, name: str, time: int) -> Any: ...
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
def get(self, name: str) -> Any: ...
class RedisClientProtocol(Protocol):
"""Redis client contract required by the command channel."""
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
@final
@@ -26,7 +42,7 @@ class RedisChannel:
def __init__(
self,
redis_client: "RedisClientWrapper",
redis_client: RedisClientProtocol,
channel_key: str,
command_ttl: int = 3600,
) -> None:

View File

@@ -3,13 +3,14 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Callers must provide a Redis client dependency from outside the workflow package.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
@@ -17,7 +18,6 @@ from core.workflow.graph_engine.entities.commands import (
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -31,8 +31,12 @@ class GraphEngineManager:
by sending commands through Redis channels, without user validation.
"""
@staticmethod
def send_stop_command(task_id: str, reason: str | None = None) -> None:
_redis_client: RedisClientProtocol
def __init__(self, redis_client: RedisClientProtocol) -> None:
self._redis_client = redis_client
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
@@ -41,34 +45,31 @@ class GraphEngineManager:
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
GraphEngineManager._send_command(task_id, abort_command)
self._send_command(task_id, abort_command)
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
self._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
self._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
channel = RedisChannel(self._redis_client, channel_key)
try:
channel.send_command(command)

View File

@@ -111,6 +111,7 @@ class RedisClientWrapper:
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ...
def __getattr__(self, item: str) -> Any:
if self._client is None:

View File

@@ -8,6 +8,7 @@ new GraphEngine command channel mechanism.
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_redis import redis_client
from models.model import AppMode
@@ -42,4 +43,4 @@ class AppTaskService:
# New mechanism: Send stop command via GraphEngine for workflow-based apps
# This ensures proper workflow status recording in the persistence layer
if app_mode in (AppMode.ADVANCED_CHAT, AppMode.WORKFLOW):
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)

View File

@@ -596,7 +596,8 @@ class TestWorkflowTaskStopApiPost:
assert result == {"result": "success"}
mock_queue_mgr.set_stop_flag_no_user_check.assert_called_once_with("task-1")
mock_graph_mgr.send_stop_command.assert_called_once_with("task-1")
mock_graph_mgr.assert_called_once()
mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
"""Test NotWorkflowAppError when app mode is not workflow."""

View File

@@ -32,25 +32,26 @@ class TestRedisStopIntegration:
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
# Execute
GraphEngineManager.send_stop_command(task_id, reason="Test stop")
manager = GraphEngineManager(mock_redis)
# Verify
mock_redis.pipeline.assert_called_once()
# Execute
manager.send_stop_command(task_id, reason="Test stop")
# Check that rpush was called with correct arguments
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
# Verify
mock_redis.pipeline.assert_called_once()
# Verify the channel key
assert calls[0][0][0] == expected_channel_key
# Check that rpush was called with correct arguments
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
# Verify the command data
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
# Verify the channel key
assert calls[0][0][0] == expected_channel_key
# Verify the command data
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.ABORT
assert command_data["reason"] == "Test stop"
def test_graph_engine_manager_sends_pause_command(self):
"""Test that GraphEngineManager correctly sends pause command through Redis."""
@@ -62,18 +63,18 @@ class TestRedisStopIntegration:
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
GraphEngineManager.send_pause_command(task_id, reason="Awaiting resources")
manager = GraphEngineManager(mock_redis)
manager.send_pause_command(task_id, reason="Awaiting resources")
mock_redis.pipeline.assert_called_once()
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert calls[0][0][0] == expected_channel_key
mock_redis.pipeline.assert_called_once()
calls = mock_pipeline.rpush.call_args_list
assert len(calls) == 1
assert calls[0][0][0] == expected_channel_key
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.PAUSE.value
assert command_data["reason"] == "Awaiting resources"
command_json = calls[0][0][1]
command_data = json.loads(command_json)
assert command_data["command_type"] == CommandType.PAUSE.value
assert command_data["reason"] == "Awaiting resources"
def test_graph_engine_manager_handles_redis_failure_gracefully(self):
"""Test that GraphEngineManager handles Redis failures without raising exceptions."""
@@ -82,13 +83,13 @@ class TestRedisStopIntegration:
# Mock redis client to raise exception
mock_redis = MagicMock()
mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed")
manager = GraphEngineManager(mock_redis)
with patch("core.workflow.graph_engine.manager.redis_client", mock_redis):
# Should not raise exception
try:
GraphEngineManager.send_stop_command(task_id)
except Exception as e:
pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
# Should not raise exception
try:
manager.send_stop_command(task_id)
except Exception as e:
pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly")
def test_app_queue_manager_no_user_check(self):
"""Test that AppQueueManager.set_stop_flag_no_user_check works without user validation."""
@@ -251,13 +252,10 @@ class TestRedisStopIntegration:
mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline)
mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None)
with (
patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis),
patch("core.workflow.graph_engine.manager.redis_client", mock_redis),
):
with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis):
# Execute both stop mechanisms
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager.send_stop_command(task_id)
GraphEngineManager(mock_redis).send_stop_command(task_id)
# Verify legacy stop flag was set
expected_stop_flag_key = f"generate_task_stopped:{task_id}"

View File

@@ -44,9 +44,10 @@ class TestAppTaskService:
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
if should_call_graph_engine:
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
mock_graph_engine_manager.assert_called_once()
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
else:
mock_graph_engine_manager.send_stop_command.assert_not_called()
mock_graph_engine_manager.assert_not_called()
@pytest.mark.parametrize(
"invoke_from",
@@ -76,7 +77,8 @@ class TestAppTaskService:
# Assert
mock_app_queue_manager.set_stop_flag.assert_called_once_with(task_id, invoke_from, user_id)
mock_graph_engine_manager.send_stop_command.assert_called_once_with(task_id)
mock_graph_engine_manager.assert_called_once()
mock_graph_engine_manager.return_value.send_stop_command.assert_called_once_with(task_id)
@patch("services.app_task_service.GraphEngineManager")
@patch("services.app_task_service.AppQueueManager")
@@ -96,7 +98,7 @@ class TestAppTaskService:
app_mode = AppMode.ADVANCED_CHAT
# Simulate GraphEngine failure
mock_graph_engine_manager.send_stop_command.side_effect = Exception("GraphEngine error")
mock_graph_engine_manager.return_value.send_stop_command.side_effect = Exception("GraphEngine error")
# Act & Assert - should raise the exception since it's not caught
with pytest.raises(Exception, match="GraphEngine error"):