mirror of
https://github.com/langgenius/dify.git
synced 2026-02-24 18:05:11 +00:00
test: migrate workflow run repository SQL tests to testcontainers (#32519)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,506 @@
|
||||
"""Integration tests for DifyAPISQLAlchemyWorkflowRunRepository using testcontainers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Engine, delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities import WorkflowExecution
|
||||
from core.workflow.entities.pause_reason import PauseReasonType
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Concrete repository for tests where save() is not under test."""
|
||||
|
||||
def save(self, execution: WorkflowExecution) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _TestScope:
|
||||
"""Per-test data scope used to isolate DB rows and storage keys."""
|
||||
|
||||
tenant_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
app_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
workflow_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
user_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
state_keys: set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
def _create_workflow_run(
|
||||
session: Session,
|
||||
scope: _TestScope,
|
||||
*,
|
||||
status: WorkflowExecutionStatus,
|
||||
created_at: datetime | None = None,
|
||||
) -> WorkflowRun:
|
||||
"""Create and persist a workflow run bound to the current test scope."""
|
||||
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(uuid4()),
|
||||
tenant_id=scope.tenant_id,
|
||||
app_id=scope.app_id,
|
||||
workflow_id=scope.workflow_id,
|
||||
type="workflow",
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
version="draft",
|
||||
graph="{}",
|
||||
inputs="{}",
|
||||
status=status,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=scope.user_id,
|
||||
created_at=created_at or naive_utc_now(),
|
||||
)
|
||||
session.add(workflow_run)
|
||||
session.commit()
|
||||
return workflow_run
|
||||
|
||||
|
||||
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
|
||||
"""Remove test-created DB rows and storage objects for a test scope."""
|
||||
|
||||
pause_ids_subquery = select(WorkflowPause.id).where(WorkflowPause.workflow_id == scope.workflow_id)
|
||||
session.execute(delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids_subquery)))
|
||||
session.execute(delete(WorkflowPause).where(WorkflowPause.workflow_id == scope.workflow_id))
|
||||
session.execute(
|
||||
delete(WorkflowAppLog).where(
|
||||
WorkflowAppLog.tenant_id == scope.tenant_id,
|
||||
WorkflowAppLog.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
delete(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == scope.tenant_id,
|
||||
WorkflowRun.app_id == scope.app_id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
for state_key in scope.state_keys:
|
||||
try:
|
||||
storage.delete(state_key)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Build a repository backed by the testcontainers database engine."""
|
||||
|
||||
engine = db_session_with_containers.get_bind()
|
||||
assert isinstance(engine, Engine)
|
||||
return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_scope(db_session_with_containers: Session) -> _TestScope:
|
||||
"""Provide an isolated scope and clean related data after each test."""
|
||||
|
||||
scope = _TestScope()
|
||||
yield scope
|
||||
_cleanup_scope_data(db_session_with_containers, scope)
|
||||
|
||||
|
||||
class TestGetRunsBatchByTimeRange:
|
||||
"""Integration tests for get_runs_batch_by_time_range."""
|
||||
|
||||
def test_get_runs_batch_by_time_range_filters_terminal_statuses(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Return only terminal workflow runs, excluding RUNNING and PAUSED."""
|
||||
|
||||
now = naive_utc_now()
|
||||
ended_statuses = [
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
]
|
||||
ended_run_ids = {
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=status,
|
||||
created_at=now - timedelta(minutes=3),
|
||||
).id
|
||||
for status in ended_statuses
|
||||
}
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
created_at=now - timedelta(minutes=2),
|
||||
)
|
||||
_create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
created_at=now - timedelta(minutes=1),
|
||||
)
|
||||
|
||||
runs = repository.get_runs_batch_by_time_range(
|
||||
start_from=now - timedelta(days=1),
|
||||
end_before=now + timedelta(days=1),
|
||||
last_seen=None,
|
||||
batch_size=50,
|
||||
tenant_ids=[test_scope.tenant_id],
|
||||
)
|
||||
|
||||
returned_ids = {run.id for run in runs}
|
||||
returned_statuses = {run.status for run in runs}
|
||||
|
||||
assert returned_ids == ended_run_ids
|
||||
assert returned_statuses == set(ended_statuses)
|
||||
|
||||
|
||||
class TestDeleteRunsWithRelated:
|
||||
"""Integration tests for delete_runs_with_related."""
|
||||
|
||||
def test_uses_trigger_log_repository(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Delete run-related records and invoke injected trigger-log deleter."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
app_log = WorkflowAppLog(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
pause_reason = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.SCHEDULED_PAUSE,
|
||||
message="scheduled pause",
|
||||
)
|
||||
db_session_with_containers.add_all([app_log, pause, pause_reason])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.delete_by_run_ids.return_value = 3
|
||||
|
||||
counts = repository.delete_runs_with_related(
|
||||
[workflow_run],
|
||||
delete_node_executions=lambda session, runs: (2, 1),
|
||||
delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.delete_by_run_ids.assert_called_once_with([workflow_run.id])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["app_logs"] == 1
|
||||
assert counts["pauses"] == 1
|
||||
assert counts["pause_reasons"] == 1
|
||||
assert counts["runs"] == 1
|
||||
with Session(bind=db_session_with_containers.get_bind()) as verification_session:
|
||||
assert verification_session.get(WorkflowRun, workflow_run.id) is None
|
||||
|
||||
|
||||
class TestCountRunsWithRelated:
|
||||
"""Integration tests for count_runs_with_related."""
|
||||
|
||||
def test_uses_trigger_log_repository(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Count run-related records and invoke injected trigger-log counter."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
app_log = WorkflowAppLog(
|
||||
tenant_id=test_scope.tenant_id,
|
||||
app_id=test_scope.app_id,
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=test_scope.user_id,
|
||||
)
|
||||
pause = WorkflowPause(
|
||||
id=str(uuid4()),
|
||||
workflow_id=test_scope.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_object_key=f"workflow-state-{uuid4()}.json",
|
||||
)
|
||||
pause_reason = WorkflowPauseReason(
|
||||
pause_id=pause.id,
|
||||
type_=PauseReasonType.SCHEDULED_PAUSE,
|
||||
message="scheduled pause",
|
||||
)
|
||||
db_session_with_containers.add_all([app_log, pause, pause_reason])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.count_by_run_ids.return_value = 3
|
||||
|
||||
counts = repository.count_runs_with_related(
|
||||
[workflow_run],
|
||||
count_node_executions=lambda session, runs: (2, 1),
|
||||
count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.count_by_run_ids.assert_called_once_with([workflow_run.id])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["app_logs"] == 1
|
||||
assert counts["pauses"] == 1
|
||||
assert counts["pause_reasons"] == 1
|
||||
assert counts["runs"] == 1
|
||||
|
||||
|
||||
class TestCreateWorkflowPause:
|
||||
"""Integration tests for create_workflow_pause."""
|
||||
|
||||
def test_create_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Create pause successfully, persist pause record, and set run status to PAUSED."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
state = '{"test": "state"}'
|
||||
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state=state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
|
||||
assert pause_model is not None
|
||||
test_scope.state_keys.add(pause_model.state_object_key)
|
||||
|
||||
db_session_with_containers.refresh(workflow_run)
|
||||
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
assert pause_entity.id == pause_model.id
|
||||
assert pause_entity.workflow_execution_id == workflow_run.id
|
||||
assert pause_entity.get_pause_reasons() == []
|
||||
assert pause_entity.get_state() == state.encode()
|
||||
|
||||
def test_create_workflow_pause_not_found(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Raise ValueError when the workflow run does not exist."""
|
||||
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=str(uuid4()),
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
def test_create_workflow_pause_invalid_status(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Raise _WorkflowRunError when pausing a run in non-pausable status."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
)
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
|
||||
class TestResumeWorkflowPause:
|
||||
"""Integration tests for resume_workflow_pause."""
|
||||
|
||||
def test_resume_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Resume pause successfully and switch workflow run status back to RUNNING."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
|
||||
assert pause_model is not None
|
||||
test_scope.state_keys.add(pause_model.state_object_key)
|
||||
|
||||
resumed_entity = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
db_session_with_containers.refresh(workflow_run)
|
||||
db_session_with_containers.refresh(pause_model)
|
||||
assert resumed_entity.id == pause_entity.id
|
||||
assert resumed_entity.resumed_at is not None
|
||||
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
assert pause_model.resumed_at is not None
|
||||
|
||||
def test_resume_workflow_pause_not_paused(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Raise _WorkflowRunError when workflow run is not in PAUSED status."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = str(uuid4())
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_workflow_pause_id_mismatch(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Raise _WorkflowRunError when pause entity ID mismatches persisted pause ID."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
|
||||
assert pause_model is not None
|
||||
test_scope.state_keys.add(pause_model.state_object_key)
|
||||
|
||||
mismatched_pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
mismatched_pause_entity.id = str(uuid4())
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pause_entity=mismatched_pause_entity,
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteWorkflowPause:
|
||||
"""Integration tests for delete_workflow_pause."""
|
||||
|
||||
def test_delete_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
db_session_with_containers: Session,
|
||||
test_scope: _TestScope,
|
||||
) -> None:
|
||||
"""Delete pause record and its state object from storage."""
|
||||
|
||||
workflow_run = _create_workflow_run(
|
||||
db_session_with_containers,
|
||||
test_scope,
|
||||
status=WorkflowExecutionStatus.RUNNING,
|
||||
)
|
||||
pause_entity = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run.id,
|
||||
state_owner_user_id=test_scope.user_id,
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
|
||||
assert pause_model is not None
|
||||
state_key = pause_model.state_object_key
|
||||
test_scope.state_keys.add(state_key)
|
||||
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
with Session(bind=db_session_with_containers.get_bind()) as verification_session:
|
||||
assert verification_session.get(WorkflowPause, pause_entity.id) is None
|
||||
with pytest.raises(FileNotFoundError):
|
||||
storage.load(state_key)
|
||||
|
||||
def test_delete_workflow_pause_not_found(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
) -> None:
|
||||
"""Raise _WorkflowRunError when deleting a non-existent pause."""
|
||||
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = str(uuid4())
|
||||
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
@@ -1,435 +1,50 @@
|
||||
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
"""Unit tests for non-SQL helper logic in workflow run repository."""
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from models.workflow import WorkflowPauseReason
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock session."""
|
||||
return Mock(spec=Session)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(self, mock_session):
|
||||
"""Create a mock sessionmaker."""
|
||||
session_maker = Mock(spec=sessionmaker)
|
||||
|
||||
# Create a context manager mock
|
||||
context_manager = Mock()
|
||||
context_manager.__enter__ = Mock(return_value=mock_session)
|
||||
context_manager.__exit__ = Mock(return_value=None)
|
||||
session_maker.return_value = context_manager
|
||||
|
||||
# Mock session.begin() context manager
|
||||
begin_context_manager = Mock()
|
||||
begin_context_manager.__enter__ = Mock(return_value=None)
|
||||
begin_context_manager.__exit__ = Mock(return_value=None)
|
||||
mock_session.begin = Mock(return_value=begin_context_manager)
|
||||
|
||||
# Add missing session methods
|
||||
mock_session.commit = Mock()
|
||||
mock_session.rollback = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.delete = Mock()
|
||||
mock_session.get = Mock()
|
||||
mock_session.scalar = Mock()
|
||||
mock_session.scalars = Mock()
|
||||
|
||||
# Also support expire_on_commit parameter
|
||||
def make_session(expire_on_commit=None):
|
||||
cm = Mock()
|
||||
cm.__enter__ = Mock(return_value=mock_session)
|
||||
cm.__exit__ = Mock(return_value=None)
|
||||
return cm
|
||||
|
||||
session_maker.side_effect = make_session
|
||||
return session_maker
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_session_maker):
|
||||
"""Create repository instance with mocked dependencies."""
|
||||
|
||||
# Create a testable subclass that implements the save method
|
||||
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def __init__(self, session_maker):
|
||||
# Initialize without calling parent __init__ to avoid any instantiation issues
|
||||
self._session_maker = session_maker
|
||||
|
||||
def save(self, execution):
|
||||
"""Mock implementation of save method."""
|
||||
return None
|
||||
|
||||
# Create repository instance
|
||||
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
|
||||
|
||||
return repo
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_run(self):
|
||||
"""Create a sample WorkflowRun model."""
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.id = "workflow-run-123"
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.app_id = "app-123"
|
||||
workflow_run.workflow_id = "workflow-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
return workflow_run
|
||||
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause(self):
|
||||
"""Create a sample WorkflowPauseModel."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_get_runs_batch_by_time_range_filters_terminal_statuses(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||
):
|
||||
scalar_result = Mock()
|
||||
scalar_result.all.return_value = []
|
||||
mock_session.scalars.return_value = scalar_result
|
||||
|
||||
repository.get_runs_batch_by_time_range(
|
||||
start_from=None,
|
||||
end_before=datetime(2024, 1, 1),
|
||||
last_seen=None,
|
||||
batch_size=50,
|
||||
)
|
||||
|
||||
stmt = mock_session.scalars.call_args[0][0]
|
||||
compiled_sql = str(
|
||||
stmt.compile(
|
||||
dialect=postgresql.dialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "workflow_runs.status" in compiled_sql
|
||||
for status in (
|
||||
WorkflowExecutionStatus.SUCCEEDED,
|
||||
WorkflowExecutionStatus.FAILED,
|
||||
WorkflowExecutionStatus.STOPPED,
|
||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
|
||||
):
|
||||
assert f"'{status.value}'" in compiled_sql
|
||||
|
||||
assert "'running'" not in compiled_sql
|
||||
assert "'paused'" not in compiled_sql
|
||||
|
||||
|
||||
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test create_workflow_pause method."""
|
||||
|
||||
def test_create_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test successful workflow pause creation."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
state_owner_user_id = "user-123"
|
||||
state = '{"test": "state"}'
|
||||
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
|
||||
mock_uuidv7.side_effect = ["pause-123"]
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
result = repository.create_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
state_owner_user_id=state_owner_user_id,
|
||||
state=state,
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
assert result.workflow_execution_id == workflow_run_id
|
||||
assert result.get_pause_reasons() == []
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
|
||||
mock_storage.save.assert_called_once()
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_create_workflow_pause_not_found(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow run not found."""
|
||||
# Arrange
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
|
||||
|
||||
def test_create_workflow_pause_invalid_status(
|
||||
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
|
||||
):
|
||||
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||
# Arrange
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
state='{"test": "state"}',
|
||||
pause_reasons=[],
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
|
||||
node_ids_result = Mock()
|
||||
node_ids_result.all.return_value = []
|
||||
pause_ids_result = Mock()
|
||||
pause_ids_result.all.return_value = []
|
||||
mock_session.scalars.side_effect = [node_ids_result, pause_ids_result]
|
||||
|
||||
# app_logs delete, runs delete
|
||||
mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)]
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.delete_by_run_ids.return_value = 3
|
||||
|
||||
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
|
||||
counts = repository.delete_runs_with_related(
|
||||
[run],
|
||||
delete_node_executions=lambda session, runs: (2, 1),
|
||||
delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["runs"] == 1
|
||||
|
||||
|
||||
class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
|
||||
pause_ids_result = Mock()
|
||||
pause_ids_result.all.return_value = ["pause-1", "pause-2"]
|
||||
mock_session.scalars.return_value = pause_ids_result
|
||||
mock_session.scalar.side_effect = [5, 2]
|
||||
|
||||
fake_trigger_repo = Mock()
|
||||
fake_trigger_repo.count_by_run_ids.return_value = 3
|
||||
|
||||
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
|
||||
counts = repository.count_runs_with_related(
|
||||
[run],
|
||||
count_node_executions=lambda session, runs: (2, 1),
|
||||
count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
|
||||
)
|
||||
|
||||
fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"])
|
||||
assert counts["node_executions"] == 2
|
||||
assert counts["offloads"] == 1
|
||||
assert counts["trigger_logs"] == 3
|
||||
assert counts["app_logs"] == 5
|
||||
assert counts["pauses"] == 2
|
||||
assert counts["pause_reasons"] == 2
|
||||
assert counts["runs"] == 1
|
||||
|
||||
|
||||
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test resume_workflow_pause method."""
|
||||
|
||||
def test_resume_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause resume."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
# Setup workflow run and pause
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
sample_workflow_pause.resumed_at = None
|
||||
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
|
||||
mock_now.return_value = datetime.now(UTC)
|
||||
|
||||
# Act
|
||||
result = repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, _PrivateWorkflowPauseEntity)
|
||||
assert result.id == "pause-123"
|
||||
|
||||
# Verify state transitions
|
||||
assert sample_workflow_pause.resumed_at is not None
|
||||
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
|
||||
|
||||
# Verify database interactions
|
||||
mock_session.add.assert_called()
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_resume_workflow_pause_not_paused(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
):
|
||||
"""Test resume when workflow is not paused."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
def test_resume_workflow_pause_id_mismatch(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_run: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test resume when pause ID doesn't match."""
|
||||
# Arrange
|
||||
workflow_run_id = "workflow-run-123"
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-456" # Different ID
|
||||
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_pause.id = "pause-123"
|
||||
sample_workflow_run.pause = sample_workflow_pause
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
|
||||
repository.resume_workflow_pause(
|
||||
workflow_run_id=workflow_run_id,
|
||||
pause_entity=pause_entity,
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
"""Test delete_workflow_pause method."""
|
||||
|
||||
def test_delete_workflow_pause_success(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
sample_workflow_pause: Mock,
|
||||
):
|
||||
"""Test successful workflow pause deletion."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = sample_workflow_pause
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
# Act
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
# Assert
|
||||
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
mock_session.delete.assert_called_once_with(sample_workflow_pause)
|
||||
# When using session.begin() context manager, commit is handled automatically
|
||||
# No explicit commit call is expected
|
||||
|
||||
def test_delete_workflow_pause_not_found(
|
||||
self,
|
||||
repository: DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
mock_session: Mock,
|
||||
):
|
||||
"""Test delete when pause not found."""
|
||||
# Arrange
|
||||
pause_entity = Mock(spec=WorkflowPauseEntity)
|
||||
pause_entity.id = "pause-123"
|
||||
|
||||
mock_session.get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
|
||||
repository.delete_workflow_pause(pause_entity=pause_entity)
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
@pytest.fixture
|
||||
def sample_workflow_pause() -> Mock:
|
||||
"""Create a sample WorkflowPause model."""
|
||||
pause = Mock(spec=WorkflowPauseModel)
|
||||
pause.id = "pause-123"
|
||||
pause.workflow_id = "workflow-123"
|
||||
pause.workflow_run_id = "workflow-run-123"
|
||||
pause.state_object_key = "workflow-state-123.json"
|
||||
pause.resumed_at = None
|
||||
pause.created_at = datetime.now(UTC)
|
||||
return pause
|
||||
|
||||
|
||||
class TestPrivateWorkflowPauseEntity:
|
||||
"""Test _PrivateWorkflowPauseEntity class."""
|
||||
|
||||
def test_properties(self, sample_workflow_pause: Mock):
|
||||
def test_properties(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test entity properties."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
|
||||
# Act & Assert
|
||||
# Assert
|
||||
assert entity.id == sample_workflow_pause.id
|
||||
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
|
||||
assert entity.resumed_at == sample_workflow_pause.resumed_at
|
||||
|
||||
def test_get_state(self, sample_workflow_pause: Mock):
|
||||
def test_get_state(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test getting state from storage."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
@@ -445,7 +60,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
assert result == expected_state
|
||||
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
|
||||
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock):
|
||||
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
|
||||
"""Test state caching in get_state method."""
|
||||
# Arrange
|
||||
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
|
||||
@@ -456,16 +71,20 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
|
||||
# Act
|
||||
result1 = entity.get_state()
|
||||
result2 = entity.get_state() # Should use cache
|
||||
result2 = entity.get_state()
|
||||
|
||||
# Assert
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once() # Only called once due to caching
|
||||
mock_storage.load.assert_called_once()
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
def test_prefers_backstage_token_when_available(self):
|
||||
"""Test helper that builds HumanInputRequired pause reasons."""
|
||||
|
||||
def test_prefers_backstage_token_when_available(self) -> None:
|
||||
"""Use backstage token when multiple recipient types may exist."""
|
||||
# Arrange
|
||||
expiration_time = datetime.now(UTC)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
@@ -504,8 +123,10 @@ class TestBuildHumanInputRequiredReason:
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Act
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
|
||||
|
||||
# Assert
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
|
||||
Reference in New Issue
Block a user