Compare commits

..

22 Commits

Author SHA1 Message Date
Stephen Zhou
56265a5217 fix: Instrument_Serif 2026-02-25 17:44:14 +08:00
Stephen Zhou
084eeac776 fix: constants shim 2026-02-25 17:38:27 +08:00
Stephen Zhou
3e92f85beb fix: image size plugin 2026-02-25 17:25:52 +08:00
Stephen Zhou
7bd987c6f1 use pkg new pr 2026-02-25 13:01:12 +08:00
Stephen Zhou
af80c10ed3 fix load module 2026-02-25 12:46:41 +08:00
Stephen Zhou
af6218d4b5 update 2026-02-25 11:33:27 +08:00
Stephen Zhou
84fda207a6 update 2026-02-25 11:32:20 +08:00
Stephen Zhou
eb81f0563d jiti v1 2026-02-25 11:31:23 +08:00
Stephen Zhou
cd22550454 load env 2026-02-25 11:28:18 +08:00
Stephen Zhou
1bbc2f147d env 2026-02-25 11:24:32 +08:00
Stephen Zhou
b787c0af7f import ts in next.config 2026-02-25 11:12:35 +08:00
Stephen Zhou
56c8bef073 chore: add vinext as dev server 2026-02-25 11:00:29 +08:00
akashseth-ifp
4e142f72e8 test(base): add test coverage for more base/form components (#32437)
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
2026-02-25 10:47:25 +08:00
木之本澪
a6456da393 test: migrate delete_archived_workflow_run SQL tests to testcontainers (#32549)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:18:52 +09:00
木之本澪
b863f8edbd test: migrate test_document_service_display_status SQL tests to testcontainers (#32545)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:13:22 +09:00
木之本澪
64296da7e7 test: migrate remove_app_and_related_data_task SQL tests to testcontainers (#32547)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:12:23 +09:00
木之本澪
02fef84d7f test: migrate node execution repository sql tests to testcontainers (#32524)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:01:26 +09:00
木之本澪
28f2098b00 test: migrate workflow trigger log repository sql tests to testcontainers (#32525)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:53:16 +09:00
木之本澪
59681ce760 test: migrate message extra contents tests to testcontainers (#32532)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:51:14 +09:00
木之本澪
4997b82a63 test: migrate end user service SQL tests to testcontainers (#32530)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:49:49 +09:00
木之本澪
3abfbc0246 test: migrate remaining DocumentSegment navigation SQL tests to testcontainers (#32523)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 02:51:38 +09:00
木之本澪
beea1acd92 test: migrate workflow run repository SQL tests to testcontainers (#32519)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 01:36:39 +09:00
49 changed files with 3886 additions and 1669 deletions

View File

@@ -269,3 +269,221 @@ class TestDatasetDocumentProperties:
db_session_with_containers.flush()
assert doc.hit_count == 25
class TestDocumentSegmentNavigationProperties:
    """Integration tests for DocumentSegment navigation properties.

    Each test arranges a Dataset -> Document -> DocumentSegment chain and then
    exercises exactly one navigation property (``dataset``, ``document``,
    ``previous_segment``, ``next_segment``). The identical Arrange steps are
    factored into private helpers so each test states only what is specific
    to it.
    """

    @pytest.fixture(autouse=True)
    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
        """Automatically rollback session changes after each test."""
        yield
        db_session_with_containers.rollback()

    @staticmethod
    def _create_dataset_and_document(
        session: Session, tenant_id: str, created_by: str
    ) -> tuple[Dataset, Document]:
        """Create and flush a Dataset plus one Document attached to it.

        Uses flush (not commit) so ids are assigned while the autouse
        rollback fixture can still undo everything.
        """
        dataset = Dataset(
            tenant_id=tenant_id,
            name="Test Dataset",
            data_source_type="upload_file",
            created_by=created_by,
        )
        session.add(dataset)
        session.flush()
        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=1,
            data_source_type="upload_file",
            batch="batch_001",
            name="test.pdf",
            created_from="web",
            created_by=created_by,
        )
        session.add(document)
        session.flush()
        return dataset, document

    @staticmethod
    def _build_segment(
        dataset: Dataset,
        document: Document,
        *,
        position: int,
        content: str,
        created_by: str,
    ) -> DocumentSegment:
        """Build (without persisting) a segment with the shared default counts.

        All tests use word_count=1 / tokens=2; only position and content vary.
        """
        return DocumentSegment(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
            position=position,
            content=content,
            word_count=1,
            tokens=2,
            created_by=created_by,
        )

    def test_document_segment_dataset_property(self, db_session_with_containers: Session) -> None:
        """Test segment can access its parent dataset."""
        # Arrange
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset, document = self._create_dataset_and_document(db_session_with_containers, tenant_id, created_by)
        segment = self._build_segment(dataset, document, position=1, content="Test", created_by=created_by)
        db_session_with_containers.add(segment)
        db_session_with_containers.flush()
        # Act
        related_dataset = segment.dataset
        # Assert
        assert related_dataset is not None
        assert related_dataset.id == dataset.id

    def test_document_segment_document_property(self, db_session_with_containers: Session) -> None:
        """Test segment can access its parent document."""
        # Arrange
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset, document = self._create_dataset_and_document(db_session_with_containers, tenant_id, created_by)
        segment = self._build_segment(dataset, document, position=1, content="Test", created_by=created_by)
        db_session_with_containers.add(segment)
        db_session_with_containers.flush()
        # Act
        related_document = segment.document
        # Assert
        assert related_document is not None
        assert related_document.id == document.id

    def test_document_segment_previous_segment(self, db_session_with_containers: Session) -> None:
        """Test segment can access previous segment."""
        # Arrange: two segments in the same document at positions 1 and 2.
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset, document = self._create_dataset_and_document(db_session_with_containers, tenant_id, created_by)
        previous_segment = self._build_segment(dataset, document, position=1, content="Previous", created_by=created_by)
        segment = self._build_segment(dataset, document, position=2, content="Current", created_by=created_by)
        db_session_with_containers.add_all([previous_segment, segment])
        db_session_with_containers.flush()
        # Act
        prev_seg = segment.previous_segment
        # Assert
        assert prev_seg is not None
        assert prev_seg.position == 1

    def test_document_segment_next_segment(self, db_session_with_containers: Session) -> None:
        """Test segment can access next segment."""
        # Arrange: two segments in the same document at positions 1 and 2.
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset, document = self._create_dataset_and_document(db_session_with_containers, tenant_id, created_by)
        segment = self._build_segment(dataset, document, position=1, content="Current", created_by=created_by)
        next_segment = self._build_segment(dataset, document, position=2, content="Next", created_by=created_by)
        db_session_with_containers.add_all([segment, next_segment])
        db_session_with_containers.flush()
        # Act
        next_seg = segment.next_segment
        # Assert
        assert next_seg is not None
        assert next_seg.position == 2

View File

@@ -0,0 +1,143 @@
"""Integration tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository using testcontainers."""
from __future__ import annotations
from datetime import timedelta
from uuid import uuid4
from sqlalchemy import Engine, delete
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
def _create_node_execution(
    session: Session,
    *,
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    workflow_run_id: str,
    status: WorkflowNodeExecutionStatus,
    index: int,
    created_by: str,
    created_at_offset_seconds: int,
) -> WorkflowNodeExecutionModel:
    """Persist a minimal WorkflowNodeExecutionModel row for these tests.

    The row is flushed (not committed) so it is visible within the current
    transaction; the caller decides when to commit. ``created_at_offset_seconds``
    shifts ``created_at`` relative to "now" so callers can control ordering of
    rows created in the same test.
    """
    now = naive_utc_now()
    node_execution = WorkflowNodeExecutionModel(
        id=str(uuid4()),
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        triggered_from="workflow-run",
        workflow_run_id=workflow_run_id,
        index=index,
        predecessor_node_id=None,
        node_execution_id=None,
        # node_id/title derive from index so rows are distinguishable in failures.
        node_id=f"node-{index}",
        node_type="llm",
        title=f"Node {index}",
        inputs="{}",
        process_data="{}",
        outputs="{}",
        status=status,
        error=None,
        elapsed_time=0.0,
        execution_metadata="{}",
        created_at=now + timedelta(seconds=created_at_offset_seconds),
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=created_by,
        finished_at=None,
    )
    session.add(node_execution)
    session.flush()
    return node_execution
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
    """Integration tests for get_executions_by_workflow_run scoping behavior."""

    def test_get_executions_by_workflow_run_keeps_paused_records(self, db_session_with_containers: Session) -> None:
        """PAUSED rows are returned, and results are scoped to tenant/app/run."""
        # Arrange: identifiers for the target scope plus an unrelated tenant/app.
        tenant_id = str(uuid4())
        app_id = str(uuid4())
        workflow_id = str(uuid4())
        workflow_run_id = str(uuid4())
        created_by = str(uuid4())
        other_tenant_id = str(uuid4())
        other_app_id = str(uuid4())
        # Two rows inside the target run: one PAUSED, one SUCCEEDED.
        included_paused = _create_node_execution(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            status=WorkflowNodeExecutionStatus.PAUSED,
            index=1,
            created_by=created_by,
            created_at_offset_seconds=0,
        )
        included_succeeded = _create_node_execution(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            index=2,
            created_by=created_by,
            created_at_offset_seconds=1,
        )
        # Must be excluded: same tenant/app but a different workflow run.
        _create_node_execution(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            workflow_id=workflow_id,
            workflow_run_id=str(uuid4()),
            status=WorkflowNodeExecutionStatus.PAUSED,
            index=3,
            created_by=created_by,
            created_at_offset_seconds=2,
        )
        # Must be excluded: same run id but a different tenant/app (tenant isolation).
        _create_node_execution(
            db_session_with_containers,
            tenant_id=other_tenant_id,
            app_id=other_app_id,
            workflow_id=str(uuid4()),
            workflow_run_id=workflow_run_id,
            status=WorkflowNodeExecutionStatus.PAUSED,
            index=4,
            created_by=str(uuid4()),
            created_at_offset_seconds=3,
        )
        # Commit so the repository's own sessions can see the rows.
        db_session_with_containers.commit()
        # Build the repository on the same engine as the test session.
        engine = db_session_with_containers.get_bind()
        assert isinstance(engine, Engine)
        repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(sessionmaker(bind=engine, expire_on_commit=False))
        try:
            # Act
            results = repository.get_executions_by_workflow_run(
                tenant_id=tenant_id,
                app_id=app_id,
                workflow_run_id=workflow_run_id,
            )
            # Assert: exactly the two in-scope rows, PAUSED not filtered out.
            # The expected id order matches the arranged created_at/index order
            # (presumably the repository sorts that way — confirmed only by this assertion).
            assert len(results) == 2
            assert [result.id for result in results] == [included_paused.id, included_succeeded.id]
            assert any(result.status == WorkflowNodeExecutionStatus.PAUSED for result in results)
            assert all(result.tenant_id == tenant_id for result in results)
            assert all(result.app_id == app_id for result in results)
            assert all(result.workflow_run_id == workflow_run_id for result in results)
        finally:
            # Cleanup: the rows were committed above, so delete them explicitly.
            db_session_with_containers.execute(
                delete(WorkflowNodeExecutionModel).where(
                    WorkflowNodeExecutionModel.tenant_id.in_([tenant_id, other_tenant_id])
                )
            )
            db_session_with_containers.commit()

View File

@@ -0,0 +1,506 @@
"""Integration tests for DifyAPISQLAlchemyWorkflowRunRepository using testcontainers."""
from __future__ import annotations

from collections.abc import Generator
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
from uuid import uuid4

import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.entities import WorkflowExecution
from core.workflow.entities.pause_reason import PauseReasonType
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
    DifyAPISQLAlchemyWorkflowRunRepository,
    _WorkflowRunError,
)
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
    """Concrete repository for tests where save() is not under test."""

    def save(self, execution: WorkflowExecution) -> None:
        # Intentional no-op: satisfies the base-class contract without touching
        # the database. NOTE(review): presumably save() is abstract/required on
        # the base class — confirm against DifyAPISQLAlchemyWorkflowRunRepository.
        return None
@dataclass
class _TestScope:
    """Per-test data scope used to isolate DB rows and storage keys."""

    # Fresh UUIDs per instance (default_factory) so tests never share rows.
    tenant_id: str = field(default_factory=lambda: str(uuid4()))
    app_id: str = field(default_factory=lambda: str(uuid4()))
    workflow_id: str = field(default_factory=lambda: str(uuid4()))
    user_id: str = field(default_factory=lambda: str(uuid4()))
    # Storage object keys created during the test; removed by _cleanup_scope_data.
    state_keys: set[str] = field(default_factory=set)
def _create_workflow_run(
    session: Session,
    scope: _TestScope,
    *,
    status: WorkflowExecutionStatus,
    created_at: datetime | None = None,
) -> WorkflowRun:
    """Create and persist a workflow run bound to the current test scope.

    The row is committed immediately so repository code running in its own
    sessions can see it. ``created_at`` defaults to "now" when not given.
    """
    workflow_run = WorkflowRun(
        id=str(uuid4()),
        tenant_id=scope.tenant_id,
        app_id=scope.app_id,
        workflow_id=scope.workflow_id,
        type="workflow",
        triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
        version="draft",
        graph="{}",
        inputs="{}",
        status=status,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=scope.user_id,
        created_at=created_at or naive_utc_now(),
    )
    session.add(workflow_run)
    session.commit()
    return workflow_run
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
    """Remove test-created DB rows and storage objects for a test scope.

    Deletion order matters: pause reasons reference pauses (pause_id), so they
    go first; runs are deleted last.
    """
    pause_ids_subquery = select(WorkflowPause.id).where(WorkflowPause.workflow_id == scope.workflow_id)
    session.execute(delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids_subquery)))
    session.execute(delete(WorkflowPause).where(WorkflowPause.workflow_id == scope.workflow_id))
    session.execute(
        delete(WorkflowAppLog).where(
            WorkflowAppLog.tenant_id == scope.tenant_id,
            WorkflowAppLog.app_id == scope.app_id,
        )
    )
    session.execute(
        delete(WorkflowRun).where(
            WorkflowRun.tenant_id == scope.tenant_id,
            WorkflowRun.app_id == scope.app_id,
        )
    )
    session.commit()
    # Best-effort storage cleanup: a key may already be gone when the test
    # itself deleted the pause (and its state object).
    for state_key in scope.state_keys:
        try:
            storage.delete(state_key)
        except FileNotFoundError:
            continue
@pytest.fixture
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
    """Build a repository backed by the testcontainers database engine."""
    # Reuse the engine behind the test session so repository sessions and the
    # test session hit the same database.
    engine = db_session_with_containers.get_bind()
    assert isinstance(engine, Engine)  # get_bind() may also return a Connection
    return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope, None, None]:
    """Provide an isolated scope and clean related data after each test.

    Yields a fresh _TestScope; the code after ``yield`` runs at fixture
    teardown and removes every DB row and storage object the test created
    under that scope.

    Note: the return annotation is ``Generator[...]`` (not ``_TestScope``)
    because this is a yield fixture — the function itself returns a generator.
    """
    scope = _TestScope()
    yield scope
    _cleanup_scope_data(db_session_with_containers, scope)
class TestGetRunsBatchByTimeRange:
    """Integration tests for get_runs_batch_by_time_range."""

    def test_get_runs_batch_by_time_range_filters_terminal_statuses(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Return only terminal workflow runs, excluding RUNNING and PAUSED."""
        now = naive_utc_now()
        # Arrange: one run per terminal status, all inside the queried window.
        ended_statuses = [
            WorkflowExecutionStatus.SUCCEEDED,
            WorkflowExecutionStatus.FAILED,
            WorkflowExecutionStatus.STOPPED,
            WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        ]
        ended_run_ids = {
            _create_workflow_run(
                db_session_with_containers,
                test_scope,
                status=status,
                created_at=now - timedelta(minutes=3),
            ).id
            for status in ended_statuses
        }
        # Non-terminal runs that must be filtered out of the batch.
        _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
            created_at=now - timedelta(minutes=2),
        )
        _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.PAUSED,
            created_at=now - timedelta(minutes=1),
        )
        # Act: window spans all created rows; batch size exceeds the row count,
        # so a single batch must contain every matching run.
        runs = repository.get_runs_batch_by_time_range(
            start_from=now - timedelta(days=1),
            end_before=now + timedelta(days=1),
            last_seen=None,
            batch_size=50,
            tenant_ids=[test_scope.tenant_id],
        )
        # Assert
        returned_ids = {run.id for run in runs}
        returned_statuses = {run.status for run in runs}
        assert returned_ids == ended_run_ids
        assert returned_statuses == set(ended_statuses)
class TestDeleteRunsWithRelated:
    """Integration tests for delete_runs_with_related."""

    def test_uses_trigger_log_repository(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Delete run-related records and invoke injected trigger-log deleter."""
        # Arrange: one run plus one related row of each kind (app log, pause, pause reason).
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.SUCCEEDED,
        )
        app_log = WorkflowAppLog(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            workflow_id=test_scope.workflow_id,
            workflow_run_id=workflow_run.id,
            created_from="service-api",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=test_scope.user_id,
        )
        pause = WorkflowPause(
            id=str(uuid4()),
            workflow_id=test_scope.workflow_id,
            workflow_run_id=workflow_run.id,
            state_object_key=f"workflow-state-{uuid4()}.json",
        )
        pause_reason = WorkflowPauseReason(
            pause_id=pause.id,
            type_=PauseReasonType.SCHEDULED_PAUSE,
            message="scheduled pause",
        )
        db_session_with_containers.add_all([app_log, pause, pause_reason])
        db_session_with_containers.commit()
        # Injected collaborators: node-execution deletion is stubbed to report
        # (2 node executions, 1 offload); trigger-log deletion goes via a Mock.
        fake_trigger_repo = Mock()
        fake_trigger_repo.delete_by_run_ids.return_value = 3
        # Act
        counts = repository.delete_runs_with_related(
            [workflow_run],
            delete_node_executions=lambda session, runs: (2, 1),
            delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
        )
        # Assert: the injected deleter received exactly this run's id ...
        fake_trigger_repo.delete_by_run_ids.assert_called_once_with([workflow_run.id])
        # ... and counts mix the stubbed values with real deleted-row counts.
        assert counts["node_executions"] == 2
        assert counts["offloads"] == 1
        assert counts["trigger_logs"] == 3
        assert counts["app_logs"] == 1
        assert counts["pauses"] == 1
        assert counts["pause_reasons"] == 1
        assert counts["runs"] == 1
        # Verify from a fresh session that the run row is really gone.
        with Session(bind=db_session_with_containers.get_bind()) as verification_session:
            assert verification_session.get(WorkflowRun, workflow_run.id) is None
class TestCountRunsWithRelated:
    """Integration tests for count_runs_with_related."""

    def test_uses_trigger_log_repository(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Count run-related records and invoke injected trigger-log counter."""
        # Arrange: mirrors the delete test — one run plus one row of each related kind.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.SUCCEEDED,
        )
        app_log = WorkflowAppLog(
            tenant_id=test_scope.tenant_id,
            app_id=test_scope.app_id,
            workflow_id=test_scope.workflow_id,
            workflow_run_id=workflow_run.id,
            created_from="service-api",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=test_scope.user_id,
        )
        pause = WorkflowPause(
            id=str(uuid4()),
            workflow_id=test_scope.workflow_id,
            workflow_run_id=workflow_run.id,
            state_object_key=f"workflow-state-{uuid4()}.json",
        )
        pause_reason = WorkflowPauseReason(
            pause_id=pause.id,
            type_=PauseReasonType.SCHEDULED_PAUSE,
            message="scheduled pause",
        )
        db_session_with_containers.add_all([app_log, pause, pause_reason])
        db_session_with_containers.commit()
        # Injected collaborators: counting (not deleting) versions of the stubs.
        fake_trigger_repo = Mock()
        fake_trigger_repo.count_by_run_ids.return_value = 3
        # Act
        counts = repository.count_runs_with_related(
            [workflow_run],
            count_node_executions=lambda session, runs: (2, 1),
            count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
        )
        # Assert: injected counter called with this run's id; counts as arranged.
        fake_trigger_repo.count_by_run_ids.assert_called_once_with([workflow_run.id])
        assert counts["node_executions"] == 2
        assert counts["offloads"] == 1
        assert counts["trigger_logs"] == 3
        assert counts["app_logs"] == 1
        assert counts["pauses"] == 1
        assert counts["pause_reasons"] == 1
        assert counts["runs"] == 1
class TestCreateWorkflowPause:
    """Integration tests for create_workflow_pause."""

    def test_create_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Create pause successfully, persist pause record, and set run status to PAUSED."""
        # Arrange: a RUNNING run (pausable per the repository's status check).
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        state = '{"test": "state"}'
        # Act
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=test_scope.user_id,
            state=state,
            pause_reasons=[],
        )
        # Assert: pause row persisted ...
        pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
        assert pause_model is not None
        # Track the state object so scope cleanup removes it from storage.
        test_scope.state_keys.add(pause_model.state_object_key)
        # ... run flipped to PAUSED ...
        db_session_with_containers.refresh(workflow_run)
        assert workflow_run.status == WorkflowExecutionStatus.PAUSED
        # ... and the returned entity mirrors persisted data; state round-trips as bytes.
        assert pause_entity.id == pause_model.id
        assert pause_entity.workflow_execution_id == workflow_run.id
        assert pause_entity.get_pause_reasons() == []
        assert pause_entity.get_state() == state.encode()

    def test_create_workflow_pause_not_found(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        test_scope: _TestScope,
    ) -> None:
        """Raise ValueError when the workflow run does not exist."""
        # A random UUID matches no persisted run.
        with pytest.raises(ValueError, match="WorkflowRun not found"):
            repository.create_workflow_pause(
                workflow_run_id=str(uuid4()),
                state_owner_user_id=test_scope.user_id,
                state='{"test": "state"}',
                pause_reasons=[],
            )

    def test_create_workflow_pause_invalid_status(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Raise _WorkflowRunError when pausing a run in non-pausable status."""
        # A SUCCEEDED (terminal) run cannot be paused.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.SUCCEEDED,
        )
        with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
            repository.create_workflow_pause(
                workflow_run_id=workflow_run.id,
                state_owner_user_id=test_scope.user_id,
                state='{"test": "state"}',
                pause_reasons=[],
            )
class TestResumeWorkflowPause:
    """Integration tests for resume_workflow_pause."""

    def test_resume_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Resume pause successfully and switch workflow run status back to RUNNING."""
        # Arrange: pause a RUNNING run via the repository itself.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=test_scope.user_id,
            state='{"test": "state"}',
            pause_reasons=[],
        )
        pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
        assert pause_model is not None
        # Track the state object so scope cleanup removes it from storage.
        test_scope.state_keys.add(pause_model.state_object_key)
        # Act
        resumed_entity = repository.resume_workflow_pause(
            workflow_run_id=workflow_run.id,
            pause_entity=pause_entity,
        )
        # Assert: returned entity and persisted rows both reflect the resume.
        db_session_with_containers.refresh(workflow_run)
        db_session_with_containers.refresh(pause_model)
        assert resumed_entity.id == pause_entity.id
        assert resumed_entity.resumed_at is not None
        assert workflow_run.status == WorkflowExecutionStatus.RUNNING
        assert pause_model.resumed_at is not None

    def test_resume_workflow_pause_not_paused(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Raise _WorkflowRunError when workflow run is not in PAUSED status."""
        # Arrange: the run stays RUNNING — no pause was ever created.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        # A mock entity suffices here: the expected error is about run status,
        # so no persisted pause row should be needed.
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = str(uuid4())
        with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run.id,
                pause_entity=pause_entity,
            )

    def test_resume_workflow_pause_id_mismatch(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Raise _WorkflowRunError when pause entity ID mismatches persisted pause ID."""
        # Arrange: a genuinely paused run with a persisted pause row.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=test_scope.user_id,
            state='{"test": "state"}',
            pause_reasons=[],
        )
        pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
        assert pause_model is not None
        test_scope.state_keys.add(pause_model.state_object_key)
        # Act / Assert: an entity carrying a different id must be rejected.
        mismatched_pause_entity = Mock(spec=WorkflowPauseEntity)
        mismatched_pause_entity.id = str(uuid4())
        with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run.id,
                pause_entity=mismatched_pause_entity,
            )
class TestDeleteWorkflowPause:
    """Integration tests for delete_workflow_pause."""

    def test_delete_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        db_session_with_containers: Session,
        test_scope: _TestScope,
    ) -> None:
        """Delete pause record and its state object from storage."""
        # Arrange: create a real pause so a state object exists in storage.
        workflow_run = _create_workflow_run(
            db_session_with_containers,
            test_scope,
            status=WorkflowExecutionStatus.RUNNING,
        )
        pause_entity = repository.create_workflow_pause(
            workflow_run_id=workflow_run.id,
            state_owner_user_id=test_scope.user_id,
            state='{"test": "state"}',
            pause_reasons=[],
        )
        pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
        assert pause_model is not None
        state_key = pause_model.state_object_key
        # Track the key anyway; scope cleanup tolerates it already being deleted.
        test_scope.state_keys.add(state_key)
        # Act
        repository.delete_workflow_pause(pause_entity=pause_entity)
        # Assert from a fresh session: DB row gone, storage object gone.
        with Session(bind=db_session_with_containers.get_bind()) as verification_session:
            assert verification_session.get(WorkflowPause, pause_entity.id) is None
        with pytest.raises(FileNotFoundError):
            storage.load(state_key)

    def test_delete_workflow_pause_not_found(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
    ) -> None:
        """Raise _WorkflowRunError when deleting a non-existent pause."""
        # An entity with a random id matches no persisted pause row.
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = str(uuid4())
        with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
            repository.delete_workflow_pause(pause_entity=pause_entity)

View File

@@ -0,0 +1,134 @@
"""Integration tests for SQLAlchemyWorkflowTriggerLogRepository using testcontainers."""
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from models.enums import AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
from models.trigger import WorkflowTriggerLog
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
def _create_trigger_log(
    session: Session,
    *,
    tenant_id: str,
    app_id: str,
    workflow_id: str,
    workflow_run_id: str,
    created_by: str,
) -> WorkflowTriggerLog:
    """Persist a minimal SUCCEEDED webhook trigger-log row for these tests.

    The row is flushed (not committed); callers commit when they need the row
    visible to other sessions.
    """
    trigger_log = WorkflowTriggerLog(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        workflow_run_id=workflow_run_id,
        root_node_id=None,
        trigger_metadata="{}",
        trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
        trigger_data="{}",
        inputs="{}",
        outputs=None,
        status=WorkflowTriggerStatus.SUCCEEDED,
        error=None,
        queue_name="default",
        celery_task_id=None,
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=created_by,
        retry_count=0,
    )
    session.add(trigger_log)
    session.flush()
    return trigger_log
def test_delete_by_run_ids_executes_delete(db_session_with_containers: Session) -> None:
    """delete_by_run_ids removes only logs for the given run ids."""
    # Arrange: three logs in one tenant — two targeted runs, one untouched run.
    tenant_id = str(uuid4())
    app_id = str(uuid4())
    workflow_id = str(uuid4())
    created_by = str(uuid4())
    run_id_1 = str(uuid4())
    run_id_2 = str(uuid4())
    untouched_run_id = str(uuid4())
    _create_trigger_log(
        db_session_with_containers,
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        workflow_run_id=run_id_1,
        created_by=created_by,
    )
    _create_trigger_log(
        db_session_with_containers,
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        workflow_run_id=run_id_2,
        created_by=created_by,
    )
    _create_trigger_log(
        db_session_with_containers,
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        workflow_run_id=untouched_run_id,
        created_by=created_by,
    )
    db_session_with_containers.commit()
    repository = SQLAlchemyWorkflowTriggerLogRepository(db_session_with_containers)
    try:
        # Act
        deleted = repository.delete_by_run_ids([run_id_1, run_id_2])
        db_session_with_containers.commit()
        # Assert: exactly the two targeted rows deleted; the third survives.
        assert deleted == 2
        remaining_logs = db_session_with_containers.scalars(
            select(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id)
        ).all()
        assert len(remaining_logs) == 1
        assert remaining_logs[0].workflow_run_id == untouched_run_id
    finally:
        # Cleanup: remove every committed row for this tenant.
        db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id))
        db_session_with_containers.commit()
def test_delete_by_run_ids_empty_short_circuits(db_session_with_containers: Session) -> None:
    """An empty run-id list deletes nothing and leaves existing logs intact."""
    tenant_id = str(uuid4())
    app_id = str(uuid4())
    workflow_id = str(uuid4())
    created_by = str(uuid4())
    run_id = str(uuid4())
    _create_trigger_log(
        db_session_with_containers,
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_id,
        workflow_run_id=run_id,
        created_by=created_by,
    )
    db_session_with_containers.commit()
    repository = SQLAlchemyWorkflowTriggerLogRepository(db_session_with_containers)
    try:
        deleted = repository.delete_by_run_ids([])
        db_session_with_containers.commit()
        assert deleted == 0
        # The seeded log must still be present after the no-op delete.
        surviving = db_session_with_containers.scalar(
            select(func.count())
            .select_from(WorkflowTriggerLog)
            .where(WorkflowTriggerLog.tenant_id == tenant_id)
            .where(WorkflowTriggerLog.workflow_run_id == run_id)
        )
        assert surviving == 1
    finally:
        # Clean up this tenant's rows regardless of assertion outcome.
        db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id))
        db_session_with_containers.commit()

