Compare commits

...

1 Commits

Author SHA1 Message Date
木之本澪
5d927b413f test: migrate workflow_node_execution_service_repository SQL tests to testcontainers (#32591)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Check Changed Files (push) Waiting to run
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Waiting to run
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-26 03:42:08 +09:00
2 changed files with 436 additions and 240 deletions

View File

@@ -0,0 +1,436 @@
from datetime import datetime, timedelta
from uuid import uuid4
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
@staticmethod
def _create_repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowNodeExecutionRepository:
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(
session_maker=sessionmaker(bind=engine, expire_on_commit=False)
)
@staticmethod
def _create_execution(
db_session_with_containers: Session,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_run_id: str,
node_id: str,
status: WorkflowNodeExecutionStatus,
index: int,
created_at: datetime,
) -> WorkflowNodeExecutionModel:
execution = WorkflowNodeExecutionModel(
id=str(uuid4()),
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
triggered_from="workflow-run",
workflow_run_id=workflow_run_id,
index=index,
predecessor_node_id=None,
node_execution_id=None,
node_id=node_id,
node_type="llm",
title=f"Node {index}",
inputs="{}",
process_data="{}",
outputs="{}",
status=status,
error=None,
elapsed_time=0.0,
execution_metadata="{}",
created_at=created_at,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
finished_at=None,
)
db_session_with_containers.add(execution)
db_session_with_containers.commit()
return execution
def test_get_node_last_execution_found(self, db_session_with_containers):
"""Test getting the last execution for a node when it exists."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
node_id = "node-202"
workflow_run_id = str(uuid4())
now = naive_utc_now()
self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=node_id,
status=WorkflowNodeExecutionStatus.PAUSED,
index=1,
created_at=now - timedelta(minutes=2),
)
expected = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=node_id,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_at=now - timedelta(minutes=1),
)
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.get_node_last_execution(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
node_id=node_id,
)
# Assert
assert result is not None
assert result.id == expected.id
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
def test_get_node_last_execution_not_found(self, db_session_with_containers):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.get_node_last_execution(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
node_id="node-202",
)
# Assert
assert result is None
def test_get_executions_by_workflow_run_empty(self, db_session_with_containers):
"""Test getting executions for a workflow run when none exist."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_run_id = str(uuid4())
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.get_executions_by_workflow_run(
tenant_id=tenant_id,
app_id=app_id,
workflow_run_id=workflow_run_id,
)
# Assert
assert result == []
def test_get_execution_by_id_found(self, db_session_with_containers):
"""Test getting execution by ID when it exists."""
# Arrange
execution = self._create_execution(
db_session_with_containers,
tenant_id=str(uuid4()),
app_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_run_id=str(uuid4()),
node_id="node-202",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=1,
created_at=naive_utc_now(),
)
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.get_execution_by_id(execution.id)
# Assert
assert result is not None
assert result.id == execution.id
def test_get_execution_by_id_not_found(self, db_session_with_containers):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
repository = self._create_repository(db_session_with_containers)
missing_execution_id = str(uuid4())
# Act
result = repository.get_execution_by_id(missing_execution_id)
# Assert
assert result is None
def test_delete_expired_executions(self, db_session_with_containers):
"""Test deleting expired executions."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
now = naive_utc_now()
before_date = now - timedelta(days=1)
old_execution_1 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-1",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=1,
created_at=now - timedelta(days=3),
)
old_execution_2 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-2",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_at=now - timedelta(days=2),
)
kept_execution = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-3",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=3,
created_at=now,
)
old_execution_1_id = old_execution_1.id
old_execution_2_id = old_execution_2.id
kept_execution_id = kept_execution.id
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.delete_expired_executions(
tenant_id=tenant_id,
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
remaining_ids = {
execution.id
for execution in db_session_with_containers.scalars(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
).all()
}
assert old_execution_1_id not in remaining_ids
assert old_execution_2_id not in remaining_ids
assert kept_execution_id in remaining_ids
def test_delete_executions_by_app(self, db_session_with_containers):
"""Test deleting executions by app."""
# Arrange
tenant_id = str(uuid4())
target_app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
created_at = naive_utc_now()
deleted_1 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=target_app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-1",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=1,
created_at=created_at,
)
deleted_2 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=target_app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-2",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_at=created_at,
)
kept = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=str(uuid4()),
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-3",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=3,
created_at=created_at,
)
deleted_1_id = deleted_1.id
deleted_2_id = deleted_2.id
kept_id = kept.id
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.delete_executions_by_app(
tenant_id=tenant_id,
app_id=target_app_id,
batch_size=1000,
)
# Assert
assert result == 2
remaining_ids = {
execution.id
for execution in db_session_with_containers.scalars(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
).all()
}
assert deleted_1_id not in remaining_ids
assert deleted_2_id not in remaining_ids
assert kept_id in remaining_ids
def test_get_expired_executions_batch(self, db_session_with_containers):
"""Test getting expired executions batch for backup."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
now = naive_utc_now()
before_date = now - timedelta(days=1)
old_execution_1 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-1",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=1,
created_at=now - timedelta(days=3),
)
old_execution_2 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-2",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_at=now - timedelta(days=2),
)
self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-3",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=3,
created_at=now,
)
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.get_expired_executions_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
result_ids = {execution.id for execution in result}
assert old_execution_1.id in result_ids
assert old_execution_2.id in result_ids
def test_delete_executions_by_ids(self, db_session_with_containers):
"""Test deleting executions by IDs."""
# Arrange
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
created_at = naive_utc_now()
execution_1 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-1",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=1,
created_at=created_at,
)
execution_2 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-2",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_at=created_at,
)
execution_3 = self._create_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id="node-3",
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=3,
created_at=created_at,
)
repository = self._create_repository(db_session_with_containers)
execution_ids = [execution_1.id, execution_2.id, execution_3.id]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
remaining = db_session_with_containers.scalars(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
).all()
assert remaining == []
def test_delete_executions_by_ids_empty_list(self, db_session_with_containers):
"""Test deleting executions with empty ID list."""
# Arrange
repository = self._create_repository(db_session_with_containers)
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0

