mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
feat(api): Introduce WorkflowResumptionContext for pause state management (#28122)
Certain metadata (including but not limited to `InvokeFrom`, `call_depth`, and `streaming`) is required when resuming a paused workflow. However, these fields are not part of `GraphRuntimeState` and were not saved in the previous implementation of `PauseStatePersistenceLayer`. This commit addresses this limitation by introducing a `WorkflowResumptionContext` model that wraps both the `*GenerateEntity` and `GraphRuntimeState`. This approach provides: - A structured container for all necessary resumption data - Better separation of concerns between execution state and persistence - Enhanced extensibility for future metadata additions - Clearer naming that distinguishes from `GraphRuntimeState` The `WorkflowResumptionContext` model makes extending the pause state easier while maintaining backward compatibility and proper version management for the entire execution state ecosystem. Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,12 @@ import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
PauseStatePersistenceLayer,
|
||||
WorkflowResumptionContext,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
@@ -39,7 +44,7 @@ from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.model import UploadFile
|
||||
from models.model import AppMode, UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.file_service import FileService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
@@ -226,11 +231,39 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
|
||||
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
|
||||
|
||||
def _create_generate_entity(
|
||||
self,
|
||||
workflow_execution_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
) -> WorkflowAppGenerateEntity:
|
||||
execution_id = workflow_execution_id or getattr(self, "test_workflow_run_id", str(uuid.uuid4()))
|
||||
wf_id = workflow_id or getattr(self, "test_workflow_id", str(uuid.uuid4()))
|
||||
tenant_id = getattr(self, "test_tenant_id", "tenant-123")
|
||||
app_id = getattr(self, "test_app_id", "app-123")
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id=str(tenant_id),
|
||||
app_id=str(app_id),
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id=str(wf_id),
|
||||
)
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user_id or getattr(self, "test_user_id", str(uuid.uuid4())),
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_id=execution_id,
|
||||
)
|
||||
|
||||
def _create_pause_state_persistence_layer(
|
||||
self,
|
||||
workflow_run: WorkflowRun | None = None,
|
||||
workflow: Workflow | None = None,
|
||||
state_owner_user_id: str | None = None,
|
||||
generate_entity: WorkflowAppGenerateEntity | None = None,
|
||||
) -> PauseStatePersistenceLayer:
|
||||
"""Create PauseStatePersistenceLayer with real dependencies."""
|
||||
owner_id = state_owner_user_id
|
||||
@@ -244,10 +277,23 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
|
||||
assert owner_id is not None
|
||||
owner_id = str(owner_id)
|
||||
workflow_execution_id = (
|
||||
workflow_run.id if workflow_run is not None else getattr(self, "test_workflow_run_id", None)
|
||||
)
|
||||
assert workflow_execution_id is not None
|
||||
workflow_id = workflow.id if workflow is not None else getattr(self, "test_workflow_id", None)
|
||||
assert workflow_id is not None
|
||||
entity_user_id = getattr(self, "test_user_id", owner_id)
|
||||
entity = generate_entity or self._create_generate_entity(
|
||||
workflow_execution_id=str(workflow_execution_id),
|
||||
user_id=entity_user_id,
|
||||
workflow_id=str(workflow_id),
|
||||
)
|
||||
|
||||
return PauseStatePersistenceLayer(
|
||||
session_factory=self.session.get_bind(),
|
||||
state_owner_user_id=owner_id,
|
||||
generate_entity=entity,
|
||||
)
|
||||
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||
@@ -297,10 +343,15 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
resumption_context = WorkflowResumptionContext.loads(storage_content)
|
||||
assert resumption_context.version == "1"
|
||||
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
actual_state = json.loads(storage_content)
|
||||
|
||||
actual_state = json.loads(resumption_context.serialized_graph_runtime_state)
|
||||
assert actual_state == expected_state
|
||||
persisted_entity = resumption_context.get_generate_entity()
|
||||
assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
|
||||
assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||
"""Test that pause state can be persisted and retrieved correctly."""
|
||||
@@ -341,13 +392,15 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
retrieved_state = json.loads(state_bytes.decode())
|
||||
resumption_context = WorkflowResumptionContext.loads(state_bytes.decode())
|
||||
retrieved_state = json.loads(resumption_context.serialized_graph_runtime_state)
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
|
||||
assert retrieved_state == expected_state
|
||||
assert retrieved_state["outputs"] == complex_outputs
|
||||
assert retrieved_state["total_tokens"] == 250
|
||||
assert retrieved_state["node_run_steps"] == 10
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_database_transaction_handling(self, db_session_with_containers):
|
||||
"""Test that database transactions are handled correctly."""
|
||||
@@ -410,7 +463,9 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
|
||||
# Verify content in storage
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
assert storage_content == graph_runtime_state.dumps()
|
||||
resumption_context = WorkflowResumptionContext.loads(storage_content)
|
||||
assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||
"""Test pause state with workflows created by different users."""
|
||||
@@ -474,6 +529,8 @@ class TestPauseStatePersistenceLayerTestContainers:
|
||||
# Verify the state owner is the workflow creator
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
|
||||
assert pause_entity is not None
|
||||
resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
|
||||
assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
|
||||
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||
"""Test that layer ignores non-pause events."""
|
||||
|
||||
@@ -4,7 +4,14 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
PauseStatePersistenceLayer,
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
@@ -15,6 +22,7 @@ from core.workflow.graph_events.graph import (
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@@ -170,6 +178,25 @@ class MockCommandChannel:
|
||||
class TestPauseStatePersistenceLayer:
|
||||
"""Unit tests for PauseStatePersistenceLayer."""
|
||||
|
||||
@staticmethod
|
||||
def _create_generate_entity(workflow_execution_id: str = "run-123") -> WorkflowAppGenerateEntity:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-123",
|
||||
)
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="task-123",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-123",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_id=workflow_execution_id,
|
||||
)
|
||||
|
||||
def test_init_with_dependency_injection(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
state_owner_user_id = "user-123"
|
||||
@@ -177,6 +204,7 @@ class TestPauseStatePersistenceLayer:
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
generate_entity=self._create_generate_entity(),
|
||||
)
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
@@ -186,7 +214,11 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id="owner",
|
||||
generate_entity=self._create_generate_entity(),
|
||||
)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
@@ -198,7 +230,12 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
generate_entity = self._create_generate_entity(workflow_execution_id="run-123")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id="owner-123",
|
||||
generate_entity=generate_entity,
|
||||
)
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
@@ -221,12 +258,20 @@ class TestPauseStatePersistenceLayer:
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=expected_state,
|
||||
state=mock_repo.create_workflow_pause.call_args.kwargs["state"],
|
||||
)
|
||||
serialized_state = mock_repo.create_workflow_pause.call_args.kwargs["state"]
|
||||
resumption_context = WorkflowResumptionContext.loads(serialized_state)
|
||||
assert resumption_context.serialized_graph_runtime_state == expected_state
|
||||
assert resumption_context.get_generate_entity().model_dump() == generate_entity.model_dump()
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id="owner-123",
|
||||
generate_entity=self._create_generate_entity(),
|
||||
)
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
@@ -250,7 +295,11 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id="owner-123",
|
||||
generate_entity=self._create_generate_entity(),
|
||||
)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
@@ -259,7 +308,11 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id="owner-123",
|
||||
generate_entity=self._create_generate_entity(),
|
||||
)
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
@@ -276,3 +329,82 @@ class TestPauseStatePersistenceLayer:
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
|
||||
def _build_workflow_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
|
||||
"""Create a WorkflowAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-roundtrip",
|
||||
app_id="app-roundtrip",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-roundtrip",
|
||||
)
|
||||
serialized_state = json.dumps({"state": "workflow"})
|
||||
|
||||
return WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=serialized_state,
|
||||
generate_entity=_WorkflowGenerateEntityWrapper(
|
||||
entity=WorkflowAppGenerateEntity(
|
||||
task_id="workflow-task",
|
||||
app_config=app_config,
|
||||
inputs={"input_key": "input_value"},
|
||||
files=[],
|
||||
user_id="user-roundtrip",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_id="workflow-exec-roundtrip",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_advanced_chat_generate_entity_for_roundtrip() -> WorkflowResumptionContext:
|
||||
"""Create an AdvancedChatAppGenerateEntity with realistic data for WorkflowResumptionContext tests."""
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-advanced",
|
||||
app_id="app-advanced",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_id="workflow-advanced",
|
||||
)
|
||||
serialized_state = json.dumps({"state": "workflow"})
|
||||
|
||||
return WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=serialized_state,
|
||||
generate_entity=_AdvancedChatAppGenerateEntityWrapper(
|
||||
entity=AdvancedChatAppGenerateEntity(
|
||||
task_id="advanced-task",
|
||||
app_config=app_config,
|
||||
inputs={"topic": "roundtrip"},
|
||||
files=[],
|
||||
user_id="advanced-user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_run_id="advanced-run-id",
|
||||
query="Explain serialization behavior",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state",
|
||||
[
|
||||
pytest.param(
|
||||
_build_advanced_chat_generate_entity_for_roundtrip(),
|
||||
id="advanced_chat",
|
||||
),
|
||||
pytest.param(
|
||||
_build_workflow_generate_entity_for_roundtrip(),
|
||||
id="workflow",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_workflow_resumption_context_dumps_loads_roundtrip(state: WorkflowResumptionContext):
|
||||
"""WorkflowResumptionContext roundtrip preserves workflow generate entity metadata."""
|
||||
dumped = state.dumps()
|
||||
loaded = WorkflowResumptionContext.loads(dumped)
|
||||
|
||||
assert loaded == state
|
||||
assert loaded.serialized_graph_runtime_state == state.serialized_graph_runtime_state
|
||||
restored_entity = loaded.get_generate_entity()
|
||||
assert isinstance(restored_entity, type(state.generate_entity.entity))
|
||||
|
||||
Reference in New Issue
Block a user