View File

@@ -0,0 +1,143 @@
"""
Testcontainers integration tests for archived workflow run deletion service.
"""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
from sqlalchemy import select
from core.workflow.enums import WorkflowExecutionStatus
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowArchiveLog, WorkflowRun
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
class TestArchivedWorkflowRunDeletion:
    """Integration tests for ArchivedWorkflowRunDeletion: single-run deletion,
    batch deletion by tenant/date window, and the low-level _delete_run path."""

    def _create_workflow_run(
        self,
        db_session_with_containers,
        *,
        tenant_id: str,
        created_at: datetime,
    ) -> WorkflowRun:
        """Persist and return a finished (SUCCEEDED) debugging WorkflowRun.

        created_at doubles as finished_at so date-window filters see a
        consistent timestamp.
        """
        run = WorkflowRun(
            id=str(uuid4()),
            tenant_id=tenant_id,
            app_id=str(uuid4()),
            workflow_id=str(uuid4()),
            type="workflow",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            version="1.0.0",
            graph="{}",
            inputs="{}",
            status=WorkflowExecutionStatus.SUCCEEDED,
            outputs="{}",
            elapsed_time=0.1,
            total_tokens=1,
            total_steps=1,
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid4()),
            created_at=created_at,
            finished_at=created_at,
            exceptions_count=0,
        )
        db_session_with_containers.add(run)
        db_session_with_containers.commit()
        return run

    def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
        """Persist a WorkflowArchiveLog snapshot of `run`, marking it archived.

        All run_* columns mirror the run's fields; log_* columns are None
        because no app log is attached in these tests.
        """
        archive_log = WorkflowArchiveLog(
            tenant_id=run.tenant_id,
            app_id=run.app_id,
            workflow_id=run.workflow_id,
            workflow_run_id=run.id,
            created_by_role=run.created_by_role,
            created_by=run.created_by,
            log_id=None,
            log_created_at=None,
            log_created_from=None,
            run_version=run.version,
            run_status=run.status,
            run_triggered_from=run.triggered_from,
            run_error=run.error,
            run_elapsed_time=run.elapsed_time,
            run_total_tokens=run.total_tokens,
            run_total_steps=run.total_steps,
            run_created_at=run.created_at,
            run_finished_at=run.finished_at,
            run_exceptions_count=run.exceptions_count,
            trigger_metadata=None,
        )
        db_session_with_containers.add(archive_log)
        db_session_with_containers.commit()

    def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
        """Deleting a nonexistent run id yields a failed result with a 'not found' error."""
        deleter = ArchivedWorkflowRunDeletion()
        missing_run_id = str(uuid4())
        result = deleter.delete_by_run_id(missing_run_id)
        assert result.success is False
        assert result.error == f"Workflow run {missing_run_id} not found"

    def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
        """A run without an archive log is refused: only archived runs may be deleted."""
        tenant_id = str(uuid4())
        run = self._create_workflow_run(
            db_session_with_containers,
            tenant_id=tenant_id,
            created_at=datetime.now(UTC),
        )
        deleter = ArchivedWorkflowRunDeletion()
        result = deleter.delete_by_run_id(run.id)
        assert result.success is False
        assert result.error == f"Workflow run {run.id} is not archived"

    def test_delete_batch_uses_repo(self, db_session_with_containers):
        """delete_batch removes all archived runs inside the tenant/date window."""
        tenant_id = str(uuid4())
        base_time = datetime.now(UTC)
        run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
        run2 = self._create_workflow_run(
            db_session_with_containers,
            tenant_id=tenant_id,
            created_at=base_time + timedelta(seconds=1),
        )
        # Both runs must be archived for the batch to pick them up.
        self._create_archive_log(db_session_with_containers, run=run1)
        self._create_archive_log(db_session_with_containers, run=run2)
        run_ids = [run1.id, run2.id]
        deleter = ArchivedWorkflowRunDeletion()
        results = deleter.delete_batch(
            tenant_ids=[tenant_id],
            start_date=base_time - timedelta(minutes=1),
            end_date=base_time + timedelta(minutes=1),
            limit=2,
        )
        assert len(results) == 2
        assert all(result.success for result in results)
        # Both rows should be gone from workflow_runs afterwards.
        remaining_runs = db_session_with_containers.scalars(
            select(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
        ).all()
        assert remaining_runs == []

    def test_delete_run_calls_repo(self, db_session_with_containers):
        """_delete_run removes the row directly and reports one deleted run.

        Note: _delete_run is a private helper; this test pins its contract
        (success flag plus deleted_counts['runs']) without archive checks.
        """
        tenant_id = str(uuid4())
        run = self._create_workflow_run(
            db_session_with_containers,
            tenant_id=tenant_id,
            created_at=datetime.now(UTC),
        )
        run_id = run.id
        deleter = ArchivedWorkflowRunDeletion()
        result = deleter._delete_run(run)
        assert result.success is True
        assert result.deleted_counts["runs"] == 1
        # Drop cached instances so the follow-up get() hits the database.
        db_session_with_containers.expunge_all()
        deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
        assert deleted_run is None

View File

@@ -0,0 +1,143 @@
import datetime
from uuid import uuid4
from sqlalchemy import select
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService
def _create_dataset(db_session_with_containers) -> Dataset:
    """Persist and return a minimal upload-file dataset under a fresh tenant."""
    session = db_session_with_containers
    dataset = Dataset(
        tenant_id=str(uuid4()),
        name=f"dataset-{uuid4()}",
        data_source_type="upload_file",
        created_by=str(uuid4()),
    )
    # Assign an explicit primary key so the id is known before commit.
    dataset.id = str(uuid4())
    session.add(dataset)
    session.commit()
    return dataset
def _create_document(
    db_session_with_containers,
    *,
    dataset_id: str,
    tenant_id: str,
    indexing_status: str,
    enabled: bool = True,
    archived: bool = False,
    is_paused: bool = False,
    position: int = 1,
) -> Document:
    """Persist and return a Document with the given display-status fields."""
    session = db_session_with_containers
    document = Document(
        tenant_id=tenant_id,
        dataset_id=dataset_id,
        position=position,
        data_source_type="upload_file",
        data_source_info="{}",
        batch=f"batch-{uuid4()}",
        name=f"doc-{uuid4()}",
        created_from="web",
        created_by=str(uuid4()),
        doc_form="text_model",
    )
    document.id = str(uuid4())
    # Status columns are assigned after construction, mirroring how the
    # service flips them on live documents.
    for attr, value in (
        ("indexing_status", indexing_status),
        ("enabled", enabled),
        ("archived", archived),
        ("is_paused", is_paused),
    ):
        setattr(document, attr, value)
    if indexing_status == "completed":
        # completed_at is stored as a naive UTC timestamp.
        document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    session.add(document)
    session.commit()
    return document
def test_build_display_status_filters_available(db_session_with_containers):
    """Only an enabled, unarchived, completed document passes the 'available' filters."""
    dataset = _create_dataset(db_session_with_containers)

    def add_completed_doc(position: int, *, enabled: bool, archived: bool) -> Document:
        # Local helper: all three seeded docs share indexing_status="completed".
        return _create_document(
            db_session_with_containers,
            dataset_id=dataset.id,
            tenant_id=dataset.tenant_id,
            indexing_status="completed",
            enabled=enabled,
            archived=archived,
            position=position,
        )

    available_doc = add_completed_doc(1, enabled=True, archived=False)
    add_completed_doc(2, enabled=False, archived=False)  # disabled: filtered out
    add_completed_doc(3, enabled=True, archived=True)  # archived: filtered out

    filters = DocumentService.build_display_status_filters("available")
    assert len(filters) == 3
    assert all(condition is not None for condition in filters)
    rows = db_session_with_containers.scalars(
        select(Document).where(Document.dataset_id == dataset.id, *filters)
    ).all()
    assert [row.id for row in rows] == [available_doc.id]
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers):
    """A recognised status ('queuing') narrows the query to waiting documents."""
    dataset = _create_dataset(db_session_with_containers)
    waiting_doc = _create_document(
        db_session_with_containers,
        dataset_id=dataset.id,
        tenant_id=dataset.tenant_id,
        indexing_status="waiting",
        position=1,
    )
    # Control row with a different status; must be excluded from the result.
    _create_document(
        db_session_with_containers,
        dataset_id=dataset.id,
        tenant_id=dataset.tenant_id,
        indexing_status="completed",
        position=2,
    )
    base_query = select(Document).where(Document.dataset_id == dataset.id)
    narrowed = DocumentService.apply_display_status_filter(base_query, "queuing")
    matched_ids = [row.id for row in db_session_with_containers.scalars(narrowed).all()]
    assert matched_ids == [waiting_doc.id]
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers):
    """An unknown status string leaves the query unfiltered."""
    dataset = _create_dataset(db_session_with_containers)
    doc1 = _create_document(
        db_session_with_containers,
        dataset_id=dataset.id,
        tenant_id=dataset.tenant_id,
        indexing_status="waiting",
        position=1,
    )
    doc2 = _create_document(
        db_session_with_containers,
        dataset_id=dataset.id,
        tenant_id=dataset.tenant_id,
        indexing_status="completed",
        position=2,
    )
    base_query = select(Document).where(Document.dataset_id == dataset.id)
    result_query = DocumentService.apply_display_status_filter(base_query, "invalid")
    fetched = db_session_with_containers.scalars(result_query).all()
    # Both documents survive: the invalid status added no filter.
    assert {row.id for row in fetched} == {doc1.id, doc2.id}

