feat(api): Introduce workflow pause state management (#27298)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
@@ -0,0 +1 @@
# Core integration tests package
@@ -0,0 +1 @@
# App integration tests package
@@ -0,0 +1 @@
# Layers integration tests package
@@ -0,0 +1,520 @@
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
|
||||
|
||||
This test suite covers complete integration scenarios including:
|
||||
- Real database interactions using containerized PostgreSQL
|
||||
- Real storage operations using test storage backend
|
||||
- Complete workflow: event -> state serialization -> database save -> storage save
|
||||
- Testing with actual WorkflowRunService (not mocked)
|
||||
- Real Workflow and WorkflowRun instances in database
|
||||
- Database transactions and rollback behavior
|
||||
- Actual file upload and retrieval through storage
|
||||
- Workflow status transitions in database
|
||||
- Error handling with real database constraints
|
||||
- Multiple pause events in sequence
|
||||
- Integration with real ReadOnlyGraphRuntimeState implementations
|
||||
|
||||
These tests use TestContainers to spin up real services for integration testing,
|
||||
providing more reliable and realistic test scenarios than mocks.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
|
||||
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models import WorkflowPause as WorkflowPauseModel
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow, WorkflowRun
|
||||
from services.file_service import FileService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
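# Illustrative usage sketch (not part of the test suite): the flow exercised by the
# tests below, assuming the PauseStatePersistenceLayer API as used in this file.
#
#     layer = PauseStatePersistenceLayer(session_factory=engine, state_owner_user_id=user_id)
#     layer.initialize(graph_runtime_state, command_channel)
#     layer.on_event(GraphRunPausedEvent(reason=SchedulingPause(message="pause")))
#     # The layer serializes the runtime state, uploads it to storage, records a
#     # WorkflowPause row, and marks the WorkflowRun as PAUSED.
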
class _TestCommandChannelImpl:
    """Real implementation of CommandChannel for testing."""

    def __init__(self):
        self._commands: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """Fetch pending commands for this GraphEngine instance."""
        return self._commands.copy()

    def send_command(self, command: GraphEngineCommand) -> None:
        """Send a command to be processed by this GraphEngine instance."""
        self._commands.append(command)


class TestPauseStatePersistenceLayerTestContainers:
    """Comprehensive TestContainers-based integration tests for the PauseStatePersistenceLayer class."""

    @pytest.fixture
    def engine(self, db_session_with_containers: Session):
        """Get the database engine from the TestContainers session."""
        bind = db_session_with_containers.get_bind()
        assert isinstance(bind, Engine)
        return bind

    @pytest.fixture
    def file_service(self, engine: Engine):
        """Create a FileService instance with the TestContainers engine."""
        return FileService(engine)

    @pytest.fixture
    def workflow_run_service(self, engine: Engine, file_service: FileService):
        """Create a WorkflowRunService instance with the TestContainers engine."""
        return WorkflowRunService(engine)

@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
self.test_workflow_run_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Create test workflow run
|
||||
self.test_workflow_run = WorkflowRun(
|
||||
id=self.test_workflow_run_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session and service instances
|
||||
self.session = db_session_with_containers
|
||||
self.file_service = file_service
|
||||
self.workflow_run_service = workflow_run_service
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.add(self.test_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
try:
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
except Exception:
self.session.rollback()
raise
|
||||
|
||||
def _create_graph_runtime_state(
|
||||
self,
|
||||
outputs: dict[str, object] | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
"""Create a real GraphRuntimeState for testing."""
|
||||
start_at = time()
|
||||
|
||||
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
|
||||
if variables:
|
||||
for (node_id, var_key), value in variables.items():
|
||||
variable_pool.add([node_id, var_key], value)
|
||||
|
||||
# Create LLM usage
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Create graph runtime state
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=start_at,
|
||||
total_tokens=total_tokens,
|
||||
llm_usage=llm_usage,
|
||||
outputs=outputs or {},
|
||||
node_run_steps=node_run_steps,
|
||||
)
|
||||
|
||||
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
|
||||
|
||||
def _create_pause_state_persistence_layer(
|
||||
self,
|
||||
workflow_run: WorkflowRun | None = None,
|
||||
workflow: Workflow | None = None,
|
||||
state_owner_user_id: str | None = None,
|
||||
) -> PauseStatePersistenceLayer:
|
||||
"""Create PauseStatePersistenceLayer with real dependencies."""
|
||||
owner_id = state_owner_user_id
|
||||
if owner_id is None:
|
||||
if workflow is not None and workflow.created_by:
|
||||
owner_id = workflow.created_by
|
||||
elif workflow_run is not None and workflow_run.created_by:
|
||||
owner_id = workflow_run.created_by
|
||||
else:
|
||||
owner_id = getattr(self, "test_user_id", None)
|
||||
|
||||
assert owner_id is not None
|
||||
owner_id = str(owner_id)
|
||||
|
||||
return PauseStatePersistenceLayer(
|
||||
session_factory=self.session.get_bind(),
|
||||
state_owner_user_id=owner_id,
|
||||
)
|
||||
|
||||
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
|
||||
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create real graph runtime state with test data
|
||||
test_outputs = {"result": "test_output", "step": "intermediate"}
|
||||
test_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
}
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=test_outputs,
|
||||
total_tokens=100,
|
||||
node_run_steps=5,
|
||||
variables=test_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Create pause event
|
||||
event = GraphRunPausedEvent(
|
||||
reason=SchedulingPause(message="test pause"),
|
||||
outputs={"intermediate": "result"},
|
||||
)
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify pause state was saved to database
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Verify pause state exists in database
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.state_object_key != ""
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
actual_state = json.loads(storage_content)
|
||||
|
||||
assert actual_state == expected_state
|
||||
|
||||
def test_state_persistence_and_retrieval(self, db_session_with_containers):
|
||||
"""Test that pause state can be persisted and retrieved correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create complex test data
|
||||
complex_outputs = {
|
||||
"nested": {"key": "value", "number": 42},
|
||||
"list": [1, 2, 3, {"nested": "item"}],
|
||||
"boolean": True,
|
||||
"null_value": None,
|
||||
}
|
||||
complex_variables = {
|
||||
("node1", "var1"): "string_value",
|
||||
("node2", "var2"): {"complex": "object"},
|
||||
("node3", "var3"): [1, 2, 3],
|
||||
}
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=complex_outputs,
|
||||
total_tokens=250,
|
||||
node_run_steps=10,
|
||||
variables=complex_variables,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act - Save pause state
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Retrieve and verify
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
|
||||
|
||||
state_bytes = pause_entity.get_state()
|
||||
retrieved_state = json.loads(state_bytes.decode())
|
||||
expected_state = json.loads(graph_runtime_state.dumps())
|
||||
|
||||
assert retrieved_state == expected_state
|
||||
assert retrieved_state["outputs"] == complex_outputs
|
||||
assert retrieved_state["total_tokens"] == 250
|
||||
assert retrieved_state["node_run_steps"] == 10
|
||||
|
||||
def test_database_transaction_handling(self, db_session_with_containers):
|
||||
"""Test that database transactions are handled correctly."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"test": "transaction"},
|
||||
total_tokens=50,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify data is committed and accessible in new session
|
||||
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
|
||||
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
|
||||
assert workflow_run is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
pause_model = new_session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.workflow_run_id == self.test_workflow_run_id
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
def test_file_storage_integration(self, db_session_with_containers):
|
||||
"""Test integration with file storage system."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
|
||||
# Create large state data to test storage
|
||||
large_outputs = {"data": "x" * 10000} # 10KB of data
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs=large_outputs,
|
||||
total_tokens=1000,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify content in storage
|
||||
storage_content = storage.load(pause_model.state_object_key).decode()
|
||||
assert storage_content == graph_runtime_state.dumps()
|
||||
|
||||
def test_workflow_with_different_creators(self, db_session_with_containers):
|
||||
"""Test pause state with workflows created by different users."""
|
||||
# Arrange - Create workflow with different creator
|
||||
different_user_id = str(uuid.uuid4())
|
||||
different_workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=different_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
different_workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=different_workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=self.test_user_id, # Run created by different user
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
self.session.add(different_workflow)
|
||||
self.session.add(different_workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
layer = self._create_pause_state_persistence_layer(
|
||||
workflow_run=different_workflow_run,
|
||||
workflow=different_workflow,
|
||||
)
|
||||
|
||||
graph_runtime_state = self._create_graph_runtime_state(
|
||||
outputs={"creator_test": "different_creator"},
|
||||
workflow_run_id=different_workflow_run.id,
|
||||
)
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act
|
||||
layer.on_event(event)
|
||||
|
||||
# Assert - Should use workflow creator (not run creator)
|
||||
self.session.refresh(different_workflow_run)
|
||||
pause_model = self.session.scalars(
|
||||
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
|
||||
).first()
|
||||
assert pause_model is not None
|
||||
|
||||
# Verify the state owner is the workflow creator
|
||||
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
|
||||
assert pause_entity is not None
|
||||
|
||||
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
|
||||
"""Test that layer ignores non-pause events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
graph_runtime_state = self._create_graph_runtime_state()
|
||||
|
||||
command_channel = _TestCommandChannelImpl()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
# Import other event types
|
||||
from core.workflow.graph_events.graph import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
|
||||
# Act - Send non-pause events
|
||||
layer.on_event(GraphRunStartedEvent())
|
||||
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
|
||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||
|
||||
# Assert - No pause state should be created
|
||||
self.session.refresh(self.test_workflow_run)
|
||||
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
pause_states = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
assert len(pause_states) == 0
|
||||
|
||||
def test_layer_requires_initialization(self, db_session_with_containers):
|
||||
"""Test that layer requires proper initialization before handling events."""
|
||||
# Arrange
|
||||
layer = self._create_pause_state_persistence_layer()
|
||||
# Don't initialize - graph_runtime_state should not be set
|
||||
|
||||
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
|
||||
|
||||
# Act & Assert - Should raise AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
@@ -0,0 +1,948 @@
"""Comprehensive integration tests for workflow pause functionality.

This test suite covers complete workflow pause functionality including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using the test storage backend
- Complete workflow: create -> pause -> resume -> delete
- Testing with actual FileService (not mocked)
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in the database
- Error handling with real database constraints
- Concurrent access scenarios
- Multi-tenant isolation
- Prune functionality
- File storage integration

These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""

import json
import uuid
from dataclasses import dataclass
from datetime import timedelta

import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker

from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _WorkflowRunError,
)

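# Illustrative usage sketch (not part of the test suite): the repository-level
# pause/resume cycle these tests exercise, using the repository API shown below.
#
#     pause = repository.create_workflow_pause(
#         workflow_run_id=run.id, state_owner_user_id=user_id, state=state_json
#     )                                                   # run becomes PAUSED
#     repository.resume_workflow_pause(workflow_run_id=run.id, pause_entity=pause)  # back to RUNNING
#     repository.delete_workflow_pause(pause)             # row and stored state file removed
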
@dataclass
|
||||
class PauseWorkflowSuccessCase:
|
||||
"""Test case for successful pause workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PauseWorkflowFailureCase:
|
||||
"""Test case for pause workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowSuccessCase:
|
||||
"""Test case for successful resume workflow operations."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResumeWorkflowFailureCase:
|
||||
"""Test case for resume workflow failure scenarios."""
|
||||
|
||||
name: str
|
||||
initial_status: WorkflowExecutionStatus
|
||||
pause_resumed: bool
|
||||
set_running_status: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrunePausesTestCase:
|
||||
"""Test case for prune pauses operations."""
|
||||
|
||||
name: str
|
||||
pause_age: timedelta
|
||||
resume_age: timedelta | None
|
||||
expected_pruned_count: int
|
||||
description: str = ""
|
||||
|
||||
|
||||
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
|
||||
"""Create test cases for pause workflow failure scenarios."""
|
||||
return [
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_already_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should fail to pause an already paused workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_completed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
description="Should fail to pause a completed workflow",
|
||||
),
|
||||
PauseWorkflowFailureCase(
|
||||
name="pause_failed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.FAILED,
|
||||
description="Should fail to pause a failed workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
|
||||
"""Create test cases for successful resume workflow operations."""
|
||||
return [
|
||||
ResumeWorkflowSuccessCase(
|
||||
name="resume_paused_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
description="Should successfully resume a paused workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
|
||||
"""Create test cases for resume workflow failure scenarios."""
|
||||
return [
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_already_resumed_workflow",
|
||||
initial_status=WorkflowExecutionStatus.PAUSED,
|
||||
pause_resumed=True,
|
||||
description="Should fail to resume an already resumed workflow",
|
||||
),
|
||||
ResumeWorkflowFailureCase(
|
||||
name="resume_running_workflow",
|
||||
initial_status=WorkflowExecutionStatus.RUNNING,
|
||||
pause_resumed=False,
|
||||
set_running_status=True,
|
||||
description="Should fail to resume a running workflow",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
|
||||
"""Create test cases for prune pauses operations."""
|
||||
return [
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_active_pauses",
|
||||
pause_age=timedelta(days=7),
|
||||
resume_age=None,
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="prune_old_resumed_pauses",
|
||||
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
|
||||
resume_age=timedelta(days=7),
|
||||
expected_pruned_count=1,
|
||||
description="Should prune old resumed pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_active_pauses",
|
||||
pause_age=timedelta(hours=1),
|
||||
resume_age=None,
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent active pauses",
|
||||
),
|
||||
PrunePausesTestCase(
|
||||
name="keep_recent_resumed_pauses",
|
||||
pause_age=timedelta(days=1),
|
||||
resume_age=timedelta(hours=1),
|
||||
expected_pruned_count=0,
|
||||
description="Should keep recent resumed pauses",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
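# Illustrative note (inferred from the prune tests below): prune_pauses() takes two cutoffs
# and an optional limit, returning the IDs it deleted. Unresumed pauses created before
# `expiration` and resumed pauses whose resumed_at is before `resumption_expiration` are pruned:
#
#     pruned_ids = repository.prune_pauses(
#         expiration=now - timedelta(days=1),
#         resumption_expiration=now - timedelta(days=7),
#         limit=100,
#     )
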
class TestWorkflowPauseIntegration:
|
||||
"""Comprehensive integration tests for workflow pause functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_test_data(self, db_session_with_containers):
|
||||
"""Set up test data for each test method using TestContainers."""
|
||||
# Create test tenant and account
|
||||
|
||||
tenant = Tenant(
|
||||
name="Test Tenant",
|
||||
status="normal",
|
||||
)
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account = Account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
tenant_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(tenant_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set test data
|
||||
self.test_tenant_id = tenant.id
|
||||
self.test_user_id = account.id
|
||||
self.test_app_id = str(uuid.uuid4())
|
||||
self.test_workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Create test workflow
|
||||
self.test_workflow = Workflow(
|
||||
id=self.test_workflow_id,
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=self.test_user_id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
# Store session instance
|
||||
self.session = db_session_with_containers
|
||||
|
||||
# Save test data to database
|
||||
self.session.add(self.test_workflow)
|
||||
self.session.commit()
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
self._cleanup_test_data()
|
||||
|
||||
def _cleanup_test_data(self):
|
||||
"""Clean up test data after each test method."""
|
||||
# Clean up workflow pauses
|
||||
self.session.execute(delete(WorkflowPauseModel))
|
||||
# Clean up upload files
|
||||
self.session.execute(
|
||||
delete(UploadFile).where(
|
||||
UploadFile.tenant_id == self.test_tenant_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflow runs
|
||||
self.session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == self.test_tenant_id,
|
||||
WorkflowRun.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
# Clean up workflows
|
||||
self.session.execute(
|
||||
delete(Workflow).where(
|
||||
Workflow.tenant_id == self.test_tenant_id,
|
||||
Workflow.app_id == self.test_app_id,
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
|
||||
def _create_test_workflow_run(
|
||||
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
|
||||
) -> WorkflowRun:
|
||||
"""Create a test workflow run with specified status."""
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=self.test_tenant_id,
|
||||
app_id=self.test_app_id,
|
||||
workflow_id=self.test_workflow_id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=status,
|
||||
created_by=self.test_user_id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
return workflow_run
|
||||
|
||||
def _create_test_state(self) -> str:
|
||||
"""Create a test state string."""
|
||||
return json.dumps(
|
||||
{
|
||||
"node_id": "test-node",
|
||||
"node_type": "llm",
|
||||
"status": "paused",
|
||||
"data": {"key": "value"},
|
||||
"timestamp": naive_utc_now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def _get_workflow_run_repository(self):
|
||||
"""Get workflow run repository instance for testing."""
|
||||
# Create session factory from the test session
|
||||
engine = self.session.get_bind()
|
||||
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
# Create a test-specific repository that implements the missing save method
|
||||
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test-specific repository that implements the missing save method."""
|
||||
|
||||
def save(self, execution: WorkflowExecution):
|
||||
"""Implement the missing save method for testing."""
|
||||
# For testing purposes, we don't need to implement this method
|
||||
# as it's not used in the pause functionality tests
|
||||
pass
|
||||
|
||||
# Create and return repository instance
|
||||
repository = TestWorkflowRunRepository(session_maker=session_factory)
|
||||
return repository
|
||||
|
||||
# ==================== Complete Pause Workflow Tests ====================
|
||||
|
||||
def test_complete_pause_resume_workflow(self):
|
||||
"""Test complete workflow: create -> pause -> resume -> delete."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Pause state created
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.id is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
# Convert both to strings for comparison
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Verify database state
|
||||
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.resumed_at is None
|
||||
assert pause_model.id == pause_entity.id
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
# Act - Get pause state
|
||||
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
|
||||
|
||||
# Assert - Pause state retrieved
|
||||
assert retrieved_entity is not None
|
||||
assert retrieved_entity.id == pause_entity.id
|
||||
retrieved_state = retrieved_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
# Act - Resume workflow
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert - Workflow resumed
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify database state
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
self.session.refresh(pause_model)
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause state deleted
|
||||
with Session(bind=self.session.get_bind()) as session:
|
||||
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
def test_pause_workflow_success(self):
|
||||
"""Test successful pause workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
assert pause_entity is not None
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is None
|
||||
|
||||
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
|
||||
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
|
||||
"""Test pause workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
|
||||
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
|
||||
"""Test successful resume workflow scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
|
||||
pause_model = self.session.scalars(pause_query).first()
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
def test_resume_running_workflow(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
self.session.refresh(workflow_run)
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.add(workflow_run)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_resumed_pause(self):
|
||||
"""Test resume workflow failure scenarios."""
|
||||
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
self.session.add(pause_model)
|
||||
self.session.commit()
|
||||
|
||||
with pytest.raises(_WorkflowRunError):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Error Scenario Tests ====================
|
||||
|
||||
def test_pause_nonexistent_workflow_run(self):
|
||||
"""Test pausing a non-existent workflow run."""
|
||||
# Arrange
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
def test_resume_nonexistent_workflow_run(self):
|
||||
"""Test resuming a non-existent workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
nonexistent_id = str(uuid.uuid4())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=nonexistent_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# ==================== Prune Functionality Tests ====================
|
||||
|
||||
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
|
||||
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
|
||||
"""Test various prune pauses scenarios."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create pause state
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Manually adjust timestamps for testing
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - test_case.pause_age
|
||||
|
||||
if test_case.resume_age is not None:
|
||||
# Resume pause and adjust resume time
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
# Need to refresh to get the updated model
|
||||
self.session.refresh(pause_model)
|
||||
# Manually set the resumed_at to an older time for testing
|
||||
pause_model.resumed_at = now - test_case.resume_age
|
||||
self.session.commit() # Commit the resumed_at change
|
||||
# Refresh again to ensure the change is persisted
|
||||
self.session.refresh(pause_model)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune pauses
|
||||
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
|
||||
resumption_time = now - timedelta(
|
||||
days=7, seconds=1
|
||||
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
|
||||
|
||||
# Debug: Check pause state before pruning
|
||||
self.session.refresh(pause_model)
|
||||
print(f"Pause created_at: {pause_model.created_at}")
|
||||
print(f"Pause resumed_at: {pause_model.resumed_at}")
|
||||
print(f"Expiration time: {expiration_time}")
|
||||
print(f"Resumption time: {resumption_time}")
|
||||
|
||||
# Force commit to ensure timestamps are saved
|
||||
self.session.commit()
|
||||
|
||||
# Determine if the pause should be pruned based on timestamps
|
||||
should_be_pruned = False
|
||||
if test_case.resume_age is not None:
|
||||
# If resumed, check if resumed_at is older than resumption_time
|
||||
should_be_pruned = pause_model.resumed_at < resumption_time
|
||||
else:
|
||||
# If not resumed, check if created_at is older than expiration_time
|
||||
should_be_pruned = pause_model.created_at < expiration_time
|
||||
|
||||
# Act - Prune pauses
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
)
|
||||
|
||||
# Assert - Check pruning results
|
||||
if should_be_pruned:
|
||||
assert len(pruned_ids) == test_case.expected_pruned_count
|
||||
# Verify pause was actually deleted
|
||||
# The pause should be in the pruned_ids list if it was pruned
|
||||
assert pause_entity.id in pruned_ids
|
||||
else:
|
||||
assert len(pruned_ids) == 0
|
||||
|
||||
def test_prune_pauses_with_limit(self):
|
||||
"""Test prune pauses with limit parameter."""
|
||||
now = naive_utc_now()
|
||||
|
||||
# Create multiple pause states
|
||||
pause_entities = []
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
for i in range(5):
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
pause_entities.append(pause_entity)
|
||||
|
||||
# Make all pauses old enough to be pruned
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
pause_model.created_at = now - timedelta(days=7)
|
||||
|
||||
self.session.commit()
|
||||
|
||||
# Act - Prune with limit
|
||||
expiration_time = now - timedelta(days=1)
|
||||
resumption_time = now - timedelta(days=7)
|
||||
|
||||
pruned_ids = repository.prune_pauses(
|
||||
expiration=expiration_time,
|
||||
resumption_expiration=resumption_time,
|
||||
limit=3,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(pruned_ids) == 3
|
||||
|
||||
# Verify only 3 were deleted
|
||||
remaining_count = (
|
||||
self.session.query(WorkflowPauseModel)
|
||||
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
|
||||
.count()
|
||||
)
|
||||
assert remaining_count == 2
|
||||
|
||||
# ==================== Multi-tenant Isolation Tests ====================
|
||||
|
||||
def test_multi_tenant_pause_isolation(self):
|
||||
"""Test that pause states are properly isolated by tenant."""
|
||||
# Arrange - Create second tenant
|
||||
|
||||
tenant2 = Tenant(
|
||||
name="Test Tenant 2",
|
||||
status="normal",
|
||||
)
|
||||
self.session.add(tenant2)
|
||||
self.session.commit()
|
||||
|
||||
account2 = Account(
|
||||
email="test2@example.com",
|
||||
name="Test User 2",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
self.session.add(account2)
|
||||
self.session.commit()
|
||||
|
||||
tenant2_join = TenantAccountJoin(
|
||||
tenant_id=tenant2.id,
|
||||
account_id=account2.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
self.session.add(tenant2_join)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow for tenant 2
|
||||
workflow2 = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=str(uuid.uuid4()),
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph='{"nodes": [], "edges": []}',
|
||||
features='{"file_upload": {"enabled": false}}',
|
||||
created_by=account2.id,
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow2)
|
||||
self.session.commit()
|
||||
|
||||
# Create workflow runs for both tenants
|
||||
workflow_run1 = self._create_test_workflow_run()
|
||||
workflow_run2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=tenant2.id,
|
||||
app_id=workflow2.app_id,
|
||||
workflow_id=workflow2.id,
|
||||
type="workflow",
|
||||
triggered_from="debugging",
|
||||
version="draft",
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_by=account2.id,
|
||||
created_by_role="account",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
self.session.add(workflow_run2)
|
||||
self.session.commit()
|
||||
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause for tenant 1
|
||||
pause_entity1 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run1.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Try to access pause from tenant 2 using tenant 1's repository
|
||||
# This should work because we're using the same repository
|
||||
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
|
||||
assert pause_entity2 is None # No pause for tenant 2 yet
|
||||
|
||||
# Create pause for tenant 2
|
||||
pause_entity2 = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run2.id,
|
||||
state_owner_user_id=account2.id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Both pauses should exist and be separate
|
||||
assert pause_entity1 is not None
|
||||
assert pause_entity2 is not None
|
||||
assert pause_entity1.id != pause_entity2.id
|
||||
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
|
||||
|
||||
def test_cross_tenant_access_restriction(self):
|
||||
"""Test that cross-tenant access is properly restricted."""
|
||||
# This test would require tenant-specific repositories
|
||||
# For now, we test that pause entities are properly scoped by tenant_id
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Verify pause is properly scoped
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.workflow_id == self.test_workflow_id
|
||||
|
||||
# ==================== File Storage Integration Tests ====================
|
||||
|
||||
def test_file_storage_integration(self):
|
||||
"""Test that state files are properly stored and retrieved."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act - Create pause state
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Assert - Verify file was uploaded to storage
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Verify file content in storage
|
||||
|
||||
file_key = pause_model.state_object_key
|
||||
storage_content = storage.load(file_key).decode()
|
||||
assert storage_content == test_state
|
||||
|
||||
# Verify retrieval through entity
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == test_state
|
||||
|
||||
def test_file_cleanup_on_pause_deletion(self):
|
||||
"""Test that files are properly handled on pause deletion."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
test_state = self._create_test_state()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=test_state,
|
||||
)
|
||||
|
||||
# Get file info before deletion
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
file_key = pause_model.state_object_key
|
||||
|
||||
# Act - Delete pause state
|
||||
repository.delete_workflow_pause(pause_entity)
|
||||
|
||||
# Assert - Pause record should be deleted
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert deleted_pause is None
|
||||
|
||||
try:
|
||||
content = storage.load(file_key).decode()
|
||||
pytest.fail("File should be deleted from storage after pause deletion")
|
||||
except FileNotFoundError:
|
||||
# This is expected - file should be deleted from storage
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Unexpected error when checking file deletion: {e}")
|
||||
|
||||
def test_large_state_file_handling(self):
|
||||
"""Test handling of large state files."""
|
||||
# Arrange - Create a large state (1MB)
|
||||
large_state = "x" * (1024 * 1024) # 1MB of data
|
||||
large_state_json = json.dumps({"large_data": large_state})
|
||||
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=large_state_json,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert pause_entity is not None
|
||||
retrieved_state = pause_entity.get_state()
|
||||
if isinstance(retrieved_state, bytes):
|
||||
retrieved_state = retrieved_state.decode()
|
||||
assert retrieved_state == large_state_json
|
||||
|
||||
# Verify file size in database
|
||||
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
|
||||
assert pause_model.state_object_key != ""
|
||||
loaded_state = storage.load(pause_model.state_object_key)
|
||||
assert loaded_state.decode() == large_state_json
|
||||
|
||||
def test_multiple_pause_resume_cycles(self):
|
||||
"""Test multiple pause/resume cycles on the same workflow run."""
|
||||
# Arrange
|
||||
workflow_run = self._create_test_workflow_run()
|
||||
repository = self._get_workflow_run_repository()
|
||||
|
||||
# Act & Assert - Multiple cycles
|
||||
for i in range(3):
|
||||
state = json.dumps({"cycle": i, "data": f"state_{i}"})
|
||||
|
||||
# Reset workflow run status to RUNNING before each pause (after first cycle)
|
||||
if i > 0:
|
||||
self.session.refresh(workflow_run) # Refresh to get latest state from session
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
self.session.commit()
|
||||
self.session.refresh(workflow_run) # Refresh again after commit
|
||||
|
||||
# Pause
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=self.test_user_id,
|
||||
state=state,
|
||||
)
|
||||
assert pause_entity is not None
|
||||
|
||||
# Verify pause
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
self.session.refresh(workflow_run)
|
||||
|
||||
# Use the test session directly to verify the pause
|
||||
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
|
||||
workflow_run_with_pause = self.session.scalar(stmt)
|
||||
pause_model = workflow_run_with_pause.pause
|
||||
|
||||
# Verify pause using test session directly
|
||||
assert pause_model is not None
|
||||
assert pause_model.id == pause_entity.id
|
||||
assert pause_model.state_object_key != ""
|
||||
|
||||
# Load file content using storage directly
|
||||
file_content = storage.load(pause_model.state_object_key)
|
||||
if isinstance(file_content, bytes):
|
||||
file_content = file_content.decode()
|
||||
assert file_content == state
|
||||
|
||||
# Resume
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
assert resumed_entity is not None
|
||||
assert resumed_entity.resumed_at is not None
|
||||
|
||||
# Verify resume - check that pause is marked as resumed
|
||||
self.session.expire_all() # Clear session to ensure fresh query
|
||||
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
|
||||
resumed_pause_model = self.session.scalar(stmt)
|
||||
assert resumed_pause_model is not None
|
||||
assert resumed_pause_model.resumed_at is not None
|
||||
|
||||
# Verify workflow run status
|
||||
self.session.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
@@ -0,0 +1,278 @@
import json
from time import time
from unittest.mock import Mock

import pytest

from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import (
    GraphRunFailedEvent,
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from repositories.factory import DifyAPIRepositoryFactory


class TestDataFactory:
|
||||
"""Factory helpers for constructing graph events used in tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
|
||||
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_started_event() -> GraphRunStartedEvent:
|
||||
return GraphRunStartedEvent()
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
|
||||
return GraphRunSucceededEvent(outputs=outputs or {})
|
||||
|
||||
@staticmethod
|
||||
def create_graph_run_failed_event(
|
||||
error: str = "Test error",
|
||||
exceptions_count: int = 1,
|
||||
) -> GraphRunFailedEvent:
|
||||
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
|
||||
|
||||
|
||||
class MockSystemVariableReadOnlyView:
|
||||
"""Minimal read-only system variable view for testing."""
|
||||
|
||||
def __init__(self, workflow_execution_id: str | None = None) -> None:
|
||||
self._workflow_execution_id = workflow_execution_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str | None:
|
||||
return self._workflow_execution_id
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
"""Mock implementation of ReadOnlyVariablePool for testing."""
|
||||
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
mock_segment.value = value
|
||||
return mock_segment
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
|
||||
|
||||
|
||||
class MockReadOnlyGraphRuntimeState:
|
||||
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_at: float | None = None,
|
||||
total_tokens: int = 0,
|
||||
node_run_steps: int = 0,
|
||||
ready_queue_size: int = 0,
|
||||
exceptions_count: int = 0,
|
||||
outputs: dict[str, object] | None = None,
|
||||
variables: dict[tuple[str, str], object] | None = None,
|
||||
workflow_execution_id: str | None = None,
|
||||
):
|
||||
self._start_at = start_at or time()
|
||||
self._total_tokens = total_tokens
|
||||
self._node_run_steps = node_run_steps
|
||||
self._ready_queue_size = ready_queue_size
|
||||
self._exceptions_count = exceptions_count
|
||||
self._outputs = outputs or {}
|
||||
self._variable_pool = MockReadOnlyVariablePool(variables)
|
||||
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
|
||||
|
||||
@property
|
||||
def system_variable(self) -> MockSystemVariableReadOnlyView:
|
||||
return self._system_variable
|
||||
|
||||
@property
|
||||
def variable_pool(self) -> ReadOnlyVariablePool:
|
||||
return self._variable_pool
|
||||
|
||||
@property
|
||||
def start_at(self) -> float:
|
||||
return self._start_at
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def node_run_steps(self) -> int:
|
||||
return self._node_run_steps
|
||||
|
||||
@property
|
||||
def ready_queue_size(self) -> int:
|
||||
return self._ready_queue_size
|
||||
|
||||
@property
|
||||
def exceptions_count(self) -> int:
|
||||
return self._exceptions_count
|
||||
|
||||
@property
|
||||
def outputs(self) -> dict[str, object]:
|
||||
return self._outputs.copy()
|
||||
|
||||
@property
|
||||
def llm_usage(self):
|
||||
mock_usage = Mock()
|
||||
mock_usage.prompt_tokens = 10
|
||||
mock_usage.completion_tokens = 20
|
||||
mock_usage.total_tokens = 30
|
||||
return mock_usage
|
||||
|
||||
def get_output(self, key: str, default: object = None) -> object:
|
||||
return self._outputs.get(key, default)
|
||||
|
||||
def dumps(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"start_at": self._start_at,
|
||||
"total_tokens": self._total_tokens,
|
||||
"node_run_steps": self._node_run_steps,
|
||||
"ready_queue_size": self._ready_queue_size,
|
||||
"exceptions_count": self._exceptions_count,
|
||||
"outputs": self._outputs,
|
||||
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
|
||||
"workflow_execution_id": self._system_variable.workflow_execution_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class MockCommandChannel:
|
||||
"""Mock implementation of CommandChannel for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self._commands: list[GraphEngineCommand] = []
|
||||
|
||||
def fetch_commands(self) -> list[GraphEngineCommand]:
|
||||
return self._commands.copy()
|
||||
|
||||
def send_command(self, command: GraphEngineCommand) -> None:
|
||||
self._commands.append(command)
|
||||
|
||||
|
||||
class TestPauseStatePersistenceLayer:
|
||||
"""Unit tests for PauseStatePersistenceLayer."""
|
||||
|
||||
def test_init_with_dependency_injection(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
state_owner_user_id = "user-123"
|
||||
|
||||
layer = PauseStatePersistenceLayer(
|
||||
session_factory=session_factory,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
)
|
||||
|
||||
assert layer._session_maker is session_factory
|
||||
assert layer._state_owner_user_id == state_owner_user_id
|
||||
assert not hasattr(layer, "graph_runtime_state")
|
||||
assert not hasattr(layer, "command_channel")
|
||||
|
||||
def test_initialize_sets_dependencies(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
assert layer.graph_runtime_state is graph_runtime_state
|
||||
assert layer.command_channel is command_channel
|
||||
|
||||
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(
|
||||
outputs={"result": "test_output"},
|
||||
total_tokens=100,
|
||||
workflow_execution_id="run-123",
|
||||
)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
|
||||
expected_state = graph_runtime_state.dumps()
|
||||
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_called_once_with(session_factory)
|
||||
mock_repo.create_workflow_pause.assert_called_once_with(
|
||||
workflow_run_id="run-123",
|
||||
state_owner_user_id="owner-123",
|
||||
state=expected_state,
|
||||
)
|
||||
|
||||
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState()
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
events = [
|
||||
TestDataFactory.create_graph_run_started_event(),
|
||||
TestDataFactory.create_graph_run_succeeded_event(),
|
||||
TestDataFactory.create_graph_run_failed_event(),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
|
||||
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
layer.on_event(event)
|
||||
|
||||
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
|
||||
session_factory = Mock(name="session_factory")
|
||||
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
|
||||
|
||||
mock_repo = Mock()
|
||||
mock_factory = Mock(return_value=mock_repo)
|
||||
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
|
||||
|
||||
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
|
||||
command_channel = MockCommandChannel()
|
||||
layer.initialize(graph_runtime_state, command_channel)
|
||||
|
||||
event = TestDataFactory.create_graph_run_paused_event()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer.on_event(event)
|
||||
|
||||
mock_factory.assert_not_called()
|
||||
mock_repo.create_workflow_pause.assert_not_called()
|
||||
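The layer implementation itself is not shown in this diff. As a rough, illustrative sketch of the behavior the tests above pin down (the class and attribute names mirror the test assertions; anything else is an assumption, and the real code in core/app/layers/pause_state_persist_layer.py may differ):

# Illustrative sketch only -- not part of this diff. Relies on the
# GraphRunPausedEvent and DifyAPIRepositoryFactory imports shown above.
class PauseStatePersistenceLayerSketch:
    def __init__(self, session_factory, state_owner_user_id):
        self._session_maker = session_factory
        self._state_owner_user_id = state_owner_user_id

    def initialize(self, graph_runtime_state, command_channel):
        self.graph_runtime_state = graph_runtime_state
        self.command_channel = command_channel

    def on_event(self, event):
        # Only pause events trigger persistence; all other graph events are ignored.
        if not isinstance(event, GraphRunPausedEvent):
            return
        workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
        assert workflow_run_id is not None
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
        repo.create_workflow_pause(
            workflow_run_id=workflow_run_id,
            state_owner_user_id=self._state_owner_user_id,
            state=self.graph_runtime_state.dumps(),
        )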
@@ -0,0 +1,171 @@
"""Tests for _PrivateWorkflowPauseEntity implementation."""

from datetime import datetime
from unittest.mock import MagicMock, patch

from models.workflow import WorkflowPause as WorkflowPauseModel
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity


class TestPrivateWorkflowPauseEntity:
    """Test _PrivateWorkflowPauseEntity implementation."""

    def test_entity_initialization(self):
        """Test entity initialization with required parameters."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"
        mock_pause_model.resumed_at = None

        # Create entity
        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # Verify initialization
        assert entity._pause_model is mock_pause_model
        assert entity._cached_state is None

    def test_from_models_classmethod(self):
        """Test from_models class method."""
        # Create mock models
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"
        mock_pause_model.workflow_run_id = "execution-456"

        # Create entity using from_models
        entity = _PrivateWorkflowPauseEntity.from_models(
            workflow_pause_model=mock_pause_model,
        )

        # Verify entity creation
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model is mock_pause_model

    def test_id_property(self):
        """Test id property returns pause model ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.id = "pause-123"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.id == "pause-123"

    def test_workflow_execution_id_property(self):
        """Test workflow_execution_id property returns workflow run ID."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.workflow_run_id = "execution-456"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.workflow_execution_id == "execution-456"

    def test_resumed_at_property(self):
        """Test resumed_at property returns pause model resumed_at."""
        resumed_at = datetime(2023, 12, 25, 15, 30, 45)

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = resumed_at

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.resumed_at == resumed_at

    def test_resumed_at_property_none(self):
        """Test resumed_at property returns None when not set."""
        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.resumed_at = None

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        assert entity.resumed_at is None

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_first_call(self, mock_storage):
        """Test get_state loads from storage on first call."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # First call should load from storage
        result = entity.get_state()

        assert result == state_data
        mock_storage.load.assert_called_once_with("test-state-key")
        assert entity._cached_state == state_data

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_cached_call(self, mock_storage):
        """Test get_state returns cached data on subsequent calls."""
        state_data = b'{"test": "data", "step": 5}'
        mock_storage.load.return_value = state_data

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)
        mock_pause_model.state_object_key = "test-state-key"

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # First call
        result1 = entity.get_state()
        # Second call should use cache
        result2 = entity.get_state()

        assert result1 == state_data
        assert result2 == state_data
        # Storage should only be called once
        mock_storage.load.assert_called_once_with("test-state-key")

    @patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
    def test_get_state_with_pre_cached_data(self, mock_storage):
        """Test get_state returns pre-cached data."""
        state_data = b'{"test": "data", "step": 5}'

        mock_pause_model = MagicMock(spec=WorkflowPauseModel)

        entity = _PrivateWorkflowPauseEntity(
            pause_model=mock_pause_model,
        )

        # Pre-cache data
        entity._cached_state = state_data

        # Should return cached data without calling storage
        result = entity.get_state()

        assert result == state_data
        mock_storage.load.assert_not_called()

    def test_entity_with_binary_state_data(self):
        """Test entity with binary state data."""
        # Test with binary data that's not valid JSON
        binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = binary_data

            mock_pause_model = MagicMock(spec=WorkflowPauseModel)

            entity = _PrivateWorkflowPauseEntity(
                pause_model=mock_pause_model,
            )

            result = entity.get_state()

            assert result == binary_data
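The get_state behavior exercised above amounts to a lazy load-and-cache of the serialized state. A minimal sketch, assuming the storage helper from extensions.ext_storage and the private attribute names used by the tests (not the actual implementation from the repository module):

# Illustrative sketch only -- not part of this diff.
from extensions.ext_storage import storage

class PauseEntitySketch:
    def __init__(self, pause_model):
        self._pause_model = pause_model
        self._cached_state: bytes | None = None

    def get_state(self) -> bytes:
        # Load the serialized graph state from object storage once, keyed by
        # the pause record's state_object_key, then serve it from the cache.
        if self._cached_state is None:
            self._cached_state = storage.load(self._pause_model.state_object_key)
        return self._cached_state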
@@ -3,6 +3,7 @@
 import time
 from unittest.mock import MagicMock
 
+from core.workflow.entities.pause_reason import SchedulingPause
 from core.workflow.graph import Graph
 from core.workflow.graph_engine import GraphEngine
 from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -149,8 +150,8 @@ def test_pause_command():
     assert any(isinstance(e, GraphRunStartedEvent) for e in events)
     pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
     assert len(pause_events) == 1
-    assert pause_events[0].reason == "User requested pause"
+    assert pause_events[0].reason == SchedulingPause(message="User requested pause")
 
     graph_execution = engine.graph_runtime_state.graph_execution
     assert graph_execution.is_paused
-    assert graph_execution.pause_reason == "User requested pause"
+    assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")
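The updated assertions compare the pause reason to a freshly constructed SchedulingPause with ==, which only works if SchedulingPause compares by value. A minimal sketch of that property (the real SchedulingPause in core.workflow.entities.pause_reason is presumably a richer model; this only illustrates the equality semantics the test depends on):

# Illustrative sketch only -- not part of this diff.
from dataclasses import dataclass

@dataclass(frozen=True)
class SchedulingPauseSketch:
    message: str

# Two instances built from the same message compare equal:
assert SchedulingPauseSketch(message="User requested pause") == SchedulingPauseSketch(message="User requested pause")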
api/tests/unit_tests/core/workflow/test_enums.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""Tests for workflow pause related enums and constants."""

from core.workflow.enums import (
    WorkflowExecutionStatus,
)


class TestWorkflowExecutionStatus:
    """Test WorkflowExecutionStatus enum."""

    def test_is_ended_method(self):
        """Test is_ended method for different statuses."""
        # Test ended statuses
        ended_statuses = [
            WorkflowExecutionStatus.SUCCEEDED,
            WorkflowExecutionStatus.FAILED,
            WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
            WorkflowExecutionStatus.STOPPED,
        ]

        for status in ended_statuses:
            assert status.is_ended(), f"{status} should be considered ended"

        # Test non-ended statuses
        non_ended_statuses = [
            WorkflowExecutionStatus.SCHEDULED,
            WorkflowExecutionStatus.RUNNING,
            WorkflowExecutionStatus.PAUSED,
        ]

        for status in non_ended_statuses:
            assert not status.is_ended(), f"{status} should not be considered ended"
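An is_ended() helper consistent with the assertions above could be written as a simple membership check over the terminal statuses. This sketch is illustrative only; the enum member values are assumptions, not taken from the real core.workflow.enums module:

# Illustrative sketch only -- not part of this diff.
from enum import StrEnum

class WorkflowExecutionStatusSketch(StrEnum):
    SCHEDULED = "scheduled"
    RUNNING = "running"
    PAUSED = "paused"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    PARTIAL_SUCCEEDED = "partial-succeeded"
    STOPPED = "stopped"

    def is_ended(self) -> bool:
        # Terminal statuses; SCHEDULED, RUNNING and PAUSED are still in flight.
        return self in {
            WorkflowExecutionStatusSketch.SUCCEEDED,
            WorkflowExecutionStatusSketch.FAILED,
            WorkflowExecutionStatusSketch.PARTIAL_SUCCEEDED,
            WorkflowExecutionStatusSketch.STOPPED,
        }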
@@ -0,0 +1,202 @@
from typing import cast

import pytest

from core.file.models import File, FileTransferMethod, FileType
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView


class TestSystemVariableReadOnlyView:
    """Test cases for SystemVariableReadOnlyView class."""

    def test_read_only_property_access(self):
        """Test that all properties return correct values from wrapped instance."""
        # Create test data
        test_file = File(
            id="file-123",
            tenant_id="tenant-123",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related-123",
        )

        datasource_info = {"key": "value", "nested": {"data": 42}}

        # Create SystemVariable with all fields
        system_var = SystemVariable(
            user_id="user-123",
            app_id="app-123",
            workflow_id="workflow-123",
            files=[test_file],
            workflow_execution_id="exec-123",
            query="test query",
            conversation_id="conv-123",
            dialogue_count=5,
            document_id="doc-123",
            original_document_id="orig-doc-123",
            dataset_id="dataset-123",
            batch="batch-123",
            datasource_type="type-123",
            datasource_info=datasource_info,
            invoke_from="invoke-123",
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test all properties
        assert read_only_view.user_id == "user-123"
        assert read_only_view.app_id == "app-123"
        assert read_only_view.workflow_id == "workflow-123"
        assert read_only_view.workflow_execution_id == "exec-123"
        assert read_only_view.query == "test query"
        assert read_only_view.conversation_id == "conv-123"
        assert read_only_view.dialogue_count == 5
        assert read_only_view.document_id == "doc-123"
        assert read_only_view.original_document_id == "orig-doc-123"
        assert read_only_view.dataset_id == "dataset-123"
        assert read_only_view.batch == "batch-123"
        assert read_only_view.datasource_type == "type-123"
        assert read_only_view.invoke_from == "invoke-123"

    def test_defensive_copying_of_mutable_objects(self):
        """Test that mutable objects are defensively copied."""
        # Create test data
        test_file = File(
            id="file-123",
            tenant_id="tenant-123",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="related-123",
        )

        datasource_info = {"key": "original_value"}

        # Create SystemVariable
        system_var = SystemVariable(
            files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test files defensive copying
        files_copy = read_only_view.files
        assert isinstance(files_copy, tuple)  # Should be an immutable tuple
        assert len(files_copy) == 1
        assert files_copy[0].id == "file-123"

        # Verify it's a copy (can't modify the original through the view)
        assert isinstance(files_copy, tuple)
        # Tuples have no append method, so the copy is effectively immutable

        # Test datasource_info defensive copying
        datasource_copy = read_only_view.datasource_info
        assert datasource_copy is not None
        assert datasource_copy["key"] == "original_value"

        datasource_copy = cast(dict, datasource_copy)
        with pytest.raises(TypeError):
            datasource_copy["key"] = "modified value"

        # Verify original is unchanged
        assert system_var.datasource_info is not None
        assert system_var.datasource_info["key"] == "original_value"
        assert read_only_view.datasource_info is not None
        assert read_only_view.datasource_info["key"] == "original_value"

    def test_always_accesses_latest_data(self):
        """Test that properties always return the latest data from wrapped instance."""
        # Create SystemVariable
        system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Verify initial value
        assert read_only_view.user_id == "original-user"

        # Modify the wrapped instance
        system_var.user_id = "modified-user"

        # Verify view returns the new value
        assert read_only_view.user_id == "modified-user"

    def test_repr_method(self):
        """Test the __repr__ method."""
        # Create SystemVariable
        system_var = SystemVariable(workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test repr
        repr_str = repr(read_only_view)
        assert "SystemVariableReadOnlyView" in repr_str
        assert "system_variable=" in repr_str

    def test_none_value_handling(self):
        """Test that None values are properly handled."""
        # Create SystemVariable with all None values except workflow_execution_id
        system_var = SystemVariable(
            user_id=None,
            app_id=None,
            workflow_id=None,
            workflow_execution_id="exec-123",
            query=None,
            conversation_id=None,
            dialogue_count=None,
            document_id=None,
            original_document_id=None,
            dataset_id=None,
            batch=None,
            datasource_type=None,
            datasource_info=None,
            invoke_from=None,
        )

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test all None values
        assert read_only_view.user_id is None
        assert read_only_view.app_id is None
        assert read_only_view.workflow_id is None
        assert read_only_view.query is None
        assert read_only_view.conversation_id is None
        assert read_only_view.dialogue_count is None
        assert read_only_view.document_id is None
        assert read_only_view.original_document_id is None
        assert read_only_view.dataset_id is None
        assert read_only_view.batch is None
        assert read_only_view.datasource_type is None
        assert read_only_view.datasource_info is None
        assert read_only_view.invoke_from is None

        # files should be an empty tuple even when the files list defaults to empty
        assert read_only_view.files == ()

    def test_empty_files_handling(self):
        """Test that empty files list is handled correctly."""
        # Create SystemVariable with empty files
        system_var = SystemVariable(files=[], workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test files handling
        assert read_only_view.files == ()
        assert isinstance(read_only_view.files, tuple)

    def test_empty_datasource_info_handling(self):
        """Test that empty datasource_info is handled correctly."""
        # Create SystemVariable with empty datasource_info
        system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")

        # Create read-only view
        read_only_view = SystemVariableReadOnlyView(system_var)

        # Test datasource_info handling
        assert read_only_view.datasource_info == {}
        # Should be a copy, not the same object
        assert read_only_view.datasource_info is not system_var.datasource_info
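The defensive-copy behavior asserted above (files come back as a tuple, datasource_info compares equal but is a different object and rejects item assignment with TypeError) suggests view properties along these lines. This is a sketch under those assumptions, not the actual SystemVariableReadOnlyView implementation:

# Illustrative sketch only -- not part of this diff.
from types import MappingProxyType

class SystemVariableReadOnlyViewSketch:
    def __init__(self, system_variable):
        self._system_variable = system_variable

    @property
    def files(self) -> tuple:
        # Expose files as an immutable tuple copied from the wrapped list.
        return tuple(self._system_variable.files)

    @property
    def datasource_info(self):
        info = self._system_variable.datasource_info
        # A read-only mapping copy: equal to the original dict, never the same
        # object, and item assignment raises TypeError.
        return MappingProxyType(dict(info)) if info is not None else None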
api/tests/unit_tests/models/test_base.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from models.base import DefaultFieldsMixin


class FooModel(DefaultFieldsMixin):
    def __init__(self, id: str):
        self.id = id


def test_repr():
    foo_model = FooModel(id="test-id")
    assert repr(foo_model) == "<FooModel(id=test-id)>"
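A __repr__ on DefaultFieldsMixin matching the assertion above would only need the subclass name and the id attribute. A minimal sketch (an assumption about the mixin, not its actual source):

# Illustrative sketch only -- not part of this diff.
class DefaultFieldsMixinSketch:
    id: str

    def __repr__(self) -> str:
        # Produces e.g. "<FooModel(id=test-id)>" for a subclass named FooModel.
        return f"<{type(self).__name__}(id={self.id})>"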
@@ -0,0 +1,370 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""

from datetime import UTC, datetime
from unittest.mock import Mock, patch

import pytest
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _PrivateWorkflowPauseEntity,
    _WorkflowRunError,
)


class TestDifyAPISQLAlchemyWorkflowRunRepository:
    """Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""

    @pytest.fixture
    def mock_session(self):
        """Create a mock session."""
        return Mock(spec=Session)

    @pytest.fixture
    def mock_session_maker(self, mock_session):
        """Create a mock sessionmaker."""
        session_maker = Mock(spec=sessionmaker)

        # Create a context manager mock
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=mock_session)
        context_manager.__exit__ = Mock(return_value=None)
        session_maker.return_value = context_manager

        # Mock session.begin() context manager
        begin_context_manager = Mock()
        begin_context_manager.__enter__ = Mock(return_value=None)
        begin_context_manager.__exit__ = Mock(return_value=None)
        mock_session.begin = Mock(return_value=begin_context_manager)

        # Add missing session methods
        mock_session.commit = Mock()
        mock_session.rollback = Mock()
        mock_session.add = Mock()
        mock_session.delete = Mock()
        mock_session.get = Mock()
        mock_session.scalar = Mock()
        mock_session.scalars = Mock()

        # Also support expire_on_commit parameter
        def make_session(expire_on_commit=None):
            cm = Mock()
            cm.__enter__ = Mock(return_value=mock_session)
            cm.__exit__ = Mock(return_value=None)
            return cm

        session_maker.side_effect = make_session
        return session_maker

    @pytest.fixture
    def repository(self, mock_session_maker):
        """Create repository instance with mocked dependencies."""

        # Create a testable subclass that implements the save method
        class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
            def __init__(self, session_maker):
                # Initialize without calling parent __init__ to avoid any instantiation issues
                self._session_maker = session_maker

            def save(self, execution):
                """Mock implementation of save method."""
                return None

        # Create repository instance
        repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)

        return repo

    @pytest.fixture
    def sample_workflow_run(self):
        """Create a sample WorkflowRun model."""
        workflow_run = Mock(spec=WorkflowRun)
        workflow_run.id = "workflow-run-123"
        workflow_run.tenant_id = "tenant-123"
        workflow_run.app_id = "app-123"
        workflow_run.workflow_id = "workflow-123"
        workflow_run.status = WorkflowExecutionStatus.RUNNING
        return workflow_run

    @pytest.fixture
    def sample_workflow_pause(self):
        """Create a sample WorkflowPauseModel."""
        pause = Mock(spec=WorkflowPauseModel)
        pause.id = "pause-123"
        pause.workflow_id = "workflow-123"
        pause.workflow_run_id = "workflow-run-123"
        pause.state_object_key = "workflow-state-123.json"
        pause.resumed_at = None
        pause.created_at = datetime.now(UTC)
        return pause


class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test create_workflow_pause method."""

    def test_create_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test successful workflow pause creation."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        state_owner_user_id = "user-123"
        state = '{"test": "state"}'

        mock_session.get.return_value = sample_workflow_run

        with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
            mock_uuidv7.side_effect = ["pause-123"]
            with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
                # Act
                result = repository.create_workflow_pause(
                    workflow_run_id=workflow_run_id,
                    state_owner_user_id=state_owner_user_id,
                    state=state,
                )

                # Assert
                assert isinstance(result, _PrivateWorkflowPauseEntity)
                assert result.id == "pause-123"
                assert result.workflow_execution_id == workflow_run_id

                # Verify database interactions
                mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
                mock_storage.save.assert_called_once()
                mock_session.add.assert_called()
                # When using session.begin() context manager, commit is handled automatically
                # No explicit commit call is expected

    def test_create_workflow_pause_not_found(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
    ):
        """Test workflow pause creation when workflow run not found."""
        # Arrange
        mock_session.get.return_value = None

        # Act & Assert
        with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
            )

        mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")

    def test_create_workflow_pause_invalid_status(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
    ):
        """Test workflow pause creation when workflow not in RUNNING status."""
        # Arrange
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        mock_session.get.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
            )


class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test resume_workflow_pause method."""

    def test_resume_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause resume."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        # Setup workflow run and pause
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_run.pause = sample_workflow_pause
        sample_workflow_pause.resumed_at = None

        mock_session.scalar.return_value = sample_workflow_run

        with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
            mock_now.return_value = datetime.now(UTC)

            # Act
            result = repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )

            # Assert
            assert isinstance(result, _PrivateWorkflowPauseEntity)
            assert result.id == "pause-123"

            # Verify state transitions
            assert sample_workflow_pause.resumed_at is not None
            assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING

            # Verify database interactions
            mock_session.add.assert_called()
            # When using session.begin() context manager, commit is handled automatically
            # No explicit commit call is expected

    def test_resume_workflow_pause_not_paused(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test resume when workflow is not paused."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
        mock_session.scalar.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )

    def test_resume_workflow_pause_id_mismatch(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test resume when pause ID doesn't match."""
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-456"  # Different ID

        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_pause.id = "pause-123"
        sample_workflow_run.pause = sample_workflow_pause
        mock_session.scalar.return_value = sample_workflow_run

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )


class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test delete_workflow_pause method."""

    def test_delete_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause deletion."""
        # Arrange
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        mock_session.get.return_value = sample_workflow_pause

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            # Act
            repository.delete_workflow_pause(pause_entity=pause_entity)

            # Assert
            mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
            mock_session.delete.assert_called_once_with(sample_workflow_pause)
            # When using session.begin() context manager, commit is handled automatically
            # No explicit commit call is expected

    def test_delete_workflow_pause_not_found(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
    ):
        """Test delete when pause not found."""
        # Arrange
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"

        mock_session.get.return_value = None

        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
            repository.delete_workflow_pause(pause_entity=pause_entity)


class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test _PrivateWorkflowPauseEntity class."""

    def test_from_models(self, sample_workflow_pause: Mock):
        """Test creating _PrivateWorkflowPauseEntity from models."""
        # Act
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)

        # Assert
        assert isinstance(entity, _PrivateWorkflowPauseEntity)
        assert entity._pause_model == sample_workflow_pause

    def test_properties(self, sample_workflow_pause: Mock):
        """Test entity properties."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)

        # Act & Assert
        assert entity.id == sample_workflow_pause.id
        assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
        assert entity.resumed_at == sample_workflow_pause.resumed_at

    def test_get_state(self, sample_workflow_pause: Mock):
        """Test getting state from storage."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result = entity.get_state()

            # Assert
            assert result == expected_state
            mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)

    def test_get_state_caching(self, sample_workflow_pause: Mock):
        """Test state caching in get_state method."""
        # Arrange
        entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
        expected_state = b'{"test": "state"}'

        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            mock_storage.load.return_value = expected_state

            # Act
            result1 = entity.get_state()
            result2 = entity.get_state()  # Should use cache

            # Assert
            assert result1 == expected_state
            assert result2 == expected_state
            mock_storage.load.assert_called_once()  # Only called once due to caching
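Taken together, these tests pin down a pause-creation flow: look up the run, require RUNNING status, generate one id, persist the serialized state to object storage, flip the run to PAUSED, and add the pause row inside a session.begin() transaction. A rough sketch under those assumptions (field names, the state_object_key format, and the exact error messages beyond those matched above are guesses, not the real repository code):

# Illustrative sketch only -- not part of this diff. Uses the names imported
# in the test module above plus an assumed uuidv7() helper and storage client.
class WorkflowRunRepositorySketch:
    def __init__(self, session_maker):
        self._session_maker = session_maker

    def create_workflow_pause(self, workflow_run_id, state_owner_user_id, state):
        with self._session_maker() as session, session.begin():
            workflow_run = session.get(WorkflowRun, workflow_run_id)
            if workflow_run is None:
                raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
            if workflow_run.status != WorkflowExecutionStatus.RUNNING:
                raise _WorkflowRunError("Only WorkflowRun with RUNNING status can be paused")
            pause_id = uuidv7()  # the success test expects exactly one uuidv7() call
            pause = WorkflowPauseModel(
                id=pause_id,
                workflow_run_id=workflow_run_id,
                state_owner_user_id=state_owner_user_id,
                state_object_key=f"workflow-state-{pause_id}.json",  # key format is a guess
            )
            storage.save(pause.state_object_key, state.encode())
            workflow_run.status = WorkflowExecutionStatus.PAUSED
            session.add(pause)
            # session.begin() commits on clean exit; no explicit session.commit()
        return _PrivateWorkflowPauseEntity.from_models(pause)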
api/tests/unit_tests/services/test_workflow_run_service_pause.py (new file, 200 lines)
@@ -0,0 +1,200 @@
"""Comprehensive unit tests for WorkflowRunService class.

This test suite covers all pause state management operations including:
- Retrieving pause state for workflow runs
- Saving pause state with file uploads
- Marking paused workflows as resumed
- Error handling and edge cases
- Database transaction management
- Repository-based approach testing
"""

from datetime import datetime
from unittest.mock import MagicMock, create_autospec, patch

import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.enums import WorkflowExecutionStatus
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
    WorkflowRunService,
)


class TestDataFactory:
    """Factory class for creating test data objects."""

    @staticmethod
    def create_workflow_run_mock(
        id: str = "workflow-run-123",
        tenant_id: str = "tenant-456",
        app_id: str = "app-789",
        workflow_id: str = "workflow-101",
        status: str | WorkflowExecutionStatus = "paused",
        pause_id: str | None = None,
        **kwargs,
    ) -> MagicMock:
        """Create a mock WorkflowRun object."""
        mock_run = MagicMock()
        mock_run.id = id
        mock_run.tenant_id = tenant_id
        mock_run.app_id = app_id
        mock_run.workflow_id = workflow_id
        mock_run.status = status
        mock_run.pause_id = pause_id

        for key, value in kwargs.items():
            setattr(mock_run, key, value)

        return mock_run

    @staticmethod
    def create_workflow_pause_mock(
        id: str = "pause-123",
        tenant_id: str = "tenant-456",
        app_id: str = "app-789",
        workflow_id: str = "workflow-101",
        workflow_execution_id: str = "workflow-execution-123",
        state_file_id: str = "file-456",
        resumed_at: datetime | None = None,
        **kwargs,
    ) -> MagicMock:
        """Create a mock WorkflowPauseModel object."""
        mock_pause = MagicMock()
        mock_pause.id = id
        mock_pause.tenant_id = tenant_id
        mock_pause.app_id = app_id
        mock_pause.workflow_id = workflow_id
        mock_pause.workflow_execution_id = workflow_execution_id
        mock_pause.state_file_id = state_file_id
        mock_pause.resumed_at = resumed_at

        for key, value in kwargs.items():
            setattr(mock_pause, key, value)

        return mock_pause

    @staticmethod
    def create_upload_file_mock(
        id: str = "file-456",
        key: str = "upload_files/test/state.json",
        name: str = "state.json",
        tenant_id: str = "tenant-456",
        **kwargs,
    ) -> MagicMock:
        """Create a mock UploadFile object."""
        mock_file = MagicMock()
        mock_file.id = id
        mock_file.key = key
        mock_file.name = name
        mock_file.tenant_id = tenant_id

        for key, value in kwargs.items():
            setattr(mock_file, key, value)

        return mock_file

    @staticmethod
    def create_pause_entity_mock(
        pause_model: MagicMock | None = None,
        upload_file: MagicMock | None = None,
    ) -> _PrivateWorkflowPauseEntity:
        """Create a mock _PrivateWorkflowPauseEntity object."""
        if pause_model is None:
            pause_model = TestDataFactory.create_workflow_pause_mock()
        if upload_file is None:
            upload_file = TestDataFactory.create_upload_file_mock()

        return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)


class TestWorkflowRunService:
    """Comprehensive unit tests for WorkflowRunService class."""

    @pytest.fixture
    def mock_session_factory(self):
        """Create a mock session factory with proper session management."""
        mock_session = create_autospec(Session)

        # Create a mock context manager for the session
        mock_session_cm = MagicMock()
        mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
        mock_session_cm.__exit__ = MagicMock(return_value=None)

        # Create a mock context manager for the transaction
        mock_transaction_cm = MagicMock()
        mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
        mock_transaction_cm.__exit__ = MagicMock(return_value=None)

        mock_session.begin = MagicMock(return_value=mock_transaction_cm)

        # Create mock factory that returns the context manager
        mock_factory = MagicMock(spec=sessionmaker)
        mock_factory.return_value = mock_session_cm

        return mock_factory, mock_session

    @pytest.fixture
    def mock_workflow_run_repository(self):
        """Create a mock APIWorkflowRunRepository."""
        mock_repo = create_autospec(APIWorkflowRunRepository)
        return mock_repo

    @pytest.fixture
    def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
        """Create WorkflowRunService instance with mocked dependencies."""
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(session_factory)
            return service

    @pytest.fixture
    def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
        """Create WorkflowRunService instance with Engine input."""
        mock_engine = create_autospec(Engine)
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(mock_engine)
            return service

    # ==================== Initialization Tests ====================

    def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
        """Test WorkflowRunService initialization with session_factory."""
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            service = WorkflowRunService(session_factory)

        assert service._session_factory == session_factory
        mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)

    def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
        """Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
        mock_engine = create_autospec(Engine)
        session_factory, _ = mock_session_factory

        with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
            mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
            with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
                service = WorkflowRunService(mock_engine)

            mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
            assert service._session_factory == session_factory
            mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)

    def test_init_with_default_dependencies(self, mock_session_factory):
        """Test WorkflowRunService initialization with default dependencies."""
        session_factory, _ = mock_session_factory

        service = WorkflowRunService(session_factory)

        assert service._session_factory == session_factory
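The initialization behavior asserted above (an Engine is wrapped into a sessionmaker with expire_on_commit=False, a plain sessionmaker is used as-is, and a repository is built through DifyAPIRepositoryFactory) can be summarized in a short sketch. Everything here beyond those three facts, including the repository attribute name, is an assumption rather than the actual service code:

# Illustrative sketch only -- not part of this diff.
class WorkflowRunServiceSketch:
    def __init__(self, session_factory_or_engine):
        # An Engine is wrapped into a sessionmaker; a sessionmaker is used as-is.
        if isinstance(session_factory_or_engine, Engine):
            session_factory_or_engine = sessionmaker(bind=session_factory_or_engine, expire_on_commit=False)
        self._session_factory = session_factory_or_engine
        # Attribute name is a guess; the tests only verify the factory call.
        self._workflow_run_repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
            self._session_factory
        )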