View File

@@ -1,12 +1,7 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
@@ -18,109 +13,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
mock_session_maker = MagicMock()
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
@pytest.fixture
def mock_execution(self):
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.id = str(uuid4())
execution.tenant_id = "tenant-123"
execution.app_id = "app-456"
execution.workflow_id = "workflow-789"
execution.workflow_run_id = "run-101"
execution.node_id = "node-202"
execution.index = 1
execution.created_at = "2023-01-01T00:00:00Z"
return execution
def test_get_node_last_execution_found(self, repository, mock_execution):
"""Test getting the last execution for a node when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_node_last_execution(
tenant_id="tenant-123",
app_id="app-456",
workflow_id="workflow-789",
node_id="node-202",
)
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalars.return_value.all.return_value = []
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == []
mock_session.execute.assert_called_once()
def test_get_execution_by_id_found(self, repository, mock_execution):
"""Test getting execution by ID when it exists."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = mock_execution
# Act
result = repository.get_execution_by_id(mock_execution.id)
# Assert
assert result == mock_execution
mock_session.scalar.assert_called_once()
def test_get_execution_by_id_not_found(self, repository):
"""Test getting execution by ID when it doesn't exist."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
mock_session.scalar.return_value = None
# Act
result = repository.get_execution_by_id("non-existent-id")
# Assert
assert result is None
mock_session.scalar.assert_called_once()
def test_repository_implements_protocol(self, repository):
"""Test that the repository implements the required protocol methods."""
# Verify all protocol methods are implemented
@@ -136,135 +28,3 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert callable(repository.delete_executions_by_app)
assert callable(repository.get_expired_executions_batch)
assert callable(repository.delete_executions_by_ids)
def test_delete_expired_executions(self, repository):
"""Test deleting expired executions."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
# Act
result = repository.delete_expired_executions(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
"""Test deleting executions by app."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
tenant_id="tenant-123",
app_id="app-456",
batch_size=1000,
)
# Assert
assert result == 2
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
"""Test getting expired executions batch for backup."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Create mock execution objects
mock_execution1 = MagicMock()
mock_execution1.id = "exec-1"
mock_execution2 = MagicMock()
mock_execution2.id = "exec-2"
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
before_date = datetime(2023, 1, 1)
# Act
result = repository.get_expired_executions_batch(
tenant_id="tenant-123",
before_date=before_date,
batch_size=1000,
)
# Assert
assert len(result) == 2
assert result[0].id == "exec-1"
assert result[1].id == "exec-2"
mock_session.execute.assert_called_once()
def test_delete_executions_by_ids(self, repository):
"""Test deleting executions by IDs."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
# Act
result = repository.delete_executions_by_ids(execution_ids)
# Assert
assert result == 3
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):
"""Test deleting executions with empty ID list."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Act
result = repository.delete_executions_by_ids([])
# Assert
assert result == 0
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()