View File

@@ -0,0 +1,416 @@
from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App, DefaultEndUserSessionID, EndUser
from services.end_user_service import EndUserService
class TestEndUserServiceFactory:
    """Factory class for creating test data and mock objects for end user service tests."""

    @staticmethod
    def create_app_and_account(db_session_with_containers):
        """Persist a tenant, an owner account joined to it, and a chat-mode app.

        Returns the App; its tenant_id links back to the created tenant.
        """
        tenant = Tenant(name=f"Tenant {uuid4()}")
        db_session_with_containers.add(tenant)
        db_session_with_containers.flush()  # populate tenant.id for the FKs below
        account = Account(
            name=f"Account {uuid4()}",
            email=f"end_user_{uuid4()}@example.com",
            password="hashed-password",
            password_salt="salt",
            interface_language="en-US",
            timezone="UTC",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.flush()  # populate account.id
        # Link the account to the tenant as its current owner.
        tenant_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role="owner",
            current=True,
        )
        db_session_with_containers.add(tenant_join)
        db_session_with_containers.flush()
        app = App(
            tenant_id=tenant.id,
            name=f"App {uuid4()}",
            description="",
            mode="chat",
            icon_type="emoji",
            icon="bot",
            icon_background="#FFFFFF",
            enable_site=False,
            enable_api=True,
            api_rpm=100,
            api_rph=100,
            is_demo=False,
            is_public=False,
            is_universal=False,
            created_by=account.id,
            updated_by=account.id,
        )
        db_session_with_containers.add(app)
        db_session_with_containers.commit()
        return app

    @staticmethod
    def create_end_user(
        db_session_with_containers,
        *,
        tenant_id: str,
        app_id: str,
        session_id: str,
        invoke_type: InvokeFrom,
        is_anonymous: bool = False,
    ):
        """Persist and return an EndUser; external_user_id mirrors session_id."""
        end_user = EndUser(
            tenant_id=tenant_id,
            app_id=app_id,
            type=invoke_type,
            external_user_id=session_id,
            name=f"User-{uuid4()}",
            is_anonymous=is_anonymous,
            session_id=session_id,
        )
        db_session_with_containers.add(end_user)
        db_session_with_containers.commit()
        return end_user
class TestEndUserServiceGetOrCreateEndUser:
    """
    Unit tests for EndUserService.get_or_create_end_user method.
    This test suite covers:
    - Creating new end users
    - Retrieving existing end users
    - Default session ID handling
    - Anonymous user creation
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
        """Test getting or creating end user with custom user_id."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        user_id = "custom-user-123"
        # Act
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
        # Assert: a new SERVICE_API end user keyed by the supplied session id
        assert result.tenant_id == app.tenant_id
        assert result.app_id == app.id
        assert result.session_id == user_id
        assert result.type == InvokeFrom.SERVICE_API
        assert result.is_anonymous is False

    def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
        """Test getting or creating end user without user_id uses default session."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        # Act
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
        # Assert
        assert result.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # The backing `_is_anonymous` column is True for default-session users.
        # NOTE(review): the original comment claimed the public is_anonymous
        # property "always returns False" — confirm against the EndUser model.
        assert result._is_anonymous is True

    def test_get_existing_end_user(self, db_session_with_containers, factory):
        """Test retrieving an existing end user."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        user_id = "existing-user-123"
        existing_user = factory.create_end_user(
            db_session_with_containers,
            tenant_id=app.tenant_id,
            app_id=app.id,
            session_id=user_id,
            invoke_type=InvokeFrom.SERVICE_API,
        )
        # Act
        result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
        # Assert: the persisted row is reused rather than duplicated
        assert result.id == existing_user.id
class TestEndUserServiceGetOrCreateEndUserByType:
    """
    Unit tests for EndUserService.get_or_create_end_user_by_type method.
    This test suite covers:
    - Creating end users with different InvokeFrom types
    - Type migration for legacy users
    - Query ordering and prioritization
    - Session management
    """

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
        """Test creating new end user with SERVICE_API type."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "user-789"
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert: the new row carries the requested type and identifiers
        assert result.type == InvokeFrom.SERVICE_API
        assert result.tenant_id == tenant_id
        assert result.app_id == app_id
        assert result.session_id == user_id

    def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
        """Test creating new end user with WEB_APP type."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "user-789"
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert
        assert result.type == InvokeFrom.WEB_APP

    @patch("services.end_user_service.logger")
    def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
        """Test upgrading legacy end user with different type."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "user-789"
        # Existing user with old type
        existing_user = factory.create_end_user(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            session_id=user_id,
            invoke_type=InvokeFrom.SERVICE_API,
        )
        # Act - Request with different type
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.WEB_APP,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert: the same row is reused but migrated to the requested type
        assert result.id == existing_user.id
        assert result.type == InvokeFrom.WEB_APP  # Type should be updated
        # The upgrade path logs exactly once.
        mock_logger.info.assert_called_once()
        # Verify log message contains upgrade info
        log_call = mock_logger.info.call_args[0][0]
        assert "Upgrading legacy EndUser" in log_call

    @patch("services.end_user_service.logger")
    def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
        """Test retrieving existing end user with matching type."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "user-789"
        existing_user = factory.create_end_user(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            session_id=user_id,
            invoke_type=InvokeFrom.SERVICE_API,
        )
        # Act - Request with same type
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert: matching type means no upgrade, hence no log line
        assert result.id == existing_user.id
        assert result.type == InvokeFrom.SERVICE_API
        mock_logger.info.assert_not_called()

    def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
        """Test creating anonymous user when user_id is None."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=None,
        )
        # Assert
        assert result.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
        # The backing `_is_anonymous` column is True for default-session users.
        # NOTE(review): the original comment claimed the public is_anonymous
        # property "always returns False" — confirm against the EndUser model.
        assert result._is_anonymous is True
        assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID

    def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
        """Test that query ordering prioritizes records with matching type."""
        # Arrange: two rows share the same session id but differ in type.
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "user-789"
        non_matching = factory.create_end_user(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            session_id=user_id,
            invoke_type=InvokeFrom.WEB_APP,
        )
        matching = factory.create_end_user(
            db_session_with_containers,
            tenant_id=tenant_id,
            app_id=app_id,
            session_id=user_id,
            invoke_type=InvokeFrom.SERVICE_API,
        )
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert: the type-matching row wins over the non-matching one
        assert result.id == matching.id
        assert result.id != non_matching.id

    def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
        """Test that external_user_id is set to match session_id."""
        # Arrange
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = "custom-external-id"
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=InvokeFrom.SERVICE_API,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert
        assert result.external_user_id == user_id
        assert result.session_id == user_id

    @pytest.mark.parametrize(
        "invoke_type",
        [
            InvokeFrom.SERVICE_API,
            InvokeFrom.WEB_APP,
            InvokeFrom.EXPLORE,
            InvokeFrom.DEBUGGER,
        ],
    )
    def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
        """Test creating end users with different InvokeFrom types."""
        # Arrange: unique user id per parametrized case to avoid collisions.
        app = factory.create_app_and_account(db_session_with_containers)
        tenant_id = app.tenant_id
        app_id = app.id
        user_id = f"user-{uuid4()}"
        # Act
        result = EndUserService.get_or_create_end_user_by_type(
            type=invoke_type,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=user_id,
        )
        # Assert
        assert result.type == invoke_type
class TestEndUserServiceGetEndUserById:
    """Unit tests for EndUserService.get_end_user_by_id."""

    @pytest.fixture
    def factory(self):
        """Provide test data factory."""
        return TestEndUserServiceFactory()

    def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
        """A persisted end user is retrievable by its id."""
        app = factory.create_app_and_account(db_session_with_containers)
        persisted = factory.create_end_user(
            db_session_with_containers,
            tenant_id=app.tenant_id,
            app_id=app.id,
            session_id=f"session-{uuid4()}",
            invoke_type=InvokeFrom.SERVICE_API,
        )
        fetched = EndUserService.get_end_user_by_id(
            tenant_id=app.tenant_id,
            app_id=app.id,
            end_user_id=persisted.id,
        )
        assert fetched is not None
        assert fetched.id == persisted.id

    def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
        """Looking up a nonexistent id yields None rather than raising."""
        app = factory.create_app_and_account(db_session_with_containers)
        lookup = EndUserService.get_end_user_by_id(
            tenant_id=app.tenant_id,
            app_id=app.id,
            end_user_id=str(uuid4()),
        )
        assert lookup is None

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from decimal import Decimal
import pytest
from models.model import Message
from services import message_service
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
def test_attach_message_extra_contents_assigns_serialized_payload(db_session_with_containers) -> None:
    """attach_message_extra_contents sets `extra_contents` per message:
    a serialized human-input payload where a form submission exists, else [].
    """
    # Fixture provides a message with an associated human-input form submission.
    fixture = create_human_input_message_fixture(db_session_with_containers)
    # Second message deliberately has no form submission attached.
    message_without_extra_content = Message(
        app_id=fixture.app.id,
        model_provider=None,
        model_id="",
        override_model_configs=None,
        conversation_id=fixture.conversation.id,
        inputs={},
        query="Query without extra content",
        message={"messages": [{"role": "user", "content": "Query without extra content"}]},
        message_tokens=0,
        message_unit_price=Decimal(0),
        message_price_unit=Decimal("0.001"),
        answer="Answer without extra content",
        answer_tokens=0,
        answer_unit_price=Decimal(0),
        answer_price_unit=Decimal("0.001"),
        parent_message_id=None,
        provider_response_latency=0,
        total_price=Decimal(0),
        currency="USD",
        status="normal",
        from_source="console",
        from_account_id=fixture.account.id,
    )
    db_session_with_containers.add(message_without_extra_content)
    db_session_with_containers.commit()
    messages = [fixture.message, message_without_extra_content]
    message_service.attach_message_extra_contents(messages)
    # First message: fully serialized human-input payload from the fixture.
    assert messages[0].extra_contents == [
        {
            "type": "human_input",
            "workflow_run_id": fixture.message.workflow_run_id,
            "submitted": True,
            "form_submission_data": {
                "node_id": fixture.form.node_id,
                "node_title": fixture.node_title,
                "rendered_content": fixture.form.rendered_content,
                "action_id": fixture.action_id,
                "action_text": fixture.action_text,
            },
        }
    ]
    # Second message: no submission, so an empty list (not None).
    assert messages[1].extra_contents == []

View File

@@ -0,0 +1,224 @@
import uuid
from unittest.mock import ANY, call, patch
import pytest
from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import (
_delete_draft_variable_offload_data,
delete_draft_variables_batch,
)
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
    """Wipe every table these tests touch before each test runs."""
    # Child tables first so foreign-key constraints never block the deletes.
    for model in (WorkflowDraftVariable, WorkflowDraftVariableFile, UploadFile, App, Tenant):
        db_session_with_containers.query(model).delete()
    db_session_with_containers.commit()
def _create_tenant_and_app(db_session_with_containers):
    """Persist a fresh tenant plus a workflow-mode app; return (tenant, app)."""
    session = db_session_with_containers
    tenant = Tenant(name=f"test_tenant_{uuid.uuid4()}")
    session.add(tenant)
    session.flush()  # populate tenant.id before building the app
    app = App(
        tenant_id=tenant.id,
        name=f"Test App for tenant {tenant.id}",
        mode="workflow",
        enable_site=True,
        enable_api=True,
    )
    session.add(app)
    session.commit()
    return tenant, app
def _create_draft_variables(
    db_session_with_containers,
    *,
    app_id: str,
    count: int,
    file_id_by_index: dict[int, str] | None = None,
) -> list[WorkflowDraftVariable]:
    """Persist `count` node-scoped draft variables for `app_id`.

    Entries in `file_id_by_index` attach an offload file id to the variable
    at that index; all other variables get file_id=None.
    """
    session = db_session_with_containers
    index_to_file = file_id_by_index or {}
    created: list[WorkflowDraftVariable] = []
    for index in range(count):
        draft = WorkflowDraftVariable.new_node_variable(
            app_id=app_id,
            node_id=f"node_{index}",
            name=f"var_{index}",
            value=StringSegment(value="test_value"),
            node_execution_id=str(uuid.uuid4()),
            file_id=index_to_file.get(index),
        )
        session.add(draft)
        created.append(draft)
    session.commit()
    return created
def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: str, count: int):
    """Persist `count` paired UploadFile / WorkflowDraftVariableFile rows.

    Returns {"upload_files": [...], "variable_files": [...]} in creation order.
    """
    session = db_session_with_containers
    upload_files: list[UploadFile] = []
    variable_files: list[WorkflowDraftVariableFile] = []
    for index in range(count):
        stored_file = UploadFile(
            tenant_id=tenant_id,
            storage_type="local",
            key=f"test/file-{uuid.uuid4()}-{index}.json",
            name=f"file-{index}.json",
            size=1024 + index,
            extension="json",
            mime_type="application/json",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid.uuid4()),
            created_at=naive_utc_now(),
            used=False,
        )
        session.add(stored_file)
        session.flush()  # need stored_file.id for the variable-file FK below
        upload_files.append(stored_file)
        offload_record = WorkflowDraftVariableFile(
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=str(uuid.uuid4()),
            upload_file_id=stored_file.id,
            size=1024 + index,
            length=10 + index,
            value_type=SegmentType.STRING,
        )
        session.add(offload_record)
        session.flush()
        variable_files.append(offload_record)
    session.commit()
    return {
        "upload_files": upload_files,
        "variable_files": variable_files,
    }
class TestDeleteDraftVariablesBatch:
    """Tests for delete_draft_variables_batch: batched removal of one app's
    draft variables plus dispatch of offload-file cleanup."""

    def test_delete_draft_variables_batch_success(self, db_session_with_containers):
        """Test successful deletion of draft variables in batches."""
        _, app1 = _create_tenant_and_app(db_session_with_containers)
        _, app2 = _create_tenant_and_app(db_session_with_containers)
        # 150 rows with batch_size=100 forces at least two delete batches.
        _create_draft_variables(db_session_with_containers, app_id=app1.id, count=150)
        _create_draft_variables(db_session_with_containers, app_id=app2.id, count=100)
        result = delete_draft_variables_batch(app1.id, batch_size=100)
        assert result == 150
        app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
            WorkflowDraftVariable.app_id == app1.id
        )
        app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
            WorkflowDraftVariable.app_id == app2.id
        )
        # Only the targeted app's variables are removed; app2 is untouched.
        assert app1_remaining.count() == 0
        assert app2_remaining.count() == 100

    def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
        """Test deletion when no draft variables exist for the app."""
        result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
        assert result == 0
        assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0

    @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
    @patch("tasks.remove_app_and_related_data_task.logger")
    def test_delete_draft_variables_batch_logs_progress(
        self, mock_logger, mock_offload_cleanup, db_session_with_containers
    ):
        """Test that batch deletion logs progress correctly."""
        tenant, app = _create_tenant_and_app(db_session_with_containers)
        offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=10)
        file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
        # Attach an offload file to every third variable (indices 0, 3, 6, ...).
        file_id_by_index: dict[int, str] = {}
        for i in range(30):
            if i % 3 == 0:
                file_id_by_index[i] = file_ids[i // 3]
        _create_draft_variables(db_session_with_containers, app_id=app.id, count=30, file_id_by_index=file_id_by_index)
        mock_offload_cleanup.return_value = len(file_id_by_index)
        result = delete_draft_variables_batch(app.id, 50)
        assert result == 30
        # Offload cleanup must be invoked once with exactly the attached file ids.
        mock_offload_cleanup.assert_called_once()
        _, called_file_ids = mock_offload_cleanup.call_args.args
        assert {str(file_id) for file_id in called_file_ids} == {str(file_id) for file_id in file_id_by_index.values()}
        # Two progress log lines expected for this workload.
        assert mock_logger.info.call_count == 2
        mock_logger.info.assert_any_call(ANY)
class TestDeleteDraftVariableOffloadData:
    """Test the Offload data cleanup functionality."""

    @patch("extensions.ext_storage.storage")
    def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers):
        """Test successful deletion of offload data."""
        tenant, app = _create_tenant_and_app(db_session_with_containers)
        offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3)
        file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
        upload_file_keys = [upload_file.key for upload_file in offload_data["upload_files"]]
        upload_file_ids = [upload_file.id for upload_file in offload_data["upload_files"]]
        # Run the cleanup inside its own transactional session.
        with session_factory.create_session() as session, session.begin():
            result = _delete_draft_variable_offload_data(session, file_ids)
        assert result == 3
        # Every storage object backing the offload files must be deleted.
        expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
        mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
        remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
            WorkflowDraftVariableFile.id.in_(file_ids)
        )
        remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
        # Both the variable-file rows and the upload-file rows are gone.
        assert remaining_var_files.count() == 0
        assert remaining_upload_files.count() == 0

    @patch("extensions.ext_storage.storage")
    @patch("tasks.remove_app_and_related_data_task.logging")
    def test_delete_draft_variable_offload_data_storage_failure(
        self, mock_logging, mock_storage, db_session_with_containers
    ):
        """Test handling of storage deletion failures."""
        tenant, app = _create_tenant_and_app(db_session_with_containers)
        offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=2)
        file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
        storage_keys = [upload_file.key for upload_file in offload_data["upload_files"]]
        upload_file_ids = [upload_file.id for upload_file in offload_data["upload_files"]]
        # First storage delete fails; the second succeeds.
        mock_storage.delete.side_effect = [Exception("Storage error"), None]
        with session_factory.create_session() as session, session.begin():
            result = _delete_draft_variable_offload_data(session, file_ids)
        # Only the successful storage delete is counted.
        assert result == 1
        mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
        remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
            WorkflowDraftVariableFile.id.in_(file_ids)
        )
        remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
        # DB rows are removed even when the storage delete failed.
        assert remaining_var_files.count() == 0
        assert remaining_upload_files.count() == 0

View File

@@ -954,148 +954,6 @@ class TestChildChunk:
assert child_chunk.index_node_hash == index_node_hash
class TestDocumentSegmentNavigation:
"""Test suite for DocumentSegment navigation properties."""
def test_document_segment_dataset_property(self):
"""Test segment can access its parent dataset."""
# Arrange
dataset_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=dataset_id,
document_id=str(uuid4()),
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
mock_dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
created_by=str(uuid4()),
)
mock_dataset.id = dataset_id
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=mock_dataset):
# Act
dataset = segment.dataset
# Assert
assert dataset is not None
assert dataset.id == dataset_id
def test_document_segment_document_property(self):
"""Test segment can access its parent document."""
# Arrange
document_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
mock_document = Document(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=str(uuid4()),
)
mock_document.id = document_id
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=mock_document):
# Act
document = segment.document
# Assert
assert document is not None
assert document.id == document_id
def test_document_segment_previous_segment(self):
"""Test segment can access previous segment."""
# Arrange
document_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=2,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
previous_segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=1,
content="Previous",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=previous_segment):
# Act
prev_seg = segment.previous_segment
# Assert
assert prev_seg is not None
assert prev_seg.position == 1
def test_document_segment_next_segment(self):
    """A segment at position 1 resolves the position-2 segment as its successor."""
    # Arrange: two segments of the same document, adjacent positions.
    doc_id = str(uuid4())
    current = DocumentSegment(
        tenant_id=str(uuid4()),
        dataset_id=str(uuid4()),
        document_id=doc_id,
        position=1,
        content="Test",
        word_count=1,
        tokens=2,
        created_by=str(uuid4()),
    )
    successor = DocumentSegment(
        tenant_id=str(uuid4()),
        dataset_id=str(uuid4()),
        document_id=doc_id,
        position=2,
        content="Next",
        word_count=1,
        tokens=2,
        created_by=str(uuid4()),
    )
    # The property issues a db.session.scalar() lookup; stub it to return the successor.
    with patch("models.dataset.db.session.scalar", return_value=successor):
        resolved = current.next_segment
    assert resolved is not None
    assert resolved.position == 2
class TestModelIntegration:
"""Test suite for model integration scenarios."""

View File

@@ -1,40 +0,0 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation."""
from unittest.mock import Mock
from sqlalchemy.orm import Session, sessionmaker
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
    def test_get_executions_by_workflow_run_keeps_paused_records(self):
        """The run-scoped query filters by tenant/app/run ids only.

        Asserts that no 'paused' criterion appears in the WHERE clause, so
        paused node-execution records are kept in the result set.
        """
        mock_session = Mock(spec=Session)
        execute_result = Mock()
        execute_result.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = execute_result
        session_maker = Mock(spec=sessionmaker)
        # Context-manager protocol wired manually onto a plain Mock so that
        # `with session_maker() as session:` yields mock_session.
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=mock_session)
        context_manager.__exit__ = Mock(return_value=None)
        session_maker.return_value = context_manager
        repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker)
        repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-123",
            workflow_run_id="workflow-run-123",
        )
        # Inspect the SELECT handed to session.execute: pull its WHERE criteria
        # (private SQLAlchemy attribute, hence the defensive getattr) and
        # string-compare them case-insensitively.
        stmt = mock_session.execute.call_args[0][0]
        where_clauses = list(getattr(stmt, "_where_criteria", []) or [])
        where_strs = [str(clause).lower() for clause in where_clauses]
        assert any("tenant_id" in clause for clause in where_strs)
        assert any("app_id" in clause for clause in where_strs)
        assert any("workflow_run_id" in clause for clause in where_strs)
        # Crucially: no filter mentioning the paused state.
        assert not any("paused" in clause for clause in where_strs)

