diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 0000000000..05a868c0c2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,506 @@ +"""Integration tests for DifyAPISQLAlchemyWorkflowRunRepository using testcontainers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from unittest.mock import Mock +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from core.workflow.entities import WorkflowExecution +from core.workflow.entities.pause_reason import PauseReasonType +from core.workflow.enums import WorkflowExecutionStatus +from extensions.ext_storage import storage +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun +from repositories.entities.workflow_pause import WorkflowPauseEntity +from repositories.sqlalchemy_api_workflow_run_repository import ( + DifyAPISQLAlchemyWorkflowRunRepository, + _WorkflowRunError, +) + + +class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + """Concrete repository for tests where save() is not under test.""" + + def save(self, execution: WorkflowExecution) -> None: + return None + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows and storage keys.""" + + tenant_id: str = field(default_factory=lambda: str(uuid4())) + app_id: str = field(default_factory=lambda: str(uuid4())) + workflow_id: str = field(default_factory=lambda: str(uuid4())) + user_id: str = field(default_factory=lambda: str(uuid4())) + state_keys: set[str] = field(default_factory=set) + + +def _create_workflow_run( + session: Session, + scope: _TestScope, + *, + status: WorkflowExecutionStatus, + created_at: datetime | None = None, +) -> WorkflowRun: + """Create and persist a workflow run bound to the current test scope.""" + + workflow_run = WorkflowRun( + id=str(uuid4()), + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_id=scope.workflow_id, + type="workflow", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + version="draft", + graph="{}", + inputs="{}", + status=status, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=scope.user_id, + created_at=created_at or naive_utc_now(), + ) + session.add(workflow_run) + session.commit() + return workflow_run + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows and storage objects for a test scope.""" + + pause_ids_subquery = select(WorkflowPause.id).where(WorkflowPause.workflow_id == scope.workflow_id) + session.execute(delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids_subquery))) + session.execute(delete(WorkflowPause).where(WorkflowPause.workflow_id == scope.workflow_id)) + session.execute( + delete(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == scope.tenant_id, + WorkflowAppLog.app_id == scope.app_id, + ) + ) + session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == scope.tenant_id, + WorkflowRun.app_id == scope.app_id, + ) + ) + session.commit() + + for state_key in scope.state_keys: + try: + storage.delete(state_key) + except FileNotFoundError: + continue + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository: + """Build a repository backed by the testcontainers database engine.""" + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> _TestScope: + """Provide an isolated scope and clean related data after each test.""" + + scope = _TestScope() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetRunsBatchByTimeRange: + """Integration tests for get_runs_batch_by_time_range.""" + + def test_get_runs_batch_by_time_range_filters_terminal_statuses( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return only terminal workflow runs, excluding RUNNING and PAUSED.""" + + now = naive_utc_now() + ended_statuses = [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.STOPPED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + ] + ended_run_ids = { + _create_workflow_run( + db_session_with_containers, + test_scope, + status=status, + created_at=now - timedelta(minutes=3), + ).id + for status in ended_statuses + } + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + created_at=now - timedelta(minutes=2), + ) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.PAUSED, + created_at=now - timedelta(minutes=1), + ) + + runs = repository.get_runs_batch_by_time_range( + start_from=now - timedelta(days=1), + end_before=now + timedelta(days=1), + last_seen=None, + batch_size=50, + tenant_ids=[test_scope.tenant_id], + ) + + returned_ids = {run.id for run in runs} + returned_statuses = {run.status for run in runs} + + assert returned_ids == ended_run_ids + assert returned_statuses == set(ended_statuses) + + +class TestDeleteRunsWithRelated: + """Integration tests for delete_runs_with_related.""" + + def test_uses_trigger_log_repository( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Delete run-related records and invoke injected trigger-log deleter.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + ) + app_log = WorkflowAppLog( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=test_scope.user_id, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + pause_reason = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.SCHEDULED_PAUSE, + message="scheduled pause", + ) + db_session_with_containers.add_all([app_log, pause, pause_reason]) + db_session_with_containers.commit() + + fake_trigger_repo = Mock() + fake_trigger_repo.delete_by_run_ids.return_value = 3 + + counts = repository.delete_runs_with_related( + [workflow_run], + delete_node_executions=lambda session, runs: (2, 1), + delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids), + ) + + fake_trigger_repo.delete_by_run_ids.assert_called_once_with([workflow_run.id]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["app_logs"] == 1 + assert counts["pauses"] == 1 + assert counts["pause_reasons"] == 1 + assert counts["runs"] == 1 + with Session(bind=db_session_with_containers.get_bind()) as verification_session: + assert verification_session.get(WorkflowRun, workflow_run.id) is None + + +class TestCountRunsWithRelated: + """Integration tests for count_runs_with_related.""" + + def test_uses_trigger_log_repository( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count run-related records and invoke injected trigger-log counter.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + ) + app_log = WorkflowAppLog( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + created_from="service-api", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=test_scope.user_id, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + pause_reason = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.SCHEDULED_PAUSE, + message="scheduled pause", + ) + db_session_with_containers.add_all([app_log, pause, pause_reason]) + db_session_with_containers.commit() + + fake_trigger_repo = Mock() + fake_trigger_repo.count_by_run_ids.return_value = 3 + + counts = repository.count_runs_with_related( + [workflow_run], + count_node_executions=lambda session, runs: (2, 1), + count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids), + ) + + fake_trigger_repo.count_by_run_ids.assert_called_once_with([workflow_run.id]) + assert counts["node_executions"] == 2 + assert counts["offloads"] == 1 + assert counts["trigger_logs"] == 3 + assert counts["app_logs"] == 1 + assert counts["pauses"] == 1 + assert counts["pause_reasons"] == 1 + assert counts["runs"] == 1 + + +class TestCreateWorkflowPause: + """Integration tests for create_workflow_pause.""" + + def test_create_workflow_pause_success( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Create pause successfully, persist pause record, and set run status to PAUSED.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state = '{"test": "state"}' + + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=test_scope.user_id, + state=state, + pause_reasons=[], + ) + + pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id) + assert pause_model is not None + test_scope.state_keys.add(pause_model.state_object_key) + + db_session_with_containers.refresh(workflow_run) + assert workflow_run.status == WorkflowExecutionStatus.PAUSED + assert pause_entity.id == pause_model.id + assert pause_entity.workflow_execution_id == workflow_run.id + assert pause_entity.get_pause_reasons() == [] + assert pause_entity.get_state() == state.encode() + + def test_create_workflow_pause_not_found( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + test_scope: _TestScope, + ) -> None: + """Raise ValueError when the workflow run does not exist.""" + + with pytest.raises(ValueError, match="WorkflowRun not found"): + repository.create_workflow_pause( + workflow_run_id=str(uuid4()), + state_owner_user_id=test_scope.user_id, + state='{"test": "state"}', + pause_reasons=[], + ) + + def test_create_workflow_pause_invalid_status( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Raise _WorkflowRunError when pausing a run in non-pausable status.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + ) + + with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): + repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=test_scope.user_id, + state='{"test": "state"}', + pause_reasons=[], + ) + + +class TestResumeWorkflowPause: + """Integration tests for resume_workflow_pause.""" + + def test_resume_workflow_pause_success( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Resume pause successfully and switch workflow run status back to RUNNING.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=test_scope.user_id, + state='{"test": "state"}', + pause_reasons=[], + ) + + pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id) + assert pause_model is not None + test_scope.state_keys.add(pause_model.state_object_key) + + resumed_entity = repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + + db_session_with_containers.refresh(workflow_run) + db_session_with_containers.refresh(pause_model) + assert resumed_entity.id == pause_entity.id + assert resumed_entity.resumed_at is not None + assert workflow_run.status == WorkflowExecutionStatus.RUNNING + assert pause_model.resumed_at is not None + + def test_resume_workflow_pause_not_paused( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Raise _WorkflowRunError when workflow run is not in PAUSED status.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause_entity = Mock(spec=WorkflowPauseEntity) + pause_entity.id = str(uuid4()) + + with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"): + repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=pause_entity, + ) + + def test_resume_workflow_pause_id_mismatch( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Raise _WorkflowRunError when pause entity ID mismatches persisted pause ID.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=test_scope.user_id, + state='{"test": "state"}', + pause_reasons=[], + ) + + pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id) + assert pause_model is not None + test_scope.state_keys.add(pause_model.state_object_key) + + mismatched_pause_entity = Mock(spec=WorkflowPauseEntity) + mismatched_pause_entity.id = str(uuid4()) + + with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"): + repository.resume_workflow_pause( + workflow_run_id=workflow_run.id, + pause_entity=mismatched_pause_entity, + ) + + +class TestDeleteWorkflowPause: + """Integration tests for delete_workflow_pause.""" + + def test_delete_workflow_pause_success( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Delete pause record and its state object from storage.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause_entity = repository.create_workflow_pause( + workflow_run_id=workflow_run.id, + state_owner_user_id=test_scope.user_id, + state='{"test": "state"}', + pause_reasons=[], + ) + pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id) + assert pause_model is not None + state_key = pause_model.state_object_key + test_scope.state_keys.add(state_key) + + repository.delete_workflow_pause(pause_entity=pause_entity) + + with Session(bind=db_session_with_containers.get_bind()) as verification_session: + assert verification_session.get(WorkflowPause, pause_entity.id) is None + with pytest.raises(FileNotFoundError): + storage.load(state_key) + + def test_delete_workflow_pause_not_found( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + ) -> None: + """Raise _WorkflowRunError when deleting a non-existent pause.""" + + pause_entity = Mock(spec=WorkflowPauseEntity) + pause_entity.id = str(uuid4()) + + with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"): + repository.delete_workflow_pause(pause_entity=pause_entity) diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 4caaa056ff..4b5b3b318c 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -1,435 +1,50 @@ -"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation.""" +"""Unit tests for non-SQL helper logic in workflow run repository.""" import secrets from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Session, sessionmaker from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType -from core.workflow.enums import WorkflowExecutionStatus from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason, WorkflowRun -from repositories.entities.workflow_pause import WorkflowPauseEntity +from models.workflow import WorkflowPauseReason from repositories.sqlalchemy_api_workflow_run_repository import ( - DifyAPISQLAlchemyWorkflowRunRepository, _build_human_input_required_reason, _PrivateWorkflowPauseEntity, - _WorkflowRunError, ) -class TestDifyAPISQLAlchemyWorkflowRunRepository: - """Test DifyAPISQLAlchemyWorkflowRunRepository implementation.""" - - @pytest.fixture - def mock_session(self): - """Create a mock session.""" - return Mock(spec=Session) - - @pytest.fixture - def mock_session_maker(self, mock_session): - """Create a mock sessionmaker.""" - session_maker = Mock(spec=sessionmaker) - - # Create a context manager mock - context_manager = Mock() - context_manager.__enter__ = Mock(return_value=mock_session) - context_manager.__exit__ = Mock(return_value=None) - session_maker.return_value = context_manager - - # Mock session.begin() context manager - begin_context_manager = Mock() - begin_context_manager.__enter__ = Mock(return_value=None) - begin_context_manager.__exit__ = Mock(return_value=None) - mock_session.begin = Mock(return_value=begin_context_manager) - - # Add missing session methods - mock_session.commit = Mock() - mock_session.rollback = Mock() - mock_session.add = Mock() - mock_session.delete = Mock() - mock_session.get = Mock() - mock_session.scalar = Mock() - mock_session.scalars = Mock() - - # Also support expire_on_commit parameter - def make_session(expire_on_commit=None): - cm = Mock() - cm.__enter__ = Mock(return_value=mock_session) - cm.__exit__ = Mock(return_value=None) - return cm - - session_maker.side_effect = make_session - return session_maker - - @pytest.fixture - def repository(self, mock_session_maker): - """Create repository instance with mocked dependencies.""" - - # Create a testable subclass that implements the save method - class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): - def __init__(self, session_maker): - # Initialize without calling parent __init__ to avoid any instantiation issues - self._session_maker = session_maker - - def save(self, execution): - """Mock implementation of save method.""" - return None - - # Create repository instance - repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) - - return repo - - @pytest.fixture - def sample_workflow_run(self): - """Create a sample WorkflowRun model.""" - workflow_run = Mock(spec=WorkflowRun) - workflow_run.id = "workflow-run-123" - workflow_run.tenant_id = "tenant-123" - workflow_run.app_id = "app-123" - workflow_run.workflow_id = "workflow-123" - workflow_run.status = WorkflowExecutionStatus.RUNNING - return workflow_run - - @pytest.fixture - def sample_workflow_pause(self): - """Create a sample WorkflowPauseModel.""" - pause = Mock(spec=WorkflowPauseModel) - pause.id = "pause-123" - pause.workflow_id = "workflow-123" - pause.workflow_run_id = "workflow-run-123" - pause.state_object_key = "workflow-state-123.json" - pause.resumed_at = None - pause.created_at = datetime.now(UTC) - return pause - - -class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository): - def test_get_runs_batch_by_time_range_filters_terminal_statuses( - self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock - ): - scalar_result = Mock() - scalar_result.all.return_value = [] - mock_session.scalars.return_value = scalar_result - - repository.get_runs_batch_by_time_range( - start_from=None, - end_before=datetime(2024, 1, 1), - last_seen=None, - batch_size=50, - ) - - stmt = mock_session.scalars.call_args[0][0] - compiled_sql = str( - stmt.compile( - dialect=postgresql.dialect(), - compile_kwargs={"literal_binds": True}, - ) - ) - - assert "workflow_runs.status" in compiled_sql - for status in ( - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.STOPPED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - ): - assert f"'{status.value}'" in compiled_sql - - assert "'running'" not in compiled_sql - assert "'paused'" not in compiled_sql - - -class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): - """Test create_workflow_pause method.""" - - def test_create_workflow_pause_success( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - sample_workflow_run: Mock, - ): - """Test successful workflow pause creation.""" - # Arrange - workflow_run_id = "workflow-run-123" - state_owner_user_id = "user-123" - state = '{"test": "state"}' - - mock_session.get.return_value = sample_workflow_run - - with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7: - mock_uuidv7.side_effect = ["pause-123"] - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - # Act - result = repository.create_workflow_pause( - workflow_run_id=workflow_run_id, - state_owner_user_id=state_owner_user_id, - state=state, - pause_reasons=[], - ) - - # Assert - assert isinstance(result, _PrivateWorkflowPauseEntity) - assert result.id == "pause-123" - assert result.workflow_execution_id == workflow_run_id - assert result.get_pause_reasons() == [] - - # Verify database interactions - mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id) - mock_storage.save.assert_called_once() - mock_session.add.assert_called() - # When using session.begin() context manager, commit is handled automatically - # No explicit commit call is expected - - def test_create_workflow_pause_not_found( - self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock - ): - """Test workflow pause creation when workflow run not found.""" - # Arrange - mock_session.get.return_value = None - - # Act & Assert - with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"): - repository.create_workflow_pause( - workflow_run_id="workflow-run-123", - state_owner_user_id="user-123", - state='{"test": "state"}', - pause_reasons=[], - ) - - mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123") - - def test_create_workflow_pause_invalid_status( - self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock - ): - """Test workflow pause creation when workflow not in RUNNING status.""" - # Arrange - sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED - mock_session.get.return_value = sample_workflow_run - - # Act & Assert - with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"): - repository.create_workflow_pause( - workflow_run_id="workflow-run-123", - state_owner_user_id="user-123", - state='{"test": "state"}', - pause_reasons=[], - ) - - -class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): - def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): - node_ids_result = Mock() - node_ids_result.all.return_value = [] - pause_ids_result = Mock() - pause_ids_result.all.return_value = [] - mock_session.scalars.side_effect = [node_ids_result, pause_ids_result] - - # app_logs delete, runs delete - mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)] - - fake_trigger_repo = Mock() - fake_trigger_repo.delete_by_run_ids.return_value = 3 - - run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") - counts = repository.delete_runs_with_related( - [run], - delete_node_executions=lambda session, runs: (2, 1), - delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids), - ) - - fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"]) - assert counts["node_executions"] == 2 - assert counts["offloads"] == 1 - assert counts["trigger_logs"] == 3 - assert counts["runs"] == 1 - - -class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository): - def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock): - pause_ids_result = Mock() - pause_ids_result.all.return_value = ["pause-1", "pause-2"] - mock_session.scalars.return_value = pause_ids_result - mock_session.scalar.side_effect = [5, 2] - - fake_trigger_repo = Mock() - fake_trigger_repo.count_by_run_ids.return_value = 3 - - run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf") - counts = repository.count_runs_with_related( - [run], - count_node_executions=lambda session, runs: (2, 1), - count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids), - ) - - fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"]) - assert counts["node_executions"] == 2 - assert counts["offloads"] == 1 - assert counts["trigger_logs"] == 3 - assert counts["app_logs"] == 5 - assert counts["pauses"] == 2 - assert counts["pause_reasons"] == 2 - assert counts["runs"] == 1 - - -class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): - """Test resume_workflow_pause method.""" - - def test_resume_workflow_pause_success( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - sample_workflow_run: Mock, - sample_workflow_pause: Mock, - ): - """Test successful workflow pause resume.""" - # Arrange - workflow_run_id = "workflow-run-123" - pause_entity = Mock(spec=WorkflowPauseEntity) - pause_entity.id = "pause-123" - - # Setup workflow run and pause - sample_workflow_run.status = WorkflowExecutionStatus.PAUSED - sample_workflow_run.pause = sample_workflow_pause - sample_workflow_pause.resumed_at = None - - mock_session.scalar.return_value = sample_workflow_run - mock_session.scalars.return_value.all.return_value = [] - - with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now: - mock_now.return_value = datetime.now(UTC) - - # Act - result = repository.resume_workflow_pause( - workflow_run_id=workflow_run_id, - pause_entity=pause_entity, - ) - - # Assert - assert isinstance(result, _PrivateWorkflowPauseEntity) - assert result.id == "pause-123" - - # Verify state transitions - assert sample_workflow_pause.resumed_at is not None - assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING - - # Verify database interactions - mock_session.add.assert_called() - # When using session.begin() context manager, commit is handled automatically - # No explicit commit call is expected - - def test_resume_workflow_pause_not_paused( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - sample_workflow_run: Mock, - ): - """Test resume when workflow is not paused.""" - # Arrange - workflow_run_id = "workflow-run-123" - pause_entity = Mock(spec=WorkflowPauseEntity) - pause_entity.id = "pause-123" - - sample_workflow_run.status = WorkflowExecutionStatus.RUNNING - mock_session.scalar.return_value = sample_workflow_run - - # Act & Assert - with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"): - repository.resume_workflow_pause( - workflow_run_id=workflow_run_id, - pause_entity=pause_entity, - ) - - def test_resume_workflow_pause_id_mismatch( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - sample_workflow_run: Mock, - sample_workflow_pause: Mock, - ): - """Test resume when pause ID doesn't match.""" - # Arrange - workflow_run_id = "workflow-run-123" - pause_entity = Mock(spec=WorkflowPauseEntity) - pause_entity.id = "pause-456" # Different ID - - sample_workflow_run.status = WorkflowExecutionStatus.PAUSED - sample_workflow_pause.id = "pause-123" - sample_workflow_run.pause = sample_workflow_pause - mock_session.scalar.return_value = sample_workflow_run - - # Act & Assert - with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"): - repository.resume_workflow_pause( - workflow_run_id=workflow_run_id, - pause_entity=pause_entity, - ) - - -class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository): - """Test delete_workflow_pause method.""" - - def test_delete_workflow_pause_success( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - sample_workflow_pause: Mock, - ): - """Test successful workflow pause deletion.""" - # Arrange - pause_entity = Mock(spec=WorkflowPauseEntity) - pause_entity.id = "pause-123" - - mock_session.get.return_value = sample_workflow_pause - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - # Act - repository.delete_workflow_pause(pause_entity=pause_entity) - - # Assert - mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key) - mock_session.delete.assert_called_once_with(sample_workflow_pause) - # When using session.begin() context manager, commit is handled automatically - # No explicit commit call is expected - - def test_delete_workflow_pause_not_found( - self, - repository: DifyAPISQLAlchemyWorkflowRunRepository, - mock_session: Mock, - ): - """Test delete when pause not found.""" - # Arrange - pause_entity = Mock(spec=WorkflowPauseEntity) - pause_entity.id = "pause-123" - - mock_session.get.return_value = None - - # Act & Assert - with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"): - repository.delete_workflow_pause(pause_entity=pause_entity) - - -class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository): +@pytest.fixture +def sample_workflow_pause() -> Mock: + """Create a sample WorkflowPause model.""" + pause = Mock(spec=WorkflowPauseModel) + pause.id = "pause-123" + pause.workflow_id = "workflow-123" + pause.workflow_run_id = "workflow-run-123" + pause.state_object_key = "workflow-state-123.json" + pause.resumed_at = None + pause.created_at = datetime.now(UTC) + return pause + + +class TestPrivateWorkflowPauseEntity: """Test _PrivateWorkflowPauseEntity class.""" - def test_properties(self, sample_workflow_pause: Mock): + def test_properties(self, sample_workflow_pause: Mock) -> None: """Test entity properties.""" # Arrange entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - # Act & Assert + # Assert assert entity.id == sample_workflow_pause.id assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id assert entity.resumed_at == sample_workflow_pause.resumed_at - def test_get_state(self, sample_workflow_pause: Mock): + def test_get_state(self, sample_workflow_pause: Mock) -> None: """Test getting state from storage.""" # Arrange entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) @@ -445,7 +60,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) assert result == expected_state mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key) - def test_get_state_caching(self, sample_workflow_pause: Mock): + def test_get_state_caching(self, sample_workflow_pause: Mock) -> None: """Test state caching in get_state method.""" # Arrange entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) @@ -456,16 +71,20 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository) # Act result1 = entity.get_state() - result2 = entity.get_state() # Should use cache + result2 = entity.get_state() # Assert assert result1 == expected_state assert result2 == expected_state - mock_storage.load.assert_called_once() # Only called once due to caching + mock_storage.load.assert_called_once() class TestBuildHumanInputRequiredReason: - def test_prefers_backstage_token_when_available(self): + """Test helper that builds HumanInputRequired pause reasons.""" + + def test_prefers_backstage_token_when_available(self) -> None: + """Use backstage token when multiple recipient types may exist.""" + # Arrange expiration_time = datetime.now(UTC) form_definition = FormDefinition( form_content="content", @@ -504,8 +123,10 @@ class TestBuildHumanInputRequiredReason: access_token=access_token, ) + # Act reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) + # Assert assert isinstance(reason, HumanInputRequired) assert reason.form_token == access_token assert reason.node_title == "Ask Name"