Compare commits

..

22 Commits

Author SHA1 Message Date
Stephen Zhou
56265a5217 fix: Instrument_Serif 2026-02-25 17:44:14 +08:00
Stephen Zhou
084eeac776 fix: constants shim 2026-02-25 17:38:27 +08:00
Stephen Zhou
3e92f85beb fix: image size plugin 2026-02-25 17:25:52 +08:00
Stephen Zhou
7bd987c6f1 use pkg new pr 2026-02-25 13:01:12 +08:00
Stephen Zhou
af80c10ed3 fix load module 2026-02-25 12:46:41 +08:00
Stephen Zhou
af6218d4b5 update 2026-02-25 11:33:27 +08:00
Stephen Zhou
84fda207a6 update 2026-02-25 11:32:20 +08:00
Stephen Zhou
eb81f0563d jiti v1 2026-02-25 11:31:23 +08:00
Stephen Zhou
cd22550454 load env 2026-02-25 11:28:18 +08:00
Stephen Zhou
1bbc2f147d env 2026-02-25 11:24:32 +08:00
Stephen Zhou
b787c0af7f import ts in next.config 2026-02-25 11:12:35 +08:00
Stephen Zhou
56c8bef073 chore: add vinext as dev server 2026-02-25 11:00:29 +08:00
akashseth-ifp
4e142f72e8 test(base): add test coverage for more base/form components (#32437)
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
2026-02-25 10:47:25 +08:00
木之本澪
a6456da393 test: migrate delete_archived_workflow_run SQL tests to testcontainers (#32549)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:18:52 +09:00
木之本澪
b863f8edbd test: migrate test_document_service_display_status SQL tests to testcontainers (#32545)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:13:22 +09:00
木之本澪
64296da7e7 test: migrate remove_app_and_related_data_task SQL tests to testcontainers (#32547)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:12:23 +09:00
木之本澪
02fef84d7f test: migrate node execution repository sql tests to testcontainers (#32524)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 05:01:26 +09:00
木之本澪
28f2098b00 test: migrate workflow trigger log repository sql tests to testcontainers (#32525)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:53:16 +09:00
木之本澪
59681ce760 test: migrate message extra contents tests to testcontainers (#32532)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:51:14 +09:00
木之本澪
4997b82a63 test: migrate end user service SQL tests to testcontainers (#32530)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 04:49:49 +09:00
木之本澪
3abfbc0246 test: migrate remaining DocumentSegment navigation SQL tests to testcontainers (#32523)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 02:51:38 +09:00
木之本澪
beea1acd92 test: migrate workflow run repository SQL tests to testcontainers (#32519)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
2026-02-25 01:36:39 +09:00
94 changed files with 5890 additions and 2943 deletions

View File

@@ -204,16 +204,6 @@ When assigned to test a directory/path, test **ALL content** within that path:
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
### `nuqs` Query State Testing (Required for URL State Hooks)
When a component or hook uses `useQueryState` / `useQueryStates`:
- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`)
- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`)
- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values)
- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable)
- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)

View File

@@ -80,9 +80,6 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`)
- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`)
- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values)
### Queries

View File

@@ -125,31 +125,6 @@ describe('Component', () => {
})
```
### 2.1 `nuqs` Query State (Preferred: Testing Adapter)
For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly.
```typescript
import { act, waitFor } from '@testing-library/react'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
it('should sync query to URL with push history', async () => {
const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), {
searchParams: '?page=1',
})
act(() => {
result.current.setQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.options.history).toBe('push')
expect(update.searchParams.get('page')).toBe('2')
})
```
Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope.
### 3. Portal Components (with Shared State)
```typescript

View File

@@ -269,3 +269,221 @@ class TestDatasetDocumentProperties:
db_session_with_containers.flush()
assert doc.hit_count == 25
class TestDocumentSegmentNavigationProperties:
"""Integration tests for DocumentSegment navigation properties."""
@pytest.fixture(autouse=True)
def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
"""Automatically rollback session changes after each test."""
yield
db_session_with_containers.rollback()
def test_document_segment_dataset_property(self, db_session_with_containers: Session) -> None:
"""Test segment can access its parent dataset."""
# Arrange
tenant_id = str(uuid4())
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
created_by=created_by,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
document = Document(
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=created_by,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=created_by,
)
db_session_with_containers.add(segment)
db_session_with_containers.flush()
# Act
related_dataset = segment.dataset
# Assert
assert related_dataset is not None
assert related_dataset.id == dataset.id
def test_document_segment_document_property(self, db_session_with_containers: Session) -> None:
"""Test segment can access its parent document."""
# Arrange
tenant_id = str(uuid4())
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
created_by=created_by,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
document = Document(
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=created_by,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=created_by,
)
db_session_with_containers.add(segment)
db_session_with_containers.flush()
# Act
related_document = segment.document
# Assert
assert related_document is not None
assert related_document.id == document.id
def test_document_segment_previous_segment(self, db_session_with_containers: Session) -> None:
"""Test segment can access previous segment."""
# Arrange
tenant_id = str(uuid4())
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
created_by=created_by,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
document = Document(
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=created_by,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
previous_segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=1,
content="Previous",
word_count=1,
tokens=2,
created_by=created_by,
)
segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=2,
content="Current",
word_count=1,
tokens=2,
created_by=created_by,
)
db_session_with_containers.add_all([previous_segment, segment])
db_session_with_containers.flush()
# Act
prev_seg = segment.previous_segment
# Assert
assert prev_seg is not None
assert prev_seg.position == 1
def test_document_segment_next_segment(self, db_session_with_containers: Session) -> None:
"""Test segment can access next segment."""
# Arrange
tenant_id = str(uuid4())
created_by = str(uuid4())
dataset = Dataset(
tenant_id=tenant_id,
name="Test Dataset",
data_source_type="upload_file",
created_by=created_by,
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
document = Document(
tenant_id=tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=created_by,
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=1,
content="Current",
word_count=1,
tokens=2,
created_by=created_by,
)
next_segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset.id,
document_id=document.id,
position=2,
content="Next",
word_count=1,
tokens=2,
created_by=created_by,
)
db_session_with_containers.add_all([segment, next_segment])
db_session_with_containers.flush()
# Act
next_seg = segment.next_segment
# Assert
assert next_seg is not None
assert next_seg.position == 2

View File

@@ -0,0 +1,143 @@
"""Integration tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository using testcontainers."""
from __future__ import annotations
from datetime import timedelta
from uuid import uuid4
from sqlalchemy import Engine, delete
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
def _create_node_execution(
session: Session,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_run_id: str,
status: WorkflowNodeExecutionStatus,
index: int,
created_by: str,
created_at_offset_seconds: int,
) -> WorkflowNodeExecutionModel:
now = naive_utc_now()
node_execution = WorkflowNodeExecutionModel(
id=str(uuid4()),
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
triggered_from="workflow-run",
workflow_run_id=workflow_run_id,
index=index,
predecessor_node_id=None,
node_execution_id=None,
node_id=f"node-{index}",
node_type="llm",
title=f"Node {index}",
inputs="{}",
process_data="{}",
outputs="{}",
status=status,
error=None,
elapsed_time=0.0,
execution_metadata="{}",
created_at=now + timedelta(seconds=created_at_offset_seconds),
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
finished_at=None,
)
session.add(node_execution)
session.flush()
return node_execution
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
def test_get_executions_by_workflow_run_keeps_paused_records(self, db_session_with_containers: Session) -> None:
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
workflow_run_id = str(uuid4())
created_by = str(uuid4())
other_tenant_id = str(uuid4())
other_app_id = str(uuid4())
included_paused = _create_node_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
status=WorkflowNodeExecutionStatus.PAUSED,
index=1,
created_by=created_by,
created_at_offset_seconds=0,
)
included_succeeded = _create_node_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
index=2,
created_by=created_by,
created_at_offset_seconds=1,
)
_create_node_execution(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=str(uuid4()),
status=WorkflowNodeExecutionStatus.PAUSED,
index=3,
created_by=created_by,
created_at_offset_seconds=2,
)
_create_node_execution(
db_session_with_containers,
tenant_id=other_tenant_id,
app_id=other_app_id,
workflow_id=str(uuid4()),
workflow_run_id=workflow_run_id,
status=WorkflowNodeExecutionStatus.PAUSED,
index=4,
created_by=str(uuid4()),
created_at_offset_seconds=3,
)
db_session_with_containers.commit()
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(sessionmaker(bind=engine, expire_on_commit=False))
try:
results = repository.get_executions_by_workflow_run(
tenant_id=tenant_id,
app_id=app_id,
workflow_run_id=workflow_run_id,
)
assert len(results) == 2
assert [result.id for result in results] == [included_paused.id, included_succeeded.id]
assert any(result.status == WorkflowNodeExecutionStatus.PAUSED for result in results)
assert all(result.tenant_id == tenant_id for result in results)
assert all(result.app_id == app_id for result in results)
assert all(result.workflow_run_id == workflow_run_id for result in results)
finally:
db_session_with_containers.execute(
delete(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.tenant_id.in_([tenant_id, other_tenant_id])
)
)
db_session_with_containers.commit()

View File

@@ -0,0 +1,506 @@
"""Integration tests for DifyAPISQLAlchemyWorkflowRunRepository using testcontainers."""
from __future__ import annotations
from collections.abc import Generator
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from unittest.mock import Mock
from uuid import uuid4
import pytest
from sqlalchemy import Engine, delete, select
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities import WorkflowExecution
from core.workflow.entities.pause_reason import PauseReasonType
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_WorkflowRunError,
)
class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
"""Concrete repository for tests where save() is not under test."""
def save(self, execution: WorkflowExecution) -> None:
return None
@dataclass
class _TestScope:
"""Per-test data scope used to isolate DB rows and storage keys."""
tenant_id: str = field(default_factory=lambda: str(uuid4()))
app_id: str = field(default_factory=lambda: str(uuid4()))
workflow_id: str = field(default_factory=lambda: str(uuid4()))
user_id: str = field(default_factory=lambda: str(uuid4()))
state_keys: set[str] = field(default_factory=set)
def _create_workflow_run(
session: Session,
scope: _TestScope,
*,
status: WorkflowExecutionStatus,
created_at: datetime | None = None,
) -> WorkflowRun:
"""Create and persist a workflow run bound to the current test scope."""
workflow_run = WorkflowRun(
id=str(uuid4()),
tenant_id=scope.tenant_id,
app_id=scope.app_id,
workflow_id=scope.workflow_id,
type="workflow",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
version="draft",
graph="{}",
inputs="{}",
status=status,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=scope.user_id,
created_at=created_at or naive_utc_now(),
)
session.add(workflow_run)
session.commit()
return workflow_run
def _cleanup_scope_data(session: Session, scope: _TestScope) -> None:
"""Remove test-created DB rows and storage objects for a test scope."""
pause_ids_subquery = select(WorkflowPause.id).where(WorkflowPause.workflow_id == scope.workflow_id)
session.execute(delete(WorkflowPauseReason).where(WorkflowPauseReason.pause_id.in_(pause_ids_subquery)))
session.execute(delete(WorkflowPause).where(WorkflowPause.workflow_id == scope.workflow_id))
session.execute(
delete(WorkflowAppLog).where(
WorkflowAppLog.tenant_id == scope.tenant_id,
WorkflowAppLog.app_id == scope.app_id,
)
)
session.execute(
delete(WorkflowRun).where(
WorkflowRun.tenant_id == scope.tenant_id,
WorkflowRun.app_id == scope.app_id,
)
)
session.commit()
for state_key in scope.state_keys:
try:
storage.delete(state_key)
except FileNotFoundError:
continue
@pytest.fixture
def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository:
"""Build a repository backed by the testcontainers database engine."""
engine = db_session_with_containers.get_bind()
assert isinstance(engine, Engine)
return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False))
@pytest.fixture
def test_scope(db_session_with_containers: Session) -> Generator[_TestScope, None, None]:
"""Provide an isolated scope and clean related data after each test."""
scope = _TestScope()
yield scope
_cleanup_scope_data(db_session_with_containers, scope)
class TestGetRunsBatchByTimeRange:
"""Integration tests for get_runs_batch_by_time_range."""
def test_get_runs_batch_by_time_range_filters_terminal_statuses(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Return only terminal workflow runs, excluding RUNNING and PAUSED."""
now = naive_utc_now()
ended_statuses = [
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.STOPPED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
]
ended_run_ids = {
_create_workflow_run(
db_session_with_containers,
test_scope,
status=status,
created_at=now - timedelta(minutes=3),
).id
for status in ended_statuses
}
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
created_at=now - timedelta(minutes=2),
)
_create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.PAUSED,
created_at=now - timedelta(minutes=1),
)
runs = repository.get_runs_batch_by_time_range(
start_from=now - timedelta(days=1),
end_before=now + timedelta(days=1),
last_seen=None,
batch_size=50,
tenant_ids=[test_scope.tenant_id],
)
returned_ids = {run.id for run in runs}
returned_statuses = {run.status for run in runs}
assert returned_ids == ended_run_ids
assert returned_statuses == set(ended_statuses)
class TestDeleteRunsWithRelated:
"""Integration tests for delete_runs_with_related."""
def test_uses_trigger_log_repository(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Delete run-related records and invoke injected trigger-log deleter."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
)
app_log = WorkflowAppLog(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
pause_reason = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.SCHEDULED_PAUSE,
message="scheduled pause",
)
db_session_with_containers.add_all([app_log, pause, pause_reason])
db_session_with_containers.commit()
fake_trigger_repo = Mock()
fake_trigger_repo.delete_by_run_ids.return_value = 3
counts = repository.delete_runs_with_related(
[workflow_run],
delete_node_executions=lambda session, runs: (2, 1),
delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
)
fake_trigger_repo.delete_by_run_ids.assert_called_once_with([workflow_run.id])
assert counts["node_executions"] == 2
assert counts["offloads"] == 1
assert counts["trigger_logs"] == 3
assert counts["app_logs"] == 1
assert counts["pauses"] == 1
assert counts["pause_reasons"] == 1
assert counts["runs"] == 1
with Session(bind=db_session_with_containers.get_bind()) as verification_session:
assert verification_session.get(WorkflowRun, workflow_run.id) is None
class TestCountRunsWithRelated:
"""Integration tests for count_runs_with_related."""
def test_uses_trigger_log_repository(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Count run-related records and invoke injected trigger-log counter."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
)
app_log = WorkflowAppLog(
tenant_id=test_scope.tenant_id,
app_id=test_scope.app_id,
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
created_from="service-api",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=test_scope.user_id,
)
pause = WorkflowPause(
id=str(uuid4()),
workflow_id=test_scope.workflow_id,
workflow_run_id=workflow_run.id,
state_object_key=f"workflow-state-{uuid4()}.json",
)
pause_reason = WorkflowPauseReason(
pause_id=pause.id,
type_=PauseReasonType.SCHEDULED_PAUSE,
message="scheduled pause",
)
db_session_with_containers.add_all([app_log, pause, pause_reason])
db_session_with_containers.commit()
fake_trigger_repo = Mock()
fake_trigger_repo.count_by_run_ids.return_value = 3
counts = repository.count_runs_with_related(
[workflow_run],
count_node_executions=lambda session, runs: (2, 1),
count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
)
fake_trigger_repo.count_by_run_ids.assert_called_once_with([workflow_run.id])
assert counts["node_executions"] == 2
assert counts["offloads"] == 1
assert counts["trigger_logs"] == 3
assert counts["app_logs"] == 1
assert counts["pauses"] == 1
assert counts["pause_reasons"] == 1
assert counts["runs"] == 1
class TestCreateWorkflowPause:
"""Integration tests for create_workflow_pause."""
def test_create_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Create pause successfully, persist pause record, and set run status to PAUSED."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
state = '{"test": "state"}'
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=test_scope.user_id,
state=state,
pause_reasons=[],
)
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
assert pause_model is not None
test_scope.state_keys.add(pause_model.state_object_key)
db_session_with_containers.refresh(workflow_run)
assert workflow_run.status == WorkflowExecutionStatus.PAUSED
assert pause_entity.id == pause_model.id
assert pause_entity.workflow_execution_id == workflow_run.id
assert pause_entity.get_pause_reasons() == []
assert pause_entity.get_state() == state.encode()
def test_create_workflow_pause_not_found(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
test_scope: _TestScope,
) -> None:
"""Raise ValueError when the workflow run does not exist."""
with pytest.raises(ValueError, match="WorkflowRun not found"):
repository.create_workflow_pause(
workflow_run_id=str(uuid4()),
state_owner_user_id=test_scope.user_id,
state='{"test": "state"}',
pause_reasons=[],
)
def test_create_workflow_pause_invalid_status(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Raise _WorkflowRunError when pausing a run in non-pausable status."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.SUCCEEDED,
)
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=test_scope.user_id,
state='{"test": "state"}',
pause_reasons=[],
)
class TestResumeWorkflowPause:
"""Integration tests for resume_workflow_pause."""
def test_resume_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Resume pause successfully and switch workflow run status back to RUNNING."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=test_scope.user_id,
state='{"test": "state"}',
pause_reasons=[],
)
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
assert pause_model is not None
test_scope.state_keys.add(pause_model.state_object_key)
resumed_entity = repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
db_session_with_containers.refresh(workflow_run)
db_session_with_containers.refresh(pause_model)
assert resumed_entity.id == pause_entity.id
assert resumed_entity.resumed_at is not None
assert workflow_run.status == WorkflowExecutionStatus.RUNNING
assert pause_model.resumed_at is not None
def test_resume_workflow_pause_not_paused(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Raise _WorkflowRunError when workflow run is not in PAUSED status."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = str(uuid4())
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=pause_entity,
)
def test_resume_workflow_pause_id_mismatch(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Raise _WorkflowRunError when pause entity ID mismatches persisted pause ID."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=test_scope.user_id,
state='{"test": "state"}',
pause_reasons=[],
)
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
assert pause_model is not None
test_scope.state_keys.add(pause_model.state_object_key)
mismatched_pause_entity = Mock(spec=WorkflowPauseEntity)
mismatched_pause_entity.id = str(uuid4())
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run.id,
pause_entity=mismatched_pause_entity,
)
class TestDeleteWorkflowPause:
"""Integration tests for delete_workflow_pause."""
def test_delete_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
db_session_with_containers: Session,
test_scope: _TestScope,
) -> None:
"""Delete pause record and its state object from storage."""
workflow_run = _create_workflow_run(
db_session_with_containers,
test_scope,
status=WorkflowExecutionStatus.RUNNING,
)
pause_entity = repository.create_workflow_pause(
workflow_run_id=workflow_run.id,
state_owner_user_id=test_scope.user_id,
state='{"test": "state"}',
pause_reasons=[],
)
pause_model = db_session_with_containers.get(WorkflowPause, pause_entity.id)
assert pause_model is not None
state_key = pause_model.state_object_key
test_scope.state_keys.add(state_key)
repository.delete_workflow_pause(pause_entity=pause_entity)
with Session(bind=db_session_with_containers.get_bind()) as verification_session:
assert verification_session.get(WorkflowPause, pause_entity.id) is None
with pytest.raises(FileNotFoundError):
storage.load(state_key)
def test_delete_workflow_pause_not_found(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
) -> None:
"""Raise _WorkflowRunError when deleting a non-existent pause."""
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = str(uuid4())
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"):
repository.delete_workflow_pause(pause_entity=pause_entity)

View File

@@ -0,0 +1,134 @@
"""Integration tests for SQLAlchemyWorkflowTriggerLogRepository using testcontainers."""
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from models.enums import AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
from models.trigger import WorkflowTriggerLog
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
def _create_trigger_log(
session: Session,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_run_id: str,
created_by: str,
) -> WorkflowTriggerLog:
trigger_log = WorkflowTriggerLog(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
root_node_id=None,
trigger_metadata="{}",
trigger_type=AppTriggerType.TRIGGER_WEBHOOK,
trigger_data="{}",
inputs="{}",
outputs=None,
status=WorkflowTriggerStatus.SUCCEEDED,
error=None,
queue_name="default",
celery_task_id=None,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
retry_count=0,
)
session.add(trigger_log)
session.flush()
return trigger_log
def test_delete_by_run_ids_executes_delete(db_session_with_containers: Session) -> None:
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
created_by = str(uuid4())
run_id_1 = str(uuid4())
run_id_2 = str(uuid4())
untouched_run_id = str(uuid4())
_create_trigger_log(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=run_id_1,
created_by=created_by,
)
_create_trigger_log(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=run_id_2,
created_by=created_by,
)
_create_trigger_log(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=untouched_run_id,
created_by=created_by,
)
db_session_with_containers.commit()
repository = SQLAlchemyWorkflowTriggerLogRepository(db_session_with_containers)
try:
deleted = repository.delete_by_run_ids([run_id_1, run_id_2])
db_session_with_containers.commit()
assert deleted == 2
remaining_logs = db_session_with_containers.scalars(
select(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id)
).all()
assert len(remaining_logs) == 1
assert remaining_logs[0].workflow_run_id == untouched_run_id
finally:
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id))
db_session_with_containers.commit()
def test_delete_by_run_ids_empty_short_circuits(db_session_with_containers: Session) -> None:
tenant_id = str(uuid4())
app_id = str(uuid4())
workflow_id = str(uuid4())
created_by = str(uuid4())
run_id = str(uuid4())
_create_trigger_log(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
workflow_run_id=run_id,
created_by=created_by,
)
db_session_with_containers.commit()
repository = SQLAlchemyWorkflowTriggerLogRepository(db_session_with_containers)
try:
deleted = repository.delete_by_run_ids([])
db_session_with_containers.commit()
assert deleted == 0
remaining_count = db_session_with_containers.scalar(
select(func.count())
.select_from(WorkflowTriggerLog)
.where(WorkflowTriggerLog.tenant_id == tenant_id)
.where(WorkflowTriggerLog.workflow_run_id == run_id)
)
assert remaining_count == 1
finally:
db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.tenant_id == tenant_id))
db_session_with_containers.commit()

View File

@@ -0,0 +1,143 @@
"""
Testcontainers integration tests for archived workflow run deletion service.
"""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
from sqlalchemy import select
from core.workflow.enums import WorkflowExecutionStatus
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
from models.workflow import WorkflowArchiveLog, WorkflowRun
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
class TestArchivedWorkflowRunDeletion:
def _create_workflow_run(
self,
db_session_with_containers,
*,
tenant_id: str,
created_at: datetime,
) -> WorkflowRun:
run = WorkflowRun(
id=str(uuid4()),
tenant_id=tenant_id,
app_id=str(uuid4()),
workflow_id=str(uuid4()),
type="workflow",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
version="1.0.0",
graph="{}",
inputs="{}",
status=WorkflowExecutionStatus.SUCCEEDED,
outputs="{}",
elapsed_time=0.1,
total_tokens=1,
total_steps=1,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid4()),
created_at=created_at,
finished_at=created_at,
exceptions_count=0,
)
db_session_with_containers.add(run)
db_session_with_containers.commit()
return run
def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
archive_log = WorkflowArchiveLog(
tenant_id=run.tenant_id,
app_id=run.app_id,
workflow_id=run.workflow_id,
workflow_run_id=run.id,
created_by_role=run.created_by_role,
created_by=run.created_by,
log_id=None,
log_created_at=None,
log_created_from=None,
run_version=run.version,
run_status=run.status,
run_triggered_from=run.triggered_from,
run_error=run.error,
run_elapsed_time=run.elapsed_time,
run_total_tokens=run.total_tokens,
run_total_steps=run.total_steps,
run_created_at=run.created_at,
run_finished_at=run.finished_at,
run_exceptions_count=run.exceptions_count,
trigger_metadata=None,
)
db_session_with_containers.add(archive_log)
db_session_with_containers.commit()
def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
deleter = ArchivedWorkflowRunDeletion()
missing_run_id = str(uuid4())
result = deleter.delete_by_run_id(missing_run_id)
assert result.success is False
assert result.error == f"Workflow run {missing_run_id} not found"
def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
deleter = ArchivedWorkflowRunDeletion()
result = deleter.delete_by_run_id(run.id)
assert result.success is False
assert result.error == f"Workflow run {run.id} is not archived"
def test_delete_batch_uses_repo(self, db_session_with_containers):
tenant_id = str(uuid4())
base_time = datetime.now(UTC)
run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
run2 = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=base_time + timedelta(seconds=1),
)
self._create_archive_log(db_session_with_containers, run=run1)
self._create_archive_log(db_session_with_containers, run=run2)
run_ids = [run1.id, run2.id]
deleter = ArchivedWorkflowRunDeletion()
results = deleter.delete_batch(
tenant_ids=[tenant_id],
start_date=base_time - timedelta(minutes=1),
end_date=base_time + timedelta(minutes=1),
limit=2,
)
assert len(results) == 2
assert all(result.success for result in results)
remaining_runs = db_session_with_containers.scalars(
select(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
).all()
assert remaining_runs == []
def test_delete_run_calls_repo(self, db_session_with_containers):
tenant_id = str(uuid4())
run = self._create_workflow_run(
db_session_with_containers,
tenant_id=tenant_id,
created_at=datetime.now(UTC),
)
run_id = run.id
deleter = ArchivedWorkflowRunDeletion()
result = deleter._delete_run(run)
assert result.success is True
assert result.deleted_counts["runs"] == 1
db_session_with_containers.expunge_all()
deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
assert deleted_run is None

View File

@@ -0,0 +1,143 @@
import datetime
from uuid import uuid4
from sqlalchemy import select
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService
def _create_dataset(db_session_with_containers) -> Dataset:
dataset = Dataset(
tenant_id=str(uuid4()),
name=f"dataset-{uuid4()}",
data_source_type="upload_file",
created_by=str(uuid4()),
)
dataset.id = str(uuid4())
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_document(
db_session_with_containers,
*,
dataset_id: str,
tenant_id: str,
indexing_status: str,
enabled: bool = True,
archived: bool = False,
is_paused: bool = False,
position: int = 1,
) -> Document:
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type="upload_file",
data_source_info="{}",
batch=f"batch-{uuid4()}",
name=f"doc-{uuid4()}",
created_from="web",
created_by=str(uuid4()),
doc_form="text_model",
)
document.id = str(uuid4())
document.indexing_status = indexing_status
document.enabled = enabled
document.archived = archived
document.is_paused = is_paused
if indexing_status == "completed":
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
def test_build_display_status_filters_available(db_session_with_containers):
dataset = _create_dataset(db_session_with_containers)
available_doc = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
enabled=True,
archived=False,
position=1,
)
_create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
enabled=False,
archived=False,
position=2,
)
_create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
enabled=True,
archived=True,
position=3,
)
filters = DocumentService.build_display_status_filters("available")
assert len(filters) == 3
for condition in filters:
assert condition is not None
rows = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id, *filters)).all()
assert [row.id for row in rows] == [available_doc.id]
def test_apply_display_status_filter_applies_when_status_present(db_session_with_containers):
dataset = _create_dataset(db_session_with_containers)
waiting_doc = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
position=1,
)
_create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
position=2,
)
query = select(Document).where(Document.dataset_id == dataset.id)
filtered = DocumentService.apply_display_status_filter(query, "queuing")
rows = db_session_with_containers.scalars(filtered).all()
assert [row.id for row in rows] == [waiting_doc.id]
def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_containers):
dataset = _create_dataset(db_session_with_containers)
doc1 = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="waiting",
position=1,
)
doc2 = _create_document(
db_session_with_containers,
dataset_id=dataset.id,
tenant_id=dataset.tenant_id,
indexing_status="completed",
position=2,
)
query = select(Document).where(Document.dataset_id == dataset.id)
filtered = DocumentService.apply_display_status_filter(query, "invalid")
rows = db_session_with_containers.scalars(filtered).all()
assert {row.id for row in rows} == {doc1.id, doc2.id}

View File

@@ -0,0 +1,416 @@
from __future__ import annotations
from unittest.mock import patch
from uuid import uuid4
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account, Tenant, TenantAccountJoin
from models.model import App, DefaultEndUserSessionID, EndUser
from services.end_user_service import EndUserService
class TestEndUserServiceFactory:
"""Factory class for creating test data and mock objects for end user service tests."""
@staticmethod
def create_app_and_account(db_session_with_containers):
tenant = Tenant(name=f"Tenant {uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
account = Account(
name=f"Account {uuid4()}",
email=f"end_user_{uuid4()}@example.com",
password="hashed-password",
password_salt="salt",
interface_language="en-US",
timezone="UTC",
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
tenant_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role="owner",
current=True,
)
db_session_with_containers.add(tenant_join)
db_session_with_containers.flush()
app = App(
tenant_id=tenant.id,
name=f"App {uuid4()}",
description="",
mode="chat",
icon_type="emoji",
icon="bot",
icon_background="#FFFFFF",
enable_site=False,
enable_api=True,
api_rpm=100,
api_rph=100,
is_demo=False,
is_public=False,
is_universal=False,
created_by=account.id,
updated_by=account.id,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
@staticmethod
def create_end_user(
db_session_with_containers,
*,
tenant_id: str,
app_id: str,
session_id: str,
invoke_type: InvokeFrom,
is_anonymous: bool = False,
):
end_user = EndUser(
tenant_id=tenant_id,
app_id=app_id,
type=invoke_type,
external_user_id=session_id,
name=f"User-{uuid4()}",
is_anonymous=is_anonymous,
session_id=session_id,
)
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
return end_user
class TestEndUserServiceGetOrCreateEndUser:
"""
Unit tests for EndUserService.get_or_create_end_user method.
This test suite covers:
- Creating new end users
- Retrieving existing end users
- Default session ID handling
- Anonymous user creation
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
"""Test getting or creating end user with custom user_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
user_id = "custom-user-123"
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
assert result.tenant_id == app.tenant_id
assert result.app_id == app.id
assert result.session_id == user_id
assert result.type == InvokeFrom.SERVICE_API
assert result.is_anonymous is False
def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
"""Test getting or creating end user without user_id uses default session."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
# Assert
assert result.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Verify the underlying _is_anonymous flag is set (the public is_anonymous property always returns False)
assert result._is_anonymous is True
def test_get_existing_end_user(self, db_session_with_containers, factory):
"""Test retrieving an existing end user."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
user_id = "existing-user-123"
existing_user = factory.create_end_user(
db_session_with_containers,
tenant_id=app.tenant_id,
app_id=app.id,
session_id=user_id,
invoke_type=InvokeFrom.SERVICE_API,
)
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
assert result.id == existing_user.id
class TestEndUserServiceGetOrCreateEndUserByType:
"""
Unit tests for EndUserService.get_or_create_end_user_by_type method.
This test suite covers:
- Creating end users with different InvokeFrom types
- Type migration for legacy users
- Query ordering and prioritization
- Session management
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
"""Test creating new end user with SERVICE_API type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "user-789"
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.type == InvokeFrom.SERVICE_API
assert result.tenant_id == tenant_id
assert result.app_id == app_id
assert result.session_id == user_id
def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
"""Test creating new end user with WEB_APP type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "user-789"
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.type == InvokeFrom.WEB_APP
@patch("services.end_user_service.logger")
def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
"""Test upgrading legacy end user with different type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "user-789"
# Existing user with old type
existing_user = factory.create_end_user(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
invoke_type=InvokeFrom.SERVICE_API,
)
# Act - Request with different type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.id == existing_user.id
assert result.type == InvokeFrom.WEB_APP # Type should be updated
mock_logger.info.assert_called_once()
# Verify log message contains upgrade info
log_call = mock_logger.info.call_args[0][0]
assert "Upgrading legacy EndUser" in log_call
@patch("services.end_user_service.logger")
def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
"""Test retrieving existing end user with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "user-789"
existing_user = factory.create_end_user(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
invoke_type=InvokeFrom.SERVICE_API,
)
# Act - Request with same type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.id == existing_user.id
assert result.type == InvokeFrom.SERVICE_API
mock_logger.info.assert_not_called()
def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
"""Test creating anonymous user when user_id is None."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=None,
)
# Assert
assert result.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Verify the underlying _is_anonymous flag is set (the public is_anonymous property always returns False)
assert result._is_anonymous is True
assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
"""Test that query ordering prioritizes records with matching type."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "user-789"
non_matching = factory.create_end_user(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
invoke_type=InvokeFrom.WEB_APP,
)
matching = factory.create_end_user(
db_session_with_containers,
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
invoke_type=InvokeFrom.SERVICE_API,
)
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.id == matching.id
assert result.id != non_matching.id
def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
"""Test that external_user_id is set to match session_id."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = "custom-external-id"
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.external_user_id == user_id
assert result.session_id == user_id
@pytest.mark.parametrize(
"invoke_type",
[
InvokeFrom.SERVICE_API,
InvokeFrom.WEB_APP,
InvokeFrom.EXPLORE,
InvokeFrom.DEBUGGER,
],
)
def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
"""Test creating end users with different InvokeFrom types."""
# Arrange
app = factory.create_app_and_account(db_session_with_containers)
tenant_id = app.tenant_id
app_id = app.id
user_id = f"user-{uuid4()}"
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=invoke_type,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result.type == invoke_type
class TestEndUserServiceGetEndUserById:
"""Unit tests for EndUserService.get_end_user_by_id."""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestEndUserServiceFactory()
def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
app = factory.create_app_and_account(db_session_with_containers)
existing_user = factory.create_end_user(
db_session_with_containers,
tenant_id=app.tenant_id,
app_id=app.id,
session_id=f"session-{uuid4()}",
invoke_type=InvokeFrom.SERVICE_API,
)
result = EndUserService.get_end_user_by_id(
tenant_id=app.tenant_id,
app_id=app.id,
end_user_id=existing_user.id,
)
assert result is not None
assert result.id == existing_user.id
def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
app = factory.create_app_and_account(db_session_with_containers)
result = EndUserService.get_end_user_by_id(
tenant_id=app.tenant_id,
app_id=app.id,
end_user_id=str(uuid4()),
)
assert result is None

View File

@@ -0,0 +1,63 @@
from __future__ import annotations
from decimal import Decimal
import pytest
from models.model import Message
from services import message_service
from tests.test_containers_integration_tests.helpers.execution_extra_content import (
create_human_input_message_fixture,
)
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
def test_attach_message_extra_contents_assigns_serialized_payload(db_session_with_containers) -> None:
fixture = create_human_input_message_fixture(db_session_with_containers)
message_without_extra_content = Message(
app_id=fixture.app.id,
model_provider=None,
model_id="",
override_model_configs=None,
conversation_id=fixture.conversation.id,
inputs={},
query="Query without extra content",
message={"messages": [{"role": "user", "content": "Query without extra content"}]},
message_tokens=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
answer="Answer without extra content",
answer_tokens=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
parent_message_id=None,
provider_response_latency=0,
total_price=Decimal(0),
currency="USD",
status="normal",
from_source="console",
from_account_id=fixture.account.id,
)
db_session_with_containers.add(message_without_extra_content)
db_session_with_containers.commit()
messages = [fixture.message, message_without_extra_content]
message_service.attach_message_extra_contents(messages)
assert messages[0].extra_contents == [
{
"type": "human_input",
"workflow_run_id": fixture.message.workflow_run_id,
"submitted": True,
"form_submission_data": {
"node_id": fixture.form.node_id,
"node_title": fixture.node_title,
"rendered_content": fixture.form.rendered_content,
"action_id": fixture.action_id,
"action_text": fixture.action_text,
},
}
]
assert messages[1].extra_contents == []

View File

@@ -0,0 +1,224 @@
import uuid
from unittest.mock import ANY, call, patch
import pytest
from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import (
_delete_draft_variable_offload_data,
delete_draft_variables_batch,
)
@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers):
db_session_with_containers.query(WorkflowDraftVariable).delete()
db_session_with_containers.query(WorkflowDraftVariableFile).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.commit()
def _create_tenant_and_app(db_session_with_containers):
tenant = Tenant(name=f"test_tenant_{uuid.uuid4()}")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
app = App(
tenant_id=tenant.id,
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db_session_with_containers.add(app)
db_session_with_containers.commit()
return tenant, app
def _create_draft_variables(
db_session_with_containers,
*,
app_id: str,
count: int,
file_id_by_index: dict[int, str] | None = None,
) -> list[WorkflowDraftVariable]:
variables: list[WorkflowDraftVariable] = []
file_id_by_index = file_id_by_index or {}
for i in range(count):
variable = WorkflowDraftVariable.new_node_variable(
app_id=app_id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
file_id=file_id_by_index.get(i),
)
db_session_with_containers.add(variable)
variables.append(variable)
db_session_with_containers.commit()
return variables
def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: str, count: int):
upload_files: list[UploadFile] = []
variable_files: list[WorkflowDraftVariableFile] = []
for i in range(count):
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type="local",
key=f"test/file-{uuid.uuid4()}-{i}.json",
name=f"file-{i}.json",
size=1024 + i,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
db_session_with_containers.add(upload_file)
db_session_with_containers.flush()
upload_files.append(upload_file)
variable_file = WorkflowDraftVariableFile(
tenant_id=tenant_id,
app_id=app_id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file.id,
size=1024 + i,
length=10 + i,
value_type=SegmentType.STRING,
)
db_session_with_containers.add(variable_file)
db_session_with_containers.flush()
variable_files.append(variable_file)
db_session_with_containers.commit()
return {
"upload_files": upload_files,
"variable_files": variable_files,
}
class TestDeleteDraftVariablesBatch:
def test_delete_draft_variables_batch_success(self, db_session_with_containers):
"""Test successful deletion of draft variables in batches."""
_, app1 = _create_tenant_and_app(db_session_with_containers)
_, app2 = _create_tenant_and_app(db_session_with_containers)
_create_draft_variables(db_session_with_containers, app_id=app1.id, count=150)
_create_draft_variables(db_session_with_containers, app_id=app2.id, count=100)
result = delete_draft_variables_batch(app1.id, batch_size=100)
assert result == 150
app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app1.id
)
app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where(
WorkflowDraftVariable.app_id == app2.id
)
assert app1_remaining.count() == 0
assert app2_remaining.count() == 100
def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers):
"""Test deletion when no draft variables exist for the app."""
result = delete_draft_variables_batch(str(uuid.uuid4()), 1000)
assert result == 0
assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(
self, mock_logger, mock_offload_cleanup, db_session_with_containers
):
"""Test that batch deletion logs progress correctly."""
tenant, app = _create_tenant_and_app(db_session_with_containers)
offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=10)
file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
file_id_by_index: dict[int, str] = {}
for i in range(30):
if i % 3 == 0:
file_id_by_index[i] = file_ids[i // 3]
_create_draft_variables(db_session_with_containers, app_id=app.id, count=30, file_id_by_index=file_id_by_index)
mock_offload_cleanup.return_value = len(file_id_by_index)
result = delete_draft_variables_batch(app.id, 50)
assert result == 30
mock_offload_cleanup.assert_called_once()
_, called_file_ids = mock_offload_cleanup.call_args.args
assert {str(file_id) for file_id in called_file_ids} == {str(file_id) for file_id in file_id_by_index.values()}
assert mock_logger.info.call_count == 2
mock_logger.info.assert_any_call(ANY)
class TestDeleteDraftVariableOffloadData:
"""Test the Offload data cleanup functionality."""
@patch("extensions.ext_storage.storage")
def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers):
"""Test successful deletion of offload data."""
tenant, app = _create_tenant_and_app(db_session_with_containers)
offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3)
file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
upload_file_keys = [upload_file.key for upload_file in offload_data["upload_files"]]
upload_file_ids = [upload_file.id for upload_file in offload_data["upload_files"]]
with session_factory.create_session() as session, session.begin():
result = _delete_draft_variable_offload_data(session, file_ids)
assert result == 3
expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys]
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0
@patch("extensions.ext_storage.storage")
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variable_offload_data_storage_failure(
self, mock_logging, mock_storage, db_session_with_containers
):
"""Test handling of storage deletion failures."""
tenant, app = _create_tenant_and_app(db_session_with_containers)
offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=2)
file_ids = [variable_file.id for variable_file in offload_data["variable_files"]]
storage_keys = [upload_file.key for upload_file in offload_data["upload_files"]]
upload_file_ids = [upload_file.id for upload_file in offload_data["upload_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]
with session_factory.create_session() as session, session.begin():
result = _delete_draft_variable_offload_data(session, file_ids)
assert result == 1
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0])
remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where(
WorkflowDraftVariableFile.id.in_(file_ids)
)
remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids))
assert remaining_var_files.count() == 0
assert remaining_upload_files.count() == 0

View File

@@ -954,148 +954,6 @@ class TestChildChunk:
assert child_chunk.index_node_hash == index_node_hash
class TestDocumentSegmentNavigation:
"""Test suite for DocumentSegment navigation properties."""
def test_document_segment_dataset_property(self):
"""Test segment can access its parent dataset."""
# Arrange
dataset_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=dataset_id,
document_id=str(uuid4()),
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
mock_dataset = Dataset(
tenant_id=str(uuid4()),
name="Test Dataset",
data_source_type="upload_file",
created_by=str(uuid4()),
)
mock_dataset.id = dataset_id
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=mock_dataset):
# Act
dataset = segment.dataset
# Assert
assert dataset is not None
assert dataset.id == dataset_id
def test_document_segment_document_property(self):
"""Test segment can access its parent document."""
# Arrange
document_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
mock_document = Document(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
position=1,
data_source_type="upload_file",
batch="batch_001",
name="test.pdf",
created_from="web",
created_by=str(uuid4()),
)
mock_document.id = document_id
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=mock_document):
# Act
document = segment.document
# Assert
assert document is not None
assert document.id == document_id
def test_document_segment_previous_segment(self):
"""Test segment can access previous segment."""
# Arrange
document_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=2,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
previous_segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=1,
content="Previous",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=previous_segment):
# Act
prev_seg = segment.previous_segment
# Assert
assert prev_seg is not None
assert prev_seg.position == 1
def test_document_segment_next_segment(self):
"""Test segment can access next segment."""
# Arrange
document_id = str(uuid4())
segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=1,
content="Test",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
next_segment = DocumentSegment(
tenant_id=str(uuid4()),
dataset_id=str(uuid4()),
document_id=document_id,
position=2,
content="Next",
word_count=1,
tokens=2,
created_by=str(uuid4()),
)
# Mock the database session scalar
with patch("models.dataset.db.session.scalar", return_value=next_segment):
# Act
next_seg = segment.next_segment
# Assert
assert next_seg is not None
assert next_seg.position == 2
class TestModelIntegration:
"""Test suite for model integration scenarios."""

View File

@@ -1,40 +0,0 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation."""
from unittest.mock import Mock
from sqlalchemy.orm import Session, sessionmaker
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
def test_get_executions_by_workflow_run_keeps_paused_records(self):
mock_session = Mock(spec=Session)
execute_result = Mock()
execute_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = execute_result
session_maker = Mock(spec=sessionmaker)
context_manager = Mock()
context_manager.__enter__ = Mock(return_value=mock_session)
context_manager.__exit__ = Mock(return_value=None)
session_maker.return_value = context_manager
repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker)
repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-123",
workflow_run_id="workflow-run-123",
)
stmt = mock_session.execute.call_args[0][0]
where_clauses = list(getattr(stmt, "_where_criteria", []) or [])
where_strs = [str(clause).lower() for clause in where_clauses]
assert any("tenant_id" in clause for clause in where_strs)
assert any("app_id" in clause for clause in where_strs)
assert any("workflow_run_id" in clause for clause in where_strs)
assert not any("paused" in clause for clause in where_strs)
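The assertion style above pulls the WHERE clauses off the captured statement via `_where_criteria`, which is private SQLAlchemy API. A minimal reproduction of that inspection, assuming SQLAlchemy 1.4+ and using a lightweight `sa.table` instead of the real ORM models:

```python
import sqlalchemy as sa

# A lightweight table object is enough to build the kind of SELECT the
# repository test inspects; no engine or ORM models are needed.
executions = sa.table(
    "workflow_node_executions",
    sa.column("tenant_id"), sa.column("app_id"), sa.column("workflow_run_id"),
)
stmt = sa.select(executions).where(
    executions.c.tenant_id == "tenant-123",
    executions.c.app_id == "app-123",
)
# getattr with a default keeps the check from crashing if this private
# attribute changes in a future SQLAlchemy release.
where_strs = [str(c).lower() for c in getattr(stmt, "_where_criteria", []) or []]
assert any("tenant_id" in c for c in where_strs)
assert any("app_id" in c for c in where_strs)
```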

View File

@@ -1,435 +1,50 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
class TestDifyAPISQLAlchemyWorkflowRunRepository:
"""Test DifyAPISQLAlchemyWorkflowRunRepository implementation."""
@pytest.fixture
def mock_session(self):
"""Create a mock session."""
return Mock(spec=Session)
@pytest.fixture
def mock_session_maker(self, mock_session):
"""Create a mock sessionmaker."""
session_maker = Mock(spec=sessionmaker)
# Create a context manager mock
context_manager = Mock()
context_manager.__enter__ = Mock(return_value=mock_session)
context_manager.__exit__ = Mock(return_value=None)
session_maker.return_value = context_manager
# Mock session.begin() context manager
begin_context_manager = Mock()
begin_context_manager.__enter__ = Mock(return_value=None)
begin_context_manager.__exit__ = Mock(return_value=None)
mock_session.begin = Mock(return_value=begin_context_manager)
# Add missing session methods
mock_session.commit = Mock()
mock_session.rollback = Mock()
mock_session.add = Mock()
mock_session.delete = Mock()
mock_session.get = Mock()
mock_session.scalar = Mock()
mock_session.scalars = Mock()
# Also support expire_on_commit parameter
def make_session(expire_on_commit=None):
cm = Mock()
cm.__enter__ = Mock(return_value=mock_session)
cm.__exit__ = Mock(return_value=None)
return cm
session_maker.side_effect = make_session
return session_maker
@pytest.fixture
def repository(self, mock_session_maker):
"""Create repository instance with mocked dependencies."""
# Create a testable subclass that implements the save method
class TestableDifyAPISQLAlchemyWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository):
def __init__(self, session_maker):
# Initialize without calling parent __init__ to avoid any instantiation issues
self._session_maker = session_maker
def save(self, execution):
"""Mock implementation of save method."""
return None
# Create repository instance
repo = TestableDifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker)
return repo
@pytest.fixture
def sample_workflow_run(self):
"""Create a sample WorkflowRun model."""
workflow_run = Mock(spec=WorkflowRun)
workflow_run.id = "workflow-run-123"
workflow_run.tenant_id = "tenant-123"
workflow_run.app_id = "app-123"
workflow_run.workflow_id = "workflow-123"
workflow_run.status = WorkflowExecutionStatus.RUNNING
return workflow_run
@pytest.fixture
def sample_workflow_pause(self):
"""Create a sample WorkflowPauseModel."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
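The `make_session` trick inside the `mock_session_maker` fixture can be sketched without SQLAlchemy at all: the essential point is that every call to the maker, whatever its arguments, hands back a context manager yielding the same shared session mock. A stdlib-only version (names hypothetical):

```python
from unittest.mock import MagicMock, Mock

def make_session_maker(session):
    """Build a callable that mimics sessionmaker: every call (regardless of
    arguments such as expire_on_commit) returns a context manager that
    yields the one shared mock session."""
    def _maker(*args, **kwargs):
        cm = MagicMock()  # MagicMock supports the context-manager protocol
        cm.__enter__.return_value = session
        return cm
    return _maker

mock_session = Mock()
session_maker = make_session_maker(mock_session)
with session_maker(expire_on_commit=False) as session:
    session.add("row")
mock_session.add.assert_called_once_with("row")
```

Routing every call through one shared mock is what lets the tests assert on `mock_session.add` after the code under test has opened and closed its own session.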
class TestGetRunsBatchByTimeRange(TestDifyAPISQLAlchemyWorkflowRunRepository):
def test_get_runs_batch_by_time_range_filters_terminal_statuses(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
):
scalar_result = Mock()
scalar_result.all.return_value = []
mock_session.scalars.return_value = scalar_result
repository.get_runs_batch_by_time_range(
start_from=None,
end_before=datetime(2024, 1, 1),
last_seen=None,
batch_size=50,
)
stmt = mock_session.scalars.call_args[0][0]
compiled_sql = str(
stmt.compile(
dialect=postgresql.dialect(),
compile_kwargs={"literal_binds": True},
)
)
assert "workflow_runs.status" in compiled_sql
for status in (
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.STOPPED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
):
assert f"'{status.value}'" in compiled_sql
assert "'running'" not in compiled_sql
assert "'paused'" not in compiled_sql
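The compile-and-grep technique above depends on `literal_binds`, which inlines bound parameters into the rendered SQL so string assertions can see actual values instead of placeholders like `:status_1`. A small sketch, assuming SQLAlchemy 1.4+ (the status strings here are illustrative, not the real enum values):

```python
import sqlalchemy as sa

runs = sa.table("workflow_runs", sa.column("id"), sa.column("status"))
terminal = ["succeeded", "failed", "stopped", "partial-succeeded"]
stmt = sa.select(runs.c.id).where(runs.c.status.in_(terminal))

# literal_binds renders the IN list with its literal members.
compiled_sql = str(stmt.compile(compile_kwargs={"literal_binds": True}))
assert "workflow_runs.status" in compiled_sql
assert "'succeeded'" in compiled_sql
assert "'running'" not in compiled_sql
```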
class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test create_workflow_pause method."""
def test_create_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test successful workflow pause creation."""
# Arrange
workflow_run_id = "workflow-run-123"
state_owner_user_id = "user-123"
state = '{"test": "state"}'
mock_session.get.return_value = sample_workflow_run
with patch("repositories.sqlalchemy_api_workflow_run_repository.uuidv7") as mock_uuidv7:
mock_uuidv7.side_effect = ["pause-123"]
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
result = repository.create_workflow_pause(
workflow_run_id=workflow_run_id,
state_owner_user_id=state_owner_user_id,
state=state,
pause_reasons=[],
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
assert result.workflow_execution_id == workflow_run_id
assert result.get_pause_reasons() == []
# Verify database interactions
mock_session.get.assert_called_once_with(WorkflowRun, workflow_run_id)
mock_storage.save.assert_called_once()
mock_session.add.assert_called()
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_create_workflow_pause_not_found(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock
):
"""Test workflow pause creation when workflow run not found."""
# Arrange
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(ValueError, match="WorkflowRun not found: workflow-run-123"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
pause_reasons=[],
)
mock_session.get.assert_called_once_with(WorkflowRun, "workflow-run-123")
def test_create_workflow_pause_invalid_status(
self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock, sample_workflow_run: Mock
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
state='{"test": "state"}',
pause_reasons=[],
)
class TestDeleteRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
node_ids_result = Mock()
node_ids_result.all.return_value = []
pause_ids_result = Mock()
pause_ids_result.all.return_value = []
mock_session.scalars.side_effect = [node_ids_result, pause_ids_result]
# app_logs delete, runs delete
mock_session.execute.side_effect = [Mock(rowcount=0), Mock(rowcount=1)]
fake_trigger_repo = Mock()
fake_trigger_repo.delete_by_run_ids.return_value = 3
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
counts = repository.delete_runs_with_related(
[run],
delete_node_executions=lambda session, runs: (2, 1),
delete_trigger_logs=lambda session, run_ids: fake_trigger_repo.delete_by_run_ids(run_ids),
)
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
assert counts["node_executions"] == 2
assert counts["offloads"] == 1
assert counts["trigger_logs"] == 3
assert counts["runs"] == 1
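The test above exercises a callback-driven design: the repository owns the run and app-log deletes and delegates node-execution and trigger-log cleanup to injected callables, then aggregates the counts. A hypothetical sketch of that shape (the real method also issues SQL; here the callbacks carry all the work):

```python
from unittest.mock import Mock

def delete_runs_with_related(runs, delete_node_executions, delete_trigger_logs):
    """Hypothetical sketch: delegate related deletes to callbacks and
    aggregate their counts into one summary dict."""
    run_ids = [run.id for run in runs]
    node_executions, offloads = delete_node_executions(None, runs)
    trigger_logs = delete_trigger_logs(None, run_ids)
    return {
        "node_executions": node_executions,
        "offloads": offloads,
        "trigger_logs": trigger_logs,
        "runs": len(runs),
    }

fake_trigger_repo = Mock()
fake_trigger_repo.delete_by_run_ids.return_value = 3
run = Mock(id="run-1")
counts = delete_runs_with_related(
    [run],
    delete_node_executions=lambda session, runs: (2, 1),
    delete_trigger_logs=lambda session, ids: fake_trigger_repo.delete_by_run_ids(ids),
)
assert counts == {"node_executions": 2, "offloads": 1, "trigger_logs": 3, "runs": 1}
fake_trigger_repo.delete_by_run_ids.assert_called_once_with(["run-1"])
```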
class TestCountRunsWithRelated(TestDifyAPISQLAlchemyWorkflowRunRepository):
def test_uses_trigger_log_repository(self, repository: DifyAPISQLAlchemyWorkflowRunRepository, mock_session: Mock):
pause_ids_result = Mock()
pause_ids_result.all.return_value = ["pause-1", "pause-2"]
mock_session.scalars.return_value = pause_ids_result
mock_session.scalar.side_effect = [5, 2]
fake_trigger_repo = Mock()
fake_trigger_repo.count_by_run_ids.return_value = 3
run = Mock(id="run-1", tenant_id="t1", app_id="a1", workflow_id="w1", triggered_from="tf")
counts = repository.count_runs_with_related(
[run],
count_node_executions=lambda session, runs: (2, 1),
count_trigger_logs=lambda session, run_ids: fake_trigger_repo.count_by_run_ids(run_ids),
)
fake_trigger_repo.count_by_run_ids.assert_called_once_with(["run-1"])
assert counts["node_executions"] == 2
assert counts["offloads"] == 1
assert counts["trigger_logs"] == 3
assert counts["app_logs"] == 5
assert counts["pauses"] == 2
assert counts["pause_reasons"] == 2
assert counts["runs"] == 1
class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test resume_workflow_pause method."""
def test_resume_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause resume."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
# Setup workflow run and pause
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.pause = sample_workflow_pause
sample_workflow_pause.resumed_at = None
mock_session.scalar.return_value = sample_workflow_run
mock_session.scalars.return_value.all.return_value = []
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
mock_now.return_value = datetime.now(UTC)
# Act
result = repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
# Assert
assert isinstance(result, _PrivateWorkflowPauseEntity)
assert result.id == "pause-123"
# Verify state transitions
assert sample_workflow_pause.resumed_at is not None
assert sample_workflow_run.status == WorkflowExecutionStatus.RUNNING
# Verify database interactions
mock_session.add.assert_called()
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_resume_workflow_pause_not_paused(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
):
"""Test resume when workflow is not paused."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
sample_workflow_run.status = WorkflowExecutionStatus.RUNNING
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowRun is not in PAUSED status"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
def test_resume_workflow_pause_id_mismatch(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_run: Mock,
sample_workflow_pause: Mock,
):
"""Test resume when pause ID doesn't match."""
# Arrange
workflow_run_id = "workflow-run-123"
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-456" # Different ID
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_pause.id = "pause-123"
sample_workflow_run.pause = sample_workflow_pause
mock_session.scalar.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="different id in WorkflowPause and WorkflowPauseEntity"):
repository.resume_workflow_pause(
workflow_run_id=workflow_run_id,
pause_entity=pause_entity,
)
class TestDeleteWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
"""Test delete_workflow_pause method."""
def test_delete_workflow_pause_success(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
sample_workflow_pause: Mock,
):
"""Test successful workflow pause deletion."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = sample_workflow_pause
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
# Act
repository.delete_workflow_pause(pause_entity=pause_entity)
# Assert
mock_storage.delete.assert_called_once_with(sample_workflow_pause.state_object_key)
mock_session.delete.assert_called_once_with(sample_workflow_pause)
# When using session.begin() context manager, commit is handled automatically
# No explicit commit call is expected
def test_delete_workflow_pause_not_found(
self,
repository: DifyAPISQLAlchemyWorkflowRunRepository,
mock_session: Mock,
):
"""Test delete when pause not found."""
# Arrange
pause_entity = Mock(spec=WorkflowPauseEntity)
pause_entity.id = "pause-123"
mock_session.get.return_value = None
# Act & Assert
with pytest.raises(_WorkflowRunError, match="WorkflowPause not found: pause-123"):
repository.delete_workflow_pause(pause_entity=pause_entity)
@pytest.fixture
def sample_workflow_pause() -> Mock:
"""Create a sample WorkflowPause model."""
pause = Mock(spec=WorkflowPauseModel)
pause.id = "pause-123"
pause.workflow_id = "workflow-123"
pause.workflow_run_id = "workflow-run-123"
pause.state_object_key = "workflow-state-123.json"
pause.resumed_at = None
pause.created_at = datetime.now(UTC)
return pause
class TestPrivateWorkflowPauseEntity:
"""Test _PrivateWorkflowPauseEntity class."""
def test_properties(self, sample_workflow_pause: Mock) -> None:
"""Test entity properties."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
# Assert
assert entity.id == sample_workflow_pause.id
assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id
assert entity.resumed_at == sample_workflow_pause.resumed_at
def test_get_state(self, sample_workflow_pause: Mock) -> None:
"""Test getting state from storage."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
@@ -445,7 +60,7 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
assert result == expected_state
mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key)
def test_get_state_caching(self, sample_workflow_pause: Mock) -> None:
"""Test state caching in get_state method."""
# Arrange
entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[])
@@ -456,16 +71,20 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
# Act
result1 = entity.get_state()
result2 = entity.get_state()
# Assert
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once()
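The caching behavior verified above is a simple load-once memoization: the first `get_state()` hits storage, later calls return the cached value. A minimal stdlib sketch (class and attribute names hypothetical):

```python
class PauseState:
    """Minimal sketch of the load-once caching tested above."""
    def __init__(self, loader, key):
        self._loader = loader
        self._key = key
        self._state = None

    def get_state(self):
        if self._state is None:
            self._state = self._loader(self._key)  # hit storage only once
        return self._state

calls = []
def fake_load(key):
    calls.append(key)
    return b'{"test": "state"}'

entity = PauseState(fake_load, "workflow-state-123.json")
assert entity.get_state() == b'{"test": "state"}'
assert entity.get_state() == b'{"test": "state"}'
assert len(calls) == 1  # second call served from cache
```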
class TestBuildHumanInputRequiredReason:
"""Test helper that builds HumanInputRequired pause reasons."""
def test_prefers_backstage_token_when_available(self) -> None:
"""Use backstage token when multiple recipient types may exist."""
# Arrange
expiration_time = datetime.now(UTC)
form_definition = FormDefinition(
form_content="content",
@@ -504,8 +123,10 @@ class TestBuildHumanInputRequiredReason:
access_token=access_token,
)
# Act
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
# Assert
assert isinstance(reason, HumanInputRequired)
assert reason.form_token == access_token
assert reason.node_title == "Ask Name"

View File

@@ -1,31 +0,0 @@
from unittest.mock import Mock
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
def test_delete_by_run_ids_executes_delete():
session = Mock(spec=Session)
session.execute.return_value = Mock(rowcount=2)
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
deleted = repo.delete_by_run_ids(["run-1", "run-2"])
stmt = session.execute.call_args[0][0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
assert "workflow_trigger_logs" in compiled_sql
assert "'run-1'" in compiled_sql
assert "'run-2'" in compiled_sql
assert deleted == 2
def test_delete_by_run_ids_empty_short_circuits():
session = Mock(spec=Session)
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
deleted = repo.delete_by_run_ids([])
session.execute.assert_not_called()
assert deleted == 0
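The empty-list test pins down a guard worth having in any bulk-delete helper: return early on an empty id list so no DELETE is issued at all. A hypothetical sketch of that shape against a mocked session:

```python
from unittest.mock import Mock

def delete_by_run_ids(session, run_ids):
    """Sketch of the repository method: short-circuit on an empty id list."""
    if not run_ids:
        return 0  # no SQL issued at all
    # Placeholder statement object; the real method builds a SQLAlchemy delete().
    return session.execute(("delete_trigger_logs", tuple(run_ids))).rowcount

session = Mock()
session.execute.return_value = Mock(rowcount=2)
assert delete_by_run_ids(session, ["run-1", "run-2"]) == 2

empty_session = Mock()
assert delete_by_run_ids(empty_session, []) == 0
empty_session.execute.assert_not_called()
```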

View File

@@ -6,66 +6,6 @@ from unittest.mock import MagicMock, patch
class TestArchivedWorkflowRunDeletion:
def test_delete_by_run_id_returns_error_when_run_missing(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
session = MagicMock()
session.get.return_value = None
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
):
result = deleter.delete_by_run_id("run-1")
assert result.success is False
assert result.error == "Workflow run run-1 not found"
repo.get_archived_run_ids.assert_not_called()
def test_delete_by_run_id_returns_error_when_not_archived(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
repo.get_archived_run_ids.return_value = set()
run = MagicMock()
run.id = "run-1"
run.tenant_id = "tenant-1"
session = MagicMock()
session.get.return_value = run
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(deleter, "_delete_run") as mock_delete_run,
):
result = deleter.delete_by_run_id("run-1")
assert result.success is False
assert result.error == "Workflow run run-1 is not archived"
mock_delete_run.assert_not_called()
def test_delete_by_run_id_calls_delete_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@@ -98,55 +38,6 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is True
mock_delete_run.assert_called_once_with(run)
def test_delete_batch_uses_repo(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
repo = MagicMock()
run1 = MagicMock()
run1.id = "run-1"
run1.tenant_id = "tenant-1"
run2 = MagicMock()
run2.id = "run-2"
run2.tenant_id = "tenant-1"
repo.get_archived_runs_by_time_range.return_value = [run1, run2]
session = MagicMock()
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
session_maker.return_value.__exit__.return_value = None
start_date = MagicMock()
end_date = MagicMock()
mock_db = MagicMock()
mock_db.engine = MagicMock()
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(
deleter, "_delete_run", side_effect=[MagicMock(success=True), MagicMock(success=True)]
) as mock_delete_run,
):
results = deleter.delete_batch(
tenant_ids=["tenant-1"],
start_date=start_date,
end_date=end_date,
limit=2,
)
assert len(results) == 2
repo.get_archived_runs_by_time_range.assert_called_once_with(
session=session,
tenant_ids=["tenant-1"],
start_date=start_date,
end_date=end_date,
limit=2,
)
assert mock_delete_run.call_count == 2
def test_delete_run_dry_run(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@@ -160,21 +51,3 @@ class TestArchivedWorkflowRunDeletion:
assert result.success is True
mock_get_repo.assert_not_called()
def test_delete_run_calls_repo(self):
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
deleter = ArchivedWorkflowRunDeletion()
run = MagicMock()
run.id = "run-1"
run.tenant_id = "tenant-1"
repo = MagicMock()
repo.delete_runs_with_related.return_value = {"runs": 1}
with patch.object(deleter, "_get_workflow_run_repo", return_value=repo):
result = deleter._delete_run(run)
assert result.success is True
assert result.deleted_counts == {"runs": 1}
repo.delete_runs_with_related.assert_called_once()

View File

@@ -1,6 +1,3 @@
import sqlalchemy as sa
from models.dataset import Document
from services.dataset_service import DocumentService
@@ -9,25 +6,3 @@ def test_normalize_display_status_alias_mapping():
assert DocumentService.normalize_display_status("enabled") == "available"
assert DocumentService.normalize_display_status("archived") == "archived"
assert DocumentService.normalize_display_status("unknown") is None
def test_build_display_status_filters_available():
filters = DocumentService.build_display_status_filters("available")
assert len(filters) == 3
for condition in filters:
assert condition is not None
def test_apply_display_status_filter_applies_when_status_present():
query = sa.select(Document)
filtered = DocumentService.apply_display_status_filter(query, "queuing")
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
assert "WHERE" in compiled
assert "documents.indexing_status = 'waiting'" in compiled
def test_apply_display_status_filter_returns_same_when_invalid():
query = sa.select(Document)
filtered = DocumentService.apply_display_status_filter(query, "invalid")
compiled = str(filtered.compile(compile_kwargs={"literal_binds": True}))
assert "WHERE" not in compiled
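The alias-mapping tests above can be read against a small table-driven sketch. This is a hypothetical reimplementation for illustration only; the real mapping lives in `DocumentService`:

```python
# Hypothetical alias table mirroring the behavior the tests assert.
_DISPLAY_STATUS_ALIASES = {
    "available": "available",
    "enabled": "available",
    "archived": "archived",
    "queuing": "queuing",
}

def normalize_display_status(status):
    """Return the canonical display status, or None for unknown values."""
    if not isinstance(status, str):
        return None
    return _DISPLAY_STATUS_ALIASES.get(status.lower())

assert normalize_display_status("enabled") == "available"
assert normalize_display_status("archived") == "archived"
assert normalize_display_status("unknown") is None
```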

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import App, DefaultEndUserSessionID, EndUser
from services.end_user_service import EndUserService
@@ -44,113 +44,6 @@ class TestEndUserServiceFactory:
return end_user
class TestEndUserServiceGetOrCreateEndUser:
"""
Unit tests for EndUserService.get_or_create_end_user method.
This test suite covers:
- Creating new end users
- Retrieving existing end users
- Default session ID handling
- Anonymous user creation
"""
@pytest.fixture
def factory(self):
"""Provide test data factory."""
return TestEndUserServiceFactory()
# Test 01: Get or create with custom user_id
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_with_custom_user_id(self, mock_db, mock_session_class, factory):
"""Test getting or creating end user with custom user_id."""
# Arrange
app = factory.create_app_mock()
user_id = "custom-user-123"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None # No existing user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify the created user has correct attributes
added_user = mock_session.add.call_args[0][0]
assert added_user.tenant_id == app.tenant_id
assert added_user.app_id == app.id
assert added_user.session_id == user_id
assert added_user.type == InvokeFrom.SERVICE_API
assert added_user.is_anonymous is False
# Test 02: Get or create without user_id (default session)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_or_create_end_user_without_user_id(self, mock_db, mock_session_class, factory):
"""Test getting or creating end user without user_id uses default session."""
# Arrange
app = factory.create_app_mock()
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None # No existing user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=None)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Verify the private _is_anonymous flag is set (the public property always returns False)
assert added_user._is_anonymous is True
# Test 03: Get existing end user
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user(self, mock_db, mock_session_class, factory):
"""Test retrieving an existing end user."""
# Arrange
app = factory.create_app_mock()
user_id = "existing-user-123"
existing_user = factory.create_end_user_mock(
tenant_id=app.tenant_id,
app_id=app.id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act
result = EndUserService.get_or_create_end_user(app_model=app, user_id=user_id)
# Assert
assert result == existing_user
mock_session.add.assert_not_called() # Should not create new user
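The three tests above all stub the same query chain (`query(...).where(...).order_by(...).first()`), which `MagicMock` handles by chaining `return_value` attributes. A stdlib sketch of the lookup-then-create flow being exercised (function and field names hypothetical):

```python
from unittest.mock import MagicMock

def get_or_create_end_user(session, session_id, default_session_id="DEFAULT-USER"):
    """Hypothetical sketch: look up an end user by session id, falling back
    to a default id, and create an anonymous one when nothing matches."""
    sid = session_id or default_session_id
    existing = session.query("EndUser").where(sid).order_by("created_at").first()
    if existing is not None:
        return existing
    user = {"session_id": sid, "is_anonymous": session_id is None}
    session.add(user)
    session.commit()
    return user

session = MagicMock()
# Configure the whole chain to report "no existing user".
session.query.return_value.where.return_value.order_by.return_value.first.return_value = None
user = get_or_create_end_user(session, None)
assert user["session_id"] == "DEFAULT-USER" and user["is_anonymous"] is True
session.add.assert_called_once()
```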
class TestEndUserServiceGetOrCreateEndUserByType:
"""
Unit tests for EndUserService.get_or_create_end_user_by_type method.
@@ -167,226 +60,6 @@ class TestEndUserServiceGetOrCreateEndUserByType:
"""Provide test data factory."""
return TestEndUserServiceFactory()
# Test 04: Create new end user with SERVICE_API type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_service_api_type(self, mock_db, mock_session_class, factory):
"""Test creating new end user with SERVICE_API type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.type == InvokeFrom.SERVICE_API
assert added_user.tenant_id == tenant_id
assert added_user.app_id == app_id
assert added_user.session_id == user_id
# Test 05: Create new end user with WEB_APP type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_web_app_type(self, mock_db, mock_session_class, factory):
"""Test creating new end user with WEB_APP type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.type == InvokeFrom.WEB_APP
# Test 06: Upgrade legacy end user type
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_upgrade_legacy_end_user_type(self, mock_db, mock_session_class, mock_logger, factory):
"""Test upgrading legacy end user with different type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
# Existing user with old type
existing_user = factory.create_end_user_mock(
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act - Request with different type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.WEB_APP,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result == existing_user
assert existing_user.type == InvokeFrom.WEB_APP # Type should be updated
mock_session.commit.assert_called_once()
mock_logger.info.assert_called_once()
# Verify log message contains upgrade info
log_call = mock_logger.info.call_args[0][0]
assert "Upgrading legacy EndUser" in log_call
# Test 07: Get existing end user with matching type (no upgrade needed)
@patch("services.end_user_service.logger")
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_existing_end_user_matching_type(self, mock_db, mock_session_class, mock_logger, factory):
"""Test retrieving existing end user with matching type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
existing_user = factory.create_end_user_mock(
tenant_id=tenant_id,
app_id=app_id,
session_id=user_id,
type=InvokeFrom.SERVICE_API,
)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = existing_user
# Act - Request with same type
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
assert result == existing_user
assert existing_user.type == InvokeFrom.SERVICE_API
# No commit should be called (no type update needed)
mock_session.commit.assert_not_called()
mock_logger.info.assert_not_called()
# Test 08: Create anonymous user with default session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_anonymous_user_with_default_session(self, mock_db, mock_session_class, factory):
"""Test creating anonymous user when user_id is None."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=None,
)
# Assert
mock_session.add.assert_called_once()
added_user = mock_session.add.call_args[0][0]
assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# The private _is_anonymous attribute is True for anonymous users, even though the public is_anonymous property always returns False
assert added_user._is_anonymous is True
assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
# Test 09: Query ordering prioritizes matching type
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_query_ordering_prioritizes_matching_type(self, mock_db, mock_session_class, factory):
"""Test that query ordering prioritizes records with matching type."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
# Verify order_by was called (for type prioritization)
mock_query.order_by.assert_called_once()
# Test 10: Session context manager properly closes
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
@@ -420,117 +93,3 @@ class TestEndUserServiceGetOrCreateEndUserByType:
# Verify context manager was entered and exited
mock_context.__enter__.assert_called_once()
mock_context.__exit__.assert_called_once()
# Test 11: External user ID matches session ID
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_external_user_id_matches_session_id(self, mock_db, mock_session_class, factory):
"""Test that external_user_id is set to match session_id."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "custom-external-id"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=InvokeFrom.SERVICE_API,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.external_user_id == user_id
assert added_user.session_id == user_id
# Test 12: Different InvokeFrom types
@pytest.mark.parametrize(
"invoke_type",
[
InvokeFrom.SERVICE_API,
InvokeFrom.WEB_APP,
InvokeFrom.EXPLORE,
InvokeFrom.DEBUGGER,
],
)
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_create_end_user_with_different_invoke_types(self, mock_db, mock_session_class, invoke_type, factory):
"""Test creating end users with different InvokeFrom types."""
# Arrange
tenant_id = "tenant-123"
app_id = "app-456"
user_id = "user-789"
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.first.return_value = None
# Act
result = EndUserService.get_or_create_end_user_by_type(
type=invoke_type,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
)
# Assert
added_user = mock_session.add.call_args[0][0]
assert added_user.type == invoke_type
class TestEndUserServiceGetEndUserById:
"""Unit tests for EndUserService.get_end_user_by_id."""
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_end_user(self, mock_db, mock_session_class):
tenant_id = "tenant-123"
app_id = "app-456"
end_user_id = "end-user-789"
existing_user = MagicMock(spec=EndUser)
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = existing_user
result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id)
assert result == existing_user
mock_session.query.assert_called_once_with(EndUser)
mock_query.where.assert_called_once()
assert len(mock_query.where.call_args[0]) == 3
@patch("services.end_user_service.Session")
@patch("services.end_user_service.db")
def test_get_end_user_by_id_returns_none(self, mock_db, mock_session_class):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.first.return_value = None
result = EndUserService.get_end_user_by_id(tenant_id="tenant", app_id="app", end_user_id="end-user")
assert result is None
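
The get-or-create behavior these tests pin down can be summarized with a minimal, in-memory sketch. Everything here is an assumption reconstructed from the assertions above (the record fields, the default session id value, the upgrade log message); it is not the real `EndUserService` implementation:

```python
import logging
from dataclasses import dataclass

DEFAULT_SESSION_ID = "DEFAULT-USER"  # assumed stand-in for DefaultEndUserSessionID.DEFAULT_SESSION_ID


@dataclass
class EndUserRecord:
    tenant_id: str
    app_id: str
    session_id: str
    type: str
    external_user_id: str
    is_anonymous: bool


def get_or_create_end_user_by_type(store, *, type, tenant_id, app_id, user_id):
    # Anonymous callers (user_id=None) share one fixed default session id.
    session_id = user_id or DEFAULT_SESSION_ID
    existing = next(
        (u for u in store
         if (u.tenant_id, u.app_id, u.session_id) == (tenant_id, app_id, session_id)),
        None,
    )
    if existing is not None:
        if existing.type != type:
            # Legacy rows created via another entry point are upgraded in place.
            logging.info("Upgrading legacy EndUser %s to type %s", session_id, type)
            existing.type = type
        return existing
    user = EndUserRecord(
        tenant_id=tenant_id,
        app_id=app_id,
        session_id=session_id,
        type=type,
        external_user_id=session_id,  # external id mirrors the session id
        is_anonymous=user_id is None,
    )
    store.append(user)
    return user
```

This mirrors what the tests assert: a matching row is returned (and its `type` upgraded with a log line when it differs), otherwise a new record is created with `external_user_id == session_id` and the anonymous flag set only when `user_id` is `None`.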


@@ -1,61 +0,0 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
repo = type(
"Repo",
(),
{
"get_by_message_ids": lambda _self, message_ids: [
[
HumanInputContent(
workflow_run_id="workflow-run-1",
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id="node-1",
node_title="Approval",
rendered_content="Rendered",
action_id="approve",
action_text="Approve",
),
)
],
[],
]
},
)()
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
message_service.attach_message_extra_contents(messages)
assert messages[0].extra_contents == [
{
"type": "human_input",
"workflow_run_id": "workflow-run-1",
"submitted": True,
"form_submission_data": {
"node_id": "node-1",
"node_title": "Approval",
"rendered_content": "Rendered",
"action_id": "approve",
"action_text": "Approve",
},
}
]
assert messages[1].extra_contents == []


@@ -1,4 +1,4 @@
from unittest.mock import ANY, MagicMock, call, patch
from unittest.mock import MagicMock, call, patch
import pytest
@@ -14,124 +14,6 @@ from tasks.remove_app_and_related_data_task import (
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_sf.create_session.return_value = mock_context_manager
# Mock two batches of results, then empty
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
batch2_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(100, 150)]
batch1_ids = [row[0] for row in batch1_data]
batch1_file_ids = [row[1] for row in batch1_data if row[1] is not None]
batch2_ids = [row[0] for row in batch2_data]
batch2_file_ids = [row[1] for row in batch2_data if row[1] is not None]
# Setup side effects for execute calls in the correct order:
# 1. SELECT (returns batch1_data with id, file_id)
# 2. DELETE (returns result with rowcount=100)
# 3. SELECT (returns batch2_data)
# 4. DELETE (returns result with rowcount=50)
# 5. SELECT (returns empty, ends loop)
# Create mock results with actual integer rowcount attributes
class MockResult:
def __init__(self, rowcount):
self.rowcount = rowcount
# First SELECT result
select_result1 = MagicMock()
select_result1.__iter__.return_value = iter(batch1_data)
# First DELETE result
delete_result1 = MockResult(rowcount=100)
# Second SELECT result
select_result2 = MagicMock()
select_result2.__iter__.return_value = iter(batch2_data)
# Second DELETE result
delete_result2 = MockResult(rowcount=50)
# Third SELECT result (empty, ends loop)
select_result3 = MagicMock()
select_result3.__iter__.return_value = iter([])
# Configure side effects in the correct order
mock_session.execute.side_effect = [
select_result1, # First SELECT
delete_result1, # First DELETE
select_result2, # Second SELECT
delete_result2, # Second DELETE
select_result3, # Third SELECT (empty)
]
# Mock offload data cleanup
mock_offload_cleanup.side_effect = [len(batch1_file_ids), len(batch2_file_ids)]
# Execute the function
result = delete_draft_variables_batch(app_id, batch_size)
# Verify the result
assert result == 150
# Verify database calls
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
# Verify offload cleanup was called for both batches with file_ids
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
# Simplified verification - check that the right number of calls were made
# and that the SQL queries contain the expected patterns
actual_calls = mock_session.execute.call_args_list
for i, actual_call in enumerate(actual_calls):
sql_text = str(actual_call[0][0])
normalized = " ".join(sql_text.split())
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
assert "WHERE app_id = :app_id" in normalized
assert "LIMIT :batch_size" in normalized
else: # DELETE calls (odd indices: 1, 3)
assert "DELETE FROM workflow_draft_variables" in normalized
assert "WHERE id IN :ids" in normalized
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_sf.create_session.return_value = mock_context_manager
# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_session.execute.return_value = empty_result
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 0
assert mock_session.execute.call_count == 1 # Only one select query
mock_offload_cleanup.assert_not_called() # No files to clean up
def test_delete_draft_variables_batch_invalid_batch_size(self):
"""Test that invalid batch size raises ValueError."""
app_id = "test-app-id"
@@ -142,66 +24,6 @@ class TestDeleteDraftVariablesBatch:
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_sf.create_session.return_value = mock_context_manager
# Mock one batch then empty
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
batch_ids = [row[0] for row in batch_data]
batch_file_ids = [row[1] for row in batch_data if row[1] is not None]
# Create properly configured mocks
select_result = MagicMock()
select_result.__iter__.return_value = iter(batch_data)
# Create simple object with rowcount attribute
class MockResult:
def __init__(self, rowcount):
self.rowcount = rowcount
delete_result = MockResult(rowcount=30)
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_session.execute.side_effect = [
# Select query result
select_result,
# Delete query result
delete_result,
# Empty select result (end condition)
empty_result,
]
# Mock offload cleanup
mock_offload_cleanup.return_value = len(batch_file_ids)
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 30
# Verify offload cleanup was called with file_ids
if batch_file_ids:
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
# Verify logging calls
assert mock_logging.info.call_count == 2
mock_logging.info.assert_any_call(ANY)  # click.style call
@patch("tasks.remove_app_and_related_data_task.delete_draft_variables_batch")
def test_delete_draft_variables_calls_batch_function(self, mock_batch_delete):
"""Test that _delete_draft_variables calls the batch function correctly."""
@@ -218,58 +40,6 @@ class TestDeleteDraftVariablesBatch:
class TestDeleteDraftVariableOffloadData:
"""Test the Offload data cleanup functionality."""
@patch("extensions.ext_storage.storage")
def test_delete_draft_variable_offload_data_success(self, mock_storage):
"""Test successful deletion of offload data."""
# Mock connection
mock_conn = MagicMock()
file_ids = ["file-1", "file-2", "file-3"]
# Mock query results: (variable_file_id, storage_key, upload_file_id)
query_results = [
("file-1", "storage/key/1", "upload-1"),
("file-2", "storage/key/2", "upload-2"),
("file-3", "storage/key/3", "upload-3"),
]
mock_result = MagicMock()
mock_result.__iter__.return_value = iter(query_results)
mock_conn.execute.return_value = mock_result
# Execute function
result = _delete_draft_variable_offload_data(mock_conn, file_ids)
# Verify return value
assert result == 3
# Verify storage deletion calls
expected_storage_calls = [call("storage/key/1"), call("storage/key/2"), call("storage/key/3")]
mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)
# Verify database calls - should be 3 calls total
assert mock_conn.execute.call_count == 3
# Verify the queries were called
actual_calls = mock_conn.execute.call_args_list
# First call should be the SELECT query
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
# Second call should be DELETE upload_files
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
assert "DELETE FROM upload_files" in delete_upload_call_sql
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
# Third call should be DELETE workflow_draft_variable_files
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql
def test_delete_draft_variable_offload_data_empty_file_ids(self):
"""Test handling of empty file_ids list."""
mock_conn = MagicMock()
@@ -279,38 +49,6 @@ class TestDeleteDraftVariableOffloadData:
assert result == 0
mock_conn.execute.assert_not_called()
@patch("extensions.ext_storage.storage")
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variable_offload_data_storage_failure(self, mock_logging, mock_storage):
"""Test handling of storage deletion failures."""
mock_conn = MagicMock()
file_ids = ["file-1", "file-2"]
# Mock query results
query_results = [
("file-1", "storage/key/1", "upload-1"),
("file-2", "storage/key/2", "upload-2"),
]
mock_result = MagicMock()
mock_result.__iter__.return_value = iter(query_results)
mock_conn.execute.return_value = mock_result
# Make storage.delete fail for the first file
mock_storage.delete.side_effect = [Exception("Storage error"), None]
# Execute function
result = _delete_draft_variable_offload_data(mock_conn, file_ids)
# Both files are still processed for database cleanup, but the return value counts only successful storage deletions
assert result == 1  # Only one storage deletion succeeded
# Verify warning was logged
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", "storage/key/1")
# Verify both database cleanup calls still happened
assert mock_conn.execute.call_count == 3
@patch("tasks.remove_app_and_related_data_task.logging")
def test_delete_draft_variable_offload_data_database_failure(self, mock_logging):
"""Test handling of database operation failures."""


@@ -8,11 +8,11 @@
*/
import type { AppListResponse } from '@/models/app'
import type { App } from '@/types/app'
import { fireEvent, screen } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import List from '@/app/components/apps/list'
import { AccessMode } from '@/models/access-control'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { AppModeEnum } from '@/types/app'
let mockIsCurrentWorkspaceEditor = true
@@ -161,9 +161,10 @@ const createPage = (apps: App[], hasMore = false, page = 1): AppListResponse =>
})
const renderList = (searchParams?: Record<string, string>) => {
return renderWithNuqs(
<List controlRefreshList={0} />,
{ searchParams },
return render(
<NuqsTestingAdapter searchParams={searchParams}>
<List controlRefreshList={0} />
</NuqsTestingAdapter>,
)
}
@@ -208,7 +209,11 @@ describe('App List Browsing Flow', () => {
it('should transition from loading to content when data loads', () => {
mockIsLoading = true
const { rerender } = renderWithNuqs(<List controlRefreshList={0} />)
const { rerender } = render(
<NuqsTestingAdapter>
<List controlRefreshList={0} />
</NuqsTestingAdapter>,
)
const skeletonCards = document.querySelectorAll('.animate-pulse')
expect(skeletonCards.length).toBeGreaterThan(0)
@@ -219,7 +224,11 @@ describe('App List Browsing Flow', () => {
createMockApp({ id: 'app-1', name: 'Loaded App' }),
])]
rerender(<List controlRefreshList={0} />)
rerender(
<NuqsTestingAdapter>
<List controlRefreshList={0} />
</NuqsTestingAdapter>,
)
expect(screen.getByText('Loaded App')).toBeInTheDocument()
})
@@ -415,9 +424,17 @@ describe('App List Browsing Flow', () => {
it('should call refetch when controlRefreshList increments', () => {
mockPages = [createPage([createMockApp()])]
const { rerender } = renderWithNuqs(<List controlRefreshList={0} />)
const { rerender } = render(
<NuqsTestingAdapter>
<List controlRefreshList={0} />
</NuqsTestingAdapter>,
)
rerender(<List controlRefreshList={1} />)
rerender(
<NuqsTestingAdapter>
<List controlRefreshList={1} />
</NuqsTestingAdapter>,
)
expect(mockRefetch).toHaveBeenCalled()
})


@@ -9,11 +9,11 @@
*/
import type { AppListResponse } from '@/models/app'
import type { App } from '@/types/app'
import { fireEvent, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import List from '@/app/components/apps/list'
import { AccessMode } from '@/models/access-control'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { AppModeEnum } from '@/types/app'
let mockIsCurrentWorkspaceEditor = true
@@ -214,7 +214,11 @@ const createPage = (apps: App[]): AppListResponse => ({
})
const renderList = () => {
return renderWithNuqs(<List controlRefreshList={0} />)
return render(
<NuqsTestingAdapter>
<List controlRefreshList={0} />
</NuqsTestingAdapter>,
)
}
describe('Create App Flow', () => {


@@ -7,10 +7,9 @@
*/
import type { SimpleDocumentDetail } from '@/models/datasets'
import { act, renderHook, waitFor } from '@testing-library/react'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { DataSourceType } from '@/models/datasets'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
const mockPush = vi.fn()
vi.mock('next/navigation', () => ({
@@ -29,16 +28,12 @@ const { useDocumentSort } = await import(
const { useDocumentSelection } = await import(
'@/app/components/datasets/documents/components/document-list/hooks/use-document-selection',
)
const { useDocumentListQueryState } = await import(
const { default: useDocumentListQueryState } = await import(
'@/app/components/datasets/documents/hooks/use-document-list-query-state',
)
type LocalDoc = SimpleDocumentDetail & { percent?: number }
const renderQueryStateHook = (searchParams = '') => {
return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams })
}
const createDoc = (overrides?: Partial<LocalDoc>): LocalDoc => ({
id: `doc-${Math.random().toString(36).slice(2, 8)}`,
name: 'test-doc.txt',
@@ -90,7 +85,7 @@ describe('Document Management Flow', () => {
describe('URL-based Query State', () => {
it('should parse default query from empty URL params', () => {
const { result } = renderQueryStateHook()
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query).toEqual({
page: 1,
@@ -101,85 +96,107 @@ describe('Document Management Flow', () => {
})
})
it('should update keyword query with replace history', async () => {
const { result, onUrlUpdate } = renderQueryStateHook()
it('should update query and push to router', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ keyword: 'test', page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.options.history).toBe('replace')
expect(update.searchParams.get('keyword')).toBe('test')
expect(update.searchParams.get('page')).toBe('2')
expect(mockPush).toHaveBeenCalled()
// The push call should contain the updated query params
const pushUrl = mockPush.mock.calls[0][0] as string
expect(pushUrl).toContain('keyword=test')
expect(pushUrl).toContain('page=2')
})
it('should reset query to defaults', async () => {
const { result, onUrlUpdate } = renderQueryStateHook()
it('should reset query to defaults', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.resetQuery()
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.options.history).toBe('replace')
expect(update.searchParams.toString()).toBe('')
expect(mockPush).toHaveBeenCalled()
// Default query omits default values from URL
const pushUrl = mockPush.mock.calls[0][0] as string
expect(pushUrl).toBe('/datasets/ds-1/documents')
})
})
describe('Document Sort Integration', () => {
it('should derive sort field and order from remote sort value', () => {
it('should return documents unsorted when no sort field set', () => {
const docs = [
createDoc({ id: 'doc-1', name: 'Banana.txt', word_count: 300 }),
createDoc({ id: 'doc-2', name: 'Apple.txt', word_count: 100 }),
createDoc({ id: 'doc-3', name: 'Cherry.txt', word_count: 200 }),
]
const { result } = renderHook(() => useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '-created_at',
onRemoteSortChange: vi.fn(),
}))
expect(result.current.sortField).toBe('created_at')
expect(result.current.sortField).toBeNull()
expect(result.current.sortedDocuments).toHaveLength(3)
})
it('should sort by name descending', () => {
const docs = [
createDoc({ id: 'doc-1', name: 'Banana.txt' }),
createDoc({ id: 'doc-2', name: 'Apple.txt' }),
createDoc({ id: 'doc-3', name: 'Cherry.txt' }),
]
const { result } = renderHook(() => useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '-created_at',
}))
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortField).toBe('name')
expect(result.current.sortOrder).toBe('desc')
const names = result.current.sortedDocuments.map(d => d.name)
expect(names).toEqual(['Cherry.txt', 'Banana.txt', 'Apple.txt'])
})
it('should call remote sort change with descending sort for a new field', () => {
const onRemoteSortChange = vi.fn()
it('should toggle sort order on same field click', () => {
const docs = [createDoc({ id: 'doc-1', name: 'A.txt' }), createDoc({ id: 'doc-2', name: 'B.txt' })]
const { result } = renderHook(() => useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '-created_at',
onRemoteSortChange,
}))
act(() => {
result.current.handleSort('hit_count')
})
act(() => result.current.handleSort('name'))
expect(result.current.sortOrder).toBe('desc')
expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count')
act(() => result.current.handleSort('name'))
expect(result.current.sortOrder).toBe('asc')
})
it('should toggle descending to ascending when clicking active field', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-hit_count',
onRemoteSortChange,
}))
act(() => {
result.current.handleSort('hit_count')
})
expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count')
})
it('should ignore null sort field updates', () => {
const onRemoteSortChange = vi.fn()
it('should filter by status before sorting', () => {
const docs = [
createDoc({ id: 'doc-1', name: 'A.txt', display_status: 'available' }),
createDoc({ id: 'doc-2', name: 'B.txt', display_status: 'error' }),
createDoc({ id: 'doc-3', name: 'C.txt', display_status: 'available' }),
]
const { result } = renderHook(() => useDocumentSort({
documents: docs,
statusFilterValue: 'available',
remoteSortValue: '-created_at',
onRemoteSortChange,
}))
act(() => {
result.current.handleSort(null)
})
expect(onRemoteSortChange).not.toHaveBeenCalled()
// Only 'available' documents should remain
expect(result.current.sortedDocuments).toHaveLength(2)
expect(result.current.sortedDocuments.every(d => d.display_status === 'available')).toBe(true)
})
})
@@ -292,13 +309,14 @@ describe('Document Management Flow', () => {
describe('Cross-Module: Query State → Sort → Selection Pipeline', () => {
it('should maintain consistent default state across all hooks', () => {
const docs = [createDoc({ id: 'doc-1' })]
const { result: queryResult } = renderQueryStateHook()
const { result: queryResult } = renderHook(() => useDocumentListQueryState())
const { result: sortResult } = renderHook(() => useDocumentSort({
documents: docs,
statusFilterValue: queryResult.current.query.status,
remoteSortValue: queryResult.current.query.sort,
onRemoteSortChange: vi.fn(),
}))
const { result: selResult } = renderHook(() => useDocumentSelection({
documents: docs,
documents: sortResult.current.sortedDocuments,
selectedIds: [],
onSelectedIdChange: vi.fn(),
}))
@@ -307,9 +325,8 @@ describe('Document Management Flow', () => {
expect(queryResult.current.query.sort).toBe('-created_at')
expect(queryResult.current.query.status).toBe('all')
// Sort state is derived from URL default sort.
expect(sortResult.current.sortField).toBe('created_at')
expect(sortResult.current.sortOrder).toBe('desc')
// Sort inherits 'all' status → no filtering applied
expect(sortResult.current.sortedDocuments).toHaveLength(1)
// Selection starts empty
expect(selResult.current.isAllSelected).toBe(false)

View File

@@ -28,13 +28,9 @@ vi.mock('react-i18next', () => ({
}),
}))
vi.mock('nuqs', async (importOriginal) => {
const actual = await importOriginal<typeof import('nuqs')>()
return {
...actual,
useQueryState: () => ['builtin', vi.fn()],
}
})
vi.mock('nuqs', () => ({
useQueryState: () => ['builtin', vi.fn()],
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: () => ({ enable_marketplace: false }),
@@ -216,12 +212,6 @@ vi.mock('@/app/components/tools/marketplace', () => ({
default: () => null,
}))
vi.mock('@/app/components/tools/marketplace/hooks', () => ({
useMarketplace: () => ({
handleScroll: vi.fn(),
}),
}))
vi.mock('@/app/components/tools/mcp', () => ({
default: () => <div data-testid="mcp-list">MCP List</div>,
}))

View File

@@ -1,8 +1,9 @@
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import { act, fireEvent, screen } from '@testing-library/react'
import type { ReactNode } from 'react'
import { act, fireEvent, render, screen } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import * as React from 'react'
import { useStore as useTagStore } from '@/app/components/base/tag-management/store'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { AppModeEnum } from '@/types/app'
import List from '../list'
@@ -185,13 +186,15 @@ beforeAll(() => {
} as unknown as typeof IntersectionObserver
})
// Render helper wrapping with shared nuqs testing helper.
// Render helper wrapping with NuqsTestingAdapter
const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
const renderList = (searchParams = '') => {
return renderWithNuqs(
<List />,
{ searchParams, onUrlUpdate },
const wrapper = ({ children }: { children: ReactNode }) => (
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
{children}
</NuqsTestingAdapter>
)
return render(<List />, { wrapper })
}
describe('List', () => {
@@ -388,10 +391,18 @@ describe('List', () => {
describe('Edge Cases', () => {
it('should handle multiple renders without issues', () => {
const { rerender } = renderWithNuqs(<List />)
const { rerender } = render(
<NuqsTestingAdapter>
<List />
</NuqsTestingAdapter>,
)
expect(screen.getByText('app.types.all')).toBeInTheDocument()
rerender(<List />)
rerender(
<NuqsTestingAdapter>
<List />
</NuqsTestingAdapter>,
)
expect(screen.getByText('app.types.all')).toBeInTheDocument()
})

View File

@@ -1,9 +1,18 @@
import { act, waitFor } from '@testing-library/react'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ReactNode } from 'react'
import { act, renderHook, waitFor } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import useAppsQueryState from '../use-apps-query-state'
const renderWithAdapter = (searchParams = '') => {
return renderHookWithNuqs(() => useAppsQueryState(), { searchParams })
const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
const wrapper = ({ children }: { children: ReactNode }) => (
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
{children}
</NuqsTestingAdapter>
)
const { result } = renderHook(() => useAppsQueryState(), { wrapper })
return { result, onUrlUpdate }
}
describe('useAppsQueryState', () => {

View File

@@ -3,7 +3,7 @@
import type { FC } from 'react'
import { useDebounceFn } from 'ahooks'
import dynamic from 'next/dynamic'
import { parseAsStringLiteral, useQueryState } from 'nuqs'
import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
@@ -16,7 +16,7 @@ import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { CheckModal } from '@/hooks/use-pay'
import { useInfiniteAppList } from '@/service/use-apps'
import { AppModeEnum, AppModes } from '@/types/app'
import { AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
import AppCard from './app-card'
import { AppCardSkeleton } from './app-card-skeleton'
@@ -33,18 +33,6 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
ssr: false,
})
const APP_LIST_CATEGORY_VALUES = ['all', ...AppModes] as const
type AppListCategory = typeof APP_LIST_CATEGORY_VALUES[number]
const appListCategorySet = new Set<string>(APP_LIST_CATEGORY_VALUES)
const isAppListCategory = (value: string): value is AppListCategory => {
return appListCategorySet.has(value)
}
const parseAsAppListCategory = parseAsStringLiteral(APP_LIST_CATEGORY_VALUES)
.withDefault('all')
.withOptions({ history: 'push' })
type Props = {
controlRefreshList?: number
}
@@ -57,7 +45,7 @@ const List: FC<Props> = ({
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [activeTab, setActiveTab] = useQueryState(
'category',
parseAsAppListCategory,
parseAsString.withDefault('all').withOptions({ history: 'push' }),
)
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
@@ -92,7 +80,7 @@ const List: FC<Props> = ({
name: searchKeywords,
tag_ids: tagIDs,
is_created_by_me: isCreatedByMe,
...(activeTab !== 'all' ? { mode: activeTab } : {}),
...(activeTab !== 'all' ? { mode: activeTab as AppModeEnum } : {}),
}
const {
@@ -198,10 +186,7 @@ const List: FC<Props> = ({
<div className="sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-5 pt-7">
<TabSliderNew
value={activeTab}
onChange={(nextValue) => {
if (isAppListCategory(nextValue))
setActiveTab(nextValue)
}}
onChange={setActiveTab}
options={options}
/>
<div className="flex items-center gap-2">

View File

@@ -0,0 +1,48 @@
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { FormTypeEnum } from '@/app/components/base/form/types'
import AuthForm from './index'
const formSchemas = [{
type: FormTypeEnum.textInput,
name: 'apiKey',
label: 'API Key',
required: true,
}] as const
const renderWithQueryClient = (ui: Parameters<typeof render>[0]) => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
describe('AuthForm', () => {
it('should render configured fields', () => {
renderWithQueryClient(<AuthForm formSchemas={[...formSchemas]} />)
expect(screen.getByText('API Key')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
it('should use provided default values', () => {
renderWithQueryClient(<AuthForm formSchemas={[...formSchemas]} defaultValues={{ apiKey: 'value-123' }} />)
expect(screen.getByDisplayValue('value-123')).toBeInTheDocument()
})
it('should render nothing when no schema is provided', () => {
const { container } = renderWithQueryClient(<AuthForm formSchemas={[]} />)
expect(container).toBeEmptyDOMElement()
})
})

View File

@@ -0,0 +1,137 @@
import type { BaseConfiguration } from './types'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { TransferMethod } from '@/types/app'
import { useAppForm } from '../..'
import BaseField from './field'
import { BaseFieldType } from './types'
vi.mock('next/navigation', () => ({
useParams: () => ({}),
}))
const createConfig = (overrides: Partial<BaseConfiguration> = {}): BaseConfiguration => ({
type: BaseFieldType.textInput,
variable: 'fieldA',
label: 'Field A',
required: false,
showConditions: [],
...overrides,
})
type FieldHarnessProps = {
config: BaseConfiguration
initialData?: Record<string, unknown>
}
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
const form = useAppForm({
defaultValues: initialData,
onSubmit: () => {},
})
const Component = useMemo(() => BaseField({ initialData, config }), [config, initialData])
return <Component form={form} />
}
describe('BaseField', () => {
it('should render a text input field when configured as text input', () => {
render(<FieldHarness config={createConfig({ label: 'Username' })} initialData={{ fieldA: '' }} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByText('Username')).toBeInTheDocument()
})
it('should render a number input when configured as number input', () => {
render(<FieldHarness config={createConfig({ type: BaseFieldType.numberInput, label: 'Age' })} initialData={{ fieldA: 20 }} />)
expect(screen.getByRole('spinbutton')).toBeInTheDocument()
expect(screen.getByText('Age')).toBeInTheDocument()
})
it('should render a checkbox when configured as checkbox', () => {
render(<FieldHarness config={createConfig({ type: BaseFieldType.checkbox, label: 'Agree' })} initialData={{ fieldA: false }} />)
expect(screen.getByText('Agree')).toBeInTheDocument()
})
it('should render paragraph and select fields based on configuration', () => {
const scenarios: Array<{ config: BaseConfiguration, initialData: Record<string, unknown> }> = [
{
config: createConfig({
type: BaseFieldType.paragraph,
label: 'Description',
}),
initialData: { fieldA: 'hello' },
},
{
config: createConfig({
type: BaseFieldType.select,
label: 'Mode',
options: [{ value: 'safe', label: 'Safe' }],
}),
initialData: { fieldA: 'safe' },
},
]
for (const scenario of scenarios) {
const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
unmount()
}
})
it('should render file uploader when configured as file', () => {
const scenarios: Array<{ config: BaseConfiguration, initialData: Record<string, unknown> }> = [
{
config: createConfig({
type: BaseFieldType.file,
label: 'Attachment',
allowedFileExtensions: ['txt'],
allowedFileTypes: ['document'],
allowedFileUploadMethods: [TransferMethod.local_file],
}),
initialData: { fieldA: [] },
},
{
config: createConfig({
type: BaseFieldType.fileList,
label: 'Attachments',
maxLength: 2,
allowedFileExtensions: ['txt'],
allowedFileTypes: ['document'],
allowedFileUploadMethods: [TransferMethod.local_file],
}),
initialData: { fieldA: [] },
},
]
for (const scenario of scenarios) {
const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
unmount()
}
render(
<FieldHarness
config={createConfig({ type: 'unsupported' as BaseFieldType, label: 'Unsupported' })}
initialData={{ fieldA: '' }}
/>,
)
expect(screen.queryByText('Unsupported')).not.toBeInTheDocument()
})
it('should not render when show conditions are not met', () => {
render(
<FieldHarness
config={createConfig({
label: 'Hidden Field',
showConditions: [{ variable: 'toggle', value: true }],
})}
initialData={{ fieldA: '', toggle: false }}
/>,
)
expect(screen.queryByText('Hidden Field')).not.toBeInTheDocument()
})
})

View File

@@ -0,0 +1,94 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import BaseForm from './index'
import { BaseFieldType } from './types'
const baseConfigurations = [{
type: BaseFieldType.textInput,
variable: 'name',
label: 'Name',
required: false,
showConditions: [],
}]
describe('BaseForm', () => {
it('should render configured fields', () => {
render(
<BaseForm
initialData={{ name: 'Alice' }}
configurations={[...baseConfigurations]}
onSubmit={() => {}}
/>,
)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByDisplayValue('Alice')).toBeInTheDocument()
})
it('should submit current form values when submit button is clicked', async () => {
const onSubmit = vi.fn()
render(
<BaseForm
initialData={{ name: 'Alice' }}
configurations={[...baseConfigurations]}
onSubmit={onSubmit}
CustomActions={({ form }) => (
<button type="button" onClick={() => form.handleSubmit()}>
Submit
</button>
)}
/>,
)
fireEvent.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ name: 'Alice' })
})
})
it('should render custom actions when provided', () => {
render(
<BaseForm
initialData={{ name: 'Alice' }}
configurations={[...baseConfigurations]}
onSubmit={() => {}}
CustomActions={() => <button type="button">Save Form</button>}
/>,
)
expect(screen.getByRole('button', { name: /save form/i })).toBeInTheDocument()
expect(screen.queryByRole('button', { name: /common.operation.submit/i })).not.toBeInTheDocument()
})
it('should handle native form submit and block invalid submission', async () => {
const onSubmit = vi.fn()
const requiredConfig = [{
type: BaseFieldType.textInput,
variable: 'name',
label: 'Name',
required: true,
showConditions: [],
maxLength: 2,
}]
const { container } = render(
<BaseForm
initialData={{ name: 'ok' }}
configurations={requiredConfig}
onSubmit={onSubmit}
/>,
)
const form = container.querySelector('form')
const input = screen.getByRole('textbox')
expect(form).not.toBeNull()
fireEvent.submit(form!)
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ name: 'ok' })
})
fireEvent.change(input, { target: { value: 'long' } })
fireEvent.submit(form!)
expect(onSubmit).toHaveBeenCalledTimes(1)
})
})

View File

@@ -0,0 +1,15 @@
import { BaseFieldType } from './types'
describe('base scenario types', () => {
it('should include all supported base field types', () => {
expect(Object.values(BaseFieldType)).toEqual([
'text-input',
'paragraph',
'number-input',
'checkbox',
'select',
'file',
'file-list',
])
})
})

View File

@@ -0,0 +1,49 @@
import { BaseFieldType } from './types'
import { generateZodSchema } from './utils'
describe('base scenario schema generator', () => {
it('should validate required text fields with max length', () => {
const schema = generateZodSchema([{
type: BaseFieldType.textInput,
variable: 'name',
label: 'Name',
required: true,
maxLength: 3,
showConditions: [],
}])
expect(schema.safeParse({ name: 'abc' }).success).toBe(true)
expect(schema.safeParse({ name: '' }).success).toBe(false)
expect(schema.safeParse({ name: 'abcd' }).success).toBe(false)
})
it('should validate number bounds', () => {
const schema = generateZodSchema([{
type: BaseFieldType.numberInput,
variable: 'age',
label: 'Age',
required: true,
min: 18,
max: 30,
showConditions: [],
}])
expect(schema.safeParse({ age: 20 }).success).toBe(true)
expect(schema.safeParse({ age: 17 }).success).toBe(false)
expect(schema.safeParse({ age: 31 }).success).toBe(false)
})
it('should allow optional fields to be undefined or null', () => {
const schema = generateZodSchema([{
type: BaseFieldType.select,
variable: 'mode',
label: 'Mode',
required: false,
showConditions: [],
options: [{ value: 'safe', label: 'Safe' }],
}])
expect(schema.safeParse({}).success).toBe(true)
expect(schema.safeParse({ mode: null }).success).toBe(true)
})
})

View File

@@ -0,0 +1,24 @@
import { render, screen } from '@testing-library/react'
import { useAppForm } from '../..'
import ContactFields from './contact-fields'
import { demoFormOpts } from './shared-options'
const ContactFieldsHarness = () => {
const form = useAppForm({
...demoFormOpts,
onSubmit: () => {},
})
return <ContactFields form={form} />
}
describe('ContactFields', () => {
it('should render contact section fields', () => {
render(<ContactFieldsHarness />)
expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
expect(screen.getByRole('textbox', { name: /email/i })).toBeInTheDocument()
expect(screen.getByRole('textbox', { name: /phone/i })).toBeInTheDocument()
expect(screen.getByText(/preferred contact method/i)).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,69 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import DemoForm from './index'
describe('DemoForm', () => {
const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
beforeEach(() => {
vi.clearAllMocks()
})
it('should render the primary fields', () => {
render(<DemoForm />)
expect(screen.getByRole('textbox', { name: /^name$/i })).toBeInTheDocument()
expect(screen.getByRole('textbox', { name: /^surname$/i })).toBeInTheDocument()
expect(screen.getByText(/i accept the terms and conditions/i)).toBeInTheDocument()
})
it('should show contact fields after a name is entered', () => {
render(<DemoForm />)
expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
fireEvent.change(screen.getByRole('textbox', { name: /^name$/i }), { target: { value: 'Alice' } })
expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
})
it('should hide contact fields when name is cleared', () => {
render(<DemoForm />)
const nameInput = screen.getByRole('textbox', { name: /^name$/i })
fireEvent.change(nameInput, { target: { value: 'Alice' } })
expect(screen.getByRole('heading', { name: /contacts/i })).toBeInTheDocument()
fireEvent.change(nameInput, { target: { value: '' } })
expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument()
})
it('should log validation errors on invalid submit', () => {
render(<DemoForm />)
const nameInput = screen.getByRole('textbox', { name: /^name$/i }) as HTMLInputElement
fireEvent.submit(nameInput.form!)
return waitFor(() => {
expect(consoleLogSpy).toHaveBeenCalledWith('Validation errors:', expect.any(Array))
})
})
it('should log submitted values on valid submit', () => {
render(<DemoForm />)
const nameInput = screen.getByRole('textbox', { name: /^name$/i }) as HTMLInputElement
fireEvent.change(nameInput, { target: { value: 'Alice' } })
fireEvent.change(screen.getByRole('textbox', { name: /^surname$/i }), { target: { value: 'Smith' } })
fireEvent.click(screen.getByText(/i accept the terms and conditions/i))
fireEvent.change(screen.getByRole('textbox', { name: /email/i }), { target: { value: 'alice@example.com' } })
fireEvent.submit(nameInput.form!)
return waitFor(() => {
expect(consoleLogSpy).toHaveBeenCalledWith('Form submitted:', expect.objectContaining({
name: 'Alice',
surname: 'Smith',
isAcceptingTerms: true,
}))
})
})
})

View File

@@ -0,0 +1,16 @@
import { demoFormOpts } from './shared-options'
describe('demoFormOpts', () => {
it('should provide expected default values', () => {
expect(demoFormOpts.defaultValues).toEqual({
name: '',
surname: '',
isAcceptingTerms: false,
contact: {
email: '',
phone: '',
preferredContactMethod: 'email',
},
})
})
})

View File

@@ -0,0 +1,39 @@
import { ContactMethods, UserSchema } from './types'
describe('demo scenario types', () => {
it('should expose contact methods with capitalized labels', () => {
expect(ContactMethods).toEqual([
{ value: 'email', label: 'Email' },
{ value: 'phone', label: 'Phone' },
{ value: 'whatsapp', label: 'Whatsapp' },
{ value: 'sms', label: 'Sms' },
])
})
it('should validate a complete user payload', () => {
expect(UserSchema.safeParse({
name: 'Alice',
surname: 'Smith',
isAcceptingTerms: true,
contact: {
email: 'alice@example.com',
phone: '',
preferredContactMethod: 'email',
},
}).success).toBe(true)
})
it('should reject invalid user payload', () => {
const result = UserSchema.safeParse({
name: 'alice',
surname: 's',
isAcceptingTerms: false,
contact: {
email: 'invalid',
preferredContactMethod: 'email',
},
})
expect(result.success).toBe(false)
})
})

View File

@@ -0,0 +1,139 @@
import type { InputFieldConfiguration } from './types'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { useAppForm } from '../..'
import InputField from './field'
import { InputFieldType } from './types'
const createConfig = (overrides: Partial<InputFieldConfiguration> = {}): InputFieldConfiguration => ({
type: InputFieldType.textInput,
variable: 'fieldA',
label: 'Field A',
required: false,
showConditions: [],
...overrides,
})
type FieldHarnessProps = {
config: InputFieldConfiguration
initialData?: Record<string, unknown>
}
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
const form = useAppForm({
defaultValues: initialData,
onSubmit: () => {},
})
const Component = useMemo(() => InputField({ initialData, config }), [config, initialData])
return <Component form={form} />
}
describe('InputField', () => {
it('should render text input field by default', () => {
render(<FieldHarness config={createConfig({ label: 'Prompt' })} initialData={{ fieldA: '' }} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByText('Prompt')).toBeInTheDocument()
})
it('should render number slider field when configured', () => {
render(
<FieldHarness
config={createConfig({
type: InputFieldType.numberSlider,
label: 'Temperature',
description: 'Control randomness',
min: 0,
max: 1,
})}
initialData={{ fieldA: 0.5 }}
/>,
)
expect(screen.getByText('Temperature')).toBeInTheDocument()
expect(screen.getByText('Control randomness')).toBeInTheDocument()
})
it('should render select field with options when configured', () => {
render(
<FieldHarness
config={createConfig({
type: InputFieldType.select,
label: 'Mode',
options: [{ value: 'safe', label: 'Safe' }],
})}
initialData={{ fieldA: 'safe' }}
/>,
)
expect(screen.getByText('Mode')).toBeInTheDocument()
})
it('should render upload method field when configured', () => {
render(
<FieldHarness
config={createConfig({
type: InputFieldType.uploadMethod,
label: 'Upload Method',
})}
initialData={{ fieldA: 'local_file' }}
/>,
)
expect(screen.getByText('Upload Method')).toBeInTheDocument()
})
it('should hide the field when show conditions are not met', () => {
render(
<FieldHarness
config={createConfig({
label: 'Hidden Input',
showConditions: [{ variable: 'enabled', value: true }],
})}
initialData={{ enabled: false, fieldA: '' }}
/>,
)
expect(screen.queryByText('Hidden Input')).not.toBeInTheDocument()
})
it('should render remaining field types and fallback for unsupported type', () => {
const scenarios: Array<{ config: InputFieldConfiguration, initialData: Record<string, unknown> }> = [
{
config: createConfig({ type: InputFieldType.numberInput, label: 'Count', min: 1, max: 5 }),
initialData: { fieldA: 2 },
},
{
config: createConfig({ type: InputFieldType.checkbox, label: 'Enable' }),
initialData: { fieldA: false },
},
{
config: createConfig({ type: InputFieldType.inputTypeSelect, label: 'Input Type', supportFile: true }),
initialData: { fieldA: 'text' },
},
{
config: createConfig({ type: InputFieldType.fileTypes, label: 'File Types' }),
initialData: { fieldA: { allowedFileTypes: ['document'] } },
},
{
config: createConfig({ type: InputFieldType.options, label: 'Choices' }),
initialData: { fieldA: ['one'] },
},
]
for (const scenario of scenarios) {
const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
unmount()
}
render(
<FieldHarness
config={createConfig({ type: 'unsupported' as InputFieldType, label: 'Unsupported' })}
initialData={{ fieldA: '' }}
/>,
)
expect(screen.queryByText('Unsupported')).not.toBeInTheDocument()
})
})

View File

@@ -0,0 +1,17 @@
import { InputFieldType } from './types'
describe('input-field scenario types', () => {
it('should include expected input field types', () => {
expect(Object.values(InputFieldType)).toEqual([
'textInput',
'numberInput',
'numberSlider',
'checkbox',
'options',
'select',
'inputTypeSelect',
'uploadMethod',
'fileTypes',
])
})
})

View File

@@ -0,0 +1,150 @@
import { InputFieldType } from './types'
import { generateZodSchema } from './utils'
describe('input-field scenario schema generator', () => {
it('should validate required text input with max length', () => {
const schema = generateZodSchema([{
type: InputFieldType.textInput,
variable: 'prompt',
label: 'Prompt',
required: true,
maxLength: 5,
showConditions: [],
}])
expect(schema.safeParse({ prompt: 'hello' }).success).toBe(true)
expect(schema.safeParse({ prompt: '' }).success).toBe(false)
expect(schema.safeParse({ prompt: 'longer than five' }).success).toBe(false)
})
it('should validate file types payload shape', () => {
const schema = generateZodSchema([{
type: InputFieldType.fileTypes,
variable: 'files',
label: 'Files',
required: true,
showConditions: [],
}])
expect(schema.safeParse({
files: {
allowedFileExtensions: 'txt,pdf',
allowedFileTypes: ['document'],
},
}).success).toBe(true)
expect(schema.safeParse({
files: {
allowedFileTypes: ['invalid-type'],
},
}).success).toBe(false)
})
it('should allow optional upload method fields to be omitted', () => {
const schema = generateZodSchema([{
type: InputFieldType.uploadMethod,
variable: 'methods',
label: 'Methods',
required: false,
showConditions: [],
}])
expect(schema.safeParse({}).success).toBe(true)
})
it('should validate numeric bounds and other field type shapes', () => {
const schema = generateZodSchema([
{
type: InputFieldType.numberInput,
variable: 'count',
label: 'Count',
required: true,
min: 1,
max: 3,
showConditions: [],
},
{
type: InputFieldType.numberSlider,
variable: 'temperature',
label: 'Temperature',
required: true,
showConditions: [],
},
{
type: InputFieldType.checkbox,
variable: 'enabled',
label: 'Enabled',
required: true,
showConditions: [],
},
{
type: InputFieldType.options,
variable: 'choices',
label: 'Choices',
required: true,
showConditions: [],
},
{
type: InputFieldType.select,
variable: 'mode',
label: 'Mode',
required: true,
showConditions: [],
},
{
type: InputFieldType.inputTypeSelect,
variable: 'inputType',
label: 'Input Type',
required: true,
showConditions: [],
},
{
type: InputFieldType.uploadMethod,
variable: 'methods',
label: 'Methods',
required: true,
showConditions: [],
},
{
type: 'unsupported' as InputFieldType,
variable: 'other',
label: 'Other',
required: true,
showConditions: [],
},
])
expect(schema.safeParse({
count: 2,
temperature: 0.5,
enabled: true,
choices: ['a'],
mode: 'safe',
inputType: 'text',
methods: ['local_file'],
other: { key: 'value' },
}).success).toBe(true)
expect(schema.safeParse({
count: 0,
temperature: 0.5,
enabled: true,
choices: ['a'],
mode: 'safe',
inputType: 'text',
methods: ['local_file'],
other: { key: 'value' },
}).success).toBe(false)
expect(schema.safeParse({
count: 4,
temperature: 0.5,
enabled: true,
choices: ['a'],
mode: 'safe',
inputType: 'text',
methods: ['local_file'],
other: { key: 'value' },
}).success).toBe(false)
})
})

View File

@@ -0,0 +1,145 @@
import type { ReactNode } from 'react'
import type { InputFieldConfiguration } from './types'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render, screen } from '@testing-library/react'
import { useMemo } from 'react'
import { ReactFlowProvider } from 'reactflow'
import { useAppForm } from '../..'
import NodePanelField from './field'
import { InputFieldType } from './types'
vi.mock('@/app/components/workflow/nodes/_base/components/variable/var-reference-picker', () => ({
default: () => <div>Variable Picker</div>,
}))
const createConfig = (overrides: Partial<InputFieldConfiguration> = {}): InputFieldConfiguration => ({
type: InputFieldType.textInput,
variable: 'fieldA',
label: 'Field A',
required: false,
showConditions: [],
...overrides,
})
type FieldHarnessProps = {
config: InputFieldConfiguration
initialData?: Record<string, unknown>
}
const FieldHarness = ({ config, initialData = {} }: FieldHarnessProps) => {
const form = useAppForm({
defaultValues: initialData,
onSubmit: () => {},
})
const Component = useMemo(() => NodePanelField({ initialData, config }), [config, initialData])
return <Component form={form} />
}
const NodePanelWrapper = ({ children }: { children: ReactNode }) => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
})
return (
<QueryClientProvider client={queryClient}>
<ReactFlowProvider>
{children}
</ReactFlowProvider>
</QueryClientProvider>
)
}
describe('NodePanelField', () => {
it('should render text input field', () => {
render(<FieldHarness config={createConfig({ label: 'Node Name' })} initialData={{ fieldA: '' }} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByText('Node Name')).toBeInTheDocument()
})
it('should render variable-or-constant field when configured', () => {
render(
<NodePanelWrapper>
<FieldHarness
config={createConfig({
type: InputFieldType.variableOrConstant,
label: 'Mode',
})}
initialData={{ fieldA: '' }}
/>
</NodePanelWrapper>,
)
expect(screen.getByText('Mode')).toBeInTheDocument()
})
it('should hide field when show conditions are not satisfied', () => {
render(
<FieldHarness
config={createConfig({
label: 'Hidden Node Field',
showConditions: [{ variable: 'enabled', value: true }],
})}
initialData={{ enabled: false, fieldA: '' }}
/>,
)
expect(screen.queryByText('Hidden Node Field')).not.toBeInTheDocument()
})
it('should render other configured field types and hide unsupported type', () => {
const scenarios: Array<{ config: InputFieldConfiguration, initialData: Record<string, unknown> }> = [
{
config: createConfig({ type: InputFieldType.numberInput, label: 'Count', min: 1, max: 3 }),
initialData: { fieldA: 2 },
},
{
config: createConfig({ type: InputFieldType.numberSlider, label: 'Temperature', description: 'Adjust' }),
initialData: { fieldA: 0.4 },
},
{
config: createConfig({ type: InputFieldType.checkbox, label: 'Enabled' }),
initialData: { fieldA: true },
},
{
config: createConfig({ type: InputFieldType.select, label: 'Mode', options: [{ value: 'safe', label: 'Safe' }] }),
initialData: { fieldA: 'safe' },
},
{
config: createConfig({ type: InputFieldType.inputTypeSelect, label: 'Input Type', supportFile: true }),
initialData: { fieldA: 'text' },
},
{
config: createConfig({ type: InputFieldType.uploadMethod, label: 'Upload Method' }),
initialData: { fieldA: ['local_file'] },
},
{
config: createConfig({ type: InputFieldType.fileTypes, label: 'File Types' }),
initialData: { fieldA: { allowedFileTypes: ['document'] } },
},
{
config: createConfig({ type: InputFieldType.options, label: 'Options' }),
initialData: { fieldA: ['a'] },
},
]
for (const scenario of scenarios) {
const { unmount } = render(<FieldHarness config={scenario.config} initialData={scenario.initialData} />)
expect(screen.getByText(scenario.config.label)).toBeInTheDocument()
unmount()
}
render(
<FieldHarness
config={createConfig({ type: 'unsupported' as InputFieldType, label: 'Unsupported Node' })}
initialData={{ fieldA: '' }}
/>,
)
expect(screen.queryByText('Unsupported Node')).not.toBeInTheDocument()
})
})

View File

@@ -0,0 +1,7 @@
import { InputFieldType } from './types'
describe('node-panel scenario types', () => {
it('should include variableOrConstant field type', () => {
expect(Object.values(InputFieldType)).toContain('variableOrConstant')
})
})

View File

@@ -0,0 +1,12 @@
import * as hookExports from './index'
import { useCheckValidated } from './use-check-validated'
import { useGetFormValues } from './use-get-form-values'
import { useGetValidators } from './use-get-validators'
describe('hooks index exports', () => {
it('should re-export all hook modules', () => {
expect(hookExports.useCheckValidated).toBe(useCheckValidated)
expect(hookExports.useGetFormValues).toBe(useGetFormValues)
expect(hookExports.useGetValidators).toBe(useGetValidators)
})
})

View File

@@ -0,0 +1,105 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { renderHook } from '@testing-library/react'
import { FormTypeEnum } from '../types'
import { useCheckValidated } from './use-check-validated'
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
}))
describe('useCheckValidated', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should return true when form has no errors', () => {
const form = {
getAllErrors: () => undefined,
state: { values: {} },
}
const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, []))
expect(result.current.checkValidated()).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should notify and return false when visible field has errors', () => {
const form = {
getAllErrors: () => ({
fields: {
name: { errors: ['Name is required'] },
},
}),
state: { values: {} },
}
const schemas = [{
name: 'name',
label: 'Name',
required: true,
type: FormTypeEnum.textInput,
show_on: [],
}]
const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
expect(result.current.checkValidated()).toBe(false)
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'Name is required',
})
})
it('should ignore hidden field errors and return true', () => {
const form = {
getAllErrors: () => ({
fields: {
secret: { errors: ['Secret is required'] },
},
}),
state: { values: { enabled: 'false' } },
}
const schemas = [{
name: 'secret',
label: 'Secret',
required: true,
type: FormTypeEnum.textInput,
show_on: [{ variable: 'enabled', value: 'true' }],
}]
const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
expect(result.current.checkValidated()).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should notify when field is shown and has errors', () => {
const form = {
getAllErrors: () => ({
fields: {
secret: { errors: ['Secret is required'] },
},
}),
state: { values: { enabled: 'true' } },
}
const schemas = [{
name: 'secret',
label: 'Secret',
required: true,
type: FormTypeEnum.textInput,
show_on: [{ variable: 'enabled', value: 'true' }],
}]
const { result } = renderHook(() => useCheckValidated(form as unknown as AnyFormApi, schemas))
expect(result.current.checkValidated()).toBe(false)
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'Secret is required',
})
})
})
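The hidden-field behavior these cases pin down reduces to one small visibility rule. The helper below is a hypothetical reconstruction from the test fixtures, not the hook's actual source: a field with `show_on` conditions only participates in validation when every referenced form value matches.

```typescript
// Hypothetical reconstruction of the show_on visibility rule exercised above.
type ShowOnCondition = { variable: string; value: string }

// A field is visible when it has no conditions, or every condition's
// variable currently holds the expected value in the form state.
function isFieldShown(showOn: ShowOnCondition[], values: Record<string, unknown>): boolean {
  return showOn.every(cond => values[cond.variable] === cond.value)
}
```

Under this rule, `secret` with `show_on: [{ variable: 'enabled', value: 'true' }]` is skipped while `enabled` is `'false'`, which is why the hidden-field case expects no toast even though the form reports an error for it.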


@@ -0,0 +1,74 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { renderHook } from '@testing-library/react'
import { FormTypeEnum } from '../types'
import { useGetFormValues } from './use-get-form-values'
const mockCheckValidated = vi.fn()
const mockTransform = vi.fn()
vi.mock('./use-check-validated', () => ({
useCheckValidated: () => ({
checkValidated: mockCheckValidated,
}),
}))
vi.mock('../utils/secret-input', () => ({
getTransformedValuesWhenSecretInputPristine: (...args: unknown[]) => mockTransform(...args),
}))
describe('useGetFormValues', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should return raw values when validation check is disabled', () => {
const form = {
store: { state: { values: { name: 'Alice' } } },
}
const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, []))
expect(result.current.getFormValues({ needCheckValidatedValues: false })).toEqual({
values: { name: 'Alice' },
isCheckValidated: true,
})
})
it('should return transformed values when validation passes and transform is requested', () => {
const form = {
store: { state: { values: { password: 'abc123' } } },
}
const schemas = [{
name: 'password',
label: 'Password',
required: true,
type: FormTypeEnum.secretInput,
}]
mockCheckValidated.mockReturnValue(true)
mockTransform.mockReturnValue({ password: '[__HIDDEN__]' })
const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, schemas))
expect(result.current.getFormValues({
needCheckValidatedValues: true,
needTransformWhenSecretFieldIsPristine: true,
})).toEqual({
values: { password: '[__HIDDEN__]' },
isCheckValidated: true,
})
})
it('should return empty values when validation fails', () => {
const form = {
store: { state: { values: { name: '' } } },
}
mockCheckValidated.mockReturnValue(false)
const { result } = renderHook(() => useGetFormValues(form as unknown as AnyFormApi, []))
expect(result.current.getFormValues({ needCheckValidatedValues: true })).toEqual({
values: {},
isCheckValidated: false,
})
})
})


@@ -0,0 +1,78 @@
import { renderHook } from '@testing-library/react'
import { createElement } from 'react'
import { FormTypeEnum } from '../types'
import { useGetValidators } from './use-get-validators'
vi.mock('@/hooks/use-i18n', () => ({
useRenderI18nObject: () => (obj: Record<string, string>) => obj.en_US,
}))
describe('useGetValidators', () => {
it('should create required validators when field is required without custom validators', () => {
const { result } = renderHook(() => useGetValidators())
const validators = result.current.getValidators({
name: 'username',
label: 'Username',
required: true,
type: FormTypeEnum.textInput,
})
const mountMessage = validators?.onMount?.({ value: '' })
const blurMessage = validators?.onBlur?.({ value: '' })
expect(mountMessage).toContain('common.errorMsg.fieldRequired')
expect(mountMessage).toContain('"field":"Username"')
expect(blurMessage).toContain('common.errorMsg.fieldRequired')
})
it('should keep existing validators when custom validators are provided', () => {
const customValidators = {
onChange: vi.fn(() => 'custom error'),
}
const { result } = renderHook(() => useGetValidators())
const validators = result.current.getValidators({
name: 'username',
label: 'Username',
required: true,
type: FormTypeEnum.textInput,
validators: customValidators,
})
expect(validators).toBe(customValidators)
})
it('should fall back to field name when label is a react element', () => {
const { result } = renderHook(() => useGetValidators())
const validators = result.current.getValidators({
name: 'apiKey',
label: createElement('span', undefined, 'API Key'),
required: true,
type: FormTypeEnum.textInput,
})
const mountMessage = validators?.onMount?.({ value: '' })
expect(mountMessage).toContain('"field":"apiKey"')
})
it('should translate object labels and skip validators for non-required fields', () => {
const { result } = renderHook(() => useGetValidators())
const requiredValidators = result.current.getValidators({
name: 'workspace',
label: { en_US: 'Workspace', zh_Hans: '工作区' },
required: true,
type: FormTypeEnum.textInput,
})
const nonRequiredValidators = result.current.getValidators({
name: 'optionalField',
label: 'Optional',
required: false,
type: FormTypeEnum.textInput,
})
const changeMessage = requiredValidators?.onChange?.({ value: '' })
expect(changeMessage).toContain('"field":"Workspace"')
expect(nonRequiredValidators).toBeUndefined()
})
})
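The behavior asserted above suggests a small factory: custom validators win outright, non-required fields get nothing, and required fields get validators that encode the i18n key together with its interpolation payload in a single string. The sketch below is inferred from the assertions only (the string format and the `getValidators` shape are assumptions, not the real implementation).

```typescript
// Hypothetical sketch of the validator factory behavior implied by the tests.
type Validator = (ctx: { value: unknown }) => string | undefined
type Validators = Partial<Record<'onMount' | 'onBlur' | 'onChange', Validator>>

function getValidators(schema: {
  name: string
  label: unknown
  required: boolean
  validators?: Validators
}): Validators | undefined {
  // Custom validators are returned as-is, untouched.
  if (schema.validators)
    return schema.validators
  // Non-required fields get no validators at all.
  if (!schema.required)
    return undefined
  // String labels are used directly; non-string labels (e.g. react
  // elements) fall back to the field name.
  const field = typeof schema.label === 'string' ? schema.label : schema.name
  const validate: Validator = ({ value }) =>
    value ? undefined : `common.errorMsg.fieldRequired${JSON.stringify({ field })}`
  return { onMount: validate, onBlur: validate, onChange: validate }
}
```

This matches the assertions that the message contains both `common.errorMsg.fieldRequired` and `"field":"Username"`, and that `getValidators` is `undefined` for non-required fields.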


@@ -0,0 +1,64 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useAppForm, withForm } from './index'
const FormHarness = ({ onSubmit }: { onSubmit: (value: Record<string, unknown>) => void }) => {
const form = useAppForm({
defaultValues: { title: 'Initial title' },
onSubmit: ({ value }) => onSubmit(value),
})
return (
<form>
<form.AppField
name="title"
children={field => <field.TextField label="Title" />}
/>
<form.AppForm>
<button type="button" onClick={() => form.handleSubmit()}>
Submit
</button>
</form.AppForm>
</form>
)
}
const InlinePreview = withForm({
defaultValues: { title: '' },
render: ({ form }) => {
return (
<form.AppField
name="title"
children={field => <field.TextField label="Preview Title" />}
/>
)
},
})
const WithFormHarness = () => {
const form = useAppForm({
defaultValues: { title: 'Preview value' },
onSubmit: () => {},
})
return <InlinePreview form={form} />
}
describe('form index exports', () => {
it('should submit values through the generated app form', async () => {
const onSubmit = vi.fn()
render(<FormHarness onSubmit={onSubmit} />)
fireEvent.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ title: 'Initial title' })
})
})
it('should render components created with withForm', () => {
render(<WithFormHarness />)
expect(screen.getByRole('textbox')).toHaveValue('Preview value')
expect(screen.getByText('Preview Title')).toBeInTheDocument()
})
})


@@ -0,0 +1,18 @@
import { FormItemValidateStatusEnum, FormTypeEnum } from './types'
describe('form types', () => {
it('should expose expected form type values', () => {
expect(Object.values(FormTypeEnum)).toContain('text-input')
expect(Object.values(FormTypeEnum)).toContain('dynamic-select')
expect(Object.values(FormTypeEnum)).toContain('boolean')
})
it('should expose expected validation status values', () => {
expect(Object.values(FormItemValidateStatusEnum)).toEqual([
'success',
'warning',
'error',
'validating',
])
})
})


@@ -0,0 +1,54 @@
import type { AnyFormApi } from '@tanstack/react-form'
import { FormTypeEnum } from '../../types'
import { getTransformedValuesWhenSecretInputPristine, transformFormSchemasSecretInput } from './index'
describe('secret input utilities', () => {
it('should mask only selected truthy values in transformFormSchemasSecretInput', () => {
expect(transformFormSchemasSecretInput(['apiKey'], {
apiKey: 'secret',
token: 'token-value',
emptyValue: '',
})).toEqual({
apiKey: '[__HIDDEN__]',
token: 'token-value',
emptyValue: '',
})
})
it('should mask pristine secret input fields from form state', () => {
const formSchemas = [
{ name: 'apiKey', type: FormTypeEnum.secretInput, label: 'API Key', required: true },
{ name: 'name', type: FormTypeEnum.textInput, label: 'Name', required: true },
]
const form = {
store: {
state: {
values: {
apiKey: 'secret',
name: 'Alice',
},
},
},
getFieldMeta: (name: string) => ({ isPristine: name === 'apiKey' }),
}
expect(getTransformedValuesWhenSecretInputPristine(formSchemas, form as unknown as AnyFormApi)).toEqual({
apiKey: '[__HIDDEN__]',
name: 'Alice',
})
})
it('should keep value unchanged when secret input is not pristine', () => {
const formSchemas = [
{ name: 'apiKey', type: FormTypeEnum.secretInput, label: 'API Key', required: true },
]
const form = {
store: { state: { values: { apiKey: 'secret' } } },
getFieldMeta: () => ({ isPristine: false }),
}
expect(getTransformedValuesWhenSecretInputPristine(formSchemas, form as unknown as AnyFormApi)).toEqual({
apiKey: 'secret',
})
})
})
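The masking contract in the first case above is compact enough to sketch. This is a minimal hypothetical implementation consistent with the assertions, not the utility's actual source:

```typescript
// Hypothetical sketch of the masking behavior the assertions above describe.
const HIDDEN = '[__HIDDEN__]'

// Replace only the listed secret fields, and only when they hold a truthy
// value; empty strings and unlisted keys pass through unchanged.
function transformFormSchemasSecretInput(
  secretNames: string[],
  values: Record<string, unknown>,
): Record<string, unknown> {
  const result: Record<string, unknown> = { ...values }
  for (const name of secretNames) {
    if (result[name])
      result[name] = HIDDEN
  }
  return result
}
```

The pristine-field variant then only differs in how the secret name list is built: it collects `secretInput` schema fields whose field meta reports `isPristine` before masking.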


@@ -0,0 +1,39 @@
import * as z from 'zod'
import { zodSubmitValidator } from './zod-submit-validator'
describe('zodSubmitValidator', () => {
it('should return undefined for valid values', () => {
const validator = zodSubmitValidator(z.object({
name: z.string().min(2),
}))
expect(validator({ value: { name: 'Alice' } })).toBeUndefined()
})
it('should return first error message per field for invalid values', () => {
const validator = zodSubmitValidator(z.object({
name: z.string().min(3, 'Name too short'),
age: z.number().min(18, 'Must be adult'),
}))
expect(validator({ value: { name: 'Al', age: 15 } })).toEqual({
fields: {
name: 'Name too short',
age: 'Must be adult',
},
})
})
it('should ignore root-level issues without a field path', () => {
const schema = z.object({ value: z.number() }).superRefine((_value, ctx) => {
ctx.addIssue({
code: z.ZodIssueCode.custom,
message: 'Root error',
path: [],
})
})
const validator = zodSubmitValidator(schema)
expect(validator({ value: { value: 1 } })).toEqual({ fields: {} })
})
})


@@ -4,7 +4,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useProviderContext } from '@/context/provider-context'
import { DataSourceType } from '@/models/datasets'
import { useDocumentList } from '@/service/knowledge/use-document'
import { useDocumentsPageState } from '../hooks/use-documents-page-state'
import useDocumentsPageState from '../hooks/use-documents-page-state'
import Documents from '../index'
// Type for mock selector function - use `as MockState` to bypass strict type checking in tests
@@ -117,10 +117,13 @@ const mockHandleStatusFilterClear = vi.fn()
const mockHandleSortChange = vi.fn()
const mockHandlePageChange = vi.fn()
const mockHandleLimitChange = vi.fn()
const mockUpdatePollingState = vi.fn()
const mockAdjustPageForTotal = vi.fn()
vi.mock('../hooks/use-documents-page-state', () => ({
useDocumentsPageState: vi.fn(() => ({
default: vi.fn(() => ({
inputValue: '',
searchValue: '',
debouncedSearchValue: '',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'all',
@@ -135,6 +138,9 @@ vi.mock('../hooks/use-documents-page-state', () => ({
handleLimitChange: mockHandleLimitChange,
selectedIds: [] as string[],
setSelectedIds: mockSetSelectedIds,
timerCanRun: false,
updatePollingState: mockUpdatePollingState,
adjustPageForTotal: mockAdjustPageForTotal,
})),
}))
@@ -313,33 +319,6 @@ describe('Documents', () => {
expect(screen.queryByTestId('documents-list')).not.toBeInTheDocument()
})
it('should keep rendering list when loading with existing data', () => {
vi.mocked(useDocumentList).mockReturnValueOnce({
data: {
data: [
{
id: 'doc-1',
name: 'Document 1',
indexing_status: 'completed',
data_source_type: 'upload_file',
position: 1,
enabled: true,
},
],
total: 1,
page: 1,
limit: 10,
has_more: false,
} as DocumentListResponse,
isLoading: true,
refetch: vi.fn(),
} as unknown as ReturnType<typeof useDocumentList>)
render(<Documents {...defaultProps} />)
expect(screen.getByTestId('documents-list')).toBeInTheDocument()
expect(screen.getByTestId('list-documents-count')).toHaveTextContent('1')
})
it('should render empty element when no documents exist', () => {
vi.mocked(useDocumentList).mockReturnValueOnce({
data: { data: [], total: 0, page: 1, limit: 10, has_more: false },
@@ -505,75 +484,17 @@ describe('Documents', () => {
})
})
describe('Query Options', () => {
it('should pass function refetchInterval to useDocumentList', () => {
describe('Side Effects and Cleanup', () => {
it('should call updatePollingState when documents response changes', () => {
render(<Documents {...defaultProps} />)
const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0]
expect(payload).toBeDefined()
expect(typeof payload?.refetchInterval).toBe('function')
expect(mockUpdatePollingState).toHaveBeenCalled()
})
it('should stop polling when all documents are in terminal statuses', () => {
it('should call adjustPageForTotal when documents response changes', () => {
render(<Documents {...defaultProps} />)
const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0]
const refetchInterval = payload?.refetchInterval
expect(typeof refetchInterval).toBe('function')
if (typeof refetchInterval !== 'function')
throw new Error('Expected function refetchInterval')
const interval = refetchInterval({
state: {
data: {
data: [
{ indexing_status: 'completed' },
{ indexing_status: 'paused' },
{ indexing_status: 'error' },
],
},
},
} as unknown as Parameters<typeof refetchInterval>[0])
expect(interval).toBe(false)
})
it('should keep polling for transient status filters', () => {
vi.mocked(useDocumentsPageState).mockReturnValueOnce({
inputValue: '',
debouncedSearchValue: '',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'indexing',
sortValue: '-created_at' as const,
normalizedStatusFilterValue: 'indexing',
handleStatusFilterChange: mockHandleStatusFilterChange,
handleStatusFilterClear: mockHandleStatusFilterClear,
handleSortChange: mockHandleSortChange,
currPage: 0,
limit: 10,
handlePageChange: mockHandlePageChange,
handleLimitChange: mockHandleLimitChange,
selectedIds: [] as string[],
setSelectedIds: mockSetSelectedIds,
})
render(<Documents {...defaultProps} />)
const payload = vi.mocked(useDocumentList).mock.calls.at(-1)?.[0]
const refetchInterval = payload?.refetchInterval
expect(typeof refetchInterval).toBe('function')
if (typeof refetchInterval !== 'function')
throw new Error('Expected function refetchInterval')
const interval = refetchInterval({
state: {
data: {
data: [{ indexing_status: 'completed' }],
},
},
} as unknown as Parameters<typeof refetchInterval>[0])
expect(interval).toBe(2500)
expect(mockAdjustPageForTotal).toHaveBeenCalled()
})
})
@@ -670,6 +591,36 @@ describe('Documents', () => {
})
})
describe('Polling State', () => {
it('should enable polling when documents are indexing', () => {
vi.mocked(useDocumentsPageState).mockReturnValueOnce({
inputValue: '',
searchValue: '',
debouncedSearchValue: '',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'all',
sortValue: '-created_at' as const,
normalizedStatusFilterValue: 'all',
handleStatusFilterChange: mockHandleStatusFilterChange,
handleStatusFilterClear: mockHandleStatusFilterClear,
handleSortChange: mockHandleSortChange,
currPage: 0,
limit: 10,
handlePageChange: mockHandlePageChange,
handleLimitChange: mockHandleLimitChange,
selectedIds: [] as string[],
setSelectedIds: mockSetSelectedIds,
timerCanRun: true,
updatePollingState: mockUpdatePollingState,
adjustPageForTotal: mockAdjustPageForTotal,
})
render(<Documents {...defaultProps} />)
expect(screen.getByTestId('documents-list')).toBeInTheDocument()
})
})
describe('Pagination', () => {
it('should display correct total in list', () => {
render(<Documents {...defaultProps} />)
@@ -684,6 +635,7 @@ describe('Documents', () => {
it('should handle page changes', () => {
vi.mocked(useDocumentsPageState).mockReturnValueOnce({
inputValue: '',
searchValue: '',
debouncedSearchValue: '',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'all',
@@ -698,6 +650,9 @@ describe('Documents', () => {
handleLimitChange: mockHandleLimitChange,
selectedIds: [] as string[],
setSelectedIds: mockSetSelectedIds,
timerCanRun: false,
updatePollingState: mockUpdatePollingState,
adjustPageForTotal: mockAdjustPageForTotal,
})
render(<Documents {...defaultProps} />)
@@ -709,6 +664,7 @@ describe('Documents', () => {
it('should display selected count', () => {
vi.mocked(useDocumentsPageState).mockReturnValueOnce({
inputValue: '',
searchValue: '',
debouncedSearchValue: '',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'all',
@@ -723,6 +679,9 @@ describe('Documents', () => {
handleLimitChange: mockHandleLimitChange,
selectedIds: ['doc-1', 'doc-2'],
setSelectedIds: mockSetSelectedIds,
timerCanRun: false,
updatePollingState: mockUpdatePollingState,
adjustPageForTotal: mockAdjustPageForTotal,
})
render(<Documents {...defaultProps} />)
@@ -734,6 +693,7 @@ describe('Documents', () => {
it('should pass filter value to list', () => {
vi.mocked(useDocumentsPageState).mockReturnValueOnce({
inputValue: 'test search',
searchValue: 'test search',
debouncedSearchValue: 'test search',
handleInputChange: mockHandleInputChange,
statusFilterValue: 'completed',
@@ -748,6 +708,9 @@ describe('Documents', () => {
handleLimitChange: mockHandleLimitChange,
selectedIds: [] as string[],
setSelectedIds: mockSetSelectedIds,
timerCanRun: false,
updatePollingState: mockUpdatePollingState,
adjustPageForTotal: mockAdjustPageForTotal,
})
render(<Documents {...defaultProps} />)


@@ -20,8 +20,9 @@ const mockHandleSave = vi.fn()
vi.mock('../document-list/hooks', () => ({
useDocumentSort: vi.fn(() => ({
sortField: null,
sortOrder: 'desc',
sortOrder: null,
handleSort: mockHandleSort,
sortedDocuments: [],
})),
useDocumentSelection: vi.fn(() => ({
isAllSelected: false,
@@ -124,8 +125,8 @@ const defaultProps = {
pagination: { total: 0, current: 1, limit: 10, onChange: vi.fn() },
onUpdate: vi.fn(),
onManageMetadata: vi.fn(),
remoteSortValue: '-created_at',
onSortChange: vi.fn(),
statusFilterValue: 'all',
remoteSortValue: '',
}
describe('DocumentList', () => {
@@ -139,6 +140,8 @@ describe('DocumentList', () => {
render(<DocumentList {...defaultProps} />)
expect(screen.getByText('#')).toBeInTheDocument()
expect(screen.getByTestId('sort-name')).toBeInTheDocument()
expect(screen.getByTestId('sort-word_count')).toBeInTheDocument()
expect(screen.getByTestId('sort-hit_count')).toBeInTheDocument()
expect(screen.getByTestId('sort-created_at')).toBeInTheDocument()
})
@@ -161,9 +164,10 @@ describe('DocumentList', () => {
it('should render document rows from sortedDocuments', () => {
const docs = [createDoc({ id: 'a', name: 'Doc A' }), createDoc({ id: 'b', name: 'Doc B' })]
vi.mocked(useDocumentSort).mockReturnValue({
sortField: 'created_at',
sortField: null,
sortOrder: 'desc',
handleSort: mockHandleSort,
sortedDocuments: docs,
} as unknown as ReturnType<typeof useDocumentSort>)
render(<DocumentList {...defaultProps} documents={docs} />)
@@ -178,9 +182,9 @@ describe('DocumentList', () => {
it('should call handleSort when sort header is clicked', () => {
render(<DocumentList {...defaultProps} />)
fireEvent.click(screen.getByTestId('sort-created_at'))
fireEvent.click(screen.getByTestId('sort-name'))
expect(mockHandleSort).toHaveBeenCalledWith('created_at')
expect(mockHandleSort).toHaveBeenCalledWith('name')
})
})
@@ -225,6 +229,7 @@ describe('DocumentList', () => {
sortField: null,
sortOrder: 'desc',
handleSort: mockHandleSort,
sortedDocuments: [],
} as unknown as ReturnType<typeof useDocumentSort>)
render(<DocumentList {...defaultProps} documents={[]} />)


@@ -2,7 +2,7 @@ import type { ReactNode } from 'react'
import type { Props as PaginationProps } from '@/app/components/base/pagination'
import type { SimpleDocumentDetail } from '@/models/datasets'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ChunkingMode, DataSourceType } from '@/models/datasets'
import DocumentList from '../../list'
@@ -13,7 +13,6 @@ vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
}),
useSearchParams: () => new URLSearchParams(),
}))
vi.mock('@/context/dataset-detail', () => ({
@@ -91,8 +90,8 @@ describe('DocumentList', () => {
pagination: defaultPagination,
onUpdate: vi.fn(),
onManageMetadata: vi.fn(),
remoteSortValue: '-created_at',
onSortChange: vi.fn(),
statusFilterValue: '',
remoteSortValue: '',
}
beforeEach(() => {
@@ -221,15 +220,16 @@ describe('DocumentList', () => {
expect(sortIcons.length).toBeGreaterThan(0)
})
it('should call onSortChange when sortable header is clicked', () => {
const onSortChange = vi.fn()
const { container } = render(<DocumentList {...defaultProps} onSortChange={onSortChange} />, { wrapper: createWrapper() })
it('should update sort order when sort header is clicked', () => {
render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
const sortableHeaders = container.querySelectorAll('thead button')
if (sortableHeaders.length > 0)
// Find and click a sort header via its clickable wrapper element
const sortableHeaders = document.querySelectorAll('[class*="cursor-pointer"]')
if (sortableHeaders.length > 0) {
fireEvent.click(sortableHeaders[0])
}
expect(onSortChange).toHaveBeenCalled()
expect(screen.getByRole('table')).toBeInTheDocument()
})
})
@@ -360,15 +360,13 @@ describe('DocumentList', () => {
expect(modal).not.toBeInTheDocument()
})
it('should show rename modal when rename button is clicked', async () => {
it('should show rename modal when rename button is clicked', () => {
const { container } = render(<DocumentList {...defaultProps} />, { wrapper: createWrapper() })
// Find and click the rename button in the first row
const renameButtons = container.querySelectorAll('.cursor-pointer.rounded-md')
if (renameButtons.length > 0) {
await act(async () => {
fireEvent.click(renameButtons[0])
})
fireEvent.click(renameButtons[0])
}
// After clicking rename, the modal may become visible
@@ -386,7 +384,7 @@ describe('DocumentList', () => {
})
describe('Edit Metadata Modal', () => {
it('should handle edit metadata action', async () => {
it('should handle edit metadata action', () => {
const props = {
...defaultProps,
selectedIds: ['doc-1'],
@@ -395,9 +393,7 @@ describe('DocumentList', () => {
const editButton = screen.queryByRole('button', { name: /metadata/i })
if (editButton) {
await act(async () => {
fireEvent.click(editButton)
})
fireEvent.click(editButton)
}
expect(screen.getByRole('table')).toBeInTheDocument()
@@ -458,6 +454,16 @@ describe('DocumentList', () => {
expect(screen.getByRole('table')).toBeInTheDocument()
})
it('should handle status filter value', () => {
const props = {
...defaultProps,
statusFilterValue: 'completed',
}
render(<DocumentList {...props} />, { wrapper: createWrapper() })
expect(screen.getByRole('table')).toBeInTheDocument()
})
it('should handle remote sort value', () => {
const props = {
...defaultProps,


@@ -7,13 +7,11 @@ import { DataSourceType } from '@/models/datasets'
import DocumentTableRow from '../document-table-row'
const mockPush = vi.fn()
let mockSearchParams = ''
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
}),
useSearchParams: () => new URLSearchParams(mockSearchParams),
}))
const createTestQueryClient = () => new QueryClient({
@@ -97,7 +95,6 @@ describe('DocumentTableRow', () => {
beforeEach(() => {
vi.clearAllMocks()
mockSearchParams = ''
})
describe('Rendering', () => {
@@ -189,15 +186,6 @@ describe('DocumentTableRow', () => {
expect(mockPush).toHaveBeenCalledWith('/datasets/custom-dataset/documents/custom-doc')
})
it('should preserve search params when navigating to detail', () => {
mockSearchParams = 'page=2&status=error'
render(<DocumentTableRow {...defaultProps} />, { wrapper: createWrapper() })
fireEvent.click(screen.getByRole('row'))
expect(mockPush).toHaveBeenCalledWith('/datasets/dataset-1/documents/doc-1?page=2&status=error')
})
})
describe('Word Count Display', () => {


@@ -4,8 +4,8 @@ import SortHeader from '../sort-header'
describe('SortHeader', () => {
const defaultProps = {
field: 'created_at' as const,
label: 'Upload Time',
field: 'name' as const,
label: 'File Name',
currentSortField: null,
sortOrder: 'desc' as const,
onSort: vi.fn(),
@@ -14,12 +14,12 @@ describe('SortHeader', () => {
describe('rendering', () => {
it('should render the label', () => {
render(<SortHeader {...defaultProps} />)
expect(screen.getByText('Upload Time')).toBeInTheDocument()
expect(screen.getByText('File Name')).toBeInTheDocument()
})
it('should render the sort icon', () => {
const { container } = render(<SortHeader {...defaultProps} />)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).toBeInTheDocument()
})
})
@@ -27,13 +27,13 @@ describe('SortHeader', () => {
describe('inactive state', () => {
it('should have disabled text color when not active', () => {
const { container } = render(<SortHeader {...defaultProps} />)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-disabled')
})
it('should not be rotated when not active', () => {
const { container } = render(<SortHeader {...defaultProps} />)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).not.toHaveClass('rotate-180')
})
})
@@ -41,25 +41,25 @@ describe('SortHeader', () => {
describe('active state', () => {
it('should have tertiary text color when active', () => {
const { container } = render(
<SortHeader {...defaultProps} currentSortField="created_at" />,
<SortHeader {...defaultProps} currentSortField="name" />,
)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-tertiary')
})
it('should not be rotated when active and desc', () => {
const { container } = render(
<SortHeader {...defaultProps} currentSortField="created_at" sortOrder="desc" />,
<SortHeader {...defaultProps} currentSortField="name" sortOrder="desc" />,
)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).not.toHaveClass('rotate-180')
})
it('should be rotated when active and asc', () => {
const { container } = render(
<SortHeader {...defaultProps} currentSortField="created_at" sortOrder="asc" />,
<SortHeader {...defaultProps} currentSortField="name" sortOrder="asc" />,
)
const icon = container.querySelector('button span')
const icon = container.querySelector('svg')
expect(icon).toHaveClass('rotate-180')
})
})
@@ -69,22 +69,34 @@ describe('SortHeader', () => {
const onSort = vi.fn()
render(<SortHeader {...defaultProps} onSort={onSort} />)
fireEvent.click(screen.getByText('Upload Time'))
fireEvent.click(screen.getByText('File Name'))
expect(onSort).toHaveBeenCalledWith('created_at')
expect(onSort).toHaveBeenCalledWith('name')
})
it('should call onSort with correct field', () => {
const onSort = vi.fn()
render(<SortHeader {...defaultProps} field="hit_count" onSort={onSort} />)
render(<SortHeader {...defaultProps} field="word_count" onSort={onSort} />)
fireEvent.click(screen.getByText('Upload Time'))
fireEvent.click(screen.getByText('File Name'))
expect(onSort).toHaveBeenCalledWith('hit_count')
expect(onSort).toHaveBeenCalledWith('word_count')
})
})
describe('different fields', () => {
it('should work with word_count field', () => {
render(
<SortHeader
{...defaultProps}
field="word_count"
label="Words"
currentSortField="word_count"
/>,
)
expect(screen.getByText('Words')).toBeInTheDocument()
})
it('should work with hit_count field', () => {
render(
<SortHeader


@@ -1,7 +1,8 @@
import type { FC } from 'react'
import type { SimpleDocumentDetail } from '@/models/datasets'
import { RiEditLine } from '@remixicon/react'
import { pick } from 'es-toolkit/object'
import { useRouter, useSearchParams } from 'next/navigation'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
@@ -61,15 +62,13 @@ const DocumentTableRow: FC<DocumentTableRowProps> = React.memo(({
const { t } = useTranslation()
const { formatTime } = useTimestamp()
const router = useRouter()
const searchParams = useSearchParams()
const isFile = doc.data_source_type === DataSourceType.FILE
const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : ''
const queryString = searchParams.toString()
const handleRowClick = useCallback(() => {
router.push(`/datasets/${datasetId}/documents/${doc.id}${queryString ? `?${queryString}` : ''}`)
}, [router, datasetId, doc.id, queryString])
router.push(`/datasets/${datasetId}/documents/${doc.id}`)
}, [router, datasetId, doc.id])
const handleCheckboxClick = useCallback((e: React.MouseEvent) => {
e.stopPropagation()
@@ -101,7 +100,7 @@ const DocumentTableRow: FC<DocumentTableRowProps> = React.memo(({
<DocumentSourceIcon doc={doc} fileType={fileType} />
</div>
<Tooltip popupContent={doc.name}>
<span className="grow truncate text-sm">{doc.name}</span>
<span className="grow-1 truncate text-sm">{doc.name}</span>
</Tooltip>
{doc.summary_index_status && (
<div className="ml-1 hidden shrink-0 group-hover:flex">
@@ -114,7 +113,7 @@ const DocumentTableRow: FC<DocumentTableRowProps> = React.memo(({
className="cursor-pointer rounded-md p-1 hover:bg-state-base-hover"
onClick={handleRenameClick}
>
<span className="i-ri-edit-line h-4 w-4 text-text-tertiary" />
<RiEditLine className="h-4 w-4 text-text-tertiary" />
</div>
</Tooltip>
</div>


@@ -1,5 +1,6 @@
import type { FC } from 'react'
import type { SortField, SortOrder } from '../hooks'
import { RiArrowDownLine } from '@remixicon/react'
import * as React from 'react'
import { cn } from '@/utils/classnames'
@@ -22,20 +23,19 @@ const SortHeader: FC<SortHeaderProps> = React.memo(({
const isDesc = isActive && sortOrder === 'desc'
return (
<button
type="button"
className="flex items-center bg-transparent p-0 text-left hover:text-text-secondary"
<div
className="flex cursor-pointer items-center hover:text-text-secondary"
onClick={() => onSort(field)}
>
{label}
<span
<RiArrowDownLine
className={cn(
'i-ri-arrow-down-line ml-0.5 h-3 w-3 transition-all',
'ml-0.5 h-3 w-3 transition-all',
isActive ? 'text-text-tertiary' : 'text-text-disabled',
isActive && !isDesc ? 'rotate-180' : '',
)}
/>
</button>
</div>
)
})

View File

@@ -1,98 +1,340 @@
import type { SimpleDocumentDetail } from '@/models/datasets'
import { act, renderHook } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { describe, expect, it } from 'vitest'
import { useDocumentSort } from '../use-document-sort'
describe('useDocumentSort', () => {
describe('remote state parsing', () => {
it('should parse descending created_at sort', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-created_at',
onRemoteSortChange,
}))
type LocalDoc = SimpleDocumentDetail & { percent?: number }
expect(result.current.sortField).toBe('created_at')
const createMockDocument = (overrides: Partial<LocalDoc> = {}): LocalDoc => ({
id: 'doc1',
name: 'Test Document',
data_source_type: 'upload_file',
data_source_info: {},
data_source_detail_dict: {},
word_count: 100,
hit_count: 10,
created_at: 1000000,
position: 1,
doc_form: 'text_model',
enabled: true,
archived: false,
display_status: 'available',
created_from: 'api',
...overrides,
} as LocalDoc)
describe('useDocumentSort', () => {
describe('initial state', () => {
it('should return null sortField initially', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
expect(result.current.sortField).toBeNull()
expect(result.current.sortOrder).toBe('desc')
})
it('should parse ascending hit_count sort', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: 'hit_count',
onRemoteSortChange,
}))
it('should return documents unchanged when no sort is applied', () => {
const docs = [
createMockDocument({ id: 'doc1', name: 'B' }),
createMockDocument({ id: 'doc2', name: 'A' }),
]
expect(result.current.sortField).toBe('hit_count')
expect(result.current.sortOrder).toBe('asc')
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
expect(result.current.sortedDocuments).toEqual(docs)
})
})
describe('handleSort', () => {
it('should set sort field when called', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortField).toBe('name')
expect(result.current.sortOrder).toBe('desc')
})
it('should fallback to inactive field for unsupported sort key', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-name',
onRemoteSortChange,
}))
it('should toggle sort order when same field is clicked twice', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortOrder).toBe('desc')
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortOrder).toBe('asc')
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortOrder).toBe('desc')
})
it('should reset to desc when different field is selected', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('name')
})
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortOrder).toBe('asc')
act(() => {
result.current.handleSort('word_count')
})
expect(result.current.sortField).toBe('word_count')
expect(result.current.sortOrder).toBe('desc')
})
it('should not change state when null is passed', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort(null)
})
expect(result.current.sortField).toBeNull()
})
})
describe('sorting documents', () => {
const docs = [
createMockDocument({ id: 'doc1', name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 }),
createMockDocument({ id: 'doc2', name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 }),
createMockDocument({ id: 'doc3', name: 'Cherry', word_count: 300, hit_count: 1, created_at: 2000 }),
]
it('should sort by name descending', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('name')
})
const names = result.current.sortedDocuments.map(d => d.name)
expect(names).toEqual(['Cherry', 'Banana', 'Apple'])
})
it('should sort by name ascending', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('name')
})
act(() => {
result.current.handleSort('name')
})
const names = result.current.sortedDocuments.map(d => d.name)
expect(names).toEqual(['Apple', 'Banana', 'Cherry'])
})
it('should sort by word_count descending', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('word_count')
})
const counts = result.current.sortedDocuments.map(d => d.word_count)
expect(counts).toEqual([300, 200, 100])
})
it('should sort by hit_count ascending', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('hit_count')
})
act(() => {
result.current.handleSort('hit_count')
})
const counts = result.current.sortedDocuments.map(d => d.hit_count)
expect(counts).toEqual([1, 5, 10])
})
it('should sort by created_at descending', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('created_at')
})
const times = result.current.sortedDocuments.map(d => d.created_at)
expect(times).toEqual([3000, 2000, 1000])
})
})
describe('status filtering', () => {
const docs = [
createMockDocument({ id: 'doc1', display_status: 'available' }),
createMockDocument({ id: 'doc2', display_status: 'error' }),
createMockDocument({ id: 'doc3', display_status: 'available' }),
]
it('should not filter when statusFilterValue is empty', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
expect(result.current.sortedDocuments.length).toBe(3)
})
it('should not filter when statusFilterValue is all', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: 'all',
remoteSortValue: '',
}),
)
expect(result.current.sortedDocuments.length).toBe(3)
})
})
describe('remoteSortValue reset', () => {
it('should reset sort state when remoteSortValue changes', () => {
const { result, rerender } = renderHook(
({ remoteSortValue }) =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue,
}),
{ initialProps: { remoteSortValue: 'initial' } },
)
act(() => {
result.current.handleSort('name')
})
act(() => {
result.current.handleSort('name')
})
expect(result.current.sortField).toBe('name')
expect(result.current.sortOrder).toBe('asc')
rerender({ remoteSortValue: 'changed' })
expect(result.current.sortField).toBeNull()
expect(result.current.sortOrder).toBe('desc')
})
})
describe('handleSort', () => {
it('should switch to desc when selecting a different field', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-created_at',
onRemoteSortChange,
}))
describe('edge cases', () => {
it('should handle documents with missing values', () => {
const docs = [
createMockDocument({ id: 'doc1', name: undefined as unknown as string, word_count: undefined }),
createMockDocument({ id: 'doc2', name: 'Test', word_count: 100 }),
]
const { result } = renderHook(() =>
useDocumentSort({
documents: docs,
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('hit_count')
result.current.handleSort('name')
})
expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count')
expect(result.current.sortedDocuments.length).toBe(2)
})
it('should toggle desc -> asc when clicking active field', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-hit_count',
onRemoteSortChange,
}))
it('should handle empty documents array', () => {
const { result } = renderHook(() =>
useDocumentSort({
documents: [],
statusFilterValue: '',
remoteSortValue: '',
}),
)
act(() => {
result.current.handleSort('hit_count')
result.current.handleSort('name')
})
expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count')
})
it('should toggle asc -> desc when clicking active field', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: 'created_at',
onRemoteSortChange,
}))
act(() => {
result.current.handleSort('created_at')
})
expect(onRemoteSortChange).toHaveBeenCalledWith('-created_at')
})
it('should ignore null field', () => {
const onRemoteSortChange = vi.fn()
const { result } = renderHook(() => useDocumentSort({
remoteSortValue: '-created_at',
onRemoteSortChange,
}))
act(() => {
result.current.handleSort(null)
})
expect(onRemoteSortChange).not.toHaveBeenCalled()
expect(result.current.sortedDocuments).toEqual([])
})
})
})

View File

@@ -1,42 +1,102 @@
import { useCallback, useMemo } from 'react'
import type { SimpleDocumentDetail } from '@/models/datasets'
import { useCallback, useMemo, useRef, useState } from 'react'
import { normalizeStatusForQuery } from '@/app/components/datasets/documents/status-filter'
type RemoteSortField = 'hit_count' | 'created_at'
const REMOTE_SORT_FIELDS = new Set<RemoteSortField>(['hit_count', 'created_at'])
export type SortField = RemoteSortField | null
export type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at' | null
export type SortOrder = 'asc' | 'desc'
type LocalDoc = SimpleDocumentDetail & { percent?: number }
type UseDocumentSortOptions = {
documents: LocalDoc[]
statusFilterValue: string
remoteSortValue: string
onRemoteSortChange: (nextSortValue: string) => void
}
export const useDocumentSort = ({
documents,
statusFilterValue,
remoteSortValue,
onRemoteSortChange,
}: UseDocumentSortOptions) => {
const sortOrder: SortOrder = remoteSortValue.startsWith('-') ? 'desc' : 'asc'
const sortKey = remoteSortValue.startsWith('-') ? remoteSortValue.slice(1) : remoteSortValue
const [sortField, setSortField] = useState<SortField>(null)
const [sortOrder, setSortOrder] = useState<SortOrder>('desc')
const prevRemoteSortValueRef = useRef(remoteSortValue)
const sortField = useMemo<SortField>(() => {
return REMOTE_SORT_FIELDS.has(sortKey as RemoteSortField) ? sortKey as RemoteSortField : null
}, [sortKey])
// Reset sort when remote sort changes
if (prevRemoteSortValueRef.current !== remoteSortValue) {
prevRemoteSortValueRef.current = remoteSortValue
setSortField(null)
setSortOrder('desc')
}
const handleSort = useCallback((field: SortField) => {
if (!field)
if (field === null)
return
if (sortField === field) {
const nextSortOrder = sortOrder === 'desc' ? 'asc' : 'desc'
onRemoteSortChange(nextSortOrder === 'desc' ? `-${field}` : field)
return
setSortOrder(prev => prev === 'asc' ? 'desc' : 'asc')
}
onRemoteSortChange(`-${field}`)
}, [onRemoteSortChange, sortField, sortOrder])
else {
setSortField(field)
setSortOrder('desc')
}
}, [sortField])
const sortedDocuments = useMemo(() => {
let filteredDocs = documents
if (statusFilterValue && statusFilterValue !== 'all') {
filteredDocs = filteredDocs.filter(doc =>
typeof doc.display_status === 'string'
&& normalizeStatusForQuery(doc.display_status) === statusFilterValue,
)
}
if (!sortField)
return filteredDocs
const sortedDocs = [...filteredDocs].sort((a, b) => {
let aValue: string | number
let bValue: string | number
switch (sortField) {
case 'name':
aValue = a.name?.toLowerCase() || ''
bValue = b.name?.toLowerCase() || ''
break
case 'word_count':
aValue = a.word_count || 0
bValue = b.word_count || 0
break
case 'hit_count':
aValue = a.hit_count || 0
bValue = b.hit_count || 0
break
case 'created_at':
aValue = a.created_at
bValue = b.created_at
break
default:
return 0
}
if (sortField === 'name') {
const result = (aValue as string).localeCompare(bValue as string)
return sortOrder === 'asc' ? result : -result
}
else {
const result = (aValue as number) - (bValue as number)
return sortOrder === 'asc' ? result : -result
}
})
return sortedDocs
}, [documents, sortField, sortOrder, statusFilterValue])
return {
sortField,
sortOrder,
handleSort,
sortedDocuments,
}
}
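For context, the comparator introduced in `useDocumentSort` above can be exercised in isolation. The sketch below mirrors its branching (string fields via `localeCompare`, numeric fields via subtraction, order flip for `desc`); the `sortDocs` helper and the trimmed `Doc` type are illustrative, not part of the diff.

```typescript
// Standalone sketch of the local sort logic in useDocumentSort (illustrative names).
type SortField = 'name' | 'word_count' | 'hit_count' | 'created_at'
type SortOrder = 'asc' | 'desc'
type Doc = { name?: string; word_count?: number; hit_count?: number; created_at: number }

function sortDocs(docs: Doc[], field: SortField, order: SortOrder): Doc[] {
  // Copy before sorting so the input array is left untouched, as the hook does.
  return [...docs].sort((a, b) => {
    let result: number
    if (field === 'name')
      // Missing names fall back to '' so undefined sorts first in ascending order.
      result = (a.name?.toLowerCase() || '').localeCompare(b.name?.toLowerCase() || '')
    else
      // Numeric fields fall back to 0 when missing.
      result = ((a[field] || 0) as number) - ((b[field] || 0) as number)
    return order === 'asc' ? result : -result
  })
}

const docs: Doc[] = [
  { name: 'Banana', word_count: 200, hit_count: 5, created_at: 3000 },
  { name: 'Apple', word_count: 100, hit_count: 10, created_at: 1000 },
]
console.log(sortDocs(docs, 'name', 'desc').map(d => d.name)) // ['Banana', 'Apple']
```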

View File

@@ -14,7 +14,7 @@ import { useDatasetDetailContextWithSelector as useDatasetDetailContext } from '
import { ChunkingMode, DocumentActionType } from '@/models/datasets'
import BatchAction from '../detail/completed/common/batch-action'
import s from '../style.module.css'
import { DocumentTableRow, SortHeader } from './document-list/components'
import { DocumentTableRow, renderTdValue, SortHeader } from './document-list/components'
import { useDocumentActions, useDocumentSelection, useDocumentSort } from './document-list/hooks'
import RenameModal from './rename-modal'
@@ -29,8 +29,8 @@ type DocumentListProps = {
pagination: PaginationProps
onUpdate: () => void
onManageMetadata: () => void
statusFilterValue: string
remoteSortValue: string
onSortChange: (value: string) => void
}
/**
@@ -45,8 +45,8 @@ const DocumentList: FC<DocumentListProps> = ({
pagination,
onUpdate,
onManageMetadata,
statusFilterValue,
remoteSortValue,
onSortChange,
}) => {
const { t } = useTranslation()
const datasetConfig = useDatasetDetailContext(s => s.dataset)
@@ -55,9 +55,10 @@ const DocumentList: FC<DocumentListProps> = ({
const isQAMode = chunkingMode === ChunkingMode.qa
// Sorting
const { sortField, sortOrder, handleSort } = useDocumentSort({
const { sortField, sortOrder, handleSort, sortedDocuments } = useDocumentSort({
documents,
statusFilterValue,
remoteSortValue,
onRemoteSortChange: onSortChange,
})
// Selection
@@ -70,7 +71,7 @@ const DocumentList: FC<DocumentListProps> = ({
downloadableSelectedIds,
clearSelection,
} = useDocumentSelection({
documents,
documents: sortedDocuments,
selectedIds,
onSelectedIdChange,
})
@@ -134,10 +135,24 @@ const DocumentList: FC<DocumentListProps> = ({
</div>
</td>
<td>
{t('list.table.header.fileName', { ns: 'datasetDocuments' })}
<SortHeader
field="name"
label={t('list.table.header.fileName', { ns: 'datasetDocuments' })}
currentSortField={sortField}
sortOrder={sortOrder}
onSort={handleSort}
/>
</td>
<td className="w-[130px]">{t('list.table.header.chunkingMode', { ns: 'datasetDocuments' })}</td>
<td className="w-24">{t('list.table.header.words', { ns: 'datasetDocuments' })}</td>
<td className="w-24">
<SortHeader
field="word_count"
label={t('list.table.header.words', { ns: 'datasetDocuments' })}
currentSortField={sortField}
sortOrder={sortOrder}
onSort={handleSort}
/>
</td>
<td className="w-44">
<SortHeader
field="hit_count"
@@ -161,7 +176,7 @@ const DocumentList: FC<DocumentListProps> = ({
</tr>
</thead>
<tbody className="text-text-secondary">
{documents.map((doc, index) => (
{sortedDocuments.map((doc, index) => (
<DocumentTableRow
key={doc.id}
doc={doc}
@@ -233,3 +248,5 @@ const DocumentList: FC<DocumentListProps> = ({
}
export default DocumentList
export { renderTdValue }

View File

@@ -9,7 +9,6 @@ const mocks = vi.hoisted(() => {
documentError: null as Error | null,
documentMetadata: null as Record<string, unknown> | null,
media: 'desktop' as string,
searchParams: '' as string,
}
return {
state,
@@ -27,7 +26,6 @@ const mocks = vi.hoisted(() => {
// --- External mocks ---
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: mocks.push }),
useSearchParams: () => new URLSearchParams(mocks.state.searchParams),
}))
vi.mock('@/hooks/use-breakpoints', () => ({
@@ -195,7 +193,6 @@ describe('DocumentDetail', () => {
mocks.state.documentError = null
mocks.state.documentMetadata = null
mocks.state.media = 'desktop'
mocks.state.searchParams = ''
})
afterEach(() => {
@@ -289,23 +286,15 @@ describe('DocumentDetail', () => {
})
it('should toggle metadata panel when button clicked', () => {
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
const { container } = render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
expect(screen.getByTestId('metadata')).toBeInTheDocument()
fireEvent.click(screen.getByTestId('document-detail-metadata-toggle'))
const svgs = container.querySelectorAll('svg')
const toggleBtn = svgs[svgs.length - 1].closest('button')!
fireEvent.click(toggleBtn)
expect(screen.queryByTestId('metadata')).not.toBeInTheDocument()
})
it('should expose aria semantics for metadata toggle button', () => {
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
const toggle = screen.getByTestId('document-detail-metadata-toggle')
expect(toggle).toHaveAttribute('aria-label')
expect(toggle).toHaveAttribute('aria-pressed', 'true')
fireEvent.click(toggle)
expect(toggle).toHaveAttribute('aria-pressed', 'false')
})
it('should pass correct props to Metadata', () => {
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
const metadata = screen.getByTestId('metadata')
@@ -316,21 +305,20 @@ describe('DocumentDetail', () => {
describe('Navigation', () => {
it('should navigate back when back button clicked', () => {
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
fireEvent.click(screen.getByTestId('document-detail-back-button'))
const { container } = render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
const backBtn = container.querySelector('svg')!.parentElement!
fireEvent.click(backBtn)
expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents')
})
it('should expose aria label for back button', () => {
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
expect(screen.getByTestId('document-detail-back-button')).toHaveAttribute('aria-label')
})
it('should preserve query params when navigating back', () => {
mocks.state.searchParams = 'page=2&status=active'
render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
fireEvent.click(screen.getByTestId('document-detail-back-button'))
// Capture the original URL as a string: window.location is a live object,
// so reading .href after pushState would return the new URL, not the original.
const origHref = window.location.href
window.history.pushState({}, '', '?page=2&status=active')
const { container } = render(<DocumentDetail datasetId="ds-1" documentId="doc-1" />)
const backBtn = container.querySelector('svg')!.parentElement!
fireEvent.click(backBtn)
expect(mocks.push).toHaveBeenCalledWith('/datasets/ds-1/documents?page=2&status=active')
window.history.pushState({}, '', origHref)
})
})

View File

@@ -1,7 +1,8 @@
'use client'
import type { FC } from 'react'
import type { DataSourceInfo, FileItem, FullDocumentDetail, LegacyDataSourceInfo } from '@/models/datasets'
import { useRouter, useSearchParams } from 'next/navigation'
import { RiArrowLeftLine, RiLayoutLeft2Line, RiLayoutRight2Line } from '@remixicon/react'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -34,7 +35,6 @@ type DocumentDetailProps = {
const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
const router = useRouter()
const searchParams = useSearchParams()
const { t } = useTranslation()
const media = useBreakpoints()
@@ -98,8 +98,11 @@ const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
})
const backToPrev = () => {
// Preserve pagination and filter states when navigating back
const searchParams = new URLSearchParams(window.location.search)
const queryString = searchParams.toString()
const backPath = `/datasets/${datasetId}/documents${queryString ? `?${queryString}` : ''}`
const separator = queryString ? '?' : ''
const backPath = `/datasets/${datasetId}/documents${separator}${queryString}`
router.push(backPath)
}
@@ -149,11 +152,6 @@ const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
return chunkMode === ChunkingMode.parentChild && parentMode === 'full-doc'
}, [documentDetail?.doc_form, parentMode])
const backButtonLabel = t('operation.back', { ns: 'common' })
const metadataToggleLabel = `${showMetadata
? t('operation.close', { ns: 'common' })
: t('operation.view', { ns: 'common' })} ${t('metadata.title', { ns: 'datasetDocuments' })}`
return (
<DocumentContext.Provider value={{
datasetId,
@@ -164,19 +162,9 @@ const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
>
<div className="flex h-full flex-col bg-background-default">
<div className="flex min-h-16 flex-wrap items-center justify-between border-b border-b-divider-subtle py-2.5 pl-3 pr-4">
<button
type="button"
data-testid="document-detail-back-button"
aria-label={backButtonLabel}
title={backButtonLabel}
onClick={backToPrev}
className="flex h-8 w-8 shrink-0 cursor-pointer items-center justify-center rounded-full hover:bg-components-button-tertiary-bg"
>
<span
aria-hidden="true"
className="i-ri-arrow-left-line h-4 w-4 text-components-button-ghost-text hover:text-text-tertiary"
/>
</button>
<div onClick={backToPrev} className="flex h-8 w-8 shrink-0 cursor-pointer items-center justify-center rounded-full hover:bg-components-button-tertiary-bg">
<RiArrowLeftLine className="h-4 w-4 text-components-button-ghost-text hover:text-text-tertiary" />
</div>
<DocumentTitle
datasetId={datasetId}
extension={documentUploadFile?.extension}
@@ -228,17 +216,13 @@ const DocumentDetail: FC<DocumentDetailProps> = ({ datasetId, documentId }) => {
/>
<button
type="button"
data-testid="document-detail-metadata-toggle"
aria-label={metadataToggleLabel}
aria-pressed={showMetadata}
title={metadataToggleLabel}
className={style.layoutRightIcon}
onClick={() => setShowMetadata(!showMetadata)}
>
{
showMetadata
? <span aria-hidden="true" className="i-ri-layout-left-2-line h-4 w-4 text-components-button-secondary-text" />
: <span aria-hidden="true" className="i-ri-layout-right-2-line h-4 w-4 text-components-button-secondary-text" />
? <RiLayoutLeft2Line className="h-4 w-4 text-components-button-secondary-text" />
: <RiLayoutRight2Line className="h-4 w-4 text-components-button-secondary-text" />
}
</button>
</div>
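The `backToPrev` change above rebuilds the back path from the live query string. A minimal standalone sketch of that URL construction (the `buildBackPath` helper name is illustrative; `URLSearchParams` strips a leading `?` on its own):

```typescript
// Illustrative helper mirroring backToPrev's path construction above.
function buildBackPath(datasetId: string, search: string): string {
  // URLSearchParams accepts window.location.search directly, '?' included.
  const queryString = new URLSearchParams(search).toString()
  const separator = queryString ? '?' : ''
  return `/datasets/${datasetId}/documents${separator}${queryString}`
}

console.log(buildBackPath('ds-1', '?page=2&status=active'))
// /datasets/ds-1/documents?page=2&status=active
console.log(buildBackPath('ds-1', ''))
// /datasets/ds-1/documents
```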

View File

@@ -0,0 +1,439 @@
import type { DocumentListQuery } from '../use-document-list-query-state'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useDocumentListQueryState from '../use-document-list-query-state'
const mockPush = vi.fn()
const mockSearchParams = new URLSearchParams()
vi.mock('@/models/datasets', () => ({
DisplayStatusList: [
'queuing',
'indexing',
'paused',
'error',
'available',
'enabled',
'disabled',
'archived',
],
}))
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: mockPush }),
usePathname: () => '/datasets/test-id/documents',
useSearchParams: () => mockSearchParams,
}))
describe('useDocumentListQueryState', () => {
beforeEach(() => {
vi.clearAllMocks()
// Reset mock search params to empty
for (const key of [...mockSearchParams.keys()])
mockSearchParams.delete(key)
})
// Tests for parseParams (exposed via the query property)
describe('parseParams (via query)', () => {
it('should return default query when no search params present', () => {
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query).toEqual({
page: 1,
limit: 10,
keyword: '',
status: 'all',
sort: '-created_at',
})
})
it('should parse page from search params', () => {
mockSearchParams.set('page', '3')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.page).toBe(3)
})
it('should default page to 1 when page is zero', () => {
mockSearchParams.set('page', '0')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.page).toBe(1)
})
it('should default page to 1 when page is negative', () => {
mockSearchParams.set('page', '-5')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.page).toBe(1)
})
it('should default page to 1 when page is NaN', () => {
mockSearchParams.set('page', 'abc')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.page).toBe(1)
})
it('should parse limit from search params', () => {
mockSearchParams.set('limit', '50')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(50)
})
it('should default limit to 10 when limit is zero', () => {
mockSearchParams.set('limit', '0')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(10)
})
it('should default limit to 10 when limit exceeds 100', () => {
mockSearchParams.set('limit', '101')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(10)
})
it('should default limit to 10 when limit is negative', () => {
mockSearchParams.set('limit', '-1')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(10)
})
it('should accept limit at boundary 100', () => {
mockSearchParams.set('limit', '100')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(100)
})
it('should accept limit at boundary 1', () => {
mockSearchParams.set('limit', '1')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.limit).toBe(1)
})
it('should parse and decode keyword from search params', () => {
mockSearchParams.set('keyword', encodeURIComponent('hello world'))
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.keyword).toBe('hello world')
})
it('should return empty keyword when not present', () => {
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.keyword).toBe('')
})
it('should sanitize status from search params', () => {
mockSearchParams.set('status', 'available')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.status).toBe('available')
})
it('should fallback status to all for unknown status', () => {
mockSearchParams.set('status', 'badvalue')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.status).toBe('all')
})
it('should resolve active status alias to available', () => {
mockSearchParams.set('status', 'active')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.status).toBe('available')
})
it('should parse valid sort value from search params', () => {
mockSearchParams.set('sort', 'hit_count')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.sort).toBe('hit_count')
})
it('should default sort to -created_at for invalid sort value', () => {
mockSearchParams.set('sort', 'invalid_sort')
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.sort).toBe('-created_at')
})
it('should default sort to -created_at when not present', () => {
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.sort).toBe('-created_at')
})
it.each([
'-created_at',
'created_at',
'-hit_count',
'hit_count',
] as const)('should accept valid sort value %s', (sortValue) => {
mockSearchParams.set('sort', sortValue)
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current.query.sort).toBe(sortValue)
})
})
// Tests for updateQuery
describe('updateQuery', () => {
it('should call router.push with updated params when page is changed', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ page: 3 })
})
expect(mockPush).toHaveBeenCalledTimes(1)
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('page=3')
})
it('should call router.push with scroll false', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ page: 2 })
})
expect(mockPush).toHaveBeenCalledWith(
expect.any(String),
{ scroll: false },
)
})
it('should set status in URL when status is not all', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ status: 'error' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('status=error')
})
it('should not set status in URL when status is all', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ status: 'all' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('status=')
})
it('should set sort in URL when sort is not the default -created_at', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ sort: 'hit_count' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('sort=hit_count')
})
it('should not set sort in URL when sort is default -created_at', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ sort: '-created_at' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('sort=')
})
it('should encode keyword in URL when keyword is provided', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ keyword: 'test query' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
// Source code applies encodeURIComponent before setting in URLSearchParams
expect(pushedUrl).toContain('keyword=')
const params = new URLSearchParams(pushedUrl.split('?')[1])
// params.get decodes one layer, but the value was pre-encoded with encodeURIComponent
expect(decodeURIComponent(params.get('keyword')!)).toBe('test query')
})
it('should remove keyword from URL when keyword is empty', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ keyword: '' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('keyword=')
})
it('should sanitize invalid status to all and not include in URL', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ status: 'invalidstatus' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('status=')
})
it('should sanitize invalid sort to -created_at and not include in URL', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('sort=')
})
it('should omit page and limit when they are default and no keyword', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ page: 1, limit: 10 })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).not.toContain('page=')
expect(pushedUrl).not.toContain('limit=')
})
it('should include page and limit when page is greater than 1', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ page: 2 })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('page=2')
expect(pushedUrl).toContain('limit=10')
})
it('should include page and limit when limit is non-default', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ limit: 25 })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('page=1')
expect(pushedUrl).toContain('limit=25')
})
it('should include page and limit when keyword is provided', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ keyword: 'search' })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toContain('page=1')
expect(pushedUrl).toContain('limit=10')
})
it('should use pathname prefix in pushed URL', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({ page: 2 })
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toMatch(/^\/datasets\/test-id\/documents/)
})
it('should push path without query string when all values are defaults', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.updateQuery({})
})
const pushedUrl = mockPush.mock.calls[0][0] as string
expect(pushedUrl).toBe('/datasets/test-id/documents')
})
})
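Taken together, the URL-shape assertions above describe one serialization rule: default values stay out of the query string, and `page`/`limit` travel as a pair whenever either is non-default or a keyword is present. A minimal sketch of that rule (illustrative only; the real hook also pre-encodes the keyword, which is omitted here):

```typescript
type Query = { page: number, limit: number, keyword: string, status: string, sort: string }

// Hypothetical serializer mirroring the omission rules asserted in these tests.
function buildSearch(q: Query): string {
  const params = new URLSearchParams()
  const keyword = q.keyword.trim()
  // page and limit are written together when either differs from its default
  // (page 1, limit 10) or a keyword is set; otherwise both are omitted
  if (q.page !== 1 || q.limit !== 10 || keyword !== '') {
    params.set('page', String(q.page))
    params.set('limit', String(q.limit))
  }
  if (keyword !== '')
    params.set('keyword', keyword)
  if (q.status !== 'all')
    params.set('status', q.status)
  if (q.sort !== '-created_at')
    params.set('sort', q.sort)
  const qs = params.toString()
  return qs === '' ? '' : `?${qs}`
}

// All defaults produce the bare path, matching the last test above
console.log(`/datasets/test-id/documents${buildSearch({ page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' })}`)
// '/datasets/test-id/documents'
```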
// Tests for resetQuery
describe('resetQuery', () => {
it('should push URL with default query params when called', () => {
mockSearchParams.set('page', '5')
mockSearchParams.set('status', 'error')
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.resetQuery()
})
expect(mockPush).toHaveBeenCalledTimes(1)
const pushedUrl = mockPush.mock.calls[0][0] as string
// Default query has all defaults, so no params should be in the URL
expect(pushedUrl).toBe('/datasets/test-id/documents')
})
it('should call router.push with scroll false when resetting', () => {
const { result } = renderHook(() => useDocumentListQueryState())
act(() => {
result.current.resetQuery()
})
expect(mockPush).toHaveBeenCalledWith(
expect.any(String),
{ scroll: false },
)
})
})
// Tests for return value stability
describe('return value', () => {
it('should return query, updateQuery, and resetQuery', () => {
const { result } = renderHook(() => useDocumentListQueryState())
expect(result.current).toHaveProperty('query')
expect(result.current).toHaveProperty('updateQuery')
expect(result.current).toHaveProperty('resetQuery')
expect(typeof result.current.updateQuery).toBe('function')
expect(typeof result.current.resetQuery).toBe('function')
})
})
})


@@ -1,426 +0,0 @@
import type { DocumentListQuery } from '../use-document-list-query-state'
import { act, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
import { useDocumentListQueryState } from '../use-document-list-query-state'
vi.mock('@/models/datasets', () => ({
DisplayStatusList: [
'queuing',
'indexing',
'paused',
'error',
'available',
'enabled',
'disabled',
'archived',
],
}))
const renderWithAdapter = (searchParams = '') => {
return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams })
}
describe('useDocumentListQueryState', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('query parsing', () => {
it('should return default query when no search params present', () => {
const { result } = renderWithAdapter()
expect(result.current.query).toEqual({
page: 1,
limit: 10,
keyword: '',
status: 'all',
sort: '-created_at',
})
})
it('should parse page from search params', () => {
const { result } = renderWithAdapter('?page=3')
expect(result.current.query.page).toBe(3)
})
it('should default page to 1 when page is zero', () => {
const { result } = renderWithAdapter('?page=0')
expect(result.current.query.page).toBe(1)
})
it('should default page to 1 when page is negative', () => {
const { result } = renderWithAdapter('?page=-5')
expect(result.current.query.page).toBe(1)
})
it('should default page to 1 when page is NaN', () => {
const { result } = renderWithAdapter('?page=abc')
expect(result.current.query.page).toBe(1)
})
it('should parse limit from search params', () => {
const { result } = renderWithAdapter('?limit=50')
expect(result.current.query.limit).toBe(50)
})
it('should default limit to 10 when limit is zero', () => {
const { result } = renderWithAdapter('?limit=0')
expect(result.current.query.limit).toBe(10)
})
it('should default limit to 10 when limit exceeds 100', () => {
const { result } = renderWithAdapter('?limit=101')
expect(result.current.query.limit).toBe(10)
})
it('should default limit to 10 when limit is negative', () => {
const { result } = renderWithAdapter('?limit=-1')
expect(result.current.query.limit).toBe(10)
})
it('should accept limit at boundary 100', () => {
const { result } = renderWithAdapter('?limit=100')
expect(result.current.query.limit).toBe(100)
})
it('should accept limit at boundary 1', () => {
const { result } = renderWithAdapter('?limit=1')
expect(result.current.query.limit).toBe(1)
})
it('should parse keyword from search params', () => {
const { result } = renderWithAdapter('?keyword=hello+world')
expect(result.current.query.keyword).toBe('hello world')
})
it('should preserve legacy double-encoded keyword text after URL decoding', () => {
const { result } = renderWithAdapter('?keyword=test%2520query')
expect(result.current.query.keyword).toBe('test%20query')
})
it('should return empty keyword when not present', () => {
const { result } = renderWithAdapter()
expect(result.current.query.keyword).toBe('')
})
it('should sanitize status from search params', () => {
const { result } = renderWithAdapter('?status=available')
expect(result.current.query.status).toBe('available')
})
it('should fallback status to all for unknown status', () => {
const { result } = renderWithAdapter('?status=badvalue')
expect(result.current.query.status).toBe('all')
})
it('should resolve active status alias to available', () => {
const { result } = renderWithAdapter('?status=active')
expect(result.current.query.status).toBe('available')
})
it('should parse valid sort value from search params', () => {
const { result } = renderWithAdapter('?sort=hit_count')
expect(result.current.query.sort).toBe('hit_count')
})
it('should default sort to -created_at for invalid sort value', () => {
const { result } = renderWithAdapter('?sort=invalid_sort')
expect(result.current.query.sort).toBe('-created_at')
})
it('should default sort to -created_at when not present', () => {
const { result } = renderWithAdapter()
expect(result.current.query.sort).toBe('-created_at')
})
it.each([
'-created_at',
'created_at',
'-hit_count',
'hit_count',
] as const)('should accept valid sort value %s', (sortValue) => {
const { result } = renderWithAdapter(`?sort=${sortValue}`)
expect(result.current.query.sort).toBe(sortValue)
})
})
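The parsing cases above pin down a sanitization contract; the helpers below are an illustrative reconstruction of it (the names are hypothetical, not the hook's internals):

```typescript
type DocStatus = 'all' | 'queuing' | 'indexing' | 'paused' | 'error' | 'available' | 'enabled' | 'disabled' | 'archived'
type DocSort = '-created_at' | 'created_at' | '-hit_count' | 'hit_count'

const VALID_STATUSES: DocStatus[] = ['all', 'queuing', 'indexing', 'paused', 'error', 'available', 'enabled', 'disabled', 'archived']
const VALID_SORTS: DocSort[] = ['-created_at', 'created_at', '-hit_count', 'hit_count']

// page: zero, negative, or non-numeric values fall back to 1
const sanitizePage = (raw: string | null): number => {
  const n = Number(raw)
  return Number.isInteger(n) && n >= 1 ? n : 1
}

// limit: accepted only within [1, 100]; anything else falls back to 10
const sanitizeLimit = (raw: string | null): number => {
  const n = Number(raw)
  return Number.isInteger(n) && n >= 1 && n <= 100 ? n : 10
}

// status: legacy alias 'active' maps to 'available'; unknown values fall back to 'all'
const sanitizeStatus = (raw: string | null): DocStatus => {
  if (raw === 'active') return 'available'
  return raw !== null && (VALID_STATUSES as string[]).includes(raw) ? raw as DocStatus : 'all'
}

// sort: unknown values fall back to the default '-created_at'
const sanitizeSort = (raw: string | null): DocSort =>
  raw !== null && (VALID_SORTS as string[]).includes(raw) ? raw as DocSort : '-created_at'
```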
describe('updateQuery', () => {
it('should update page in state when page is changed', () => {
const { result } = renderWithAdapter()
act(() => {
result.current.updateQuery({ page: 3 })
})
expect(result.current.query.page).toBe(3)
})
it('should sync page to URL with push history', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('page')).toBe('2')
expect(update.options.history).toBe('push')
})
it('should set status in URL when status is not all', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ status: 'error' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('status')).toBe('error')
})
it('should not set status in URL when status is all', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ status: 'all' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('status')).toBe(false)
})
it('should set sort in URL when sort is not the default -created_at', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ sort: 'hit_count' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('sort')).toBe('hit_count')
})
it('should not set sort in URL when sort is default -created_at', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ sort: '-created_at' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('sort')).toBe(false)
})
it('should set keyword in URL when keyword is provided', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ keyword: 'test query' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('keyword')).toBe('test query')
expect(update.options.history).toBe('replace')
})
it('should use replace history when keyword update also resets page', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?page=3')
act(() => {
result.current.updateQuery({ keyword: 'hello', page: 1 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('keyword')).toBe('hello')
expect(update.searchParams.has('page')).toBe(false)
expect(update.options.history).toBe('replace')
})
it('should remove keyword from URL when keyword is empty', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing')
act(() => {
result.current.updateQuery({ keyword: '' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('keyword')).toBe(false)
expect(update.options.history).toBe('replace')
})
it('should remove keyword from URL when keyword contains only whitespace', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?keyword=existing')
act(() => {
result.current.updateQuery({ keyword: ' ' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('keyword')).toBe(false)
expect(result.current.query.keyword).toBe('')
})
it('should preserve literal percent-encoded-like keyword values', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ keyword: '%2F' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('keyword')).toBe('%2F')
expect(result.current.query.keyword).toBe('%2F')
})
it('should keep keyword text unchanged when updating query from legacy URL', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?keyword=test%2520query')
act(() => {
result.current.updateQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
expect(result.current.query.keyword).toBe('test%20query')
})
it('should sanitize invalid status to all and not include in URL', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ status: 'invalidstatus' })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('status')).toBe(false)
})
it('should sanitize invalid sort to -created_at and not include in URL', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ sort: 'invalidsort' as DocumentListQuery['sort'] })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('sort')).toBe(false)
})
it('should not include page in URL when page is default', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ page: 1 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('page')).toBe(false)
})
it('should include page in URL when page is greater than 1', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('page')).toBe('2')
})
it('should include limit in URL when limit is non-default', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ limit: 25 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.get('limit')).toBe('25')
})
it('should sanitize invalid page to default and omit page from URL', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ page: -1 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('page')).toBe(false)
expect(result.current.query.page).toBe(1)
})
it('should sanitize invalid limit to default and omit limit from URL', async () => {
const { result, onUrlUpdate } = renderWithAdapter()
act(() => {
result.current.updateQuery({ limit: 999 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('limit')).toBe(false)
expect(result.current.query.limit).toBe(10)
})
})
describe('resetQuery', () => {
it('should reset all values to defaults', () => {
const { result } = renderWithAdapter('?page=5&status=error&sort=hit_count')
act(() => {
result.current.resetQuery()
})
expect(result.current.query).toEqual({
page: 1,
limit: 10,
keyword: '',
status: 'all',
sort: '-created_at',
})
})
it('should clear all params from URL when called', async () => {
const { result, onUrlUpdate } = renderWithAdapter('?page=5&status=error')
act(() => {
result.current.resetQuery()
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0]
expect(update.searchParams.has('page')).toBe(false)
expect(update.searchParams.has('status')).toBe(false)
})
})
describe('return value', () => {
it('should return query, updateQuery, and resetQuery', () => {
const { result } = renderWithAdapter()
expect(result.current).toHaveProperty('query')
expect(result.current).toHaveProperty('updateQuery')
expect(result.current).toHaveProperty('resetQuery')
expect(typeof result.current.updateQuery).toBe('function')
expect(typeof result.current.resetQuery).toBe('function')
})
})
})


@@ -1,10 +1,12 @@
import type { DocumentListQuery } from '../use-document-list-query-state'
import type { DocumentListResponse } from '@/models/datasets'
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useDocumentsPageState } from '../use-documents-page-state'
const mockUpdateQuery = vi.fn()
const mockResetQuery = vi.fn()
let mockQuery: DocumentListQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' }
vi.mock('@/models/datasets', () => ({
@@ -20,70 +22,151 @@ vi.mock('@/models/datasets', () => ({
],
}))
-vi.mock('ahooks', () => ({
-  useDebounce: (value: unknown, _options?: { wait?: number }) => value,
-}))
-vi.mock('../use-document-list-query-state', async () => {
-  const React = await import('react')
-  return {
-    useDocumentListQueryState: () => {
-      const [query, setQuery] = React.useState<DocumentListQuery>(mockQuery)
-      return {
-        query,
-        updateQuery: (updates: Partial<DocumentListQuery>) => {
-          mockUpdateQuery(updates)
-          setQuery(prev => ({ ...prev, ...updates }))
-        },
-      }
-    },
-  }
-})
+vi.mock('next/navigation', () => ({
+  useRouter: () => ({ push: vi.fn() }),
+  usePathname: () => '/datasets/test-id/documents',
+  useSearchParams: () => new URLSearchParams(),
+}))
+// Mock ahooks debounce utilities: required because tests capture the debounce
+// callback reference to invoke it synchronously, bypassing real timer delays.
+let capturedDebounceFnCallback: (() => void) | null = null
+vi.mock('ahooks', () => ({
+  useDebounce: (value: unknown, _options?: { wait?: number }) => value,
+  useDebounceFn: (fn: () => void, _options?: { wait?: number }) => {
+    capturedDebounceFnCallback = fn
+    return { run: fn, cancel: vi.fn(), flush: vi.fn() }
+  },
+}))
+// Mock the dependent hook
+vi.mock('../use-document-list-query-state', () => ({
+  default: () => ({
+    query: mockQuery,
+    updateQuery: mockUpdateQuery,
+    resetQuery: mockResetQuery,
+  }),
+}))
+// Factory for creating DocumentListResponse test data
+function createDocumentListResponse(overrides: Partial<DocumentListResponse> = {}): DocumentListResponse {
+  return {
+    data: [],
+    has_more: false,
+    total: 0,
+    page: 1,
+    limit: 10,
+    ...overrides,
+  }
+}
// Factory for creating a minimal document item
function createDocumentItem(overrides: Record<string, unknown> = {}) {
return {
id: `doc-${Math.random().toString(36).slice(2, 8)}`,
name: 'test-doc.txt',
indexing_status: 'completed' as string,
display_status: 'available' as string,
enabled: true,
archived: false,
word_count: 100,
created_at: Date.now(),
updated_at: Date.now(),
created_from: 'web' as const,
created_by: 'user-1',
dataset_process_rule_id: 'rule-1',
doc_form: 'text_model' as const,
doc_language: 'en',
position: 1,
data_source_type: 'upload_file',
...overrides,
}
}
describe('useDocumentsPageState', () => {
beforeEach(() => {
vi.clearAllMocks()
capturedDebounceFnCallback = null
mockQuery = { page: 1, limit: 10, keyword: '', status: 'all', sort: '-created_at' }
})
// Initial state verification
describe('initial state', () => {
-it('should return correct initial query-derived state', () => {
+it('should return correct initial search state', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.inputValue).toBe('')
expect(result.current.searchValue).toBe('')
expect(result.current.debouncedSearchValue).toBe('')
})
it('should return correct initial filter and sort state', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.statusFilterValue).toBe('all')
expect(result.current.sortValue).toBe('-created_at')
expect(result.current.normalizedStatusFilterValue).toBe('all')
})
it('should return correct initial pagination state', () => {
const { result } = renderHook(() => useDocumentsPageState())
// page is query.page - 1 = 0
expect(result.current.currPage).toBe(0)
expect(result.current.limit).toBe(10)
})
it('should return correct initial selection state', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.selectedIds).toEqual([])
})
it('should initialize from non-default query values', () => {
mockQuery = {
page: 3,
limit: 25,
keyword: 'initial',
status: 'enabled',
sort: 'hit_count',
}
it('should return correct initial polling state', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.timerCanRun).toBe(true)
})
it('should initialize from query when query has keyword', () => {
mockQuery = { ...mockQuery, keyword: 'initial search' }
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.inputValue).toBe('initial')
expect(result.current.currPage).toBe(2)
expect(result.current.inputValue).toBe('initial search')
expect(result.current.searchValue).toBe('initial search')
})
it('should initialize pagination from query with non-default page', () => {
mockQuery = { ...mockQuery, page: 3, limit: 25 }
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.currPage).toBe(2) // page - 1
expect(result.current.limit).toBe(25)
expect(result.current.statusFilterValue).toBe('enabled')
expect(result.current.normalizedStatusFilterValue).toBe('available')
})
it('should initialize status filter from query', () => {
mockQuery = { ...mockQuery, status: 'error' }
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.statusFilterValue).toBe('error')
})
it('should initialize sort from query', () => {
mockQuery = { ...mockQuery, sort: 'hit_count' }
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.sortValue).toBe('hit_count')
})
})
// Handler behaviors
describe('handleInputChange', () => {
-it('should update keyword and reset page', () => {
+it('should update input value when called', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -91,59 +174,30 @@ describe('useDocumentsPageState', () => {
})
expect(result.current.inputValue).toBe('new value')
-expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'new value', page: 1 })
})
-it('should clear selected ids when keyword changes', () => {
+it('should trigger debounced search callback when called', () => {
  const { result } = renderHook(() => useDocumentsPageState())
-  act(() => {
-    result.current.setSelectedIds(['doc-1'])
-  })
-  expect(result.current.selectedIds).toEqual(['doc-1'])
+  // First call sets inputValue and triggers the debounced fn
  act(() => {
-    result.current.handleInputChange('keyword')
+    result.current.handleInputChange('search term')
  })
-  expect(result.current.selectedIds).toEqual([])
-})
-it('should keep selected ids when keyword is unchanged', () => {
-  mockQuery = { ...mockQuery, keyword: 'same' }
-  const { result } = renderHook(() => useDocumentsPageState())
-  act(() => {
-    result.current.setSelectedIds(['doc-1'])
-  })
-  act(() => {
-    result.current.handleInputChange('same')
-  })
-  expect(result.current.selectedIds).toEqual(['doc-1'])
-  expect(mockUpdateQuery).toHaveBeenCalledWith({ keyword: 'same', page: 1 })
-})
+  // The debounced fn captures inputValue from its render closure.
+  // After re-render with new inputValue, calling the captured callback again
+  // should reflect the updated state.
+  act(() => {
+    if (capturedDebounceFnCallback)
+      capturedDebounceFnCallback()
+  })
+  expect(result.current.searchValue).toBe('search term')
+})
})
describe('handleStatusFilterChange', () => {
-it('should sanitize status, reset page, and clear selection', () => {
-  const { result } = renderHook(() => useDocumentsPageState())
-  act(() => {
-    result.current.setSelectedIds(['doc-1'])
-  })
-  act(() => {
-    result.current.handleStatusFilterChange('invalid')
-  })
-  expect(result.current.statusFilterValue).toBe('all')
-  expect(result.current.selectedIds).toEqual([])
-  expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 })
-})
-it('should update to valid status value', () => {
+it('should update status filter value when called with valid status', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -151,23 +205,61 @@ describe('useDocumentsPageState', () => {
})
expect(result.current.statusFilterValue).toBe('error')
})
it('should reset page to 0 when status filter changes', () => {
mockQuery = { ...mockQuery, page: 3 }
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('error')
})
expect(result.current.currPage).toBe(0)
})
it('should call updateQuery with sanitized status and page 1', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('error')
})
expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'error', page: 1 })
})
it('should sanitize invalid status to all', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('invalid')
})
expect(result.current.statusFilterValue).toBe('all')
expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 })
})
})
describe('handleStatusFilterClear', () => {
-it('should reset status to all when status is not all', () => {
-  mockQuery = { ...mockQuery, status: 'error' }
+it('should set status to all and reset page when status is not all', () => {
const { result } = renderHook(() => useDocumentsPageState())
// First set a non-all status
act(() => {
result.current.handleStatusFilterChange('error')
})
vi.clearAllMocks()
// Then clear
act(() => {
result.current.handleStatusFilterClear()
})
expect(result.current.statusFilterValue).toBe('all')
expect(mockUpdateQuery).toHaveBeenCalledWith({ status: 'all', page: 1 })
})
-it('should do nothing when status is already all', () => {
+it('should not call updateQuery when status is already all', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -179,7 +271,7 @@ describe('useDocumentsPageState', () => {
})
describe('handleSortChange', () => {
-it('should update sort and reset page when sort changes', () => {
+it('should update sort value and call updateQuery when value changes', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -190,7 +282,18 @@ describe('useDocumentsPageState', () => {
expect(mockUpdateQuery).toHaveBeenCalledWith({ sort: 'hit_count', page: 1 })
})
-it('should ignore sort update when value is unchanged', () => {
+it('should reset page to 0 when sort changes', () => {
mockQuery = { ...mockQuery, page: 5 }
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleSortChange('hit_count')
})
expect(result.current.currPage).toBe(0)
})
it('should not call updateQuery when sort value is same as current', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -201,8 +304,8 @@ describe('useDocumentsPageState', () => {
})
})
-describe('pagination handlers', () => {
-  it('should update page with one-based value', () => {
+describe('handlePageChange', () => {
+  it('should update current page and call updateQuery', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -210,10 +313,23 @@ describe('useDocumentsPageState', () => {
})
expect(result.current.currPage).toBe(2)
-expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 })
+expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // newPage + 1
})
-it('should update limit and reset page', () => {
+it('should handle page 0 (first page)', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handlePageChange(0)
})
expect(result.current.currPage).toBe(0)
expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 })
})
})
describe('handleLimitChange', () => {
it('should update limit, reset page to 0, and call updateQuery', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
@@ -226,29 +342,359 @@ describe('useDocumentsPageState', () => {
})
})
// Selection state
describe('selection state', () => {
it('should update selectedIds via setSelectedIds', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.setSelectedIds(['doc-1', 'doc-2'])
})
expect(result.current.selectedIds).toEqual(['doc-1', 'doc-2'])
})
})
// Polling state management
describe('updatePollingState', () => {
it('should not update timer when documentsRes is undefined', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState(undefined)
})
// timerCanRun remains true (initial value)
expect(result.current.timerCanRun).toBe(true)
})
it('should not update timer when documentsRes.data is undefined', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState({ data: undefined } as unknown as DocumentListResponse)
})
expect(result.current.timerCanRun).toBe(true)
})
it('should set timerCanRun to false when all documents are completed and status filter is all', () => {
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'completed' }),
createDocumentItem({ indexing_status: 'completed' }),
] as DocumentListResponse['data'],
total: 2,
})
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState(response)
})
expect(result.current.timerCanRun).toBe(false)
})
it('should set timerCanRun to true when some documents are not completed', () => {
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'completed' }),
createDocumentItem({ indexing_status: 'indexing' }),
] as DocumentListResponse['data'],
total: 2,
})
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState(response)
})
expect(result.current.timerCanRun).toBe(true)
})
it('should count paused documents as completed for polling purposes', () => {
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'paused' }),
createDocumentItem({ indexing_status: 'completed' }),
] as DocumentListResponse['data'],
total: 2,
})
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState(response)
})
// All docs are "embedded" (completed, paused, error), so hasIncomplete = false
// statusFilter is 'all', so shouldForcePolling = false
expect(result.current.timerCanRun).toBe(false)
})
it('should count error documents as completed for polling purposes', () => {
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'error' }),
createDocumentItem({ indexing_status: 'completed' }),
] as DocumentListResponse['data'],
total: 2,
})
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.updatePollingState(response)
})
expect(result.current.timerCanRun).toBe(false)
})
it('should force polling when status filter is a transient status (queuing)', () => {
const { result } = renderHook(() => useDocumentsPageState())
// Set status filter to queuing
act(() => {
result.current.handleStatusFilterChange('queuing')
})
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'completed' }),
] as DocumentListResponse['data'],
total: 1,
})
act(() => {
result.current.updatePollingState(response)
})
// shouldForcePolling = true (queuing is transient), hasIncomplete = false
// timerCanRun = true || false = true
expect(result.current.timerCanRun).toBe(true)
})
it('should force polling when status filter is indexing', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('indexing')
})
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'completed' }),
] as DocumentListResponse['data'],
total: 1,
})
act(() => {
result.current.updatePollingState(response)
})
expect(result.current.timerCanRun).toBe(true)
})
it('should force polling when status filter is paused', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('paused')
})
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'paused' }),
] as DocumentListResponse['data'],
total: 1,
})
act(() => {
result.current.updatePollingState(response)
})
expect(result.current.timerCanRun).toBe(true)
})
it('should not force polling when status filter is a non-transient status (error)', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('error')
})
const response = createDocumentListResponse({
data: [
createDocumentItem({ indexing_status: 'error' }),
] as DocumentListResponse['data'],
total: 1,
})
act(() => {
result.current.updatePollingState(response)
})
// shouldForcePolling = false (error is not transient), hasIncomplete = false (error is embedded)
expect(result.current.timerCanRun).toBe(false)
})
it('should set timerCanRun to true when data is empty and filter is transient', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('indexing')
})
const response = createDocumentListResponse({ data: [] as DocumentListResponse['data'], total: 0 })
act(() => {
result.current.updatePollingState(response)
})
// shouldForcePolling = true (indexing is transient), hasIncomplete = false (0 !== 0 is false)
expect(result.current.timerCanRun).toBe(true)
})
})
// Page adjustment
describe('adjustPageForTotal', () => {
it('should not adjust page when documentsRes is undefined', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.adjustPageForTotal(undefined)
})
expect(result.current.currPage).toBe(0)
})
it('should not adjust page when currPage is within total pages', () => {
const { result } = renderHook(() => useDocumentsPageState())
const response = createDocumentListResponse({ total: 20 })
act(() => {
result.current.adjustPageForTotal(response)
})
// currPage is 0, totalPages is 2, so no adjustment needed
expect(result.current.currPage).toBe(0)
})
it('should adjust page to last page when currPage exceeds total pages', () => {
mockQuery = { ...mockQuery, page: 6 }
const { result } = renderHook(() => useDocumentsPageState())
// currPage should be 5 (page - 1)
expect(result.current.currPage).toBe(5)
const response = createDocumentListResponse({ total: 30 }) // 30/10 = 3 pages
act(() => {
result.current.adjustPageForTotal(response)
})
// currPage (5) + 1 > totalPages (3), so adjust to totalPages - 1 = 2
expect(result.current.currPage).toBe(2)
expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 3 }) // handlePageChange passes newPage + 1
})
it('should adjust page to 0 when total is 0 and currPage > 0', () => {
mockQuery = { ...mockQuery, page: 3 }
const { result } = renderHook(() => useDocumentsPageState())
const response = createDocumentListResponse({ total: 0 })
act(() => {
result.current.adjustPageForTotal(response)
})
// totalPages = 0, so adjust to max(0 - 1, 0) = 0
expect(result.current.currPage).toBe(0)
expect(mockUpdateQuery).toHaveBeenCalledWith({ page: 1 })
})
it('should not adjust page when currPage is 0 even if total is 0', () => {
const { result } = renderHook(() => useDocumentsPageState())
const response = createDocumentListResponse({ total: 0 })
act(() => {
result.current.adjustPageForTotal(response)
})
// currPage is 0, condition is currPage > 0 so no adjustment
expect(mockUpdateQuery).not.toHaveBeenCalled()
})
})
// Normalized status filter value
describe('normalizedStatusFilterValue', () => {
it('should return all for default status', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(result.current.normalizedStatusFilterValue).toBe('all')
})
it('should normalize enabled to available', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('enabled')
})
expect(result.current.normalizedStatusFilterValue).toBe('available')
})
it('should return non-aliased status as-is', () => {
const { result } = renderHook(() => useDocumentsPageState())
act(() => {
result.current.handleStatusFilterChange('error')
})
expect(result.current.normalizedStatusFilterValue).toBe('error')
})
})
// Return value shape
describe('return value', () => {
it('should return all expected properties', () => {
const { result } = renderHook(() => useDocumentsPageState())
// Search state
expect(result.current).toHaveProperty('inputValue')
expect(result.current).toHaveProperty('searchValue')
expect(result.current).toHaveProperty('debouncedSearchValue')
expect(result.current).toHaveProperty('handleInputChange')
// Filter & sort state
expect(result.current).toHaveProperty('statusFilterValue')
expect(result.current).toHaveProperty('sortValue')
expect(result.current).toHaveProperty('normalizedStatusFilterValue')
expect(result.current).toHaveProperty('handleStatusFilterChange')
expect(result.current).toHaveProperty('handleStatusFilterClear')
expect(result.current).toHaveProperty('handleSortChange')
// Pagination state
expect(result.current).toHaveProperty('currPage')
expect(result.current).toHaveProperty('limit')
expect(result.current).toHaveProperty('handlePageChange')
expect(result.current).toHaveProperty('handleLimitChange')
// Selection state
expect(result.current).toHaveProperty('selectedIds')
expect(result.current).toHaveProperty('setSelectedIds')
// Polling state
expect(result.current).toHaveProperty('timerCanRun')
expect(result.current).toHaveProperty('updatePollingState')
expect(result.current).toHaveProperty('adjustPageForTotal')
})
it('should have function types for all handlers', () => {
const { result } = renderHook(() => useDocumentsPageState())
expect(typeof result.current.handleInputChange).toBe('function')
@@ -258,6 +704,8 @@ describe('useDocumentsPageState', () => {
expect(typeof result.current.handlePageChange).toBe('function')
expect(typeof result.current.handleLimitChange).toBe('function')
expect(typeof result.current.setSelectedIds).toBe('function')
expect(typeof result.current.updatePollingState).toBe('function')
expect(typeof result.current.adjustPageForTotal).toBe('function')
})
})
})
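The transient-status branching these polling tests assert can be sketched as a small pure function. This is a hypothetical standalone extraction for illustration — the helper name and constants mirror, but are not part of, the hook's actual API:

```typescript
// Hypothetical extraction of the polling decision exercised by the tests above.
type IndexingStatus = 'queuing' | 'indexing' | 'paused' | 'completed' | 'error'

// Statuses treated as "embedded"/terminal by updatePollingState.
const TERMINAL_STATUSES = new Set<string>(['completed', 'paused', 'error'])
// Status filters that force polling even when every listed document is terminal.
const TRANSIENT_FILTERS = new Set<string>(['queuing', 'indexing', 'paused'])

function timerCanRun(statusFilter: string, statuses: IndexingStatus[]): boolean {
  const shouldForcePolling = statusFilter !== 'all' && TRANSIENT_FILTERS.has(statusFilter)
  const hasIncomplete = statuses.some(s => !TERMINAL_STATUSES.has(s))
  return shouldForcePolling || hasIncomplete
}
```

Under this sketch, `timerCanRun('queuing', ['completed'])` is `true` (polling is forced by the transient filter), while `timerCanRun('error', ['error'])` is `false` (non-transient filter, all documents terminal) — matching the expectations in the tests above.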

View File

@@ -1,6 +1,6 @@
import type { inferParserType } from 'nuqs'
import type { ReadonlyURLSearchParams } from 'next/navigation'
import type { SortType } from '@/service/datasets'
import { createParser, parseAsString, throttle, useQueryStates } from 'nuqs'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import { useCallback, useMemo } from 'react'
import { sanitizeStatusValue } from '../status-filter'
@@ -13,87 +13,99 @@ const sanitizeSortValue = (value?: string | null): SortType => {
return (ALLOWED_SORT_VALUES.includes(value as SortType) ? value : '-created_at') as SortType
}
const sanitizePageValue = (value: number): number => {
return Number.isInteger(value) && value > 0 ? value : 1
export type DocumentListQuery = {
page: number
limit: number
keyword: string
status: string
sort: SortType
}
const sanitizeLimitValue = (value: number): number => {
return Number.isInteger(value) && value > 0 && value <= 100 ? value : 10
const DEFAULT_QUERY: DocumentListQuery = {
page: 1,
limit: 10,
keyword: '',
status: 'all',
sort: '-created_at',
}
const parseAsPage = createParser<number>({
parse: (value) => {
const n = Number.parseInt(value, 10)
return Number.isNaN(n) || n <= 0 ? null : n
},
serialize: value => value.toString(),
}).withDefault(1)
// Parse the query parameters from the URL search string.
function parseParams(params: ReadonlyURLSearchParams): DocumentListQuery {
const page = Number.parseInt(params.get('page') || '1', 10)
const limit = Number.parseInt(params.get('limit') || '10', 10)
const keyword = params.get('keyword') || ''
const status = sanitizeStatusValue(params.get('status'))
const sort = sanitizeSortValue(params.get('sort'))
const parseAsLimit = createParser<number>({
parse: (value) => {
const n = Number.parseInt(value, 10)
return Number.isNaN(n) || n <= 0 || n > 100 ? null : n
},
serialize: value => value.toString(),
}).withDefault(10)
const parseAsDocStatus = createParser<string>({
parse: value => sanitizeStatusValue(value),
serialize: value => value,
}).withDefault('all')
const parseAsDocSort = createParser<SortType>({
parse: value => sanitizeSortValue(value),
serialize: value => value,
}).withDefault('-created_at' as SortType)
const parseAsKeyword = parseAsString.withDefault('')
export const documentListParsers = {
page: parseAsPage,
limit: parseAsLimit,
keyword: parseAsKeyword,
status: parseAsDocStatus,
sort: parseAsDocSort,
return {
page: page > 0 ? page : 1,
limit: (limit > 0 && limit <= 100) ? limit : 10,
keyword: keyword ? decodeURIComponent(keyword) : '',
status,
sort,
}
}
export type DocumentListQuery = inferParserType<typeof documentListParsers>
// Update the URL search string with the given query parameters.
function updateSearchParams(query: DocumentListQuery, searchParams: URLSearchParams) {
const { page, limit, keyword, status, sort } = query || {}
// Search input updates can be frequent; throttle URL writes to reduce history/API churn.
const KEYWORD_URL_UPDATE_THROTTLE = throttle(300)
const hasNonDefaultParams = (page && page > 1) || (limit && limit !== 10) || (keyword && keyword.trim())
export function useDocumentListQueryState() {
const [query, setQuery] = useQueryStates(documentListParsers)
if (hasNonDefaultParams) {
searchParams.set('page', (page || 1).toString())
searchParams.set('limit', (limit || 10).toString())
}
else {
searchParams.delete('page')
searchParams.delete('limit')
}
if (keyword && keyword.trim())
searchParams.set('keyword', encodeURIComponent(keyword))
else
searchParams.delete('keyword')
const sanitizedStatus = sanitizeStatusValue(status)
if (sanitizedStatus && sanitizedStatus !== 'all')
searchParams.set('status', sanitizedStatus)
else
searchParams.delete('status')
const sanitizedSort = sanitizeSortValue(sort)
if (sanitizedSort !== '-created_at')
searchParams.set('sort', sanitizedSort)
else
searchParams.delete('sort')
}
function useDocumentListQueryState() {
const searchParams = useSearchParams()
const query = useMemo(() => parseParams(searchParams), [searchParams])
const router = useRouter()
const pathname = usePathname()
// Helper function to update specific query parameters
const updateQuery = useCallback((updates: Partial<DocumentListQuery>) => {
const patch = { ...updates }
if ('page' in patch && patch.page !== undefined)
patch.page = sanitizePageValue(patch.page)
if ('limit' in patch && patch.limit !== undefined)
patch.limit = sanitizeLimitValue(patch.limit)
if ('status' in patch)
patch.status = sanitizeStatusValue(patch.status)
if ('sort' in patch)
patch.sort = sanitizeSortValue(patch.sort)
if ('keyword' in patch && typeof patch.keyword === 'string' && patch.keyword.trim() === '')
patch.keyword = ''
// If keyword is part of this patch (even with page reset), treat it as a search update:
// use replace to avoid creating a history entry per input-driven change.
if ('keyword' in patch) {
setQuery(patch, {
history: 'replace',
limitUrlUpdates: patch.keyword === '' ? undefined : KEYWORD_URL_UPDATE_THROTTLE,
})
return
}
setQuery(patch, { history: 'push' })
}, [setQuery])
const newQuery = { ...query, ...updates }
newQuery.status = sanitizeStatusValue(newQuery.status)
newQuery.sort = sanitizeSortValue(newQuery.sort)
const params = new URLSearchParams()
updateSearchParams(newQuery, params)
const search = params.toString()
const queryString = search ? `?${search}` : ''
router.push(`${pathname}${queryString}`, { scroll: false })
}, [query, router, pathname])
// Helper function to reset query to defaults
const resetQuery = useCallback(() => {
setQuery(null, { history: 'replace' })
}, [setQuery])
const params = new URLSearchParams()
updateSearchParams(DEFAULT_QUERY, params)
const search = params.toString()
const queryString = search ? `?${search}` : ''
router.push(`${pathname}${queryString}`, { scroll: false })
}, [router, pathname])
return useMemo(() => ({
query,
@@ -101,3 +113,5 @@ export function useDocumentListQueryState() {
resetQuery,
}), [query, updateQuery, resetQuery])
}
export default useDocumentListQueryState
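The custom `parseAsPage` / `parseAsLimit` parsers above reject out-of-range values by returning `null`, which makes nuqs fall back to the parser's `.withDefault` value. The core rule can be sketched without nuqs (the helper name is illustrative, not from the source):

```typescript
// Minimal sketch of the clamp-or-null rule behind parseAsPage / parseAsLimit:
// returning null signals "invalid", so the caller falls back to a default.
function parsePositiveInt(value: string, max: number = Number.MAX_SAFE_INTEGER): number | null {
  const n = Number.parseInt(value, 10)
  return Number.isNaN(n) || n <= 0 || n > max ? null : n
}
```

For example, `parsePositiveInt('0')` and `parsePositiveInt('101', 100)` both return `null`, so `page` and `limit` resolve to their defaults (`1` and `10`) instead of leaking invalid values into the query state.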

View File

@@ -1,63 +1,175 @@
import type { DocumentListResponse } from '@/models/datasets'
import type { SortType } from '@/service/datasets'
import { useDebounce } from 'ahooks'
import { useCallback, useState } from 'react'
import { useDebounce, useDebounceFn } from 'ahooks'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { normalizeStatusForQuery, sanitizeStatusValue } from '../status-filter'
import { useDocumentListQueryState } from './use-document-list-query-state'
import useDocumentListQueryState from './use-document-list-query-state'
/**
* Custom hook to manage documents page state including:
* - Search state (input value, debounced search value)
* - Filter state (status filter, sort value)
* - Pagination state (current page, limit)
* - Selection state (selected document ids)
* - Polling state (timer control for auto-refresh)
*/
export function useDocumentsPageState() {
const { query, updateQuery } = useDocumentListQueryState()
const inputValue = query.keyword
const debouncedSearchValue = useDebounce(query.keyword, { wait: 500 })
// Search state
const [inputValue, setInputValue] = useState<string>('')
const [searchValue, setSearchValue] = useState<string>('')
const debouncedSearchValue = useDebounce(searchValue, { wait: 500 })
const statusFilterValue = sanitizeStatusValue(query.status)
const sortValue = query.sort
const normalizedStatusFilterValue = normalizeStatusForQuery(statusFilterValue)
// Filter & sort state
const [statusFilterValue, setStatusFilterValue] = useState<string>(() => sanitizeStatusValue(query.status))
const [sortValue, setSortValue] = useState<SortType>(query.sort)
const normalizedStatusFilterValue = useMemo(
() => normalizeStatusForQuery(statusFilterValue),
[statusFilterValue],
)
const currPage = query.page - 1
const limit = query.limit
// Pagination state
const [currPage, setCurrPage] = useState<number>(query.page - 1)
const [limit, setLimit] = useState<number>(query.limit)
// Selection state
const [selectedIds, setSelectedIds] = useState<string[]>([])
// Polling state
const [timerCanRun, setTimerCanRun] = useState(true)
// Initialize search value from URL on mount
useEffect(() => {
if (query.keyword) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
}, []) // Only run on mount
// Sync local state with URL query changes
useEffect(() => {
setCurrPage(query.page - 1)
setLimit(query.limit)
if (query.keyword !== searchValue) {
setInputValue(query.keyword)
setSearchValue(query.keyword)
}
setStatusFilterValue((prev) => {
const nextValue = sanitizeStatusValue(query.status)
return prev === nextValue ? prev : nextValue
})
setSortValue(query.sort)
}, [query])
// Update URL when search changes
useEffect(() => {
if (debouncedSearchValue !== query.keyword) {
setCurrPage(0)
updateQuery({ keyword: debouncedSearchValue, page: 1 })
}
}, [debouncedSearchValue, query.keyword, updateQuery])
// Clear selection when search changes
useEffect(() => {
if (searchValue !== query.keyword)
setSelectedIds([])
}, [searchValue, query.keyword])
// Clear selection when status filter changes
useEffect(() => {
setSelectedIds([])
}, [normalizedStatusFilterValue])
// Page change handler
const handlePageChange = useCallback((newPage: number) => {
setCurrPage(newPage)
updateQuery({ page: newPage + 1 })
}, [updateQuery])
// Limit change handler
const handleLimitChange = useCallback((newLimit: number) => {
setLimit(newLimit)
setCurrPage(0)
updateQuery({ limit: newLimit, page: 1 })
}, [updateQuery])
const handleInputChange = useCallback((value: string) => {
if (value !== query.keyword)
setSelectedIds([])
updateQuery({ keyword: value, page: 1 })
}, [query.keyword, updateQuery])
// Debounced search handler
const { run: handleSearch } = useDebounceFn(() => {
setSearchValue(inputValue)
}, { wait: 500 })
// Input change handler
const handleInputChange = useCallback((value: string) => {
setInputValue(value)
handleSearch()
}, [handleSearch])
// Status filter change handler
const handleStatusFilterChange = useCallback((value: string) => {
const selectedValue = sanitizeStatusValue(value)
setSelectedIds([])
setStatusFilterValue(selectedValue)
setCurrPage(0)
updateQuery({ status: selectedValue, page: 1 })
}, [updateQuery])
// Status filter clear handler
const handleStatusFilterClear = useCallback(() => {
if (statusFilterValue === 'all')
return
setSelectedIds([])
setStatusFilterValue('all')
setCurrPage(0)
updateQuery({ status: 'all', page: 1 })
}, [statusFilterValue, updateQuery])
// Sort change handler
const handleSortChange = useCallback((value: string) => {
const next = value as SortType
if (next === sortValue)
return
setSortValue(next)
setCurrPage(0)
updateQuery({ sort: next, page: 1 })
}, [sortValue, updateQuery])
// Update polling state based on documents response
const updatePollingState = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes?.data)
return
let completedNum = 0
documentsRes.data.forEach((documentItem) => {
const { indexing_status } = documentItem
const isEmbedded = indexing_status === 'completed' || indexing_status === 'paused' || indexing_status === 'error'
if (isEmbedded)
completedNum++
})
const hasIncompleteDocuments = completedNum !== documentsRes.data.length
const transientStatuses = ['queuing', 'indexing', 'paused']
const shouldForcePolling = normalizedStatusFilterValue === 'all'
? false
: transientStatuses.includes(normalizedStatusFilterValue)
setTimerCanRun(shouldForcePolling || hasIncompleteDocuments)
}, [normalizedStatusFilterValue])
// Adjust page when total pages change
const adjustPageForTotal = useCallback((documentsRes: DocumentListResponse | undefined) => {
if (!documentsRes)
return
const totalPages = Math.ceil(documentsRes.total / limit)
if (currPage > 0 && currPage + 1 > totalPages)
handlePageChange(totalPages > 0 ? totalPages - 1 : 0)
}, [limit, currPage, handlePageChange])
return {
// Search state
inputValue,
searchValue,
debouncedSearchValue,
handleInputChange,
// Filter & sort state
statusFilterValue,
sortValue,
normalizedStatusFilterValue,
@@ -65,12 +177,21 @@ export function useDocumentsPageState() {
handleStatusFilterClear,
handleSortChange,
// Pagination state
currPage,
limit,
handlePageChange,
handleLimitChange,
// Selection state
selectedIds,
setSelectedIds,
// Polling state
timerCanRun,
updatePollingState,
adjustPageForTotal,
}
}
export default useDocumentsPageState
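The page-clamping rule `adjustPageForTotal` applies (zero-based `currPage` in state, one-based `page` in the URL) can be sketched as a pure function. The function name is illustrative; the arithmetic mirrors the hook above:

```typescript
// Sketch of adjustPageForTotal's clamping: if the current zero-based page
// points past the last available page, snap to the last page
// (or to 0 when the list is empty).
function clampPage(currPage: number, total: number, limit: number): number {
  const totalPages = Math.ceil(total / limit)
  if (currPage > 0 && currPage + 1 > totalPages)
    return totalPages > 0 ? totalPages - 1 : 0
  return currPage
}
```

This reproduces the test expectations earlier in the diff: `clampPage(5, 30, 10)` yields `2` (last of 3 pages), and `clampPage(2, 0, 10)` yields `0` for an empty list, while page `0` is never adjusted.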

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import { useRouter } from 'next/navigation'
import { useCallback } from 'react'
import { useCallback, useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useProviderContext } from '@/context/provider-context'
@@ -13,16 +13,12 @@ import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata
import DocumentsHeader from './components/documents-header'
import EmptyElement from './components/empty-element'
import List from './components/list'
import { useDocumentsPageState } from './hooks/use-documents-page-state'
import useDocumentsPageState from './hooks/use-documents-page-state'
type IDocumentsProps = {
datasetId: string
}
const POLLING_INTERVAL = 2500
const TERMINAL_INDEXING_STATUSES = new Set(['completed', 'paused', 'error'])
const FORCED_POLLING_STATUSES = new Set(['queuing', 'indexing', 'paused'])
const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
const router = useRouter()
const { plan } = useProviderContext()
@@ -48,6 +44,9 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
handleLimitChange,
selectedIds,
setSelectedIds,
timerCanRun,
updatePollingState,
adjustPageForTotal,
} = useDocumentsPageState()
// Fetch document list
@@ -60,17 +59,19 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
status: normalizedStatusFilterValue,
sort: sortValue,
},
refetchInterval: (query) => {
const shouldForcePolling = normalizedStatusFilterValue !== 'all'
&& FORCED_POLLING_STATUSES.has(normalizedStatusFilterValue)
const documents = query.state.data?.data
if (!documents)
return POLLING_INTERVAL
const hasIncompleteDocuments = documents.some(({ indexing_status }) => !TERMINAL_INDEXING_STATUSES.has(indexing_status))
return shouldForcePolling || hasIncompleteDocuments ? POLLING_INTERVAL : false
},
refetchInterval: timerCanRun ? 2500 : 0,
})
// Update polling state when documents change
useEffect(() => {
updatePollingState(documentsRes)
}, [documentsRes, updatePollingState])
// Adjust page when total changes
useEffect(() => {
adjustPageForTotal(documentsRes)
}, [documentsRes, adjustPageForTotal])
// Invalidation hooks
const invalidDocumentList = useInvalidDocumentList(datasetId)
const invalidDocumentDetail = useInvalidDocumentDetail()
@@ -118,7 +119,7 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
// Render content based on loading and data state
const renderContent = () => {
if (isListLoading && !documentsRes)
if (isListLoading)
return <Loading type="app" />
if (total > 0) {
@@ -130,8 +131,8 @@ const Documents: FC<IDocumentsProps> = ({ datasetId }) => {
onUpdate={handleUpdate}
selectedIds={selectedIds}
onSelectedIdChange={setSelectedIds}
statusFilterValue={normalizedStatusFilterValue}
remoteSortValue={sortValue}
onSortChange={handleSortChange}
pagination={{
total,
limit,

View File

@@ -1,12 +1,12 @@
import type { Mock } from 'vitest'
import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal'
import type { App } from '@/models/explore'
import { act, fireEvent, screen, waitFor } from '@testing-library/react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { fetchAppDetail } from '@/service/explore'
import { useMembers } from '@/service/use-common'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { AppModeEnum } from '@/types/app'
import AppList from '../index'
@@ -132,9 +132,10 @@ const mockMemberRole = (hasEditPermission: boolean) => {
const renderAppList = (hasEditPermission = false, onSuccess?: () => void, searchParams?: Record<string, string>) => {
mockMemberRole(hasEditPermission)
return renderWithNuqs(
<AppList onSuccess={onSuccess} />,
{ searchParams },
return render(
<NuqsTestingAdapter searchParams={searchParams}>
<AppList onSuccess={onSuccess} />
</NuqsTestingAdapter>,
)
}

View File

@@ -1,10 +1,9 @@
import type { AppRouterInstance } from 'next/dist/shared/lib/app-router-context.shared-runtime'
import type { AppContextValue } from '@/context/app-context'
import type { ModalContextState } from '@/context/modal-context'
import type { ProviderContextState } from '@/context/provider-context'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { AppRouterContext } from 'next/dist/shared/lib/app-router-context.shared-runtime'
import { useRouter } from 'next/navigation'
import { Plan } from '@/app/components/billing/type'
import { useAppContext } from '@/context/app-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
@@ -50,6 +49,14 @@ vi.mock('@/service/use-common', () => ({
useLogout: vi.fn(),
}))
vi.mock('next/navigation', async (importOriginal) => {
const actual = await importOriginal<typeof import('next/navigation')>()
return {
...actual,
useRouter: vi.fn(),
}
})
vi.mock('@/context/i18n', () => ({
useDocLink: () => (path: string) => `https://docs.dify.ai${path}`,
}))
@@ -119,15 +126,6 @@ describe('AccountDropdown', () => {
const mockSetShowAccountSettingModal = vi.fn()
const renderWithRouter = (ui: React.ReactElement) => {
const mockRouter = {
push: mockPush,
replace: vi.fn(),
prefetch: vi.fn(),
back: vi.fn(),
forward: vi.fn(),
refresh: vi.fn(),
} as unknown as AppRouterInstance
const queryClient = new QueryClient({
defaultOptions: {
queries: {
@@ -138,9 +136,7 @@ describe('AccountDropdown', () => {
return render(
<QueryClientProvider client={queryClient}>
<AppRouterContext.Provider value={mockRouter}>
{ui}
</AppRouterContext.Provider>
{ui}
</QueryClientProvider>,
)
}
@@ -166,6 +162,14 @@ describe('AccountDropdown', () => {
vi.mocked(useLogout).mockReturnValue({
mutateAsync: mockLogout,
} as unknown as ReturnType<typeof useLogout>)
vi.mocked(useRouter).mockReturnValue({
push: mockPush,
replace: vi.fn(),
prefetch: vi.fn(),
back: vi.fn(),
forward: vi.fn(),
refresh: vi.fn(),
})
})
afterEach(() => {

View File

@@ -1,20 +1,21 @@
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ReactNode } from 'react'
import { act, renderHook } from '@testing-library/react'
import { Provider as JotaiProvider } from 'jotai'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
import { DEFAULT_SORT } from '../constants'
const createWrapper = (searchParams = '') => {
const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams })
const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
const wrapper = ({ children }: { children: ReactNode }) => (
<JotaiProvider>
<NuqsWrapper>
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
{children}
</NuqsWrapper>
</NuqsTestingAdapter>
</JotaiProvider>
)
return { wrapper }
return { wrapper, onUrlUpdate }
}
describe('Marketplace sort atoms', () => {

View File

@@ -1,8 +1,9 @@
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import { Provider as JotaiProvider } from 'jotai'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
import PluginTypeSwitch from '../plugin-type-switch'
vi.mock('#i18n', () => ({
@@ -24,15 +25,15 @@ vi.mock('#i18n', () => ({
}))
const createWrapper = (searchParams = '') => {
const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams })
const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
const Wrapper = ({ children }: { children: ReactNode }) => (
<JotaiProvider>
<NuqsWrapper>
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
{children}
</NuqsWrapper>
</NuqsTestingAdapter>
</JotaiProvider>
)
return { Wrapper }
return { Wrapper, onUrlUpdate }
}
describe('PluginTypeSwitch', () => {

View File

@@ -2,8 +2,8 @@ import type { ReactNode } from 'react'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { renderHook, waitFor } from '@testing-library/react'
import { Provider as JotaiProvider } from 'jotai'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
vi.mock('@/config', () => ({
API_PREFIX: '/api',
@@ -37,7 +37,6 @@ vi.mock('@/service/client', () => ({
}))
const createWrapper = (searchParams = '') => {
const { wrapper: NuqsWrapper } = createNuqsTestWrapper({ searchParams })
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false, gcTime: 0 },
@@ -46,9 +45,9 @@ const createWrapper = (searchParams = '') => {
const Wrapper = ({ children }: { children: ReactNode }) => (
<JotaiProvider>
<QueryClientProvider client={queryClient}>
<NuqsWrapper>
<NuqsTestingAdapter searchParams={searchParams}>
{children}
</NuqsWrapper>
</NuqsTestingAdapter>
</QueryClientProvider>
</JotaiProvider>
)

View File

@@ -1,8 +1,8 @@
import type { ReactNode } from 'react'
import { render } from '@testing-library/react'
import { Provider as JotaiProvider } from 'jotai'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
import StickySearchAndSwitchWrapper from '../sticky-search-and-switch-wrapper'
vi.mock('#i18n', () => ({
@@ -20,12 +20,11 @@ vi.mock('../search-box/search-box-wrapper', () => ({
default: () => <div data-testid="search-box-wrapper">SearchBoxWrapper</div>,
}))
const { wrapper: NuqsWrapper } = createNuqsTestWrapper()
const Wrapper = ({ children }: { children: ReactNode }) => (
<JotaiProvider>
<NuqsWrapper>
<NuqsTestingAdapter>
{children}
</NuqsWrapper>
</NuqsTestingAdapter>
</JotaiProvider>
)

View File

@@ -1,5 +1,4 @@
import type { SearchParams } from 'nuqs/server'
import type { MarketplaceSearchParams } from './search-params'
import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
@@ -15,7 +14,7 @@ async function getDehydratedState(searchParams?: Promise<SearchParams>) {
return
}
const loadSearchParams = createLoader(marketplaceSearchParamsParsers)
const params: MarketplaceSearchParams = await loadSearchParams(searchParams)
const params = await loadSearchParams(searchParams)
if (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(params.category)) {
return

View File

@@ -1,4 +1,3 @@
import type { inferParserType } from 'nuqs/server'
import type { ActivePluginType } from './constants'
import { parseAsArrayOf, parseAsString, parseAsStringEnum } from 'nuqs/server'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
@@ -8,5 +7,3 @@ export const marketplaceSearchParamsParsers = {
q: parseAsString.withDefault('').withOptions({ history: 'replace' }),
tags: parseAsArrayOf(parseAsString).withDefault([]).withOptions({ history: 'replace' }),
}
export type MarketplaceSearchParams = inferParserType<typeof marketplaceSearchParamsParsers>

View File

@@ -7,9 +7,6 @@ import { PluginPageContext, PluginPageContextProvider, usePluginPageContext } fr
// Mock dependencies
vi.mock('nuqs', () => ({
parseAsStringEnum: vi.fn(() => ({
withDefault: vi.fn(() => ({})),
})),
useQueryState: vi.fn(() => ['plugins', vi.fn()]),
}))

View File

@@ -80,9 +80,6 @@ vi.mock('@/service/use-plugins', () => ({
}))
vi.mock('nuqs', () => ({
parseAsStringEnum: vi.fn(() => ({
withDefault: vi.fn(() => ({})),
})),
useQueryState: vi.fn(() => ['plugins', vi.fn()]),
}))

View File

@@ -3,7 +3,7 @@
import type { ReactNode, RefObject } from 'react'
import type { FilterState } from './filter-management'
import { noop } from 'es-toolkit/function'
import { parseAsStringEnum, useQueryState } from 'nuqs'
import { useQueryState } from 'nuqs'
import {
useMemo,
useRef,
@@ -15,19 +15,6 @@ import {
} from 'use-context-selector'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { PLUGIN_PAGE_TABS_MAP, usePluginPageTabs } from '../hooks'
import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants'
export type PluginPageTab = typeof PLUGIN_PAGE_TABS_MAP[keyof typeof PLUGIN_PAGE_TABS_MAP]
| (typeof PLUGIN_TYPE_SEARCH_MAP)[keyof typeof PLUGIN_TYPE_SEARCH_MAP]
const PLUGIN_PAGE_TAB_VALUES: PluginPageTab[] = [
PLUGIN_PAGE_TABS_MAP.plugins,
PLUGIN_PAGE_TABS_MAP.marketplace,
...Object.values(PLUGIN_TYPE_SEARCH_MAP),
]
const parseAsPluginPageTab = parseAsStringEnum<PluginPageTab>(PLUGIN_PAGE_TAB_VALUES)
.withDefault(PLUGIN_PAGE_TABS_MAP.plugins)
export type PluginPageContextValue = {
containerRef: RefObject<HTMLDivElement | null>
@@ -35,8 +22,8 @@ export type PluginPageContextValue = {
setCurrentPluginID: (pluginID?: string) => void
filters: FilterState
setFilters: (filter: FilterState) => void
activeTab: PluginPageTab
setActiveTab: (tab: PluginPageTab) => void
activeTab: string
setActiveTab: (tab: string) => void
options: Array<{ value: string, text: string }>
}
@@ -52,7 +39,7 @@ export const PluginPageContext = createContext<PluginPageContextValue>({
searchQuery: '',
},
setFilters: noop,
activeTab: PLUGIN_PAGE_TABS_MAP.plugins,
activeTab: '',
setActiveTab: noop,
options: [],
})
@@ -81,7 +68,9 @@ export const PluginPageContextProvider = ({
const options = useMemo(() => {
return enable_marketplace ? tabs : tabs.filter(tab => tab.value !== PLUGIN_PAGE_TABS_MAP.marketplace)
}, [tabs, enable_marketplace])
const [activeTab, setActiveTab] = useQueryState('tab', parseAsPluginPageTab)
const [activeTab, setActiveTab] = useQueryState('tab', {
defaultValue: options[0].value,
})
return (
<PluginPageContext.Provider


@@ -1,7 +1,6 @@
'use client'
import type { Dependency, PluginDeclaration, PluginManifestInMarket } from '../types'
import type { PluginPageTab } from './context'
import {
RiBookOpenLine,
RiDragDropLine,
@@ -38,16 +37,6 @@ import PluginTasks from './plugin-tasks'
import useReferenceSetting from './use-reference-setting'
import { useUploader } from './use-uploader'
const pluginPageTabSet = new Set<string>([
PLUGIN_PAGE_TABS_MAP.plugins,
PLUGIN_PAGE_TABS_MAP.marketplace,
...Object.values(PLUGIN_TYPE_SEARCH_MAP),
])
const isPluginPageTab = (value: string): value is PluginPageTab => {
return pluginPageTabSet.has(value)
}
export type PluginPageProps = {
plugins: React.ReactNode
marketplace: React.ReactNode
@@ -165,10 +154,7 @@ const PluginPage = ({
<div className="flex-1">
<TabSlider
value={isPluginsTab ? PLUGIN_PAGE_TABS_MAP.plugins : PLUGIN_PAGE_TABS_MAP.marketplace}
onChange={(nextTab) => {
if (isPluginPageTab(nextTab))
setActiveTab(nextTab)
}}
onChange={setActiveTab}
options={options}
/>
</div>


@@ -1,6 +1,6 @@
import { cleanup, fireEvent, screen } from '@testing-library/react'
import { cleanup, fireEvent, render, screen } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { renderWithNuqs } from '@/test/nuqs-testing'
import { ToolTypeEnum } from '../../workflow/block-selector/types'
import ProviderList from '../provider-list'
import { getToolType } from '../utils'
@@ -206,9 +206,10 @@ describe('getToolType', () => {
})
const renderProviderList = (searchParams?: Record<string, string>) => {
return renderWithNuqs(
<ProviderList />,
{ searchParams },
return render(
<NuqsTestingAdapter searchParams={searchParams}>
<ProviderList />
</NuqsTestingAdapter>,
)
}


@@ -1,6 +1,6 @@
'use client'
import type { Collection } from './types'
import { parseAsStringLiteral, useQueryState } from 'nuqs'
import { useQueryState } from 'nuqs'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
@@ -23,17 +23,6 @@ import { useMarketplace } from './marketplace/hooks'
import MCPList from './mcp'
import { getToolType } from './utils'
const TOOL_PROVIDER_CATEGORY_VALUES = ['builtin', 'api', 'workflow', 'mcp'] as const
type ToolProviderCategory = typeof TOOL_PROVIDER_CATEGORY_VALUES[number]
const toolProviderCategorySet = new Set<string>(TOOL_PROVIDER_CATEGORY_VALUES)
const isToolProviderCategory = (value: string): value is ToolProviderCategory => {
return toolProviderCategorySet.has(value)
}
const parseAsToolProviderCategory = parseAsStringLiteral(TOOL_PROVIDER_CATEGORY_VALUES)
.withDefault('builtin')
const ProviderList = () => {
// const searchParams = useSearchParams()
// searchParams.get('category') === 'workflow'
@@ -42,7 +31,9 @@ const ProviderList = () => {
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
const containerRef = useRef<HTMLDivElement>(null)
const [activeTab, setActiveTab] = useQueryState('category', parseAsToolProviderCategory)
const [activeTab, setActiveTab] = useQueryState('category', {
defaultValue: 'builtin',
})
const options = [
{ value: 'builtin', text: t('type.builtIn', { ns: 'tools' }) },
{ value: 'api', text: t('type.custom', { ns: 'tools' }) },
@@ -133,8 +124,6 @@ const ProviderList = () => {
<TabSliderNew
value={activeTab}
onChange={(state) => {
if (!isToolProviderCategory(state))
return
setActiveTab(state)
if (state !== activeTab)
setCurrentProviderId(undefined)


@@ -1,9 +1,9 @@
import { act, screen, waitFor } from '@testing-library/react'
import { act, render, screen, waitFor } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import * as React from 'react'
import { defaultPlan } from '@/app/components/billing/config'
import { Plan } from '@/app/components/billing/type'
import { ModalContextProvider } from '@/context/modal-context'
import { renderWithNuqs } from '@/test/nuqs-testing'
vi.mock('@/config', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/config')>()
@@ -71,10 +71,12 @@ const createPlan = (overrides: PlanOverrides = {}): PlanShape => ({
},
})
const renderProvider = () => renderWithNuqs(
<ModalContextProvider>
<div data-testid="modal-context-test-child" />
</ModalContextProvider>,
const renderProvider = () => render(
<NuqsTestingAdapter>
<ModalContextProvider>
<div data-testid="modal-context-test-child" />
</ModalContextProvider>
</NuqsTestingAdapter>,
)
describe('ModalContextProvider trigger events limit modal', () => {


@@ -158,7 +158,7 @@ export const ModalContextProvider = ({
}: ModalContextProviderProps) => {
// Use nuqs hooks for URL-based modal state management
const [showPricingModal, setPricingModalOpen] = usePricingModal()
const [urlAccountModalState, setUrlAccountModalState] = useAccountSettingModal()
const [urlAccountModalState, setUrlAccountModalState] = useAccountSettingModal<AccountSettingTab>()
const accountSettingCallbacksRef = useRef<Omit<ModalState<AccountSettingTab>, 'payload'> | null>(null)
const accountSettingTab = urlAccountModalState.isOpen


@@ -225,38 +225,6 @@ Simulate the interactions that matter to users—primary clicks, change events,
Mock the specific Next.js navigation hooks your component consumes (`useRouter`, `usePathname`, `useSearchParams`) and drive realistic routing flows—query parameters, redirects, guarded routes, URL updates—while asserting the rendered outcome or navigation side effects.
#### 7.1 `nuqs` Query State Testing
When testing code that uses `useQueryState` or `useQueryStates`, treat `nuqs` as the source of truth for URL synchronization.
- ✅ At runtime, keep `NuqsAdapter` in the app layout (already wired in `app/layout.tsx`).
- ✅ In tests, wrap with `NuqsTestingAdapter` (prefer helper utilities from `@/test/nuqs-testing`).
- ✅ Assert URL behavior via `onUrlUpdate` events (`searchParams`, `options.history`) instead of only asserting router mocks.
- ✅ For custom parsers created with `createParser`, keep `parse` and `serialize` bijective (round-trip safe). Add edge-case coverage for values like `%2F`, `%25`, spaces, and legacy encoded URLs.
- ✅ Assert default-clearing behavior explicitly (`clearOnDefault` semantics remove the param when its value equals the default).
- ⚠️ Only mock `nuqs` directly when URL behavior is intentionally out of scope for the test. For ESM-safe partial mocks, use async `vi.mock` with `importOriginal`.
Example:
```tsx
import { act, waitFor } from '@testing-library/react'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
it('should update query with push history', async () => {
const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), {
searchParams: '?page=1',
})
act(() => {
result.current.setQuery({ page: 2 })
})
await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled())
const update = onUrlUpdate.mock.calls.at(-1)![0]
expect(update.options.history).toBe('push')
expect(update.searchParams.get('page')).toBe('2')
})
```
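The round-trip (bijectivity) requirement for custom parsers can be sketched with plain functions. This is a hypothetical `parse`/`serialize` pair, not the project's actual parser; `createParser` from `nuqs` accepts functions of this shape:

```ts
// Hypothetical parse/serialize pair for a string param that may contain
// slashes, percent signs, and spaces. The guideline above asks that the
// pair be bijective: parse(serialize(v)) must return v for every v.
const serialize = (value: string): string => encodeURIComponent(value)

const parse = (query: string): string | null => {
  try {
    return decodeURIComponent(query)
  }
  catch {
    return null // malformed percent-encoding: treat the param as absent
  }
}

// Round-trip check over the edge cases called out above.
for (const value of ['a/b', '100%', 'hello world', '%2F']) {
  console.assert(parse(serialize(value)) === value)
}
```

Because `serialize` always produces valid percent-encoding and `parse` inverts it (returning `null` only for inputs `serialize` can never emit), legacy encoded URLs and literal `%` values both survive the round trip.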
### 8. Edge Cases (REQUIRED - All Components)
**Must Test**:


@@ -3113,6 +3113,11 @@
"count": 1
}
},
"app/components/datasets/documents/components/list.tsx": {
"react-refresh/only-export-components": {
"count": 1
}
},
"app/components/datasets/documents/create-from-pipeline/actions/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2
@@ -3482,6 +3487,16 @@
"count": 3
}
},
"app/components/datasets/documents/hooks/use-documents-page-state.ts": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 12
}
},
"app/components/datasets/documents/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 2
}
},
"app/components/datasets/external-api/external-api-modal/Form.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2


@@ -1,6 +1,8 @@
import { act, waitFor } from '@testing-library/react'
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ReactNode } from 'react'
import { act, renderHook, waitFor } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { ACCOUNT_SETTING_MODAL_ACTION } from '@/app/components/header/account-setting/constants'
import { renderHookWithNuqs } from '@/test/nuqs-testing'
import {
clearQueryParams,
PRICING_MODAL_QUERY_PARAM,
@@ -18,7 +20,14 @@ vi.mock('@/utils/client', () => ({
}))
const renderWithAdapter = <T,>(hook: () => T, searchParams = '') => {
return renderHookWithNuqs(hook, { searchParams })
const onUrlUpdate = vi.fn<(event: UrlUpdateEvent) => void>()
const wrapper = ({ children }: { children: ReactNode }) => (
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={onUrlUpdate}>
{children}
</NuqsTestingAdapter>
)
const { result } = renderHook(hook, { wrapper })
return { result, onUrlUpdate }
}
// Query param hooks: defaults, parsing, and URL sync behavior.


@@ -13,19 +13,14 @@
* - Use shallow routing to avoid unnecessary re-renders
*/
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import {
createParser,
parseAsStringEnum,
parseAsStringLiteral,
parseAsString,
useQueryState,
useQueryStates,
} from 'nuqs'
import { useCallback } from 'react'
import {
ACCOUNT_SETTING_MODAL_ACTION,
ACCOUNT_SETTING_TAB,
} from '@/app/components/header/account-setting/constants'
import { ACCOUNT_SETTING_MODAL_ACTION } from '@/app/components/header/account-setting/constants'
import { isServer } from '@/utils/client'
/**
@@ -57,10 +52,6 @@ export function usePricingModal() {
)
}
const accountSettingTabValues = Object.values(ACCOUNT_SETTING_TAB) as AccountSettingTab[]
const parseAsAccountSettingAction = parseAsStringLiteral([ACCOUNT_SETTING_MODAL_ACTION] as const)
const parseAsAccountSettingTab = parseAsStringEnum<AccountSettingTab>(accountSettingTabValues)
/**
* Hook to manage account setting modal state via URL
* @returns [state, setState] - Object with isOpen + payload (tab) and setter
@@ -70,11 +61,11 @@ const parseAsAccountSettingTab = parseAsStringEnum<AccountSettingTab>(accountSet
* setAccountModalState({ payload: 'billing' }) // Sets ?action=showSettings&tab=billing
* setAccountModalState(null) // Removes both params
*/
export function useAccountSettingModal() {
export function useAccountSettingModal<T extends string = string>() {
const [accountState, setAccountState] = useQueryStates(
{
action: parseAsAccountSettingAction,
tab: parseAsAccountSettingTab,
action: parseAsString,
tab: parseAsString,
},
{
history: 'replace',
@@ -82,7 +73,7 @@ export function useAccountSettingModal() {
)
const setState = useCallback(
(state: { payload: AccountSettingTab } | null) => {
(state: { payload: T } | null) => {
if (!state) {
setAccountState({ action: null, tab: null }, { history: 'replace' })
return
@@ -97,7 +88,7 @@ export function useAccountSettingModal() {
)
const isOpen = accountState.action === ACCOUNT_SETTING_MODAL_ACTION
const currentTab = isOpen ? accountState.tab : null
const currentTab = (isOpen ? accountState.tab : null) as T | null
return [{ isOpen, payload: currentTab }, setState] as const
}


@@ -28,9 +28,12 @@
"scripts": {
"dev": "next dev",
"dev:inspect": "next dev --inspect",
"dev:vinext": "vinext dev",
"build": "next build",
"build:docker": "next build && node scripts/optimize-standalone.js",
"build:vinext": "vinext build",
"start": "node ./scripts/copy-and-start.mjs",
"start:vinext": "vinext start",
"lint": "eslint --cache --concurrency=auto",
"lint:ci": "eslint --cache --concurrency 2",
"lint:fix": "pnpm lint --fix",
@@ -173,6 +176,7 @@
"@iconify-json/ri": "1.2.9",
"@mdx-js/loader": "3.1.1",
"@mdx-js/react": "3.1.1",
"@mdx-js/rollup": "3.1.1",
"@next/eslint-plugin-next": "16.1.6",
"@next/mdx": "16.1.5",
"@rgrove/parse-xml": "4.2.0",
@@ -211,6 +215,7 @@
"@typescript-eslint/parser": "8.54.0",
"@typescript/native-preview": "7.0.0-dev.20251209.1",
"@vitejs/plugin-react": "5.1.2",
"@vitejs/plugin-rsc": "0.5.20",
"@vitest/coverage-v8": "4.0.17",
"autoprefixer": "10.4.21",
"code-inspector-plugin": "1.3.6",
@@ -233,6 +238,7 @@
"postcss": "8.5.6",
"postcss-js": "5.0.3",
"react-scan": "0.4.3",
"react-server-dom-webpack": "19.2.4",
"sass": "1.93.2",
"serwist": "9.5.4",
"storybook": "10.2.0",
@@ -240,6 +246,7 @@
"tsx": "4.21.0",
"typescript": "5.9.3",
"uglify-js": "3.19.3",
"vinext": "https://pkg.pr.new/hyoban/vinext@f07e125",
"vite": "7.3.1",
"vite-tsconfig-paths": "6.0.4",
"vitest": "4.0.17",

web/pnpm-lock.yaml (generated, 462 changes): diff suppressed because it is too large.

@@ -1,9 +1,7 @@
import type { UseQueryOptions } from '@tanstack/react-query'
import type { DocumentDownloadResponse, DocumentDownloadZipRequest, MetadataType, SortType } from '../datasets'
import type { CommonResponse } from '@/models/common'
import type { DocumentDetailResponse, DocumentListResponse, UpdateDocumentBatchParams } from '@/models/datasets'
import {
keepPreviousData,
useMutation,
useQuery,
} from '@tanstack/react-query'
@@ -16,8 +14,6 @@ import { useInvalid } from '../use-base'
const NAME_SPACE = 'knowledge/document'
export const useDocumentListKey = [NAME_SPACE, 'documentList']
type DocumentListRefetchInterval = UseQueryOptions<DocumentListResponse>['refetchInterval']
export const useDocumentList = (payload: {
datasetId: string
query: {
@@ -27,7 +23,7 @@ export const useDocumentList = (payload: {
sort?: SortType
status?: string
}
refetchInterval?: DocumentListRefetchInterval
refetchInterval?: number | false
}) => {
const { query, datasetId, refetchInterval } = payload
const { keyword, page, limit, sort, status } = query
@@ -46,7 +42,6 @@ export const useDocumentList = (payload: {
queryFn: () => get<DocumentListResponse>(`/datasets/${datasetId}/documents`, {
params,
}),
placeholderData: keepPreviousData,
refetchInterval,
})
}


@@ -1,60 +0,0 @@
import type { UrlUpdateEvent } from 'nuqs/adapters/testing'
import type { ComponentProps, ReactElement, ReactNode } from 'react'
import type { Mock } from 'vitest'
import { render, renderHook } from '@testing-library/react'
import { NuqsTestingAdapter } from 'nuqs/adapters/testing'
import { vi } from 'vitest'
type NuqsSearchParams = ComponentProps<typeof NuqsTestingAdapter>['searchParams']
type NuqsOnUrlUpdate = (event: UrlUpdateEvent) => void
type NuqsOnUrlUpdateSpy = Mock<NuqsOnUrlUpdate>
type NuqsTestOptions = {
searchParams?: NuqsSearchParams
onUrlUpdate?: NuqsOnUrlUpdateSpy
}
type NuqsHookTestOptions<Props> = NuqsTestOptions & {
initialProps?: Props
}
type NuqsWrapperProps = {
children: ReactNode
}
export const createNuqsTestWrapper = (options: NuqsTestOptions = {}) => {
const { searchParams = '', onUrlUpdate } = options
const urlUpdateSpy = onUrlUpdate ?? vi.fn<NuqsOnUrlUpdate>()
const wrapper = ({ children }: NuqsWrapperProps) => (
<NuqsTestingAdapter searchParams={searchParams} onUrlUpdate={urlUpdateSpy}>
{children}
</NuqsTestingAdapter>
)
return {
wrapper,
onUrlUpdate: urlUpdateSpy,
}
}
export const renderWithNuqs = (ui: ReactElement, options: NuqsTestOptions = {}) => {
const { wrapper, onUrlUpdate } = createNuqsTestWrapper(options)
const rendered = render(ui, { wrapper })
return {
...rendered,
onUrlUpdate,
}
}
export const renderHookWithNuqs = <Result, Props = void>(
callback: (props: Props) => Result,
options: NuqsHookTestOptions<Props> = {},
) => {
const { initialProps, ...nuqsOptions } = options
const { wrapper, onUrlUpdate } = createNuqsTestWrapper(nuqsOptions)
const rendered = renderHook(callback, { wrapper, initialProps })
return {
...rendered,
onUrlUpdate,
}
}


@@ -1,16 +1,58 @@
import path from 'node:path'
import { fileURLToPath } from 'node:url'
import type { Plugin } from 'vite'
import mdx from '@mdx-js/rollup'
import react from '@vitejs/plugin-react'
import { defineConfig } from 'vite'
import vinext from 'vinext'
import { defineConfig, loadEnv } from 'vite'
import tsconfigPaths from 'vite-tsconfig-paths'
const __dirname = path.dirname(fileURLToPath(import.meta.url))
const isCI = !!process.env.CI
export default defineConfig({
plugins: [tsconfigPaths(), react()],
resolve: {
alias: {
'~@': __dirname,
export default defineConfig(({ mode }) => {
const env = loadEnv(mode, process.cwd(), '')
return {
plugins: [
...(mode === 'test'
? [
react(),
{
// Stub .mdx files so components importing them can be unit-tested
name: 'mdx-stub',
enforce: 'pre',
transform(_, id) {
if (id.endsWith('.mdx'))
return { code: 'export default () => null', map: null }
},
} as Plugin,
]
: [
mdx(),
vinext(),
]),
tsconfigPaths(),
],
resolve: {
alias: {
'~@': __dirname,
},
},
},
optimizeDeps: {
exclude: ['nuqs'],
},
server: {
port: 3000,
},
envPrefix: 'NEXT_PUBLIC_',
test: {
environment: 'jsdom',
globals: true,
setupFiles: ['./vitest.setup.ts'],
coverage: {
provider: 'v8',
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
},
},
define: {
'process.env': env,
},
}
})


@@ -1,27 +0,0 @@
import { defineConfig, mergeConfig } from 'vitest/config'
import viteConfig from './vite.config'
const isCI = !!process.env.CI
export default mergeConfig(viteConfig, defineConfig({
plugins: [
{
// Stub .mdx files so components importing them can be unit-tested
name: 'mdx-stub',
enforce: 'pre',
transform(_, id) {
if (id.endsWith('.mdx'))
return { code: 'export default () => null', map: null }
},
},
],
test: {
environment: 'jsdom',
globals: true,
setupFiles: ['./vitest.setup.ts'],
coverage: {
provider: 'v8',
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
},
},
}))