View File

@@ -1,435 +1,50 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
"""Unit tests for non-SQL helper logic in workflow run repository."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from models.workflow import WorkflowPauseReason
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
class TestDifyAPISQLAlchemyWorkflowRunRepository:
    """Test DifyAPISQLAlchemyWorkflowRunRepository implementation.

    Base class holding the shared fixtures; concrete test classes below
    inherit from it to reuse the mocked session plumbing.
    """

    @pytest.fixture
    def mock_session(self):
        """Create a mock session."""
        return Mock(spec=Session)

    @pytest.fixture
    def mock_session_maker(self, mock_session):
        """Create a mock sessionmaker whose context managers yield mock_session."""
        session_maker = Mock(spec=sessionmaker)
        # Create a context manager mock
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=mock_session)
        context_manager.__exit__ = Mock(return_value=None)
        session_maker.return_value = context_manager
        # Mock session.begin() context manager
        begin_context_manager = Mock()
        begin_context_manager.__enter__ = Mock(return_value=None)
        begin_context_manager.__exit__ = Mock(return_value=None)
        mock_session.begin = Mock(return_value=begin_context_manager)
        # Add missing session methods
        mock_session.commit = Mock()
        mock_session.rollback = Mock()
        mock_session.add = Mock()
        mock_session.delete = Mock()
        mock_session.get = Mock()
        mock_session.scalar = Mock()
        mock_session.scalars = Mock()
        # Also support expire_on_commit parameter.
        # NOTE(review): side_effect takes precedence over the return_value set
        # above, so every session_maker(...) call goes through make_session and
        # the earlier `session_maker.return_value = context_manager` is inert.
        def make_session(expire_on_commit=None):
            cm = Mock()
            cm.__enter__ = Mock(return_value=mock_session)
            cm.__exit__ = Mock(return_value=None)
            return cm
        session_maker.side_effect = make_session
        return session_maker

    @pytest.fixture
    def repository(self, mock_session_maker):
        """Create repository instance with mocked dependencies."""
        # Create a testable subclass that implements the save method
        class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
            def __init__(self, session_maker):
                # Initialize without calling parent __init__ to avoid any instantiation issues
                self._session_maker = session_maker

            def save(self, execution):
                """Mock implementation of save method."""
                return None

        # Create repository instance
        repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
        return repo

    @pytest.fixture
    def sample_workflow_run(self):
        """Create a sample WorkflowRun model (RUNNING by default)."""
        workflow_run = Mock(spec=WorkflowRun)
        workflow_run.id = "workflow-run-123"
        workflow_run.tenant_id = "tenant-123"
        workflow_run.app_id = "app-123"
        workflow_run.workflow_id = "workflow-123"
        workflow_run.status = WorkflowExecutionStatus.RUNNING
        return workflow_run

    @pytest.fixture
    def sample_workflow_pause(self):
        """Create a sample WorkflowPauseModel that has not been resumed yet."""
        pause = Mock(spec=WorkflowPauseModel)
        pause.id = "pause-123"
        pause.workflow_id = "workflow-123"
        pause.workflow_run_id = "workflow-run-123"
        pause.state_object_key = "workflow-state-123.json"
        pause.resumed_at = None
        pause.created_at = datetime.now(UTC)
        return pause
class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository):
    def test_get_runs_batch_by_time_range_filters_terminal_statuses(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
    ):
        """The batch query must select only terminal workflow statuses.

        Renders the captured SELECT with literal binds and checks that all
        four terminal statuses appear inline while 'running'/'paused' do not.
        """
        scalar_result = Mock()
        scalar_result.all.return_value = []
        mock_session.scalars.return_value = scalar_result
        repository.get_runs_batch_by_time_range(
            start_from=None,
            end_before=datetime(2024, 1, 1),
            last_seen=None,
            batch_size=50,
        )
        # Capture the statement passed to session.scalars and compile it with
        # literal_binds so the status values show up as SQL string literals.
        stmt = mock_session.scalars.call_args[0][0]
        compiled_sql = str(
            stmt.compile(
                dialect=postgresql.dialect(),
                compile_kwargs={"literal_binds": True},
            )
        )
        assert "workflow_runs.status" in compiled_sql
        for status in (
            WorkflowExecutionStatus.SUCCEEDED,
            WorkflowExecutionStatus.FAILED,
            WorkflowExecutionStatus.STOPPED,
            WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
        ):
            assert f"'{status.value}'" in compiled_sql
        # Non-terminal statuses must be excluded from the filter.
        assert "'running'" not in compiled_sql
        assert "'paused'" not in compiled_sql
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test create_workflow_pause method."""

    def test_create_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test successful workflow pause creation.

        Patches uuidv7 (deterministic pause id) and the storage backend, then
        checks the returned entity plus the session/storage interactions.
        """
        # Arrange
        workflow_run_id = "workflow-run-123"
        state_owner_user_id = "user-123"
        state = '{"test": "state"}'
        mock_session.get.return_value = sample_workflow_run
        with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
            mock_uuidv7.side_effect = ["pause-123"]
            with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
                # Act
                result = repository.create_workflow_pause(
                    workflow_run_id=workflow_run_id,
                    state_owner_user_id=state_owner_user_id,
                    state=state,
                    pause_reasons=[],
                )
                # Assert
                assert isinstance(result, _PrivateWorkflowPauseEntity)
                assert result.id == "pause-123"
                assert result.workflow_execution_id == workflow_run_id
                assert result.get_pause_reasons() == []
                # Verify database interactions
                mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
                mock_storage.save.assert_called_once()
                mock_session.add.assert_called()
                # When using session.begin() context manager, commit is handled automatically
                # No explicit commit call is expected

    def test_create_workflow_pause_not_found(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
    ):
        """Test workflow pause creation when workflow run not found."""
        # Arrange: session.get misses.
        mock_session.get.return_value = None
        # Act & Assert
        with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
                pause_reasons=[],
            )
        mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")

    def test_create_workflow_pause_invalid_status(
        self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
    ):
        """Test workflow pause creation when workflow not in RUNNING status."""
        # Arrange: a terminal status must be rejected.
        sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
        mock_session.get.return_value = sample_workflow_run
        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
            repository.create_workflow_pause(
                workflow_run_id="workflow-run-123",
                state_owner_user_id="user-123",
                state='{"test": "state"}',
                pause_reasons=[],
            )
class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
    def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
        """delete_runs_with_related delegates trigger-log deletion to the
        injected callback and aggregates per-table deletion counts."""
        node_ids_result = Mock()
        node_ids_result.all.return_value = []
        pause_ids_result = Mock()
        pause_ids_result.all.return_value = []
        # Two scalars() calls expected: node-execution ids, then pause ids.
        mock_session.scalars.side_effect = [node_ids_result, pause_ids_result]
        # app_logs delete, runs delete
        mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)]
        fake_trigger_repo = Mock()
        fake_trigger_repo.delete_by_run_ids.return_value = 3
        run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
        counts = repository.delete_runs_with_related(
            [run],
            delete_node_executions=lambda session, runs: (2, 1),
            delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
        )
        # The trigger-log callback is invoked once with the run ids, and each
        # count in the summary comes from the corresponding stubbed source.
        fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
        assert counts["node_executions"] == 2
        assert counts["offloads"] == 1
        assert counts["trigger_logs"] == 3
        assert counts["runs"] == 1
class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
    def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
        """count_runs_with_related mirrors the delete path: counts come from
        the injected callbacks and from session.scalar queries."""
        pause_ids_result = Mock()
        pause_ids_result.all.return_value = ["pause-1", "pause-2"]
        mock_session.scalars.return_value = pause_ids_result
        # Two scalar() calls expected: app_logs count, then pauses count.
        mock_session.scalar.side_effect = [5, 2]
        fake_trigger_repo = Mock()
        fake_trigger_repo.count_by_run_ids.return_value = 3
        run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
        counts = repository.count_runs_with_related(
            [run],
            count_node_executions=lambda session, runs: (2, 1),
            count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
        )
        fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"])
        assert counts["node_executions"] == 2
        assert counts["offloads"] == 1
        assert counts["trigger_logs"] == 3
        assert counts["app_logs"] == 5
        assert counts["pauses"] == 2
        # pause_reasons count tracks the number of pause ids returned above.
        assert counts["pause_reasons"] == 2
        assert counts["runs"] == 1
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test resume_workflow_pause method."""

    def test_resume_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause resume.

        A PAUSED run with a matching pause record should transition back to
        RUNNING and have its pause stamped with a resumed_at timestamp.
        """
        # Arrange
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"
        # Setup workflow run and pause
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_run.pause = sample_workflow_pause
        sample_workflow_pause.resumed_at = None
        mock_session.scalar.return_value = sample_workflow_run
        mock_session.scalars.return_value.all.return_value = []
        with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
            mock_now.return_value = datetime.now(UTC)
            # Act
            result = repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )
            # Assert
            assert isinstance(result, _PrivateWorkflowPauseEntity)
            assert result.id == "pause-123"
            # Verify state transitions
            assert sample_workflow_pause.resumed_at is not None
            assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
            # Verify database interactions
            mock_session.add.assert_called()
            # When using session.begin() context manager, commit is handled automatically
            # No explicit commit call is expected

    def test_resume_workflow_pause_not_paused(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
    ):
        """Test resume when workflow is not paused."""
        # Arrange: run is still RUNNING, so resume must be rejected.
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"
        sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
        mock_session.scalar.return_value = sample_workflow_run
        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )

    def test_resume_workflow_pause_id_mismatch(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_run: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test resume when pause ID doesn't match."""
        # Arrange: entity id differs from the run's attached pause record.
        workflow_run_id = "workflow-run-123"
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-456"  # Different ID
        sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
        sample_workflow_pause.id = "pause-123"
        sample_workflow_run.pause = sample_workflow_pause
        mock_session.scalar.return_value = sample_workflow_run
        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
            repository.resume_workflow_pause(
                workflow_run_id=workflow_run_id,
                pause_entity=pause_entity,
            )
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
    """Test delete_workflow_pause method."""

    def test_delete_workflow_pause_success(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
        sample_workflow_pause: Mock,
    ):
        """Test successful workflow pause deletion.

        Deleting a pause removes both its persisted state object from storage
        and the database row.
        """
        # Arrange
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"
        mock_session.get.return_value = sample_workflow_pause
        with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
            # Act
            repository.delete_workflow_pause(pause_entity=pause_entity)
            # Assert
            mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
            mock_session.delete.assert_called_once_with(sample_workflow_pause)
            # When using session.begin() context manager, commit is handled automatically
            # No explicit commit call is expected

    def test_delete_workflow_pause_not_found(
        self,
        repository: DifyAPISQLAlchemyWorkflowRunRepository,
        mock_session: Mock,
    ):
        """Test delete when pause not found."""
        # Arrange: session.get misses.
        pause_entity = Mock(spec=WorkflowPauseEntity)
        pause_entity.id = "pause-123"
        mock_session.get.return_value = None
        # Act & Assert
        with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
            repository.delete_workflow_pause(pause_entity=pause_entity)
class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository):
@pytest.fixture
def sample_workflow_pause() -> Mock:
"""Create a sample WorkflowPause model."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity class."""
def test_properties(self, sample_workflow_pause: Mock):
def test_properties(self, sample_workflow_pause: Mock) -> None:
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Act & Assert
# Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock):
def test_get_state(self, sample_workflow_pause: Mock) -> None:
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
@@ -445,7 +60,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock):
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
@@ -456,16 +71,20 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
# Act
result1 = entity.get_state()
result2 = entity.get_state() # Should use cache
result2 = entity.get_state()
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once() # Only called once due to caching
mock_storage.load.assert_called_once()
class TestBuildHumanInputRequiredReason:
def test_prefers_backstage_token_when_available(self):
"""Test helper that builds HumanInputRequired pause reasons."""
def test_prefers_backstage_token_when_available(self) -> None:
"""Use backstage token when multiple recipient types may exist."""
# Arrange
expiration_time = datetime.now(UTC)
form_definition = FormDefinition(
form_content="content",
@@ -504,8 +123,10 @@ class TestBuildHumanInputRequiredReason:
access_token=access_token,
)
# Act
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
# Assert
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"

View File

@@ -1,31 +0,0 @@
from unittest.mock import Mock
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
def test_delete_by_run_ids_executes_delete():
    """Deleting by run ids issues a DELETE against workflow_trigger_logs."""
    session = Mock(spec=Session)
    session.execute.return_value = Mock(rowcount=2)
    repo = SQLAlchemyWorkflowTriggerLogRepository(session)
    deleted = repo.delete_by_run_ids(["run-1", "run-2"])
    # Render the captured statement with literal binds so the ids are inline.
    stmt = session.execute.call_args[0][0]
    rendered = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
    for fragment in ("workflow_trigger_logs", "'run-1'", "'run-2'"):
        assert fragment in rendered
    assert deleted == 2
def test_delete_by_run_ids_empty_short_circuits():
    """An empty id list returns 0 without ever touching the session."""
    session = Mock(spec=Session)
    repo = SQLAlchemyWorkflowTriggerLogRepository(session)
    assert repo.delete_by_run_ids([]) == 0
    session.execute.assert_not_called()

View File

@@ -6,66 +6,6 @@ from unittest.mock import MagicMock, patch
class TestArchivedWorkflowRunDeletion:
def test_delete_by_run_id_returns_error_when_run_missing(self):
    """A missing workflow run yields a failure result and no archive lookup."""
    from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion

    deleter = ArchivedWorkflowRunDeletion()
    repo = MagicMock()
    session = MagicMock()
    session.get.return_value = None  # run lookup misses
    session_maker = MagicMock()
    session_maker.return_value.__enter__.return_value = session
    session_maker.return_value.__exit__.return_value = None
    fake_db = MagicMock()
    fake_db.engine = MagicMock()
    with (
        patch("services.retention.workflow_run.delete_archived_workflow_run.db", fake_db),
        patch(
            "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
        ),
        patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
    ):
        outcome = deleter.delete_by_run_id("run-1")
    assert outcome.success is False
    assert outcome.error == "Workflow run run-1 not found"
    repo.get_archived_run_ids.assert_not_called()
def test_delete_by_run_id_returns_error_when_not_archived(self):
    """An existing but non-archived run fails fast without deletion.

    The run is found, but the archived-id lookup returns an empty set, so
    the deleter must report an error and never call _delete_run.
    """
    from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
    deleter = ArchivedWorkflowRunDeletion()
    repo = MagicMock()
    repo.get_archived_run_ids.return_value = set()  # run-1 is not archived
    run = MagicMock()
    run.id = "run-1"
    run.tenant_id = "tenant-1"
    session = MagicMock()
    session.get.return_value = run
    session_maker = MagicMock()
    session_maker.return_value.__enter__.return_value = session
    session_maker.return_value.__exit__.return_value = None
    mock_db = MagicMock()
    mock_db.engine = MagicMock()
    with (
        patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
        patch(
            "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
        ),
        patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
        patch.object(deleter, "_delete_run") as mock_delete_run,
    ):
        result = deleter.delete_by_run_id("run-1")
    assert result.success is False
    assert result.error == "Workflow run run-1 is not archived"
    mock_delete_run.assert_not_called()
def test_delete_by_run_id_calls_delete_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@@ -98,55 +38,6 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is True
mock_delete_run.assert_called_once_with(run)
def test_delete_batch_uses_repo(self):
    """delete_batch fetches archived runs from the repo and deletes each one.

    Verifies that the time-range query receives the exact arguments passed
    to delete_batch and that _delete_run is invoked once per returned run.
    """
    from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
    deleter = ArchivedWorkflowRunDeletion()
    repo = MagicMock()
    run1 = MagicMock()
    run1.id = "run-1"
    run1.tenant_id = "tenant-1"
    run2 = MagicMock()
    run2.id = "run-2"
    run2.tenant_id = "tenant-1"
    repo.get_archived_runs_by_time_range.return_value = [run1, run2]
    session = MagicMock()
    session_maker = MagicMock()
    session_maker.return_value.__enter__.return_value = session
    session_maker.return_value.__exit__.return_value = None
    # Opaque sentinels: only identity is asserted, never their value.
    start_date = MagicMock()
    end_date = MagicMock()
    mock_db = MagicMock()
    mock_db.engine = MagicMock()
    with (
        patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
        patch(
            "services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
        ),
        patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
        patch.object(
            deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)]
        ) as mock_delete_run,
    ):
        results = deleter.delete_batch(
            tenant_ids=["tenant-1"],
            start_date=start_date,
            end_date=end_date,
            limit=2,
        )
    assert len(results) == 2
    repo.get_archived_runs_by_time_range.assert_called_once_with(
        session=session,
        tenant_ids=["tenant-1"],
        start_date=start_date,
        end_date=end_date,
        limit=2,
    )
    assert mock_delete_run.call_count == 2
