feat(api): Introduce workflow pause state management (#27298)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
QuantumGhost
2025-10-30 14:41:09 +08:00
committed by GitHub
parent fd7c4e8a6d
commit a1c0bd7a1c
43 changed files with 3834 additions and 44 deletions

View File

@@ -0,0 +1 @@
# Core integration tests package

View File

@@ -0,0 +1 @@
# App integration tests package

View File

@@ -0,0 +1 @@
# Layers integration tests package

View File

@@ -0,0 +1,520 @@
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class.
This test suite covers complete integration scenarios including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using test storage backend
- Complete workflow: event -> state serialization -> database save -> storage save
- Testing with actual WorkflowRunService (not mocked)
- Real Workflow and WorkflowRun instances in database
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in database
- Error handling with real database constraints
- Multiple pause events in sequence
- Integration with real ReadOnlyGraphRuntimeState implementations
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from time import time
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.runtime.variable_pool import SystemVariable, VariablePool
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from services.file_service import FileService
from services.workflow_run_service import WorkflowRunService
class _TestCommandChannelImpl:
"""Real implementation of CommandChannel for testing."""
def __init__(self):
self._commands: list[GraphEngineCommand] = []
def fetch_commands(self) -> list[GraphEngineCommand]:
"""Fetch pending commands for this GraphEngine instance."""
return self._commands.copy()
def send_command(self, command: GraphEngineCommand) -> None:
"""Send a command to be processed by this GraphEngine instance."""
self._commands.append(command)
class TestPauseStatePersistenceLayerTestContainers:
"""Comprehensive TestContainers-based integration tests for PauseStatePersistenceLayer class."""
@pytest.fixture
def engine(self, db_session_with_containers: Session):
"""Get database engine from TestContainers session."""
bind = db_session_with_containers.get_bind()
assert isinstance(bind, Engine)
return bind
@pytest.fixture
def file_service(self, engine: Engine):
"""Create FileService instance with TestContainers engine."""
return FileService(engine)
@pytest.fixture
def workflow_run_service(self, engine: Engine, file_service: FileService):
"""Create WorkflowRunService instance with TestContainers engine and FileService."""
return WorkflowRunService(engine)
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
self.test_workflow_run_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Create test workflow run
self.test_workflow_run = WorkflowRun(
id=self.test_workflow_run_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
# Store session and service instances
self.session = db_session_with_containers
self.file_service = file_service
self.workflow_run_service = workflow_run_service
# Save test data to database
self.session.add(self.test_workflow)
self.session.add(self.test_workflow_run)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
try:
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
except Exception as e:
self.session.rollback()
raise e
def _create_graph_runtime_state(
self,
outputs: dict[str, object] | None = None,
total_tokens: int = 0,
node_run_steps: int = 0,
variables: dict[tuple[str, str], object] | None = None,
workflow_run_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
"""Create a real GraphRuntimeState for testing."""
start_at = time()
execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4())
# Create variable pool
variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id))
if variables:
for (node_id, var_key), value in variables.items():
variable_pool.add([node_id, var_key], value)
# Create LLM usage
llm_usage = LLMUsage.empty_usage()
# Create graph runtime state
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=start_at,
total_tokens=total_tokens,
llm_usage=llm_usage,
outputs=outputs or {},
node_run_steps=node_run_steps,
)
return ReadOnlyGraphRuntimeStateWrapper(graph_runtime_state)
def _create_pause_state_persistence_layer(
self,
workflow_run: WorkflowRun | None = None,
workflow: Workflow | None = None,
state_owner_user_id: str | None = None,
) -> PauseStatePersistenceLayer:
"""Create PauseStatePersistenceLayer with real dependencies."""
owner_id = state_owner_user_id
if owner_id is None:
if workflow is not None and workflow.created_by:
owner_id = workflow.created_by
elif workflow_run is not None and workflow_run.created_by:
owner_id = workflow_run.created_by
else:
owner_id = getattr(self, "test_user_id", None)
assert owner_id is not None
owner_id = str(owner_id)
return PauseStatePersistenceLayer(
session_factory=self.session.get_bind(),
state_owner_user_id=owner_id,
)
def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
"""Test complete pause flow: event -> state serialization -> database save -> storage save."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create real graph runtime state with test data
test_outputs = {"result": "test_output", "step": "intermediate"}
test_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=test_outputs,
total_tokens=100,
node_run_steps=5,
variables=test_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Create pause event
event = GraphRunPausedEvent(
reason=SchedulingPause(message="test pause"),
outputs={"intermediate": "result"},
)
# Act
layer.on_event(event)
# Assert - Verify pause state was saved to database
self.session.refresh(self.test_workflow_run)
workflow_run = self.session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Verify pause state exists in database
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_id == self.test_workflow_id
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.state_object_key != ""
assert pause_model.resumed_at is None
storage_content = storage.load(pause_model.state_object_key).decode()
expected_state = json.loads(graph_runtime_state.dumps())
actual_state = json.loads(storage_content)
assert actual_state == expected_state
def test_state_persistence_and_retrieval(self, db_session_with_containers):
"""Test that pause state can be persisted and retrieved correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create complex test data
complex_outputs = {
"nested": {"key": "value", "number": 42},
"list": [1, 2, 3, {"nested": "item"}],
"boolean": True,
"null_value": None,
}
complex_variables = {
("node1", "var1"): "string_value",
("node2", "var2"): {"complex": "object"},
("node3", "var3"): [1, 2, 3],
}
graph_runtime_state = self._create_graph_runtime_state(
outputs=complex_outputs,
total_tokens=250,
node_run_steps=10,
variables=complex_variables,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act - Save pause state
layer.on_event(event)
# Assert - Retrieve and verify
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(self.test_workflow_run_id)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == self.test_workflow_run_id
state_bytes = pause_entity.get_state()
retrieved_state = json.loads(state_bytes.decode())
expected_state = json.loads(graph_runtime_state.dumps())
assert retrieved_state == expected_state
assert retrieved_state["outputs"] == complex_outputs
assert retrieved_state["total_tokens"] == 250
assert retrieved_state["node_run_steps"] == 10
def test_database_transaction_handling(self, db_session_with_containers):
"""Test that database transactions are handled correctly."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state(
outputs={"test": "transaction"},
total_tokens=50,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify data is committed and accessible in new session
with Session(bind=self.session.get_bind(), expire_on_commit=False) as new_session:
workflow_run = new_session.get(WorkflowRun, self.test_workflow_run_id)
assert workflow_run is not None
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_model = new_session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.workflow_run_id == self.test_workflow_run_id
assert pause_model.resumed_at is None
assert pause_model.state_object_key != ""
def test_file_storage_integration(self, db_session_with_containers):
"""Test integration with file storage system."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Create large state data to test storage
large_outputs = {"data": "x" * 10000} # 10KB of data
graph_runtime_state = self._create_graph_runtime_state(
outputs=large_outputs,
total_tokens=1000,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Verify file was uploaded to storage
self.session.refresh(self.test_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run.id)
).first()
assert pause_model is not None
assert pause_model.state_object_key != ""
# Verify content in storage
storage_content = storage.load(pause_model.state_object_key).decode()
assert storage_content == graph_runtime_state.dumps()
def test_workflow_with_different_creators(self, db_session_with_containers):
"""Test pause state with workflows created by different users."""
# Arrange - Create workflow with different creator
different_user_id = str(uuid.uuid4())
different_workflow = Workflow(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=different_user_id,
created_at=naive_utc_now(),
)
different_workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=different_workflow.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=self.test_user_id, # Run created by different user
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(different_workflow)
self.session.add(different_workflow_run)
self.session.commit()
layer = self._create_pause_state_persistence_layer(
workflow_run=different_workflow_run,
workflow=different_workflow,
)
graph_runtime_state = self._create_graph_runtime_state(
outputs={"creator_test": "different_creator"},
workflow_run_id=different_workflow_run.id,
)
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act
layer.on_event(event)
# Assert - Should use workflow creator (not run creator)
self.session.refresh(different_workflow_run)
pause_model = self.session.scalars(
select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == different_workflow_run.id)
).first()
assert pause_model is not None
# Verify the state owner is the workflow creator
pause_entity = self.workflow_run_service._workflow_run_repo.get_workflow_pause(different_workflow_run.id)
assert pause_entity is not None
def test_layer_ignores_non_pause_events(self, db_session_with_containers):
"""Test that layer ignores non-pause events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
graph_runtime_state = self._create_graph_runtime_state()
command_channel = _TestCommandChannelImpl()
layer.initialize(graph_runtime_state, command_channel)
# Import other event types
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
# Act - Send non-pause events
layer.on_event(GraphRunStartedEvent())
layer.on_event(GraphRunSucceededEvent(outputs={"result": "success"}))
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
# Assert - No pause state should be created
self.session.refresh(self.test_workflow_run)
assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_states = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
.all()
)
assert len(pause_states) == 0
def test_layer_requires_initialization(self, db_session_with_containers):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
event = GraphRunPausedEvent(reason=SchedulingPause(message="test pause"))
# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
layer.on_event(event)

View File

@@ -0,0 +1,948 @@
"""Comprehensive integration tests for workflow pause functionality.
This test suite covers complete workflow pause functionality including:
- Real database interactions using containerized PostgreSQL
- Real storage operations using the test storage backend
- Complete workflow: create -> pause -> resume -> delete
- Testing with actual FileService (not mocked)
- Database transactions and rollback behavior
- Actual file upload and retrieval through storage
- Workflow status transitions in the database
- Error handling with real database constraints
- Concurrent access scenarios
- Multi-tenant isolation
- Prune functionality
- File storage integration
These tests use TestContainers to spin up real services for integration testing,
providing more reliable and realistic test scenarios than mocks.
"""
import json
import uuid
from dataclasses import dataclass
from datetime import timedelta
import pytest
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.entities import WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models import Account
from models import WorkflowPause as WorkflowPauseModel
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_WorkflowRunError,
)
@dataclass
class PauseWorkflowSuccessCase:
"""Test case for successful pause workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class PauseWorkflowFailureCase:
"""Test case for pause workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowSuccessCase:
"""Test case for successful resume workflow operations."""
name: str
initial_status: WorkflowExecutionStatus
description: str = ""
@dataclass
class ResumeWorkflowFailureCase:
"""Test case for resume workflow failure scenarios."""
name: str
initial_status: WorkflowExecutionStatus
pause_resumed: bool
set_running_status: bool = False
description: str = ""
@dataclass
class PrunePausesTestCase:
"""Test case for prune pauses operations."""
name: str
pause_age: timedelta
resume_age: timedelta | None
expected_pruned_count: int
description: str = ""
def pause_workflow_failure_cases() -> list[PauseWorkflowFailureCase]:
"""Create test cases for pause workflow failure scenarios."""
return [
PauseWorkflowFailureCase(
name="pause_already_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should fail to pause an already paused workflow",
),
PauseWorkflowFailureCase(
name="pause_completed_workflow",
initial_status=WorkflowExecutionStatus.SUCCEEDED,
description="Should fail to pause a completed workflow",
),
PauseWorkflowFailureCase(
name="pause_failed_workflow",
initial_status=WorkflowExecutionStatus.FAILED,
description="Should fail to pause a failed workflow",
),
]
def resume_workflow_success_cases() -> list[ResumeWorkflowSuccessCase]:
"""Create test cases for successful resume workflow operations."""
return [
ResumeWorkflowSuccessCase(
name="resume_paused_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
description="Should successfully resume a paused workflow",
),
]
def resume_workflow_failure_cases() -> list[ResumeWorkflowFailureCase]:
"""Create test cases for resume workflow failure scenarios."""
return [
ResumeWorkflowFailureCase(
name="resume_already_resumed_workflow",
initial_status=WorkflowExecutionStatus.PAUSED,
pause_resumed=True,
description="Should fail to resume an already resumed workflow",
),
ResumeWorkflowFailureCase(
name="resume_running_workflow",
initial_status=WorkflowExecutionStatus.RUNNING,
pause_resumed=False,
set_running_status=True,
description="Should fail to resume a running workflow",
),
]
def prune_pauses_test_cases() -> list[PrunePausesTestCase]:
"""Create test cases for prune pauses operations."""
return [
PrunePausesTestCase(
name="prune_old_active_pauses",
pause_age=timedelta(days=7),
resume_age=None,
expected_pruned_count=1,
description="Should prune old active pauses",
),
PrunePausesTestCase(
name="prune_old_resumed_pauses",
pause_age=timedelta(hours=12), # Created 12 hours ago (recent)
resume_age=timedelta(days=7),
expected_pruned_count=1,
description="Should prune old resumed pauses",
),
PrunePausesTestCase(
name="keep_recent_active_pauses",
pause_age=timedelta(hours=1),
resume_age=None,
expected_pruned_count=0,
description="Should keep recent active pauses",
),
PrunePausesTestCase(
name="keep_recent_resumed_pauses",
pause_age=timedelta(days=1),
resume_age=timedelta(hours=1),
expected_pruned_count=0,
description="Should keep recent resumed pauses",
),
]
class TestWorkflowPauseIntegration:
"""Comprehensive integration tests for workflow pause functionality."""
@pytest.fixture(autouse=True)
def setup_test_data(self, db_session_with_containers):
"""Set up test data for each test method using TestContainers."""
# Create test tenant and account
tenant = Tenant(
name="Test Tenant",
status="normal",
)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
account = Account(
email="test@example.com",
name="Test User",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant-account join
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.commit()
# Set test data
self.test_tenant_id = tenant.id
self.test_user_id = account.id
self.test_app_id = str(uuid.uuid4())
self.test_workflow_id = str(uuid.uuid4())
# Create test workflow
self.test_workflow = Workflow(
id=self.test_workflow_id,
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=self.test_user_id,
created_at=naive_utc_now(),
)
# Store session instance
self.session = db_session_with_containers
# Save test data to database
self.session.add(self.test_workflow)
self.session.commit()
yield
# Cleanup
self._cleanup_test_data()
def _cleanup_test_data(self):
"""Clean up test data after each test method."""
# Clean up workflow pauses
self.session.execute(delete(WorkflowPauseModel))
# Clean up upload files
self.session.execute(
delete(UploadFile).where(
UploadFile.tenant_id == self.test_tenant_id,
)
)
# Clean up workflow runs
self.session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == self.test_tenant_id,
WorkflowRun.app_id == self.test_app_id,
)
)
# Clean up workflows
self.session.execute(
delete(Workflow).where(
Workflow.tenant_id == self.test_tenant_id,
Workflow.app_id == self.test_app_id,
)
)
self.session.commit()
def _create_test_workflow_run(
self, status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
) -> WorkflowRun:
"""Create a test workflow run with specified status."""
workflow_run = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=self.test_tenant_id,
app_id=self.test_app_id,
workflow_id=self.test_workflow_id,
type="workflow",
triggered_from="debugging",
version="draft",
status=status,
created_by=self.test_user_id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run)
self.session.commit()
return workflow_run
def _create_test_state(self) -> str:
"""Create a test state string."""
return json.dumps(
{
"node_id": "test-node",
"node_type": "llm",
"status": "paused",
"data": {"key": "value"},
"timestamp": naive_utc_now().isoformat(),
}
)
def _get_workflow_run_repository(self):
"""Get workflow run repository instance for testing."""
# Create session factory from the test session
engine = self.session.get_bind()
session_factory = sessionmaker(bind=engine, expire_on_commit=False)
# Create a test-specific repository that implements the missing save method
class TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
"""Test-specific repository that implements the missing save method."""
def save(self, execution: WorkflowExecution):
"""Implement the missing save method for testing."""
# For testing purposes, we don't need to implement this method
# as it's not used in the pause functionality tests
pass
# Create and return repository instance
repository = TestWorkflowRunRepository(session_maker=session_factory)
return repository
# ==================== Complete Pause Workflow Tests ====================
def test_complete_pause_resume_workflow(self):
"""Test complete workflow: create -> pause -> resume -> delete."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Pause state created
assert pause_entity is not None
assert pause_entity.id is not None
assert pause_entity.workflow_execution_id == workflow_run.id
# Convert both to strings for comparison
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Verify database state
query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(query).first()
assert pause_model is not None
assert pause_model.resumed_at is None
assert pause_model.id == pause_entity.id
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
# Act - Get pause state
retrieved_entity = repository.get_workflow_pause(workflow_run.id)
# Assert - Pause state retrieved
assert retrieved_entity is not None
assert retrieved_entity.id == pause_entity.id
retrieved_state = retrieved_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
# Act - Resume workflow
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Assert - Workflow resumed
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
# Verify database state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
self.session.refresh(pause_model)
assert pause_model.resumed_at is not None
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause state deleted
with Session(bind=self.session.get_bind()) as session:
deleted_pause = session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
def test_pause_workflow_success(self):
"""Test successful pause workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
assert pause_entity is not None
assert pause_entity.workflow_execution_id == workflow_run.id
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is None
@pytest.mark.parametrize("test_case", pause_workflow_failure_cases(), ids=lambda tc: tc.name)
def test_pause_workflow_failure(self, test_case: PauseWorkflowFailureCase):
"""Test pause workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
with pytest.raises(_WorkflowRunError):
repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
@pytest.mark.parametrize("test_case", resume_workflow_success_cases(), ids=lambda tc: tc.name)
def test_resume_workflow_success(self, test_case: ResumeWorkflowSuccessCase):
"""Test successful resume workflow scenarios."""
workflow_run = self._create_test_workflow_run(status=test_case.initial_status)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
pause_query = select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == workflow_run.id)
pause_model = self.session.scalars(pause_query).first()
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.resumed_at is not None
def test_resume_running_workflow(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
self.session.refresh(workflow_run)
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.add(workflow_run)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
def test_resume_resumed_pause(self):
"""Test resume workflow failure scenarios."""
workflow_run = self._create_test_workflow_run(status=WorkflowExecutionStatus.RUNNING)
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.resumed_at = naive_utc_now()
self.session.add(pause_model)
self.session.commit()
with pytest.raises(_WorkflowRunError):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# ==================== Error Scenario Tests ====================
def test_pause_nonexistent_workflow_run(self):
"""Test pausing a non-existent workflow run."""
# Arrange
nonexistent_id = str(uuid.uuid4())
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.create_workflow_pause(
workflow_run_id=nonexistent_id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
def test_resume_nonexistent_workflow_run(self):
"""Test resuming a non-existent workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
nonexistent_id = str(uuid.uuid4())
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.resume_workflow_pause(
workflow_run_id=nonexistent_id,
pause_entity=pause_entity,
)
# ==================== Prune Functionality Tests ====================
@pytest.mark.parametrize("test_case", prune_pauses_test_cases(), ids=lambda tc: tc.name)
def test_prune_pauses_scenarios(self, test_case: PrunePausesTestCase):
"""Test various prune pauses scenarios."""
now = naive_utc_now()
# Create pause state
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Manually adjust timestamps for testing
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - test_case.pause_age
if test_case.resume_age is not None:
# Resume pause and adjust resume time
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
# Need to refresh to get the updated model
self.session.refresh(pause_model)
# Manually set the resumed_at to an older time for testing
pause_model.resumed_at = now - test_case.resume_age
self.session.commit() # Commit the resumed_at change
# Refresh again to ensure the change is persisted
self.session.refresh(pause_model)
self.session.commit()
# Act - Prune pauses
expiration_time = now - timedelta(days=1, seconds=1) # Expire pauses older than 1 day (plus 1 second)
resumption_time = now - timedelta(
days=7, seconds=1
) # Clean up pauses resumed more than 7 days ago (plus 1 second)
# Debug: Check pause state before pruning
self.session.refresh(pause_model)
print(f"Pause created_at: {pause_model.created_at}")
print(f"Pause resumed_at: {pause_model.resumed_at}")
print(f"Expiration time: {expiration_time}")
print(f"Resumption time: {resumption_time}")
# Force commit to ensure timestamps are saved
self.session.commit()
# Determine if the pause should be pruned based on timestamps
should_be_pruned = False
if test_case.resume_age is not None:
# If resumed, check if resumed_at is older than resumption_time
should_be_pruned = pause_model.resumed_at < resumption_time
else:
# If not resumed, check if created_at is older than expiration_time
should_be_pruned = pause_model.created_at < expiration_time
# Act - Prune pauses
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
)
# Assert - Check pruning results
if should_be_pruned:
assert len(pruned_ids) == test_case.expected_pruned_count
# Verify pause was actually deleted
# The pause should be in the pruned_ids list if it was pruned
assert pause_entity.id in pruned_ids
else:
assert len(pruned_ids) == 0
def test_prune_pauses_with_limit(self):
"""Test prune pauses with limit parameter."""
now = naive_utc_now()
# Create multiple pause states
pause_entities = []
repository = self._get_workflow_run_repository()
for i in range(5):
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
pause_entities.append(pause_entity)
# Make all pauses old enough to be pruned
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
pause_model.created_at = now - timedelta(days=7)
self.session.commit()
# Act - Prune with limit
expiration_time = now - timedelta(days=1)
resumption_time = now - timedelta(days=7)
pruned_ids = repository.prune_pauses(
expiration=expiration_time,
resumption_expiration=resumption_time,
limit=3,
)
# Assert
assert len(pruned_ids) == 3
# Verify only 3 were deleted
remaining_count = (
self.session.query(WorkflowPauseModel)
.filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
.count()
)
assert remaining_count == 2
# ==================== Multi-tenant Isolation Tests ====================
def test_multi_tenant_pause_isolation(self):
"""Test that pause states are properly isolated by tenant."""
# Arrange - Create second tenant
tenant2 = Tenant(
name="Test Tenant 2",
status="normal",
)
self.session.add(tenant2)
self.session.commit()
account2 = Account(
email="test2@example.com",
name="Test User 2",
interface_language="en-US",
status="active",
)
self.session.add(account2)
self.session.commit()
tenant2_join = TenantAccountJoin(
tenant_id=tenant2.id,
account_id=account2.id,
role=TenantAccountRole.OWNER,
current=True,
)
self.session.add(tenant2_join)
self.session.commit()
# Create workflow for tenant 2
workflow2 = Workflow(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=str(uuid.uuid4()),
type="workflow",
version="draft",
graph='{"nodes": [], "edges": []}',
features='{"file_upload": {"enabled": false}}',
created_by=account2.id,
created_at=naive_utc_now(),
)
self.session.add(workflow2)
self.session.commit()
# Create workflow runs for both tenants
workflow_run1 = self._create_test_workflow_run()
workflow_run2 = WorkflowRun(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
app_id=workflow2.app_id,
workflow_id=workflow2.id,
type="workflow",
triggered_from="debugging",
version="draft",
status=WorkflowExecutionStatus.RUNNING,
created_by=account2.id,
created_by_role="account",
created_at=naive_utc_now(),
)
self.session.add(workflow_run2)
self.session.commit()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause for tenant 1
pause_entity1 = repository.create_workflow_pause(
workflow_run_id=workflow_run1.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Try to access pause from tenant 2 using tenant 1's repository
# This should work because we're using the same repository
pause_entity2 = repository.get_workflow_pause(workflow_run2.id)
assert pause_entity2 is None # No pause for tenant 2 yet
# Create pause for tenant 2
pause_entity2 = repository.create_workflow_pause(
workflow_run_id=workflow_run2.id,
state_owner_user_id=account2.id,
state=test_state,
)
# Assert - Both pauses should exist and be separate
assert pause_entity1 is not None
assert pause_entity2 is not None
assert pause_entity1.id != pause_entity2.id
assert pause_entity1.workflow_execution_id != pause_entity2.workflow_execution_id
def test_cross_tenant_access_restriction(self):
"""Test that cross-tenant access is properly restricted."""
# This test would require tenant-specific repositories
# For now, we test that pause entities are properly scoped by tenant_id
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Verify pause is properly scoped
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.workflow_id == self.test_workflow_id
# ==================== File Storage Integration Tests ====================
def test_file_storage_integration(self):
"""Test that state files are properly stored and retrieved."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
# Act - Create pause state
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Assert - Verify file was uploaded to storage
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
# Verify file content in storage
file_key = pause_model.state_object_key
storage_content = storage.load(file_key).decode()
assert storage_content == test_state
# Verify retrieval through entity
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == test_state
def test_file_cleanup_on_pause_deletion(self):
"""Test that files are properly handled on pause deletion."""
# Arrange
workflow_run = self._create_test_workflow_run()
test_state = self._create_test_state()
repository = self._get_workflow_run_repository()
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=test_state,
)
# Get file info before deletion
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
file_key = pause_model.state_object_key
# Act - Delete pause state
repository.delete_workflow_pause(pause_entity)
# Assert - Pause record should be deleted
self.session.expire_all() # Clear session to ensure fresh query
deleted_pause = self.session.get(WorkflowPauseModel, pause_entity.id)
assert deleted_pause is None
try:
content = storage.load(file_key).decode()
pytest.fail("File should be deleted from storage after pause deletion")
except FileNotFoundError:
# This is expected - file should be deleted from storage
pass
except Exception as e:
pytest.fail(f"Unexpected error when checking file deletion: {e}")
def test_large_state_file_handling(self):
"""Test handling of large state files."""
# Arrange - Create a large state (1MB)
large_state = "x" * (1024 * 1024) # 1MB of data
large_state_json = json.dumps({"large_data": large_state})
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=large_state_json,
)
# Assert
assert pause_entity is not None
retrieved_state = pause_entity.get_state()
if isinstance(retrieved_state, bytes):
retrieved_state = retrieved_state.decode()
assert retrieved_state == large_state_json
# Verify file size in database
pause_model = self.session.get(WorkflowPauseModel, pause_entity.id)
assert pause_model.state_object_key != ""
loaded_state = storage.load(pause_model.state_object_key)
assert loaded_state.decode() == large_state_json
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles on the same workflow run."""
# Arrange
workflow_run = self._create_test_workflow_run()
repository = self._get_workflow_run_repository()
# Act & Assert - Multiple cycles
for i in range(3):
state = json.dumps({"cycle": i, "data": f"state_{i}"})
# Reset workflow run status to RUNNING before each pause (after first cycle)
if i > 0:
self.session.refresh(workflow_run) # Refresh to get latest state from session
workflow_run.status = WorkflowExecutionStatus.RUNNING
self.session.commit()
self.session.refresh(workflow_run) # Refresh again after commit
# Pause
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=self.test_user_id,
state=state,
)
assert pause_entity is not None
# Verify pause
self.session.expire_all() # Clear session to ensure fresh query
self.session.refresh(workflow_run)
# Use the test session directly to verify the pause
stmt = select(WorkflowRun).options(selectinload(WorkflowRun.pause)).where(WorkflowRun.id == workflow_run.id)
workflow_run_with_pause = self.session.scalar(stmt)
pause_model = workflow_run_with_pause.pause
# Verify pause using test session directly
assert pause_model is not None
assert pause_model.id == pause_entity.id
assert pause_model.state_object_key != ""
# Load file content using storage directly
file_content = storage.load(pause_model.state_object_key)
if isinstance(file_content, bytes):
file_content = file_content.decode()
assert file_content == state
# Resume
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
assert resumed_entity is not None
assert resumed_entity.resumed_at is not None
# Verify resume - check that pause is marked as resumed
self.session.expire_all() # Clear session to ensure fresh query
stmt = select(WorkflowPauseModel).where(WorkflowPauseModel.id == pause_entity.id)
resumed_pause_model = self.session.scalar(stmt)
assert resumed_pause_model is not None
assert resumed_pause_model.resumed_at is not None
# Verify workflow run status
self.session.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.RUNNING

View File

@@ -0,0 +1,278 @@
import json
from time import time
from unittest.mock import Mock
import pytest
from core.app.layers.pause_state_persist_layer import PauseStatePersistenceLayer
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
from repositories.factory import DifyAPIRepositoryFactory
class TestDataFactory:
"""Factory helpers for constructing graph events used in tests."""
@staticmethod
def create_graph_run_paused_event(outputs: dict[str, object] | None = None) -> GraphRunPausedEvent:
return GraphRunPausedEvent(reason=SchedulingPause(message="test pause"), outputs=outputs or {})
@staticmethod
def create_graph_run_started_event() -> GraphRunStartedEvent:
return GraphRunStartedEvent()
@staticmethod
def create_graph_run_succeeded_event(outputs: dict[str, object] | None = None) -> GraphRunSucceededEvent:
return GraphRunSucceededEvent(outputs=outputs or {})
@staticmethod
def create_graph_run_failed_event(
error: str = "Test error",
exceptions_count: int = 1,
) -> GraphRunFailedEvent:
return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
class MockSystemVariableReadOnlyView:
"""Minimal read-only system variable view for testing."""
def __init__(self, workflow_execution_id: str | None = None) -> None:
self._workflow_execution_id = workflow_execution_id
@property
def workflow_execution_id(self) -> str | None:
return self._workflow_execution_id
class MockReadOnlyVariablePool:
"""Mock implementation of ReadOnlyVariablePool for testing."""
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
self._variables = variables or {}
def get(self, node_id: str, variable_key: str) -> Segment | None:
value = self._variables.get((node_id, variable_key))
if value is None:
return None
mock_segment = Mock(spec=Segment)
mock_segment.value = value
return mock_segment
def get_all_by_node(self, node_id: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
def get_by_prefix(self, prefix: str) -> dict[str, object]:
return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
class MockReadOnlyGraphRuntimeState:
"""Mock implementation of ReadOnlyGraphRuntimeState for testing."""
def __init__(
self,
start_at: float | None = None,
total_tokens: int = 0,
node_run_steps: int = 0,
ready_queue_size: int = 0,
exceptions_count: int = 0,
outputs: dict[str, object] | None = None,
variables: dict[tuple[str, str], object] | None = None,
workflow_execution_id: str | None = None,
):
self._start_at = start_at or time()
self._total_tokens = total_tokens
self._node_run_steps = node_run_steps
self._ready_queue_size = ready_queue_size
self._exceptions_count = exceptions_count
self._outputs = outputs or {}
self._variable_pool = MockReadOnlyVariablePool(variables)
self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
@property
def system_variable(self) -> MockSystemVariableReadOnlyView:
return self._system_variable
@property
def variable_pool(self) -> ReadOnlyVariablePool:
return self._variable_pool
@property
def start_at(self) -> float:
return self._start_at
@property
def total_tokens(self) -> int:
return self._total_tokens
@property
def node_run_steps(self) -> int:
return self._node_run_steps
@property
def ready_queue_size(self) -> int:
return self._ready_queue_size
@property
def exceptions_count(self) -> int:
return self._exceptions_count
@property
def outputs(self) -> dict[str, object]:
return self._outputs.copy()
@property
def llm_usage(self):
mock_usage = Mock()
mock_usage.prompt_tokens = 10
mock_usage.completion_tokens = 20
mock_usage.total_tokens = 30
return mock_usage
def get_output(self, key: str, default: object = None) -> object:
return self._outputs.get(key, default)
def dumps(self) -> str:
return json.dumps(
{
"start_at": self._start_at,
"total_tokens": self._total_tokens,
"node_run_steps": self._node_run_steps,
"ready_queue_size": self._ready_queue_size,
"exceptions_count": self._exceptions_count,
"outputs": self._outputs,
"variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
"workflow_execution_id": self._system_variable.workflow_execution_id,
}
)
class MockCommandChannel:
"""Mock implementation of CommandChannel for testing."""
def __init__(self):
self._commands: list[GraphEngineCommand] = []
def fetch_commands(self) -> list[GraphEngineCommand]:
return self._commands.copy()
def send_command(self, command: GraphEngineCommand) -> None:
self._commands.append(command)
class TestPauseStatePersistenceLayer:
"""Unit tests for PauseStatePersistenceLayer."""
def test_init_with_dependency_injection(self):
session_factory = Mock(name="session_factory")
state_owner_user_id = "user-123"
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
state_owner_user_id=state_owner_user_id,
)
assert layer._session_maker is session_factory
assert layer._state_owner_user_id == state_owner_user_id
assert not hasattr(layer, "graph_runtime_state")
assert not hasattr(layer, "command_channel")
def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner")
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
assert layer.graph_runtime_state is graph_runtime_state
assert layer.command_channel is command_channel
def test_on_event_with_graph_run_paused_event(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState(
outputs={"result": "test_output"},
total_tokens=100,
workflow_execution_id="run-123",
)
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
event = TestDataFactory.create_graph_run_paused_event(outputs={"intermediate": "result"})
expected_state = graph_runtime_state.dumps()
layer.on_event(event)
mock_factory.assert_called_once_with(session_factory)
mock_repo.create_workflow_pause.assert_called_once_with(
workflow_run_id="run-123",
state_owner_user_id="owner-123",
state=expected_state,
)
def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState()
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
events = [
TestDataFactory.create_graph_run_started_event(),
TestDataFactory.create_graph_run_succeeded_event(),
TestDataFactory.create_graph_run_failed_event(),
]
for event in events:
layer.on_event(event)
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()
def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
event = TestDataFactory.create_graph_run_paused_event()
with pytest.raises(AttributeError):
layer.on_event(event)
def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(session_factory=session_factory, state_owner_user_id="owner-123")
mock_repo = Mock()
mock_factory = Mock(return_value=mock_repo)
monkeypatch.setattr(DifyAPIRepositoryFactory, "create_api_workflow_run_repository", mock_factory)
graph_runtime_state = MockReadOnlyGraphRuntimeState(workflow_execution_id=None)
command_channel = MockCommandChannel()
layer.initialize(graph_runtime_state, command_channel)
event = TestDataFactory.create_graph_run_paused_event()
with pytest.raises(AssertionError):
layer.on_event(event)
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()

View File

@@ -0,0 +1,171 @@
"""Tests for _PrivateWorkflowPauseEntity implementation."""
from datetime import datetime
from unittest.mock import MagicMock, patch
from models.workflow import WorkflowPause as WorkflowPauseModel
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity implementation."""
def test_entity_initialization(self):
"""Test entity initialization with required parameters."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
mock_pause_model.resumed_at = None
# Create entity
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# Verify initialization
assert entity._pause_model is mock_pause_model
assert entity._cached_state is None
def test_from_models_classmethod(self):
"""Test from_models class method."""
# Create mock models
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
mock_pause_model.workflow_run_id = "execution-456"
# Create entity using from_models
entity = _PrivateWorkflowPauseEntity.from_models(
workflow_pause_model=mock_pause_model,
)
# Verify entity creation
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model is mock_pause_model
def test_id_property(self):
"""Test id property returns pause model ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.id = "pause-123"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.id == "pause-123"
def test_workflow_execution_id_property(self):
"""Test workflow_execution_id property returns workflow run ID."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.workflow_run_id = "execution-456"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.workflow_execution_id == "execution-456"
def test_resumed_at_property(self):
"""Test resumed_at property returns pause model resumed_at."""
resumed_at = datetime(2023, 12, 25, 15, 30, 45)
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = resumed_at
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.resumed_at == resumed_at
def test_resumed_at_property_none(self):
"""Test resumed_at property returns None when not set."""
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.resumed_at = None
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
assert entity.resumed_at is None
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_first_call(self, mock_storage):
"""Test get_state loads from storage on first call."""
state_data = b'{"test": "data", "step": 5}'
mock_storage.load.return_value = state_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# First call should load from storage
result = entity.get_state()
assert result == state_data
mock_storage.load.assert_called_once_with("test-state-key")
assert entity._cached_state == state_data
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_cached_call(self, mock_storage):
"""Test get_state returns cached data on subsequent calls."""
state_data = b'{"test": "data", "step": 5}'
mock_storage.load.return_value = state_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
mock_pause_model.state_object_key = "test-state-key"
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# First call
result1 = entity.get_state()
# Second call should use cache
result2 = entity.get_state()
assert result1 == state_data
assert result2 == state_data
# Storage should only be called once
mock_storage.load.assert_called_once_with("test-state-key")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
def test_get_state_with_pre_cached_data(self, mock_storage):
"""Test get_state returns pre-cached data."""
state_data = b'{"test": "data", "step": 5}'
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
# Pre-cache data
entity._cached_state = state_data
# Should return cached data without calling storage
result = entity.get_state()
assert result == state_data
mock_storage.load.assert_not_called()
def test_entity_with_binary_state_data(self):
"""Test entity with binary state data."""
# Test with binary data that's not valid JSON
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = binary_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
entity = _PrivateWorkflowPauseEntity(
pause_model=mock_pause_model,
)
result = entity.get_state()
assert result == binary_data

View File

@@ -3,6 +3,7 @@
import time
from unittest.mock import MagicMock
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -149,8 +150,8 @@ def test_pause_command():
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)]
assert len(pause_events) == 1
assert pause_events[0].reason == "User requested pause"
assert pause_events[0].reason == SchedulingPause(message="User requested pause")
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.is_paused
assert graph_execution.pause_reason == "User requested pause"
assert graph_execution.pause_reason == SchedulingPause(message="User requested pause")

View File

@@ -0,0 +1,32 @@
"""Tests for workflow pause related enums and constants."""
from core.workflow.enums import (
WorkflowExecutionStatus,
)
class TestWorkflowExecutionStatus:
"""Test WorkflowExecutionStatus enum."""
def test_is_ended_method(self):
"""Test is_ended method for different statuses."""
# Test ended statuses
ended_statuses = [
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
for status in ended_statuses:
assert status.is_ended(), f"{status} should be considered ended"
# Test non-ended statuses
non_ended_statuses = [
WorkflowExecutionStatus.SCHEDULED,
WorkflowExecutionStatus.RUNNING,
WorkflowExecutionStatus.PAUSED,
]
for status in non_ended_statuses:
assert not status.is_ended(), f"{status} should not be considered ended"

View File

@@ -0,0 +1,202 @@
from typing import cast
import pytest
from core.file.models import File, FileTransferMethod, FileType
from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView
class TestSystemVariableReadOnlyView:
"""Test cases for SystemVariableReadOnlyView class."""
def test_read_only_property_access(self):
"""Test that all properties return correct values from wrapped instance."""
# Create test data
test_file = File(
id="file-123",
tenant_id="tenant-123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related-123",
)
datasource_info = {"key": "value", "nested": {"data": 42}}
# Create SystemVariable with all fields
system_var = SystemVariable(
user_id="user-123",
app_id="app-123",
workflow_id="workflow-123",
files=[test_file],
workflow_execution_id="exec-123",
query="test query",
conversation_id="conv-123",
dialogue_count=5,
document_id="doc-123",
original_document_id="orig-doc-123",
dataset_id="dataset-123",
batch="batch-123",
datasource_type="type-123",
datasource_info=datasource_info,
invoke_from="invoke-123",
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test all properties
assert read_only_view.user_id == "user-123"
assert read_only_view.app_id == "app-123"
assert read_only_view.workflow_id == "workflow-123"
assert read_only_view.workflow_execution_id == "exec-123"
assert read_only_view.query == "test query"
assert read_only_view.conversation_id == "conv-123"
assert read_only_view.dialogue_count == 5
assert read_only_view.document_id == "doc-123"
assert read_only_view.original_document_id == "orig-doc-123"
assert read_only_view.dataset_id == "dataset-123"
assert read_only_view.batch == "batch-123"
assert read_only_view.datasource_type == "type-123"
assert read_only_view.invoke_from == "invoke-123"
def test_defensive_copying_of_mutable_objects(self):
"""Test that mutable objects are defensively copied."""
# Create test data
test_file = File(
id="file-123",
tenant_id="tenant-123",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related-123",
)
datasource_info = {"key": "original_value"}
# Create SystemVariable
system_var = SystemVariable(
files=[test_file], datasource_info=datasource_info, workflow_execution_id="exec-123"
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test files defensive copying
files_copy = read_only_view.files
assert isinstance(files_copy, tuple) # Should be immutable tuple
assert len(files_copy) == 1
assert files_copy[0].id == "file-123"
# Verify it's a copy (can't modify original through view)
assert isinstance(files_copy, tuple)
# tuples don't have append method, so they're immutable
# Test datasource_info defensive copying
datasource_copy = read_only_view.datasource_info
assert datasource_copy is not None
assert datasource_copy["key"] == "original_value"
datasource_copy = cast(dict, datasource_copy)
with pytest.raises(TypeError):
datasource_copy["key"] = "modified value"
# Verify original is unchanged
assert system_var.datasource_info is not None
assert system_var.datasource_info["key"] == "original_value"
assert read_only_view.datasource_info is not None
assert read_only_view.datasource_info["key"] == "original_value"
def test_always_accesses_latest_data(self):
"""Test that properties always return the latest data from wrapped instance."""
# Create SystemVariable
system_var = SystemVariable(user_id="original-user", workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Verify initial value
assert read_only_view.user_id == "original-user"
# Modify the wrapped instance
system_var.user_id = "modified-user"
# Verify view returns the new value
assert read_only_view.user_id == "modified-user"
def test_repr_method(self):
"""Test the __repr__ method."""
# Create SystemVariable
system_var = SystemVariable(workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test repr
repr_str = repr(read_only_view)
assert "SystemVariableReadOnlyView" in repr_str
assert "system_variable=" in repr_str
def test_none_value_handling(self):
"""Test that None values are properly handled."""
# Create SystemVariable with all None values except workflow_execution_id
system_var = SystemVariable(
user_id=None,
app_id=None,
workflow_id=None,
workflow_execution_id="exec-123",
query=None,
conversation_id=None,
dialogue_count=None,
document_id=None,
original_document_id=None,
dataset_id=None,
batch=None,
datasource_type=None,
datasource_info=None,
invoke_from=None,
)
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test all None values
assert read_only_view.user_id is None
assert read_only_view.app_id is None
assert read_only_view.workflow_id is None
assert read_only_view.query is None
assert read_only_view.conversation_id is None
assert read_only_view.dialogue_count is None
assert read_only_view.document_id is None
assert read_only_view.original_document_id is None
assert read_only_view.dataset_id is None
assert read_only_view.batch is None
assert read_only_view.datasource_type is None
assert read_only_view.datasource_info is None
assert read_only_view.invoke_from is None
# files should be empty tuple even when default list is empty
assert read_only_view.files == ()
def test_empty_files_handling(self):
"""Test that empty files list is handled correctly."""
# Create SystemVariable with empty files
system_var = SystemVariable(files=[], workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test files handling
assert read_only_view.files == ()
assert isinstance(read_only_view.files, tuple)
def test_empty_datasource_info_handling(self):
"""Test that empty datasource_info is handled correctly."""
# Create SystemVariable with empty datasource_info
system_var = SystemVariable(datasource_info={}, workflow_execution_id="exec-123")
# Create read-only view
read_only_view = SystemVariableReadOnlyView(system_var)
# Test datasource_info handling
assert read_only_view.datasource_info == {}
# Should be a copy, not the same object
assert read_only_view.datasource_info is not system_var.datasource_info

View File

@@ -0,0 +1,11 @@
from models.base import DefaultFieldsMixin
class FooModel(DefaultFieldsMixin):
def __init__(self, id: str):
self.id = id
def test_repr():
foo_model = FooModel(id="test-id")
assert repr(foo_model) == "<FooModel(id=test-id)>"

View File

@@ -0,0 +1,370 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.workflow_pause import WorkflowPauseEntity
from core.workflow.enums import WorkflowExecutionStatus
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
class TestDifyAPISQLAlchemyWorkflowRunRepository:
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
@pytest.fixture
def mock_session(self):
"""Create a mock session."""
return Mock(spec=Session)
@pytest.fixture
def mock_session_maker(self, mock_session):
"""Create a mock sessionmaker."""
session_maker = Mock(spec=sessionmaker)
# Create a context manager mock
context_manager = Mock()
context_manager.__enter__ = Mock(return_value=mock_session)
context_manager.__exit__ = Mock(return_value=None)
session_maker.return_value = context_manager
# Mock session.begin() context manager
begin_context_manager = Mock()
begin_context_manager.__enter__ = Mock(return_value=None)
begin_context_manager.__exit__ = Mock(return_value=None)
mock_session.begin = Mock(return_value=begin_context_manager)
# Add missing session methods
mock_session.commit = Mock()
mock_session.rollback = Mock()
mock_session.add = Mock()
mock_session.delete = Mock()
mock_session.get = Mock()
mock_session.scalar = Mock()
mock_session.scalars = Mock()
# Also support expire_on_commit parameter
def make_session(expire_on_commit=None):
cm = Mock()
cm.__enter__ = Mock(return_value=mock_session)
cm.__exit__ = Mock(return_value=None)
return cm
session_maker.side_effect = make_session
return session_maker
@pytest.fixture
def repository(self, mock_session_maker):
"""Create repository instance with mocked dependencies."""
# Create a testable subclass that implements the save method
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
def __init__(self, session_maker):
# Initialize without calling parent __init__ to avoid any instantiation issues
self._session_maker = session_maker
def save(self, execution):
"""Mock implementation of save method."""
return None
# Create repository instance
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
return repo
@pytest.fixture
def sample_workflow_run(self):
"""Create a sample WorkflowRun model."""
workflow_run = Mock(spec=WorkflowRun)
workflow_run.id = "workflow-run-123"
workflow_run.tenant_id = "tenant-123"
workflow_run.app_id = "app-123"
workflow_run.workflow_id = "workflow-123"
workflow_run.status = WorkflowExecutionStatus.RUNNING
return workflow_run
@pytest.fixture
def sample_workflow_pause(self):
"""Create a sample WorkflowPauseModel."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test create_workflow_pause method."""
def test_create_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test successful workflow pause creation."""
# Arrange
workflow_run_id = "workflow-run-123"
state_owner_user_id = "user-123"
state = '{"test": "state"}'
mock_session.get.return_value = sample_workflow_run
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
mock_uuidv7.side_effect = ["pause-123"]
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
result = repository.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id,
state=state,
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id
# Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
mock_storage.save.assert_called_once()
mock_session.add.assert_called()
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_create_workflow_pause_not_found(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
):
"""Test workflow pause creation when workflow run not found."""
# Arrange
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
)
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
def test_create_workflow_pause_invalid_status(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
)
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test resume_workflow_pause method."""
def test_resume_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause resume."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
# Setup workflow run and pause
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.pause = sample_workflow_pause
sample_workflow_pause.resumed_at = None
mock_session.scalar.return_value = sample_workflow_run
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
mock_now.return_value = datetime.now(UTC)
# Act
result = repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
# Verify state transitions
assert sample_workflow_pause.resumed_at is not None
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
# Verify database interactions
mock_session.add.assert_called()
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_resume_workflow_pause_not_paused(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test resume when workflow is not paused."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
def test_resume_workflow_pause_id_mismatch(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test resume when pause ID doesn't match."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-456" # Different ID
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_pause.id = "pause-123"
sample_workflow_run.pause = sample_workflow_pause
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test delete_workflow_pause method."""
def test_delete_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause deletion."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = sample_workflow_pause
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
repository.delete_workflow_pause(pause_entity=pause_entity)
# Assert
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
mock_session.delete.assert_called_once_with(sample_workflow_pause)
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_delete_workflow_pause_not_found(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
):
"""Test delete when pause not found."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test _PrivateWorkflowPauseEntity class."""
def test_from_models(self, sample_workflow_pause: Mock):
"""Test creating _PrivateWorkflowPauseEntity from models."""
# Act
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Assert
assert isinstance(entity, _PrivateWorkflowPauseEntity)
assert entity._pause_model == sample_workflow_pause
def test_properties(self, sample_workflow_pause: Mock):
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
# Act & Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock):
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result = entity.get_state()
# Assert
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock):
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity.from_models(sample_workflow_pause)
expected_state = b'{"test": "state"}'
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
mock_storage.load.return_value = expected_state
# Act
result1 = entity.get_state()
result2 = entity.get_state() # Should use cache
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once() # Only called once due to caching

View File

@@ -0,0 +1,200 @@
"""Comprehensive unit tests for WorkflowRunService class.
This test suite covers all pause state management operations including:
- Retrieving pause state for workflow runs
- Saving pause state with file uploads
- Marking paused workflows as resumed
- Error handling and edge cases
- Database transaction management
- Repository-based approach testing
"""
from datetime import datetime
from unittest.mock import MagicMock, create_autospec, patch
import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowExecutionStatus
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity
from services.workflow_run_service import (
WorkflowRunService,
)
class TestDataFactory:
"""Factory class for creating test data objects."""
@staticmethod
def create_workflow_run_mock(
id: str = "workflow-run-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
mock_run = MagicMock()
mock_run.id = id
mock_run.tenant_id = tenant_id
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)
return mock_run
@staticmethod
def create_workflow_pause_mock(
id: str = "pause-123",
tenant_id: str = "tenant-456",
app_id: str = "app-789",
workflow_id: str = "workflow-101",
workflow_execution_id: str = "workflow-execution-123",
state_file_id: str = "file-456",
resumed_at: datetime | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowPauseModel object."""
mock_pause = MagicMock()
mock_pause.id = id
mock_pause.tenant_id = tenant_id
mock_pause.app_id = app_id
mock_pause.workflow_id = workflow_id
mock_pause.workflow_execution_id = workflow_execution_id
mock_pause.state_file_id = state_file_id
mock_pause.resumed_at = resumed_at
for key, value in kwargs.items():
setattr(mock_pause, key, value)
return mock_pause
@staticmethod
def create_upload_file_mock(
id: str = "file-456",
key: str = "upload_files/test/state.json",
name: str = "state.json",
tenant_id: str = "tenant-456",
**kwargs,
) -> MagicMock:
"""Create a mock UploadFile object."""
mock_file = MagicMock()
mock_file.id = id
mock_file.key = key
mock_file.name = name
mock_file.tenant_id = tenant_id
for key, value in kwargs.items():
setattr(mock_file, key, value)
return mock_file
@staticmethod
def create_pause_entity_mock(
pause_model: MagicMock | None = None,
upload_file: MagicMock | None = None,
) -> _PrivateWorkflowPauseEntity:
"""Create a mock _PrivateWorkflowPauseEntity object."""
if pause_model is None:
pause_model = TestDataFactory.create_workflow_pause_mock()
if upload_file is None:
upload_file = TestDataFactory.create_upload_file_mock()
return _PrivateWorkflowPauseEntity.from_models(pause_model, upload_file)
class TestWorkflowRunService:
"""Comprehensive unit tests for WorkflowRunService class."""
@pytest.fixture
def mock_session_factory(self):
"""Create a mock session factory with proper session management."""
mock_session = create_autospec(Session)
# Create a mock context manager for the session
mock_session_cm = MagicMock()
mock_session_cm.__enter__ = MagicMock(return_value=mock_session)
mock_session_cm.__exit__ = MagicMock(return_value=None)
# Create a mock context manager for the transaction
mock_transaction_cm = MagicMock()
mock_transaction_cm.__enter__ = MagicMock(return_value=mock_session)
mock_transaction_cm.__exit__ = MagicMock(return_value=None)
mock_session.begin = MagicMock(return_value=mock_transaction_cm)
# Create mock factory that returns the context manager
mock_factory = MagicMock(spec=sessionmaker)
mock_factory.return_value = mock_session_cm
return mock_factory, mock_session
@pytest.fixture
def mock_workflow_run_repository(self):
"""Create a mock APIWorkflowRunRepository."""
mock_repo = create_autospec(APIWorkflowRunRepository)
return mock_repo
@pytest.fixture
def workflow_run_service(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with mocked dependencies."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
return service
@pytest.fixture
def workflow_run_service_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Create WorkflowRunService instance with Engine input."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(mock_engine)
return service
# ==================== Initialization Tests ====================
def test_init_with_session_factory(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with session_factory."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_engine(self, mock_session_factory, mock_workflow_run_repository):
"""Test WorkflowRunService initialization with Engine (should convert to sessionmaker)."""
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
service = WorkflowRunService(mock_engine)
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
assert service._session_factory == session_factory
mock_factory.create_api_workflow_run_repository.assert_called_once_with(session_factory)
def test_init_with_default_dependencies(self, mock_session_factory):
"""Test WorkflowRunService initialization with default dependencies."""
session_factory, _ = mock_session_factory
service = WorkflowRunService(session_factory)
assert service._session_factory == session_factory