def test_delete_run_dry_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@@ -160,21 +51,3 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is True
mock_get_repo.assert_not_called()
def test_delete_run_calls_repo(self):
    """_delete_run forwards to the repository and reports its counts."""
    from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion

    deleter = ArchivedWorkflowRunDeletion()
    run = MagicMock()
    run.id = "run-1"
    run.tenant_id = "tenant-1"
    repo = MagicMock()
    repo.delete_runs_with_related.return_value = {"runs": 1}
    with patch.object(deleter, "_get_workflow_run_repo", return_value=repo):
        outcome = deleter._delete_run(run)
    assert outcome.success is True
    assert outcome.deleted_counts == {"runs": 1}
    repo.delete_runs_with_related.assert_called_once()

View File

@@ -1,6 +1,3 @@
import sqlalchemy as sa
from models.dataset import Document
from services.dataset_service import DocumentService
@@ -9,25 +6,3 @@ def test_normalize_display_status_alias_mapping():
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None
def test_build_display_status_filters_available():
    """'available' expands to exactly three non-null filter conditions."""
    filters = DocumentService.build_display_status_filters("available")
    assert len(filters) == 3
    assert all(condition is not None for condition in filters)
def test_apply_display_status_filter_applies_when_status_present():
    """A recognized status ('queuing') adds a WHERE clause on indexing_status."""
    base_query = sa.select(Document)
    filtered_query = DocumentService.apply_display_status_filter(base_query, "queuing")
    rendered = str(filtered_query.compile(compile_kwargs={"literal_binds": True}))
    assert "WHERE" in rendered
    assert "documents.indexing_status = 'waiting'" in rendered
def test_apply_display_status_filter_returns_same_when_invalid():
    """An unrecognized status leaves the query unfiltered (no WHERE clause)."""
    base_query = sa.select(Document)
    filtered_query = DocumentService.apply_display_status_filter(base_query, "invalid")
    rendered = str(filtered_query.compile(compile_kwargs={"literal_binds": True}))
    assert "WHERE" not in rendered

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import App, DefaultEndUserSessionID, EndUser
from models.model import App, EndUser
from services.end_user_service import EndUserService
@@ -44,113 +44,6 @@ class TestEndUserServiceFactory:
return end_user
class TestEndUserServiceGetOrCreateEndUser:
"""
Unit tests for EndUserService.get_or_create_end_user method.
This test suite covers:
- Creating new end users
- Retrieving existing end users
- Default session ID handling
- Anonymous user creation
"""
@pytest.fixture
def factory(self):
    """Expose the shared test-data factory to tests in this class."""
    return TestEndUserServiceFactory()
# Test 01: Get or create with custom user_id
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
    """Test getting or creating end user with custom user_id.

    With no existing row, the service must create and commit a new EndUser
    carrying the caller-supplied session id. Fix: the return value was bound
    to an unused local (`result`, lint F841); the binding is dropped since
    the test only inspects what was added to the session.
    """
    # Arrange
    app = factory.create_app_mock()
    user_id = "custom-user-123"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = None  # No existing user -> create path
    # Act (return value intentionally ignored; assertions inspect the session)
    EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
    # Assert
    mock_session.add.assert_called_once()
    mock_session.commit.assert_called_once()
    # Verify the created user has correct attributes
    added_user = mock_session.add.call_args[0][0]
    assert added_user.tenant_id == app.tenant_id
    assert added_user.app_id == app.id
    assert added_user.session_id == user_id
    assert added_user.type == InvokeFrom.SERVICE_API
    assert added_user.is_anonymous is False
# Test 02: Get or create without user_id (default session)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
    """Test getting or creating end user without user_id uses default session.

    Fix: the return value was bound to an unused local (`result`, lint F841);
    the binding is dropped since assertions only inspect the added user.
    """
    # Arrange
    app = factory.create_app_mock()
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = None  # No existing user -> create path
    # Act (return value intentionally ignored)
    EndUserService.get_or_create_end_user(app_model=app, user_id=None)
    # Assert
    mock_session.add.assert_called_once()
    added_user = mock_session.add.call_args[0][0]
    assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
    # Verify _is_anonymous is set correctly (property always returns False)
    assert added_user._is_anonymous is True
# Test 03: Get existing end user
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
    """When a matching end user exists, it is returned and nothing is created."""
    # Arrange
    app = factory.create_app_mock()
    user_id = "existing-user-123"
    existing_user = factory.create_end_user_mock(
        tenant_id=app.tenant_id,
        app_id=app.id,
        session_id=user_id,
        type=InvokeFrom.SERVICE_API,
    )
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = existing_user  # lookup hits
    # Act
    found = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
    # Assert: the existing row is handed back and no new user is added.
    assert found == existing_user
    mock_session.add.assert_not_called()
class TestEndUserServiceGetOrCreateEndUserByType:
"""
Unit tests for EndUserService.get_or_create_end_user_by_type method.
@@ -167,226 +60,6 @@ class TestEndUserServiceGetOrCreateEndUserByType:
"""Provide test data factory."""
return TestEndUserServiceFactory()
# Test 04: Create new end user with SERVICE_API type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
    """Test creating new end user with SERVICE_API type.

    Fix: the return value was bound to an unused local (`result`, lint F841);
    the binding is dropped since assertions only inspect the added user.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = None  # No existing user -> create path
    # Act (return value intentionally ignored)
    EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    mock_session.add.assert_called_once()
    mock_session.commit.assert_called_once()
    added_user = mock_session.add.call_args[0][0]
    assert added_user.type == InvokeFrom.SERVICE_API
    assert added_user.tenant_id == tenant_id
    assert added_user.app_id == app_id
    assert added_user.session_id == user_id
# Test 05: Create new end user with WEB_APP type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
    """Test creating new end user with WEB_APP type.

    Fix: the return value was bound to an unused local (`result`, lint F841);
    the binding is dropped since assertions only inspect the added user.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = None  # No existing user -> create path
    # Act (return value intentionally ignored)
    EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.WEB_APP,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    mock_session.add.assert_called_once()
    added_user = mock_session.add.call_args[0][0]
    assert added_user.type == InvokeFrom.WEB_APP
# Test 06: Upgrade legacy end user type
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
    """Test upgrading legacy end user with different type.

    An existing user stored with SERVICE_API type, when requested as
    WEB_APP, must be mutated in place, committed, and the upgrade logged.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    # Existing user with old type
    existing_user = factory.create_end_user_mock(
        tenant_id=tenant_id,
        app_id=app_id,
        session_id=user_id,
        type=InvokeFrom.SERVICE_API,
    )
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = existing_user
    # Act - Request with different type
    result = EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.WEB_APP,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    assert result == existing_user
    assert existing_user.type == InvokeFrom.WEB_APP  # Type should be updated
    mock_session.commit.assert_called_once()
    mock_logger.info.assert_called_once()
    # Verify log message contains upgrade info
    log_call = mock_logger.info.call_args[0][0]
    assert "Upgrading legacy EndUser" in log_call
# Test 07: Get existing end user with matching type (no upgrade needed)
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
    """Test retrieving existing end user with matching type.

    When the stored type already matches the requested one, the record is
    returned untouched: no commit, no upgrade log.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    existing_user = factory.create_end_user_mock(
        tenant_id=tenant_id,
        app_id=app_id,
        session_id=user_id,
        type=InvokeFrom.SERVICE_API,
    )
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    # Query finds a record whose type already matches the request.
    mock_query.first.return_value = existing_user
    # Act - Request with same type
    result = EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    assert result == existing_user
    assert existing_user.type == InvokeFrom.SERVICE_API
    # No commit should be called (no type update needed)
    mock_session.commit.assert_not_called()
    mock_logger.info.assert_not_called()
# Test 08: Create anonymous user with default session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
    """Test creating anonymous user when user_id is None.

    With no user_id, the service is expected to fall back to the default
    session id and mark the new record as anonymous.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    # No existing record -> create path.
    mock_query.first.return_value = None
    # Act
    result = EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=None,
    )
    # Assert: the created user falls back to the default session id.
    mock_session.add.assert_called_once()
    added_user = mock_session.add.call_args[0][0]
    assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
    # The private _is_anonymous attribute is set for anonymous users.
    # NOTE(review): the public is_anonymous property's behavior is not
    # visible from this test file — confirm against the EndUser model.
    assert added_user._is_anonymous is True
    assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Test 09: Query ordering prioritizes matching type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
    """Test that query ordering prioritizes records with matching type.

    Only checks that order_by is invoked on the lookup query; the actual
    ordering expression is not inspected here.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    mock_query.first.return_value = None
    # Act
    EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    # Verify order_by was called (for type prioritization)
    mock_query.order_by.assert_called_once()
# Test 10: Session context manager properly closes
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
@@ -420,117 +93,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
# Verify context manager was entered and exited
mock_context.__enter__.assert_called_once()
mock_context.__exit__.assert_called_once()
# Test 11: External user ID matches session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
    """Test that external_user_id is set to match session_id.

    When a concrete user_id is supplied, the created record stores it in
    both external_user_id and session_id.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "custom-external-id"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    # No existing record -> create path.
    mock_query.first.return_value = None
    # Act
    result = EndUserService.get_or_create_end_user_by_type(
        type=InvokeFrom.SERVICE_API,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert: both identifier fields mirror the supplied user_id.
    added_user = mock_session.add.call_args[0][0]
    assert added_user.external_user_id == user_id
    assert added_user.session_id == user_id
# Test 12: Different InvokeFrom types
@pytest.mark.parametrize(
    "invoke_type",
    [
        InvokeFrom.SERVICE_API,
        InvokeFrom.WEB_APP,
        InvokeFrom.EXPLORE,
        InvokeFrom.DEBUGGER,
    ],
)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
    """Test creating end users with different InvokeFrom types.

    Parametrized over every InvokeFrom variant; each run checks the created
    record carries the requested type.
    """
    # Arrange
    tenant_id = "tenant-123"
    app_id = "app-456"
    user_id = "user-789"
    mock_session = MagicMock()
    mock_session_class.return_value.__enter__.return_value = mock_session
    mock_query = MagicMock()
    mock_session.query.return_value = mock_query
    mock_query.where.return_value = mock_query
    mock_query.order_by.return_value = mock_query
    # No existing record -> create path.
    mock_query.first.return_value = None
    # Act
    result = EndUserService.get_or_create_end_user_by_type(
        type=invoke_type,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
    )
    # Assert
    added_user = mock_session.add.call_args[0][0]
    assert added_user.type == invoke_type
class TestEndUserServiceGetEndUserById:
    """Unit tests for EndUserService.get_end_user_by_id."""

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
        """A matching record is returned and the query filters on three conditions."""
        tenant_id = "tenant-123"
        app_id = "app-456"
        end_user_id = "end-user-789"
        existing_user = MagicMock(spec=EndUser)
        mock_session = MagicMock()
        # Route `with Session(...) as session` to the mock session.
        mock_session_class.return_value.__enter__.return_value = mock_session
        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = existing_user
        result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
        assert result == existing_user
        mock_session.query.assert_called_once_with(EndUser)
        mock_query.where.assert_called_once()
        # Expect exactly three filter expressions: tenant, app, and end-user id.
        assert len(mock_query.where.call_args[0]) == 3

    @patch("services.end_user_service.Session")
    @patch("services.end_user_service.db")
    def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
        """When no record matches, the service returns None rather than raising."""
        mock_session = MagicMock()
        mock_session_class.return_value.__enter__.return_value = mock_session
        mock_query = MagicMock()
        mock_session.query.return_value = mock_query
        mock_query.where.return_value = mock_query
        mock_query.first.return_value = None
        result = EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user")
        assert result is None

View File

@@ -1,61 +0,0 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """attach_message_extra_contents serializes repository contents onto each message.

    The repository is replaced by an anonymous object whose get_by_message_ids
    returns one HumanInputContent for the first message and nothing for the
    second; the test then asserts the serialized dict shape per message.
    """
    messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
    # Anonymous repo class: get_by_message_ids ignores its input and returns
    # one content list per message, positionally aligned with `messages`.
    repo = type(
        "Repo",
        (),
        {
            "get_by_message_ids": lambda _self, message_ids: [
                [
                    HumanInputContent(
                        workflow_run_id="workflow-run-1",
                        submitted=True,
                        form_submission_data=HumanInputFormSubmissionData(
                            node_id="node-1",
                            node_title="Approval",
                            rendered_content="Rendered",
                            action_id="approve",
                            action_text="Approve",
                        ),
                    )
                ],
                [],
            ]
        },
    )()
    # Swap the repository factory so the service uses our fake repo.
    monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
    message_service.attach_message_extra_contents(messages)
    # First message receives the serialized human-input payload.
    assert messages[0].extra_contents == [
        {
            "type": "human_input",
            "workflow_run_id": "workflow-run-1",
            "submitted": True,
            "form_submission_data": {
                "node_id": "node-1",
                "node_title": "Approval",
                "rendered_content": "Rendered",
                "action_id": "approve",
                "action_text": "Approve",
            },
        }
    ]
    # Second message had no contents, so it gets an empty list (not None).
    assert messages[1].extra_contents == []

View File

@@ -1,4 +1,4 @@
from unittest.mock import ANY, MagicMock, call, patch
from unittest.mock import MagicMock, call, patch
import pytest
@@ -14,124 +14,6 @@ from tasks.remove_app_and_related_data_task import (
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
    """Test successful deletion of draft variables in batches.

    Simulates two non-empty SELECT batches followed by an empty one, and
    verifies the total delete count, the SELECT/DELETE call ordering, and
    that offload-file cleanup runs once per batch with that batch's file ids.
    """
    app_id = "test-app-id"
    batch_size = 100
    # Mock session via session_factory
    mock_session = MagicMock()
    mock_context_manager = MagicMock()
    mock_context_manager.__enter__.return_value = mock_session
    mock_context_manager.__exit__.return_value = None
    mock_sf.create_session.return_value = mock_context_manager
    # Mock two batches of results, then empty.
    # Rows are (variable_id, file_id-or-None); only some rows carry a file_id.
    batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
    batch2_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(100, 150)]
    # NOTE(review): batch1_ids/batch2_ids are computed but never asserted on.
    batch1_ids = [row[0] for row in batch1_data]
    batch1_file_ids = [row[1] for row in batch1_data if row[1] is not None]
    batch2_ids = [row[0] for row in batch2_data]
    batch2_file_ids = [row[1] for row in batch2_data if row[1] is not None]
    # Setup side effects for execute calls in the correct order:
    # 1. SELECT (returns batch1_data with id, file_id)
    # 2. DELETE (returns result with rowcount=100)
    # 3. SELECT (returns batch2_data)
    # 4. DELETE (returns result with rowcount=50)
    # 5. SELECT (returns empty, ends loop)
    # Create mock results with actual integer rowcount attributes
    class MockResult:
        # Plain object (not MagicMock) so .rowcount is a real int.
        def __init__(self, rowcount):
            self.rowcount = rowcount
    # First SELECT result
    select_result1 = MagicMock()
    select_result1.__iter__.return_value = iter(batch1_data)
    # First DELETE result
    delete_result1 = MockResult(rowcount=100)
    # Second SELECT result
    select_result2 = MagicMock()
    select_result2.__iter__.return_value = iter(batch2_data)
    # Second DELETE result
    delete_result2 = MockResult(rowcount=50)
    # Third SELECT result (empty, ends loop)
    select_result3 = MagicMock()
    select_result3.__iter__.return_value = iter([])
    # Configure side effects in the correct order
    mock_session.execute.side_effect = [
        select_result1,  # First SELECT
        delete_result1,  # First DELETE
        select_result2,  # Second SELECT
        delete_result2,  # Second DELETE
        select_result3,  # Third SELECT (empty)
    ]
    # Mock offload data cleanup
    mock_offload_cleanup.side_effect = [len(batch1_file_ids), len(batch2_file_ids)]
    # Execute the function
    result = delete_draft_variables_batch(app_id, batch_size)
    # Verify the result: 100 + 50 deleted rows in total.
    assert result == 150
    # Verify database calls
    assert mock_session.execute.call_count == 5  # 3 selects + 2 deletes
    # Verify offload cleanup was called for both batches with file_ids
    expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
    mock_offload_cleanup.assert_has_calls(expected_offload_calls)
    # Simplified verification - check that the right number of calls were made
    # and that the SQL queries contain the expected patterns
    actual_calls = mock_session.execute.call_args_list
    for i, actual_call in enumerate(actual_calls):
        # Normalize whitespace so multi-line SQL can be matched with substrings.
        sql_text = str(actual_call[0][0])
        normalized = " ".join(sql_text.split())
        if i % 2 == 0:  # SELECT calls (even indices: 0, 2, 4)
            assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
            assert "WHERE app_id = :app_id" in normalized
            assert "LIMIT :batch_size" in normalized
        else:  # DELETE calls (odd indices: 1, 3)
            assert "DELETE FROM workflow_draft_variables" in normalized
            assert "WHERE id IN :ids" in normalized
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
    """Test deletion when no draft variables exist for the app.

    The first SELECT returns nothing, so the loop must stop after a single
    query with zero deletions and no offload cleanup.
    """
    app_id = "nonexistent-app-id"
    batch_size = 1000
    # Mock session via session_factory
    mock_session = MagicMock()
    mock_context_manager = MagicMock()
    mock_context_manager.__enter__.return_value = mock_session
    mock_context_manager.__exit__.return_value = None
    mock_sf.create_session.return_value = mock_context_manager
    # Mock empty result
    empty_result = MagicMock()
    empty_result.__iter__.return_value = iter([])
    mock_session.execute.return_value = empty_result
    result = delete_draft_variables_batch(app_id, batch_size)
    assert result == 0
    assert mock_session.execute.call_count == 1  # Only one select query
    mock_offload_cleanup.assert_not_called()  # No files to clean up
def test_delete_draft_variables_batch_invalid_batch_size(self):
"""Test that invalid batch size raises ValueError."""
app_id = "test-app-id"
@@ -142,66 +24,6 @@ class TestDeleteDraftVariablesBatch:
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
    """Test that batch deletion logs progress correctly.

    One 30-row batch followed by an empty SELECT; asserts the delete count,
    the offload cleanup call, and that two info log lines were emitted.
    """
    app_id = "test-app-id"
    batch_size = 50
    # Mock session via session_factory
    mock_session = MagicMock()
    mock_context_manager = MagicMock()
    mock_context_manager.__enter__.return_value = mock_session
    mock_context_manager.__exit__.return_value = None
    mock_sf.create_session.return_value = mock_context_manager
    # Mock one batch then empty; every third row carries a file_id.
    batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
    # NOTE(review): batch_ids is computed but never asserted on.
    batch_ids = [row[0] for row in batch_data]
    batch_file_ids = [row[1] for row in batch_data if row[1] is not None]
    # Create properly configured mocks
    select_result = MagicMock()
    select_result.__iter__.return_value = iter(batch_data)
    # Create simple object with rowcount attribute
    class MockResult:
        def __init__(self, rowcount):
            self.rowcount = rowcount
    delete_result = MockResult(rowcount=30)
    empty_result = MagicMock()
    empty_result.__iter__.return_value = iter([])
    mock_session.execute.side_effect = [
        # Select query result
        select_result,
        # Delete query result
        delete_result,
        # Empty select result (end condition)
        empty_result,
    ]
    # Mock offload cleanup
    mock_offload_cleanup.return_value = len(batch_file_ids)
    result = delete_draft_variables_batch(app_id, batch_size)
    assert result == 30
    # Verify offload cleanup was called with file_ids
    if batch_file_ids:
        mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
    # Verify logging calls
    assert mock_logging.info.call_count == 2
    # NOTE(review): assert_any_call(ANY) only checks that info() was called
    # with a single positional argument; it adds little beyond the count
    # assertion above, and depends on ANY being imported in this module.
    mock_logging.info.assert_any_call(
        ANY  # click.style call
    )
@patch("tasks.remove_app_and_related_data_task.delete_draft_variables_batch")
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
"""Test that _delete_draft_variables calls the batch function correctly."""
@@ -218,58 +40,6 @@ class TestDeleteDraftVariablesBatch:
class TestDeleteDraftVariableOffloadData:
"""Test the Offload data cleanup functionality."""
@patch("extensions.ext_storage.storage")
def test_delete_draft_variable_offload_data_success(self, mock_storage):
    """Test successful deletion of offload data.

    Verifies the three-step cleanup: SELECT the storage keys, delete the
    storage objects, then DELETE upload_files and workflow_draft_variable_files
    rows — in that order.
    """
    # Mock connection
    mock_conn = MagicMock()
    file_ids = ["file-1", "file-2", "file-3"]
    # Mock query results: (variable_file_id, storage_key, upload_file_id)
    query_results = [
        ("file-1", "storage/key/1", "upload-1"),
        ("file-2", "storage/key/2", "upload-2"),
        ("file-3", "storage/key/3", "upload-3"),
    ]
    mock_result = MagicMock()
    mock_result.__iter__.return_value = iter(query_results)
    mock_conn.execute.return_value = mock_result
    # Execute function
    result = _delete_draft_variable_offload_data(mock_conn, file_ids)
    # Verify return value: one per successfully deleted storage object.
    assert result == 3
    # Verify storage deletion calls
    expected_storage_calls = [call("storage/key/1"), call("storage/key/2"), call("storage/key/3")]
    mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
    # Verify database calls - should be 3 calls total
    assert mock_conn.execute.call_count == 3
    # Verify the queries were called
    actual_calls = mock_conn.execute.call_args_list
    # First call should be the SELECT query
    select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
    assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
    assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
    assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
    assert "WHERE wdvf.id IN :file_ids" in select_call_sql
    # Second call should be DELETE upload_files
    delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
    assert "DELETE FROM upload_files" in delete_upload_call_sql
    assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
    # Third call should be DELETE workflow_draft_variable_files
    delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
    assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
    assert "WHERE id IN :file_ids" in delete_variable_files_call_sql
def test_delete_draft_variable_offload_data_empty_file_ids(self):
"""Test handling of empty file_ids list."""
mock_conn = MagicMock()
@@ -279,38 +49,6 @@ class TestDeleteDraftVariableOffloadData:
assert result == 0
mock_conn.execute.assert_not_called()
@patch("extensions.ext_storage.storage")
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variable_offload_data_storage_failure(self, mock_logging, mock_storage):
    """Test handling of storage deletion failures.

    A failed storage delete must be logged and skipped in the count, while
    the database cleanup still runs for all rows.
    """
    mock_conn = MagicMock()
    file_ids = ["file-1", "file-2"]
    # Mock query results
    query_results = [
        ("file-1", "storage/key/1", "upload-1"),
        ("file-2", "storage/key/2", "upload-2"),
    ]
    mock_result = MagicMock()
    mock_result.__iter__.return_value = iter(query_results)
    mock_conn.execute.return_value = mock_result
    # Make storage.delete fail for the first file
    mock_storage.delete.side_effect = [Exception("Storage error"), None]
    # Execute function
    result = _delete_draft_variable_offload_data(mock_conn, file_ids)
    # The return value counts successful storage deletions only, so the
    # failed first delete leaves a count of 1.
    assert result == 1  # Only one storage deletion succeeded
    # Verify warning was logged
    mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", "storage/key/1")
    # Verify both database cleanup calls still happened
    assert mock_conn.execute.call_count == 3
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variable_offload_data_database_failure(self, mock_logging):
"""Test handling of database operation failures."""

View File

@@ -0,0 +1,48 @@
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { FormTypeEnum } from '@/app/components/base/form/types'
import AuthForm from './index'
// Single required text-input schema shared by the AuthForm test cases below.
const formSchemas = [{
  type: FormTypeEnum.textInput,
  name: 'apiKey',
  label: 'API Key',
  required: true,
}] as const
// Render helper: wraps the component in a QueryClientProvider whose client
// has query retries disabled, so failing queries surface immediately in tests.
const renderWithQueryClient = (ui: Parameters<typeof render>[0]) => {
  const client = new QueryClient({
    defaultOptions: { queries: { retry: false } },
  })
  return render(<QueryClientProvider client={client}>{ui}</QueryClientProvider>)
}
describe('AuthForm', () => {
  // The configured schema produces a labelled text input.
  it('should render configured fields', () => {
    renderWithQueryClient(<AuthForm formSchemas={[...formSchemas]} />)
    expect(screen.getByText('API Key')).toBeInTheDocument()
    expect(screen.getByRole('textbox')).toBeInTheDocument()
  })
  // defaultValues pre-populates the matching field by name.
  it('should use provided default values', () => {
    renderWithQueryClient(<AuthForm formSchemas={[...formSchemas]} defaultValues={{ apiKey: 'value-123' }} />)
    expect(screen.getByDisplayValue('value-123')).toBeInTheDocument()
  })
  // An empty schema list renders an empty DOM subtree.
  it('should render nothing when no schema is provided', () => {
    const { container } = renderWithQueryClient(<AuthForm formSchemas={[]} />)
    expect(container).toBeEmptyDOMElement()
  })
})

View File

@@ -0,0 +1,137 @@
import type { BaseConfiguration } from './types'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { TransferMethod } from '@/types/app'
import { useAppForm } from '../..'
import BaseField from './field'
import { BaseFieldType } from './types'
// BaseField reads route params via next/navigation; stub it so the component
// can render outside a Next.js router context.
vi.mock('next/navigation', () => ({
  useParams: () => ({}),
}))

// Factory for a minimal text-input field configuration; individual tests
// override just the properties they exercise.
const createConfig = (overrides: Partial<BaseConfiguration> = {}): BaseConfiguration => ({
  type: BaseFieldType.textInput,
  variable: 'fieldA',
  label: 'Field A',
  required: false,
  showConditions: [],
  ...overrides,
})
type FieldHarnessProps = {
  config: BaseConfiguration
  initialData?: Record<string, unknown>
}

// Hosts a single BaseField inside a real useAppForm instance so the field
// receives a live form context, as it would in production.
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
  const form = useAppForm({
    defaultValues: initialData,
    onSubmit: () => {},
  })
  // BaseField is a factory returning a component; memoize so the component
  // identity is stable across re-renders of the harness.
  const Component = useMemo(() => BaseField({ initialData, config }), [config, initialData])
  return <Component form={form} />
}
describe('BaseField', () => {
  // Each field type maps to its corresponding ARIA role / rendered label.
  it('should render a text input field when configured as text input', () => {
    render(<FieldHarness config={createConfig({ label: 'Username' })} initialData={{ fieldA: '' }} />)
    expect(screen.getByRole('textbox')).toBeInTheDocument()
    expect(screen.getByText('Username')).toBeInTheDocument()
  })
  it('should render a number input when configured as number input', () => {
    render(<FieldHarness config={createConfig({ type: BaseFieldType.numberInput, label: 'Age' })} initialData={{ fieldA: 20 }} />)
    // Number inputs expose the "spinbutton" role.
    expect(screen.getByRole('spinbutton')).toBeInTheDocument()
    expect(screen.getByText('Age')).toBeInTheDocument()
  })
  it('should render a checkbox when configured as checkbox', () => {
    render(<FieldHarness config={createConfig({ type: BaseFieldType.checkbox, label: 'Agree' })} initialData={{ fieldA: false }} />)
    expect(screen.getByText('Agree')).toBeInTheDocument()
  })
  it('should render paragraph and select fields based on configuration', () => {
    // Table-driven: mount each scenario, assert its label, unmount to keep
    // queries scoped to one rendered field at a time.
    const scenarios: Array<{ config: BaseConfiguration, initialData: Record<string, unknown> }> = [
      {
        config: createConfig({
          type: BaseFieldType.paragraph,
          label: 'Description',
        }),
        initialData: { fieldA: 'hello' },
      },
      {
        config: createConfig({
          type: BaseFieldType.select,
          label: 'Mode',
          options: [{ value: 'safe', label: 'Safe' }],
        }),
        initialData: { fieldA: 'safe' },
      },
    ]
    for (const scenario of scenarios) {
      const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
      expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
      unmount()
    }
  })
  it('should render file uploader when configured as file', () => {
    const scenarios: Array<{ config: BaseConfiguration, initialData: Record<string, unknown> }> = [
      {
        config: createConfig({
          type: BaseFieldType.file,
          label: 'Attachment',
          allowedFileExtensions: ['txt'],
          allowedFileTypes: ['document'],
          allowedFileUploadMethods: [TransferMethod.local_file],
        }),
        initialData: { fieldA: [] },
      },
      {
        config: createConfig({
          type: BaseFieldType.fileList,
          label: 'Attachments',
          maxLength: 2,
          allowedFileExtensions: ['txt'],
          allowedFileTypes: ['document'],
          allowedFileUploadMethods: [TransferMethod.local_file],
        }),
        initialData: { fieldA: [] },
      },
    ]
    for (const scenario of scenarios) {
      const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
      expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
      unmount()
    }
    // An unknown field type renders nothing at all.
    render(
      <FieldHarness
        config={createConfig({ type: 'unsupported' as BaseFieldType, label: 'Unsupported' })}
        initialData={{ fieldA: '' }}
      />,
    )
    expect(screen.queryByText('Unsupported')).not.toBeInTheDocument()
  })
  it('should not render when show conditions are not met', () => {
    // showConditions requires `toggle === true`; initial data sets it false.
    render(
      <FieldHarness
        config={createConfig({
          label: 'Hidden Field',
          showConditions: [{ variable: 'toggle', value: true }],
        })}
        initialData={{ fieldA: '', toggle: false }}
      />,
    )
    expect(screen.queryByText('Hidden Field')).not.toBeInTheDocument()
  })
})

View File

@@ -0,0 +1,94 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import BaseForm from './index'
import { BaseFieldType } from './types'
// Single optional text-input configuration shared across the BaseForm tests.
const baseConfigurations = [{
  type: BaseFieldType.textInput,
  variable: 'name',
  label: 'Name',
  required: false,
  showConditions: [],
}]
describe('BaseForm', () => {
  // initialData pre-fills the field mapped by `variable`.
  it('should render configured fields', () => {
    render(
      <BaseForm
        initialData={{ name: 'Alice' }}
        configurations={[...baseConfigurations]}
        onSubmit={() => {}}
      />,
    )
    expect(screen.getByRole('textbox')).toBeInTheDocument()
    expect(screen.getByDisplayValue('Alice')).toBeInTheDocument()
  })
  // CustomActions receives the form instance so custom buttons can trigger
  // form.handleSubmit(); submission delivers the current values.
  it('should submit current form values when submit button is clicked', async () => {
    const onSubmit = vi.fn()
    render(
      <BaseForm
        initialData={{ name: 'Alice' }}
        configurations={[...baseConfigurations]}
        onSubmit={onSubmit}
        CustomActions={({ form }) => (
          <button type="button" onClick={() => form.handleSubmit()}>
            Submit
          </button>
        )}
      />,
    )
    fireEvent.click(screen.getByRole('button', { name: /submit/i }))
    await waitFor(() => {
      expect(onSubmit).toHaveBeenCalledWith({ name: 'Alice' })
    })
  })
  // Providing CustomActions replaces the default submit button entirely.
  it('should render custom actions when provided', () => {
    render(
      <BaseForm
        initialData={{ name: 'Alice' }}
        configurations={[...baseConfigurations]}
        onSubmit={() => {}}
        CustomActions={() => <button type="button">Save Form</button>}
      />,
    )
    expect(screen.getByRole('button', { name: /save form/i })).toBeInTheDocument()
    expect(screen.queryByRole('button', { name: /common.operation.submit/i })).not.toBeInTheDocument()
  })
  it('should handle native form submit and block invalid submission', async () => {
    const onSubmit = vi.fn()
    // maxLength: 2 makes values longer than two characters invalid.
    const requiredConfig = [{
      type: BaseFieldType.textInput,
      variable: 'name',
      label: 'Name',
      required: true,
      showConditions: [],
      maxLength: 2,
    }]
    const { container } = render(
      <BaseForm
        initialData={{ name: 'ok' }}
        configurations={requiredConfig}
        onSubmit={onSubmit}
      />,
    )
    const form = container.querySelector('form')
    const input = screen.getByRole('textbox')
    expect(form).not.toBeNull()
    // Valid value ('ok', length 2) submits successfully.
    fireEvent.submit(form!)
    await waitFor(() => {
      expect(onSubmit).toHaveBeenCalledWith({ name: 'ok' })
    })
    // Over-length value ('long') must be rejected: no second onSubmit call.
    fireEvent.change(input, { target: { value: 'long' } })
    fireEvent.submit(form!)
    expect(onSubmit).toHaveBeenCalledTimes(1)
  })
})

View File

@@ -0,0 +1,15 @@
import { BaseFieldType } from './types'
describe('base scenario types', () => {
  // Guards the enum's string values: renderers and serialized configs depend
  // on these exact literals, so any change here is a breaking change.
  it('should include all supported base field types', () => {
    expect(Object.values(BaseFieldType)).toEqual([
      'text-input',
      'paragraph',
      'number-input',
      'checkbox',
      'select',
      'file',
      'file-list',
    ])
  })
})

View File

@@ -0,0 +1,49 @@
import { BaseFieldType } from './types'
import { generateZodSchema } from './utils'
describe('base scenario schema generator', () => {
  // required + maxLength: empty and over-length strings must both fail.
  it('should validate required text fields with max length', () => {
    const schema = generateZodSchema([{
      type: BaseFieldType.textInput,
      variable: 'name',
      label: 'Name',
      required: true,
      maxLength: 3,
      showConditions: [],
    }])
    expect(schema.safeParse({ name: 'abc' }).success).toBe(true)
    expect(schema.safeParse({ name: '' }).success).toBe(false)
    expect(schema.safeParse({ name: 'abcd' }).success).toBe(false)
  })
  // min/max produce inclusive numeric bounds.
  it('should validate number bounds', () => {
    const schema = generateZodSchema([{
      type: BaseFieldType.numberInput,
      variable: 'age',
      label: 'Age',
      required: true,
      min: 18,
      max: 30,
      showConditions: [],
    }])
    expect(schema.safeParse({ age: 20 }).success).toBe(true)
    expect(schema.safeParse({ age: 17 }).success).toBe(false)
    expect(schema.safeParse({ age: 31 }).success).toBe(false)
  })
  // Optional fields accept both an absent key and an explicit null.
  it('should allow optional fields to be undefined or null', () => {
    const schema = generateZodSchema([{
      type: BaseFieldType.select,
      variable: 'mode',
      label: 'Mode',
      required: false,
      showConditions: [],
      options: [{ value: 'safe', label: 'Safe' }],
    }])
    expect(schema.safeParse({}).success).toBe(true)
    expect(schema.safeParse({ mode: null }).success).toBe(true)
  })
})

View File

@@ -0,0 +1,24 @@
import { render, screen } from '@testing-library/react'
import { useAppForm } from '../..'
import ContactFields from './contact-fields'
import { demoFormOpts } from './shared-options'
// Hosts ContactFields inside a real useAppForm instance built from the shared
// demo options, so the section renders with its production default values.
const ContactFieldsHarness = () => {
  const form = useAppForm({
    ...demoFormOpts,
    onSubmit: () => {},
  })
  return <ContactFields form={form} />
}
describe('ContactFields', () => {
  // The section renders its heading plus the email/phone inputs and the
  // preferred-contact-method control.
  it('should render contact section fields', () => {
    render(<ContactFieldsHarness />)
    expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
    expect(screen.getByRole('textbox', { name: /email/i })).toBeInTheDocument()
    expect(screen.getByRole('textbox', { name: /phone/i })).toBeInTheDocument()
    expect(screen.getByText(/preferred contact method/i)).toBeInTheDocument()
  })
})

View File

@@ -0,0 +1,69 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import DemoForm from './index'
describe('DemoForm', () => {
  // DemoForm reports results via console.log; spy on it and silence output.
  const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
  beforeEach(() => {
    // Reset recorded calls between tests (the spy itself stays installed).
    vi.clearAllMocks()
  })
  it('should render the primary fields', () => {
    render(<DemoForm />)
    expect(screen.getByRole('textbox', { name: /^name$/i })).toBeInTheDocument()
    expect(screen.getByRole('textbox', { name: /^surname$/i })).toBeInTheDocument()
    expect(screen.getByText(/i accept the terms and conditions/i)).toBeInTheDocument()
  })
  // The contacts section is conditionally shown once a name is present.
  it('should show contact fields after a name is entered', () => {
    render(<DemoForm />)
    expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
    fireEvent.change(screen.getByRole('textbox', { name: /^name$/i }), { target: { value: 'Alice' } })
    expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
  })
  it('should hide contact fields when name is cleared', () => {
    render(<DemoForm />)
    const nameInput = screen.getByRole('textbox', { name: /^name$/i })
    fireEvent.change(nameInput, { target: { value: 'Alice' } })
    expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
    fireEvent.change(nameInput, { target: { value: '' } })
    expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
  })
  it('should log validation errors on invalid submit', () => {
    render(<DemoForm />)
    const nameInput = screen.getByRole('textbox', { name: /^name$/i }) as HTMLInputElement
    // Submit via the input's owning form element with all fields empty.
    fireEvent.submit(nameInput.form!)
    return waitFor(() => {
      expect(consoleLogSpy).toHaveBeenCalledWith('Validation errors:', expect.any(Array))
    })
  })
  it('should log submitted values on valid submit', () => {
    render(<DemoForm />)
    const nameInput = screen.getByRole('textbox', { name: /^name$/i }) as HTMLInputElement
    // Fill every field the schema requires before submitting.
    fireEvent.change(nameInput, { target: { value: 'Alice' } })
    fireEvent.change(screen.getByRole('textbox', { name: /^surname$/i }), { target: { value: 'Smith' } })
    fireEvent.click(screen.getByText(/i accept the terms and conditions/i))
    fireEvent.change(screen.getByRole('textbox', { name: /email/i }), { target: { value: 'alice@example.com' } })
    fireEvent.submit(nameInput.form!)
    return waitFor(() => {
      expect(consoleLogSpy).toHaveBeenCalledWith('Form submitted:', expect.objectContaining({
        name: 'Alice',
        surname: 'Smith',
        isAcceptingTerms: true,
      }))
    })
  })
})

View File

@@ -0,0 +1,16 @@
import { demoFormOpts } from './shared-options'

// Guards the shared form options against accidental changes to the
// default values every demo form instance starts from.
describe('demoFormOpts', () => {
  it('should provide expected default values', () => {
    const expectedDefaults = {
      name: '',
      surname: '',
      isAcceptingTerms: false,
      contact: {
        email: '',
        phone: '',
        preferredContactMethod: 'email',
      },
    }
    expect(demoFormOpts.defaultValues).toEqual(expectedDefaults)
  })
})

View File

@@ -0,0 +1,39 @@
import { ContactMethods, UserSchema } from './types'

// Covers the demo scenario type helpers: the contact-method option list
// and the zod UserSchema's accept/reject behavior.
describe('demo scenario types', () => {
  it('should expose contact methods with capitalized labels', () => {
    // Each option label is expected to be the value with its first
    // letter upper-cased (e.g. 'whatsapp' -> 'Whatsapp').
    const expectedMethods = ['email', 'phone', 'whatsapp', 'sms'].map(value => ({
      value,
      label: value.charAt(0).toUpperCase() + value.slice(1),
    }))
    expect(ContactMethods).toEqual(expectedMethods)
  })
  it('should validate a complete user payload', () => {
    const validUser = {
      name: 'Alice',
      surname: 'Smith',
      isAcceptingTerms: true,
      contact: {
        email: 'alice@example.com',
        phone: '',
        preferredContactMethod: 'email',
      },
    }
    expect(UserSchema.safeParse(validUser).success).toBe(true)
  })
  it('should reject invalid user payload', () => {
    // Deliberately violates several schema rules at once (short surname,
    // unaccepted terms, malformed email, missing phone).
    const invalidUser = {
      name: 'alice',
      surname: 's',
      isAcceptingTerms: false,
      contact: {
        email: 'invalid',
        preferredContactMethod: 'email',
      },
    }
    expect(UserSchema.safeParse(invalidUser).success).toBe(false)
  })
})

View File

@@ -0,0 +1,139 @@
import type { InputFieldConfiguration } from './types'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { useAppForm } from '../..'
import InputField from './field'
import { InputFieldType } from './types'

// Factory for a minimal text-input configuration; each test overrides
// only the properties it cares about.
const createConfig = (overrides: Partial<InputFieldConfiguration> = {}): InputFieldConfiguration => ({
  type: InputFieldType.textInput,
  variable: 'fieldA',
  label: 'Field A',
  required: false,
  showConditions: [],
  ...overrides,
})

type FieldHarnessProps = {
  config: InputFieldConfiguration
  initialData?: Record<string, unknown>
}

// Test harness: wires a fresh app form around a single InputField so the
// component can render in isolation. InputField appears to be a
// component factory (called as a function), so the result is memoized to
// keep a stable component identity across re-renders.
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
  const form = useAppForm({
    defaultValues: initialData,
    onSubmit: () => {},
  })
  const Component = useMemo(() => InputField({ initialData, config }), [config, initialData])
  return <Component form={form} />
}

describe('InputField', () => {
  it('should render text input field by default', () => {
    render(<FieldHarness config={createConfig({ label: 'Prompt' })} initialData={{ fieldA: '' }} />)
    expect(screen.getByRole('textbox')).toBeInTheDocument()
    expect(screen.getByText('Prompt')).toBeInTheDocument()
  })
  it('should render number slider field when configured', () => {
    render(
      <FieldHarness
        config={createConfig({
          type: InputFieldType.numberSlider,
          label: 'Temperature',
          description: 'Control randomness',
          min: 0,
          max: 1,
        })}
        initialData={{ fieldA: 0.5 }}
      />,
    )
    // The slider variant renders both its label and its description text.
    expect(screen.getByText('Temperature')).toBeInTheDocument()
    expect(screen.getByText('Control randomness')).toBeInTheDocument()
  })
  it('should render select field with options when configured', () => {
    render(
      <FieldHarness
        config={createConfig({
          type: InputFieldType.select,
          label: 'Mode',
          options: [{ value: 'safe', label: 'Safe' }],
        })}
        initialData={{ fieldA: 'safe' }}
      />,
    )
    expect(screen.getByText('Mode')).toBeInTheDocument()
  })
  it('should render upload method field when configured', () => {
    render(
      <FieldHarness
        config={createConfig({
          type: InputFieldType.uploadMethod,
          label: 'Upload Method',
        })}
        initialData={{ fieldA: 'local_file' }}
      />,
    )
    expect(screen.getByText('Upload Method')).toBeInTheDocument()
  })
  it('should hide the field when show conditions are not met', () => {
    // showConditions requires `enabled === true`, but initial data sets
    // it to false, so the field must not render at all.
    render(
      <FieldHarness
        config={createConfig({
          label: 'Hidden Input',
          showConditions: [{ variable: 'enabled', value: true }],
        })}
        initialData={{ enabled: false, fieldA: '' }}
      />,
    )
    expect(screen.queryByText('Hidden Input')).not.toBeInTheDocument()
  })
  it('should render remaining field types and fallback for unsupported type', () => {
    // Table-driven pass over the field types not covered above; each
    // scenario is rendered and unmounted so label queries don't collide.
    const scenarios: Array<{ config: InputFieldConfiguration, initialData: Record<string, unknown> }> = [
      {
        config: createConfig({ type: InputFieldType.numberInput, label: 'Count', min: 1, max: 5 }),
        initialData: { fieldA: 2 },
      },
      {
        config: createConfig({ type: InputFieldType.checkbox, label: 'Enable' }),
        initialData: { fieldA: false },
      },
      {
        config: createConfig({ type: InputFieldType.inputTypeSelect, label: 'Input Type', supportFile: true }),
        initialData: { fieldA: 'text' },
      },
      {
        config: createConfig({ type: InputFieldType.fileTypes, label: 'File Types' }),
        initialData: { fieldA: { allowedFileTypes: ['document'] } },
      },
      {
        config: createConfig({ type: InputFieldType.options, label: 'Choices' }),
        initialData: { fieldA: ['one'] },
      },
    ]
    for (const scenario of scenarios) {
      const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
      expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
      unmount()
    }
    // An unrecognized type should render nothing rather than crash.
    render(
      <FieldHarness
        config={createConfig({ type: 'unsupported' as InputFieldType, label: 'Unsupported' })}
        initialData={{ fieldA: '' }}
      />,
    )
    expect(screen.queryByText('Unsupported')).not.toBeInTheDocument()
  })
})

View File

@@ -0,0 +1,17 @@
import { InputFieldType } from './types'

// Pins the complete set of input-field type values so adding or removing
// a variant is a conscious, reviewed change.
describe('input-field scenario types', () => {
  it('should include expected input field types', () => {
    const expectedTypes = [
      'textInput',
      'numberInput',
      'numberSlider',
      'checkbox',
      'options',
      'select',
      'inputTypeSelect',
      'uploadMethod',
      'fileTypes',
    ]
    expect(Object.values(InputFieldType)).toEqual(expectedTypes)
  })
})

View File

@@ -0,0 +1,150 @@
import { InputFieldType } from './types'
import { generateZodSchema } from './utils'

// Tests for generateZodSchema, which turns a list of input-field
// configurations into a single zod object schema keyed by `variable`.
describe('input-field scenario schema generator', () => {
  it('should validate required text input with max length', () => {
    const schema = generateZodSchema([{
      type: InputFieldType.textInput,
      variable: 'prompt',
      label: 'Prompt',
      required: true,
      maxLength: 5,
      showConditions: [],
    }])
    expect(schema.safeParse({ prompt: 'hello' }).success).toBe(true)
    // Empty string fails the required check.
    expect(schema.safeParse({ prompt: '' }).success).toBe(false)
    // Sixteen characters exceeds maxLength of 5.
    expect(schema.safeParse({ prompt: 'longer than five' }).success).toBe(false)
  })
  it('should validate file types payload shape', () => {
    const schema = generateZodSchema([{
      type: InputFieldType.fileTypes,
      variable: 'files',
      label: 'Files',
      required: true,
      showConditions: [],
    }])
    expect(schema.safeParse({
      files: {
        allowedFileExtensions: 'txt,pdf',
        allowedFileTypes: ['document'],
      },
    }).success).toBe(true)
    // 'invalid-type' is not one of the accepted file-type values.
    expect(schema.safeParse({
      files: {
        allowedFileTypes: ['invalid-type'],
      },
    }).success).toBe(false)
  })
  it('should allow optional upload method fields to be omitted', () => {
    const schema = generateZodSchema([{
      type: InputFieldType.uploadMethod,
      variable: 'methods',
      label: 'Methods',
      required: false,
      showConditions: [],
    }])
    // Non-required fields may be absent from the payload entirely.
    expect(schema.safeParse({}).success).toBe(true)
  })
  it('should validate numeric bounds and other field type shapes', () => {
    // One configuration per remaining field type, plus an unsupported
    // type to exercise the generator's fallback branch.
    const schema = generateZodSchema([
      {
        type: InputFieldType.numberInput,
        variable: 'count',
        label: 'Count',
        required: true,
        min: 1,
        max: 3,
        showConditions: [],
      },
      {
        type: InputFieldType.numberSlider,
        variable: 'temperature',
        label: 'Temperature',
        required: true,
        showConditions: [],
      },
      {
        type: InputFieldType.checkbox,
        variable: 'enabled',
        label: 'Enabled',
        required: true,
        showConditions: [],
      },
      {
        type: InputFieldType.options,
        variable: 'choices',
        label: 'Choices',
        required: true,
        showConditions: [],
      },
      {
        type: InputFieldType.select,
        variable: 'mode',
        label: 'Mode',
        required: true,
        showConditions: [],
      },
      {
        type: InputFieldType.inputTypeSelect,
        variable: 'inputType',
        label: 'Input Type',
        required: true,
        showConditions: [],
      },
      {
        type: InputFieldType.uploadMethod,
        variable: 'methods',
        label: 'Methods',
        required: true,
        showConditions: [],
      },
      {
        type: 'unsupported' as InputFieldType,
        variable: 'other',
        label: 'Other',
        required: true,
        showConditions: [],
      },
    ])
    // Every value inside its constraints: accepted.
    expect(schema.safeParse({
      count: 2,
      temperature: 0.5,
      enabled: true,
      choices: ['a'],
      mode: 'safe',
      inputType: 'text',
      methods: ['local_file'],
      other: { key: 'value' },
    }).success).toBe(true)
    // count of 0 is below the min of 1: rejected.
    expect(schema.safeParse({
      count: 0,
      temperature: 0.5,
      enabled: true,
      choices: ['a'],
      mode: 'safe',
      inputType: 'text',
      methods: ['local_file'],
      other: { key: 'value' },
    }).success).toBe(false)
    // count of 4 is above the max of 3: rejected.
    expect(schema.safeParse({
      count: 4,
      temperature: 0.5,
      enabled: true,
      choices: ['a'],
      mode: 'safe',
      inputType: 'text',
      methods: ['local_file'],
      other: { key: 'value' },
    }).success).toBe(false)
  })
})

View File

@@ -0,0 +1,145 @@
import type { ReactNode } from 'react'
import type { InputFieldConfiguration } from './types'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { ReactFlowProvider } from 'reactflow'
import { useAppForm } from '../..'
import NodePanelField from './field'
import { InputFieldType } from './types'

// The variable picker drags in heavy workflow dependencies; stub it out
// so the field component can render in isolation.
vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({
  default: () => <div>Variable Picker</div>,
}))

// Factory for a minimal text-input configuration; each test overrides
// only what it needs.
const createConfig = (overrides: Partial<InputFieldConfiguration> = {}): InputFieldConfiguration => ({
  type: InputFieldType.textInput,
  variable: 'fieldA',
  label: 'Field A',
  required: false,
  showConditions: [],
  ...overrides,
})

type FieldHarnessProps = {
  config: InputFieldConfiguration
  initialData?: Record<string, unknown>
}

// Renders a single NodePanelField inside a freshly created app form.
// NodePanelField is invoked as a component factory, so the result is
// memoized to keep a stable component identity across re-renders.
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
  const form = useAppForm({
    defaultValues: initialData,
    onSubmit: () => {},
  })
  const Component = useMemo(() => NodePanelField({ initialData, config }), [config, initialData])
  return <Component form={form} />
}

// Context providers required by the variable-or-constant field, which
// uses react-query and reactflow. Query retries are disabled so failing
// requests settle immediately during tests.
const NodePanelWrapper = ({ children }: { children: ReactNode }) => {
  const queryClient = new QueryClient({
    defaultOptions: {
      queries: {
        retry: false,
      },
    },
  })
  return (
    <QueryClientProvider client={queryClient}>
      <ReactFlowProvider>
        {children}
      </ReactFlowProvider>
    </QueryClientProvider>
  )
}

describe('NodePanelField', () => {
  it('should render text input field', () => {
    render(<FieldHarness config={createConfig({ label: 'Node Name' })} initialData={{ fieldA: '' }} />)
    expect(screen.getByRole('textbox')).toBeInTheDocument()
    expect(screen.getByText('Node Name')).toBeInTheDocument()
  })
  it('should render variable-or-constant field when configured', () => {
    // Only this variant needs the query/reactflow providers.
    render(
      <NodePanelWrapper>
        <FieldHarness
          config={createConfig({
            type: InputFieldType.variableOrConstant,
            label: 'Mode',
          })}
          initialData={{ fieldA: '' }}
        />
      </NodePanelWrapper>,
    )
    expect(screen.getByText('Mode')).toBeInTheDocument()
  })
  it('should hide field when show conditions are not satisfied', () => {
    // showConditions requires `enabled === true`, but the initial data
    // sets it to false, so nothing should render.
    render(
      <FieldHarness
        config={createConfig({
          label: 'Hidden Node Field',
          showConditions: [{ variable: 'enabled', value: true }],
        })}
        initialData={{ enabled: false, fieldA: '' }}
      />,
    )
    expect(screen.queryByText('Hidden Node Field')).not.toBeInTheDocument()
  })
  it('should render other configured field types and hide unsupported type', () => {
    // Table-driven pass over the remaining field types; each scenario is
    // rendered and unmounted so label queries don't collide.
    const scenarios: Array<{ config: InputFieldConfiguration, initialData: Record<string, unknown> }> = [
      {
        config: createConfig({ type: InputFieldType.numberInput, label: 'Count', min: 1, max: 3 }),
        initialData: { fieldA: 2 },
      },
      {
        config: createConfig({ type: InputFieldType.numberSlider, label: 'Temperature', description: 'Adjust' }),
        initialData: { fieldA: 0.4 },
      },
      {
        config: createConfig({ type: InputFieldType.checkbox, label: 'Enabled' }),
        initialData: { fieldA: true },
      },
      {
        config: createConfig({ type: InputFieldType.select, label: 'Mode', options: [{ value: 'safe', label: 'Safe' }] }),
        initialData: { fieldA: 'safe' },
      },
      {
        config: createConfig({ type: InputFieldType.inputTypeSelect, label: 'Input Type', supportFile: true }),
        initialData: { fieldA: 'text' },
      },
      {
        config: createConfig({ type: InputFieldType.uploadMethod, label: 'Upload Method' }),
        initialData: { fieldA: ['local_file'] },
      },
      {
        config: createConfig({ type: InputFieldType.fileTypes, label: 'File Types' }),
        initialData: { fieldA: { allowedFileTypes: ['document'] } },
      },
      {
        config: createConfig({ type: InputFieldType.options, label: 'Options' }),
        initialData: { fieldA: ['a'] },
      },
    ]
    for (const scenario of scenarios) {
      const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
      expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
      unmount()
    }
    // An unrecognized type renders nothing rather than crashing.
    render(
      <FieldHarness
        config={createConfig({ type: 'unsupported' as InputFieldType, label: 'Unsupported Node' })}
        initialData={{ fieldA: '' }}
      />,
    )
    expect(screen.queryByText('Unsupported Node')).not.toBeInTheDocument()
  })
})

View File

@@ -0,0 +1,7 @@
import { InputFieldType } from './types'

// The node-panel scenario extends the base field types with a
// variable-or-constant selector; make sure that variant stays exposed.
describe('node-panel scenario types', () => {
  it('should include variableOrConstant field type', () => {
    const allTypes = Object.values(InputFieldType)
    expect(allTypes).toContain('variableOrConstant')
  })
})

View File

@@ -0,0 +1,12 @@
import * as hookExports from './index'
import { useCheckValidated } from './use-check-validated'
import { useGetFormValues } from './use-get-form-values'
import { useGetValidators } from './use-get-validators'

// The barrel file must re-export each hook by identity (the exact same
// function reference), not wrap or redefine it.
describe('hooks index exports', () => {
  it('should re-export all hook modules', () => {
    const expectedExports = {
      useCheckValidated,
      useGetFormValues,
      useGetValidators,
    } as const
    for (const [name, hook] of Object.entries(expectedExports))
      expect(hookExports[name as keyof typeof expectedExports]).toBe(hook)
  })
})

View File

@@ -0,0 +1,105 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { renderHook } from '@testing-library/react'
import { FormTypeEnum } from '../types'
import { useCheckValidated } from './use-check-validated'

// Capture toast notifications — the hook's only side effect — so tests
// can assert on error reporting. The mock factory closes over mockNotify
// and resolves it lazily when the toast module is imported.
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
  useToastContext: () => ({
    notify: mockNotify,
  }),
}))

describe('useCheckValidated', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('should return true when form has no errors', () => {
    // Minimal fake form: no errors and an empty value map. Casting to
    // AnyFormApi keeps the test independent of the full form surface.
    const form = {
      getAllErrors: () => undefined,
      state: { values: {} },
    }
    const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, []))
    expect(result.current.checkValidated()).toBe(true)
    expect(mockNotify).not.toHaveBeenCalled()
  })
  it('should notify and return false when visible field has errors', () => {
    const form = {
      getAllErrors: () => ({
        fields: {
          name: { errors: ['Name is required'] },
        },
      }),
      state: { values: {} },
    }
    // No show_on conditions, so the field is always visible and its
    // error must surface as a toast.
    const schemas = [{
      name: 'name',
      label: 'Name',
      required: true,
      type: FormTypeEnum.textInput,
      show_on: [],
    }]
    const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
    expect(result.current.checkValidated()).toBe(false)
    expect(mockNotify).toHaveBeenCalledWith({
      type: 'error',
      message: 'Name is required',
    })
  })
  it('should ignore hidden field errors and return true', () => {
    const form = {
      getAllErrors: () => ({
        fields: {
          secret: { errors: ['Secret is required'] },
        },
      }),
      state: { values: { enabled: 'false' } },
    }
    // show_on requires enabled === 'true' but form state says 'false',
    // so the field is hidden and its errors must not block validation.
    const schemas = [{
      name: 'secret',
      label: 'Secret',
      required: true,
      type: FormTypeEnum.textInput,
      show_on: [{ variable: 'enabled', value: 'true' }],
    }]
    const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
    expect(result.current.checkValidated()).toBe(true)
    expect(mockNotify).not.toHaveBeenCalled()
  })
  it('should notify when field is shown and has errors', () => {
    const form = {
      getAllErrors: () => ({
        fields: {
          secret: { errors: ['Secret is required'] },
        },
      }),
      state: { values: { enabled: 'true' } },
    }
    // Same schema as above, but the show_on condition is now satisfied,
    // so the error must be reported.
    const schemas = [{
      name: 'secret',
      label: 'Secret',
      required: true,
      type: FormTypeEnum.textInput,
      show_on: [{ variable: 'enabled', value: 'true' }],
    }]
    const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
    expect(result.current.checkValidated()).toBe(false)
    expect(mockNotify).toHaveBeenCalledWith({
      type: 'error',
      message: 'Secret is required',
    })
  })
})

View File

@@ -0,0 +1,74 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { renderHook } from '@testing-library/react'
import { FormTypeEnum } from '../types'
import { useGetFormValues } from './use-get-form-values'

// Mock the two collaborators so each behavior can be controlled per test:
// the validation check and the pristine-secret-value transform.
const mockCheckValidated = vi.fn()
const mockTransform = vi.fn()
vi.mock('./use-check-validated', () => ({
  useCheckValidated: () => ({
    checkValidated: mockCheckValidated,
  }),
}))
vi.mock('../utils/secret-input', () => ({
  getTransformedValuesWhenSecretInputPristine: (...args: unknown[]) => mockTransform(...args),
}))

describe('useGetFormValues', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })
  it('should return raw values when validation check is disabled', () => {
    // Minimal fake form exposing only the store the hook reads.
    const form = {
      store: { state: { values: { name: 'Alice' } } },
    }
    const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, []))
    // With needCheckValidatedValues: false the values pass through
    // untouched and isCheckValidated defaults to true.
    expect(result.current.getFormValues({ needCheckValidatedValues: false })).toEqual({
      values: { name: 'Alice' },
      isCheckValidated: true,
    })
  })
  it('should return transformed values when validation passes and transform is requested', () => {
    const form = {
      store: { state: { values: { password: 'abc123' } } },
    }
    const schemas = [{
      name: 'password',
      label: 'Password',
      required: true,
      type: FormTypeEnum.secretInput,
    }]
    mockCheckValidated.mockReturnValue(true)
    // The transform masks pristine secret values with a placeholder.
    mockTransform.mockReturnValue({ password: '[__HIDDEN__]' })
    const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, schemas))
    expect(result.current.getFormValues({
      needCheckValidatedValues: true,
      needTransformWhenSecretFieldIsPristine: true,
    })).toEqual({
      values: { password: '[__HIDDEN__]' },
      isCheckValidated: true,
    })
  })
  it('should return empty values when validation fails', () => {
    const form = {
      store: { state: { values: { name: '' } } },
    }
    mockCheckValidated.mockReturnValue(false)
    const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, []))
    // Failed validation yields an empty value map so callers can't
    // accidentally submit invalid data.
    expect(result.current.getFormValues({ needCheckValidatedValues: true })).toEqual({
      values: {},
      isCheckValidated: false,
    })
  })
})

View File

@@ -0,0 +1,78 @@
import { renderHook } from '@testing-library/react'
import { createElement } from 'react'
import { FormTypeEnum } from '../types'
import { useGetValidators } from './use-get-validators'

// Stub the i18n hook so object labels deterministically resolve to
// their en_US string.
vi.mock('@/hooks/use-i18n', () => ({
  useRenderI18nObject: () => (obj: Record<string, string>) => obj.en_US,
}))

describe('useGetValidators', () => {
  it('should create required validators when field is required without custom validators', () => {
    const { result } = renderHook(() => useGetValidators())
    const validators = result.current.getValidators({
      name: 'username',
      label: 'Username',
      required: true,
      type: FormTypeEnum.textInput,
    })
    const mountMessage = validators?.onMount?.({ value: '' })
    const blurMessage = validators?.onBlur?.({ value: '' })
    // The generated message embeds the i18n key and a params payload
    // containing the field label.
    expect(mountMessage).toContain('common.errorMsg.fieldRequired')
    expect(mountMessage).toContain('"field":"Username"')
    expect(blurMessage).toContain('common.errorMsg.fieldRequired')
  })
  it('should keep existing validators when custom validators are provided', () => {
    const customValidators = {
      onChange: vi.fn(() => 'custom error'),
    }
    const { result } = renderHook(() => useGetValidators())
    const validators = result.current.getValidators({
      name: 'username',
      label: 'Username',
      required: true,
      type: FormTypeEnum.textInput,
      validators: customValidators,
    })
    // Custom validators are returned by identity — no wrapping or merging.
    expect(validators).toBe(customValidators)
  })
  it('should fallback to field name when label is a react element', () => {
    const { result } = renderHook(() => useGetValidators())
    const validators = result.current.getValidators({
      name: 'apiKey',
      label: createElement('span', undefined, 'API Key'),
      required: true,
      type: FormTypeEnum.textInput,
    })
    const mountMessage = validators?.onMount?.({ value: '' })
    // A React element can't be stringified into the message, so the
    // field name is used instead.
    expect(mountMessage).toContain('"field":"apiKey"')
  })
  it('should translate object labels and skip validators for non-required fields', () => {
    const { result } = renderHook(() => useGetValidators())
    const requiredValidators = result.current.getValidators({
      name: 'workspace',
      label: { en_US: 'Workspace', zh_Hans: '工作区' },
      required: true,
      type: FormTypeEnum.textInput,
    })
    const nonRequiredValidators = result.current.getValidators({
      name: 'optionalField',
      label: 'Optional',
      required: false,
      type: FormTypeEnum.textInput,
    })
    const changeMessage = requiredValidators?.onChange?.({ value: '' })
    // Object labels go through the (mocked) i18n renderer -> 'Workspace'.
    expect(changeMessage).toContain('"field":"Workspace"')
    // Non-required fields without custom validators get no validators.
    expect(nonRequiredValidators).toBeUndefined()
  })
})

View File

@@ -0,0 +1,64 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useAppForm, withForm } from './index'

// End-to-end harness for useAppForm: a single TextField bound to `title`,
// submitted programmatically via form.handleSubmit().
const FormHarness = ({ onSubmit }: { onSubmit: (value: Record<string, unknown>) => void }) => {
  const form = useAppForm({
    defaultValues: { title: 'Initial title' },
    onSubmit: ({ value }) => onSubmit(value),
  })
  return (
    <form>
      <form.AppField
        name="title"
        children={field => <field.TextField label="Title" />}
      />
      <form.AppForm>
        <button type="button" onClick={() => form.handleSubmit()}>
          Submit
        </button>
      </form.AppForm>
    </form>
  )
}

// A component created with withForm receives the form instance as a prop
// instead of owning it; defaultValues here only describe the shape.
const InlinePreview = withForm({
  defaultValues: { title: '' },
  render: ({ form }) => {
    return (
      <form.AppField
        name="title"
        children={field => <field.TextField label="Preview Title" />}
      />
    )
  },
})

// Owns the form instance and passes it down to the withForm component.
const WithFormHarness = () => {
  const form = useAppForm({
    defaultValues: { title: 'Preview value' },
    onSubmit: () => {},
  })
  return <InlinePreview form={form} />
}

describe('form index exports', () => {
  it('should submit values through the generated app form', async () => {
    const onSubmit = vi.fn()
    render(<FormHarness onSubmit={onSubmit} />)
    fireEvent.click(screen.getByRole('button', { name: /submit/i }))
    // handleSubmit is asynchronous; wait for the callback to fire.
    await waitFor(() => {
      expect(onSubmit).toHaveBeenCalledWith({ title: 'Initial title' })
    })
  })
  it('should render components created with withForm', () => {
    render(<WithFormHarness />)
    // The field reflects the owning form's default, not withForm's shape default.
    expect(screen.getByRole('textbox')).toHaveValue('Preview value')
    expect(screen.getByText('Preview Title')).toBeInTheDocument()
  })
})

View File

@@ -0,0 +1,18 @@
import { FormItemValidateStatusEnum, FormTypeEnum } from './types'
describe('form types', () => {
it('should expose expected form type values', () => {
expect(Object.values(FormTypeEnum)).toContain('text-input')
expect(Object.values(FormTypeEnum)).toContain('dynamic-select')
expect(Object.values(FormTypeEnum)).toContain('boolean')
})
it('should expose expected validation status values', () => {
expect(Object.values(FormItemValidateStatusEnum)).toEqual([
'success',
'warning',
'error',
'validating',
])
})
})

View File

@@ -0,0 +1,54 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { FormTypeEnum } from '../../types'
import { getTransformedValuesWhenSecretInputPristine, transformFormSchemasSecretInput } from './index'

// Tests for the secret-input masking helpers: secret field values are
// replaced with the '[__HIDDEN__]' placeholder under the right conditions.
describe('secret input utilities', () => {
  it('should mask only selected truthy values in transformFormSchemasSecretInput', () => {
    // Only keys listed in the first argument get masked, and only when
    // their value is truthy — 'token' isn't listed, '' stays as-is.
    expect(transformFormSchemasSecretInput(['apiKey'], {
      apiKey: 'secret',
      token: 'token-value',
      emptyValue: '',
    })).toEqual({
      apiKey: '[__HIDDEN__]',
      token: 'token-value',
      emptyValue: '',
    })
  })
  it('should mask pristine secret input fields from form state', () => {
    const formSchemas = [
      { name: 'apiKey', type: FormTypeEnum.secretInput, label: 'API Key', required: true },
      { name: 'name', type: FormTypeEnum.textInput, label: 'Name', required: true },
    ]
    // Fake form exposing just the pieces the helper reads: the value
    // store and per-field pristine metadata. Only 'apiKey' is pristine.
    const form = {
      store: {
        state: {
          values: {
            apiKey: 'secret',
            name: 'Alice',
          },
        },
      },
      getFieldMeta: (name: string) => ({ isPristine: name === 'apiKey' }),
    }
    // The pristine secret field is masked; the text field is untouched.
    expect(getTransformedValuesWhenSecretInputPristine(formSchemas, form as unknown as AnyFormApi)).toEqual({
      apiKey: '[__HIDDEN__]',
      name: 'Alice',
    })
  })
  it('should keep value unchanged when secret input is not pristine', () => {
    const formSchemas = [
      { name: 'apiKey', type: FormTypeEnum.secretInput, label: 'API Key', required: true },
    ]
    // A dirty (user-edited) secret field must keep its actual value.
    const form = {
      store: { state: { values: { apiKey: 'secret' } } },
      getFieldMeta: () => ({ isPristine: false }),
    }
    expect(getTransformedValuesWhenSecretInputPristine(formSchemas, form as unknown as AnyFormApi)).toEqual({
      apiKey: 'secret',
    })
  })
})

View File

@@ -0,0 +1,39 @@
import * as z from 'zod'
import { zodSubmitValidator } from './zod-submit-validator'

// zodSubmitValidator adapts a zod schema to the submit-validator
// contract: undefined on success, a { fields } error map on failure.
describe('zodSubmitValidator', () => {
  it('should return undefined for valid values', () => {
    const schema = z.object({ name: z.string().min(2) })
    const validate = zodSubmitValidator(schema)
    const outcome = validate({ value: { name: 'Alice' } })
    expect(outcome).toBeUndefined()
  })
  it('should return first error message per field for invalid values', () => {
    const schema = z.object({
      name: z.string().min(3, 'Name too short'),
      age: z.number().min(18, 'Must be adult'),
    })
    const validate = zodSubmitValidator(schema)
    const outcome = validate({ value: { name: 'Al', age: 15 } })
    expect(outcome).toEqual({
      fields: {
        name: 'Name too short',
        age: 'Must be adult',
      },
    })
  })
  it('should ignore root-level issues without a field path', () => {
    // superRefine adds an issue with an empty path — there is no field
    // to attach it to, so the validator should drop it rather than fail.
    const rootOnlySchema = z.object({ value: z.number() }).superRefine((_value, ctx) => {
      ctx.addIssue({
        code: z.ZodIssueCode.custom,
        message: 'Root error',
        path: [],
      })
    })
    const validate = zodSubmitValidator(rootOnlySchema)
    expect(validate({ value: { value: 1 } })).toEqual({ fields: {} })
  })
})

View File

@@ -1,10 +1,9 @@
import type { AppRouterInstance } from 'next/dist/shared/lib/app-router-context.shared-runtime'
import type { AppContextValue } from '@/context/app-context'
import type { ModalContextState } from '@/context/modal-context'
import type { ProviderContextState } from '@/context/provider-context'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { AppRouterContext } from 'next/dist/shared/lib/app-router-context.shared-runtime'
import { useRouter } from 'next/navigation'
import { Plan } from '@/app/components/billing/type'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
@@ -50,6 +49,14 @@ vi.mock('@/service/use-common', () => ({
useLogout: vi.fn(),
}))
vi.mock('next/navigation', async (importOriginal) => {
const actual = await importOriginal<typeof import('next/navigation')>()
return {
...actual,
useRouter: vi.fn(),
}
})
vi.mock('@/context/i18n', () => ({
useDocLink: () => (path: string) => `https://docs.dify.ai${path}`,
}))
@@ -119,15 +126,6 @@ describe('AccountDropdown', () => {
const mockSetShowAccountSettingModal = vi.fn()
const renderWithRouter = (ui: React.ReactElement) => {
const mockRouter = {
push: mockPush,
replace: vi.fn(),
prefetch: vi.fn(),
back: vi.fn(),
forward: vi.fn(),
refresh: vi.fn(),
} as unknown as AppRouterInstance
const queryClient = new QueryClient({
defaultOptions: {
queries: {
@@ -138,9 +136,7 @@ describe('AccountDropdown', () => {
return render(
<QueryClientProvider client={queryClient}>
<AppRouterContext.Provider value={mockRouter}>
{ui}
</AppRouterContext.Provider>
{ui}
</QueryClientProvider>,
)
}
@@ -166,6 +162,14 @@ describe('AccountDropdown', () => {
vi.mocked(useLogout).mockReturnValue({
mutateAsync: mockLogout,
} as unknown as ReturnType<typeof useLogout>)
vi.mocked(useRouter).mockReturnValue({
push: mockPush,
replace: vi.fn(),
prefetch: vi.fn(),
back: vi.fn(),
forward: vi.fn(),
refresh: vi.fn(),
})
})
afterEach(() => {

View File

@@ -1,4 +1,4 @@
import type { SearchParams } from 'nuqs/server'
import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'

View File

@@ -1,7 +1,7 @@
import { QueryClient } from '@tanstack/react-query'
import { cache } from 'react'
const STALE_TIME = 1000 * 60 * 5 // 5 minutes
const STALE_TIME = 1000 * 60 * 30 // 30 minutes
export function makeQueryClient() {
return new QueryClient({

View File

@@ -1,7 +1,9 @@
'use client'
import type { QueryClient } from '@tanstack/react-query'
import type { FC, PropsWithChildren } from 'react'
import { QueryClientProvider } from '@tanstack/react-query'
import { useState } from 'react'
import { TanStackDevtoolsLoader } from '@/app/components/devtools/tanstack/loader'
import { isServer } from '@/utils/client'
import { makeQueryClient } from './query-client-server'
@@ -17,8 +19,8 @@ function getQueryClient() {
return browserQueryClient
}
export const TanstackQueryInitializer = ({ children }: { children: React.ReactNode }) => {
const queryClient = getQueryClient()
export const TanstackQueryInitializer: FC<PropsWithChildren> = ({ children }) => {
const [queryClient] = useState(getQueryClient)
return (
<QueryClientProvider client={queryClient}>
{children}

View File

@@ -28,9 +28,12 @@
"scripts": {
"dev": "next dev",
"dev:inspect": "next dev --inspect",
"dev:vinext": "vinext dev",
"build": "next build",
"build:docker": "next build && node scripts/optimize-standalone.js",
"build:vinext": "vinext build",
"start": "node ./scripts/copy-and-start.mjs",
"start:vinext": "vinext start",
"lint": "eslint --cache --concurrency=auto",
"lint:ci": "eslint --cache --concurrency 2",
"lint:fix": "pnpm lint --fix",
@@ -173,6 +176,7 @@
"@iconify-json/ri": "1.2.9",
"@mdx-js/loader": "3.1.1",
"@mdx-js/react": "3.1.1",
"@mdx-js/rollup": "3.1.1",
"@next/eslint-plugin-next": "16.1.6",
"@next/mdx": "16.1.5",
"@rgrove/parse-xml": "4.2.0",
@@ -211,6 +215,7 @@
"@typescript-eslint/parser": "8.54.0",
"@typescript/native-preview": "7.0.0-dev.20251209.1",
"@vitejs/plugin-react": "5.1.2",
"@vitejs/plugin-rsc": "0.5.20",
"@vitest/coverage-v8": "4.0.17",
"autoprefixer": "10.4.21",
"code-inspector-plugin": "1.3.6",
@@ -233,6 +238,7 @@
"postcss": "8.5.6",
"postcss-js": "5.0.3",
"react-scan": "0.4.3",
"react-server-dom-webpack": "19.2.4",
"sass": "1.93.2",
"serwist": "9.5.4",
"storybook": "10.2.0",
@@ -240,6 +246,7 @@
"tsx": "4.21.0",
"typescript": "5.9.3",
"uglify-js": "3.19.3",
"vinext": "https://pkg.pr.new/hyoban/vinext@f07e125",
"vite": "7.3.1",
"vite-tsconfig-paths": "6.0.4",
"vitest": "4.0.17",

462
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -653,7 +653,7 @@ export const useMutationClearAllTaskPlugin = () => {
export const usePluginManifestInfo = (pluginUID: string) => {
return useQuery({
enabled: !!pluginUID,
queryKey: [NAME_SPACE, 'manifest', pluginUID],
queryKey: [[NAME_SPACE, 'manifest', pluginUID]],
queryFn: () => getMarketplace<{ data: { plugin: PluginInfoFromMarketPlace, version: { version: string } } }>(`/plugins/${pluginUID}`),
retry: 0,
})

View File

@@ -1,16 +1,58 @@
import path from 'node:path'
import { fileURLToPath } from 'node:url'
import type { Plugin } from 'vite'
import mdx from '@mdx-js/rollup'
import react from '@vitejs/plugin-react'
import { defineConfig } from 'vite'
import vinext from 'vinext'
import { defineConfig, loadEnv } from 'vite'
import tsconfigPaths from 'vite-tsconfig-paths'
const __dirname = path.dirname(fileURLToPath(import.meta.url))
const isCI = !!process.env.CI
export default defineConfig({
plugins: [tsconfigPaths(), react()],
resolve: {
alias: {
'~@': __dirname,
export default defineConfig(({ mode }) => {
const env = loadEnv(mode, process.cwd(), '')
return {
plugins: [
...(mode === 'test'
? [
react(),
{
// Stub .mdx files so components importing them can be unit-tested
name: 'mdx-stub',
enforce: 'pre',
transform(_, id) {
if (id.endsWith('.mdx'))
return { code: 'export default () => null', map: null }
},
} as Plugin,
]
: [
mdx(),
vinext(),
]),
tsconfigPaths(),
],
resolve: {
alias: {
'~@': __dirname,
},
},
},
optimizeDeps: {
exclude: ['nuqs'],
},
server: {
port: 3000,
},
envPrefix: 'NEXT_PUBLIC_',
test: {
environment: 'jsdom',
globals: true,
setupFiles: ['./vitest.setup.ts'],
coverage: {
provider: 'v8',
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
},
},
define: {
'process.env': env,
},
}
})

View File

@@ -1,27 +0,0 @@
import { defineConfig, mergeConfig } from 'vitest/config'
import viteConfig from './vite.config'

// CI runs want machine-readable coverage only; local runs also get text.
const isCI = !!process.env.CI

// Vitest configuration layered on top of the shared Vite config via
// mergeConfig, so app aliases/plugins apply to tests too.
export default mergeConfig(viteConfig, defineConfig({
  plugins: [
    {
      // Stub .mdx files so components importing them can be unit-tested
      name: 'mdx-stub',
      // 'pre' ensures the stub runs before other transforms touch .mdx.
      enforce: 'pre',
      transform(_, id) {
        if (id.endsWith('.mdx'))
          return { code: 'export default () => null', map: null }
      },
    },
  ],
  test: {
    environment: 'jsdom',
    globals: true,
    setupFiles: ['./vitest.setup.ts'],
    coverage: {
      provider: 'v8',
      reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
    },
  },
}))