Compare commits

..

27 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
0e4f5cb38c refactor: move console_exempt_prefixes to module level in app_factory.py
Co-authored-by: GareArc <52963600+GareArc@users.noreply.github.com>
2026-03-09 07:28:50 +00:00
copilot-swe-agent[bot]
c13d1872d4 Initial plan 2026-03-09 07:27:12 +00:00
GareArc
c911de6a6c fix: exempt setup flow endpoints from license check
Add /console/api/init and /console/api/login to the license exempt
list so that fresh installs can complete setup when the enterprise
license is inactive. Without these exemptions the init password
validation and post-setup auto-login are blocked, causing the setup
page to enter an infinite reload loop.
2026-03-08 23:46:26 -07:00
Xiyuan Chen
968bf10e1c Update api/services/enterprise/enterprise_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 17:35:50 -07:00
Xiyuan Chen
3d77a5ec08 Update api/services/feature_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 17:07:17 -07:00
GareArc
41af72449d fix: address PR review feedback on enterprise license enforcement
- Cache invalid license statuses with 30s TTL to prevent DoS amplification
- Return LicenseStatus enum (not raw str) from get_cached_license_status
- Flatten nested try/except into _read_cached_license_status / _fetch_and_cache_license_status helpers
- Escalate log levels from debug to warning with exc_info for cache failures
- Switch before_request license check from fail-open to fail-closed
- Remove dead raise_for_status parameter from BaseRequest.send_request
- Gate license expired_at behind is_authenticated; only expose status to unauthenticated callers (CVE-2025-63387)
- Remove redundant 'not is_console_api' guard in before_request
- Add 8 unit tests for get_cached_license_status
2026-03-08 17:00:12 -07:00
Xiyuan Chen
de72bdef71 Merge branch 'main' into fix/main-enterprise-api-error-handling 2026-03-08 16:28:01 -07:00
CoralGarden52
c925d17e8f chore: add TypedDict related prompt to api/AGENTS.md (#33116)
Some checks failed
Mark stale issues and pull requests / stale (push) Has been cancelled
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-08 07:03:52 +09:00
Angel
dc2a53d834 feat: add files to message end pr32019 (#32242)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: angel.k <angel.kolev@solaredge.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-07 20:01:12 +08:00
hj24
05ab107e73 feat: add export app messages (#32990)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-07 11:27:15 +08:00
pepsi
c016793efb refactor: pass KnowledgeConfiguration directly instead of dict (#32732)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: pepsi <pepsi@pepsidexuniji.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 22:15:32 +09:00
Coding On Star
a5bcbaebb7 feat(toast): add IToastProps type import to enhance type safety (#33096)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-06 19:22:55 +08:00
Saumya Talwani
f50e44b24a test: improve coverage for some test files (#32916)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Poojan <poojan@infocusp.com>
Co-authored-by: sahil-infocusp <73810410+sahil-infocusp@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Pandaaaa906 <ye.pandaaaa906@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Ijas <ijas.ahmd.ap@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: 木之本澪 <kinomotomiovo@gmail.com>
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: 不做了睡大觉 <64798754+stakeswky@users.noreply.github.com>
Co-authored-by: User <user@example.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Leilei <138381132+Inlei@users.noreply.github.com>
Co-authored-by: HaKu <104669497+haku-ink@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
Co-authored-by: Varun Chawla <34209028+veeceey@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: tda <95275462+tda1017@users.noreply.github.com>
Co-authored-by: root <root@DESKTOP-KQLO90N>
Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
Co-authored-by: Niels Kaspers <153818647+nielskaspers@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: Tyson Cung <45380903+tysoncung@users.noreply.github.com>
Co-authored-by: Stephen Zhou <hi@hyoban.cc>
Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
Co-authored-by: slegarraga <64795732+slegarraga@users.noreply.github.com>
Co-authored-by: 99 <wh2099@pm.me>
Co-authored-by: Br1an <932039080@qq.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: akkoaya <151345394+akkoaya@users.noreply.github.com>
Co-authored-by: 盐粒 Yanli <yanli@dify.ai>
Co-authored-by: lif <1835304752@qq.com>
Co-authored-by: weiguang li <codingpunk@gmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: HanWenbo <124024253+hwb96@users.noreply.github.com>
Co-authored-by: Coding On Star <447357187@qq.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: Stable Genius <stablegenius043@gmail.com>
Co-authored-by: Stable Genius <259448942+stablegenius49@users.noreply.github.com>
Co-authored-by: ふるい <46769295+Echo0ff@users.noreply.github.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2026-03-06 18:59:16 +08:00
Nite Knite
09347d5e8b chore: fix account dropdown test (#33093) 2026-03-06 18:19:02 +08:00
Stephen Zhou
299a893ac5 chore: bring back code-inspector-plugin and agentation (#33088)
Co-authored-by: zhsama <zhsama@users.noreply.github.com>
2026-03-06 17:01:18 +08:00
Junyan Chin
c477571553 perf: no longer record install count for auto upgrade (#33086) 2026-03-06 16:19:30 +08:00
QuantumGhost
d01acfc490 fix(api): fix the issue that workflow_runs.started_at is overwritten while resuming (#32851)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 15:41:30 +08:00
GareArc
f97ade7053 fix: use LicenseStatus enum instead of raw strings and tighten path prefix matching
Replace raw license status strings with LicenseStatus enum values in
app_factory.py and enterprise_service.py to prevent silent mismatches.
Use trailing-slash prefixes ('/console/api/', '/api/') to avoid false
matches on unrelated paths like /api-docs.
2026-03-05 01:17:49 -08:00
GareArc
a0dcd04546 fix: remove extra exempts 2026-03-05 01:10:23 -08:00
autofix-ci[bot]
b0138316f0 [autofix.ci] apply automated fixes 2026-03-05 09:02:35 +00:00
GareArc
099568f3da fix: expose license status to unauthenticated /system-features callers
After force-logout due to license expiry, the login page calls
/system-features without auth. The license block was gated behind
is_authenticated, so the frontend always saw status='none' instead
of the actual expiry status. Split the guard so license.status and
expired_at are always returned while workspace usage details remain
auth-gated.
2026-03-05 00:48:50 -08:00
GareArc
0623522d04 fix: exempt console bootstrap APIs from license check to prevent infinite reload loop 2026-03-04 22:13:52 -08:00
GareArc
a25d48c5bd feat: add Redis caching for enterprise license status
Cache license status for 10 minutes to reduce HTTP calls to enterprise API.
Only caches license status, not full system features.

Changes:
- Add EnterpriseService.get_cached_license_status() method
- Cache key: enterprise:license:status
- TTL: 600 seconds (10 minutes)
- Graceful degradation: falls back to API call if Redis fails

Performance improvement:
- Before: HTTP call (~50-200ms) on every API request
- After: Redis lookup (~1ms) on cached requests
- Reduces load on enterprise service by ~99%
2026-03-04 21:29:11 -08:00
GareArc
4f3a020670 feat: extend license enforcement to webapp API endpoints
Extend license middleware to also block webapp API (/api/*) when
enterprise license is expired/inactive/lost.

Changes:
- Check both /console/api and /api endpoints
- Add webapp-specific exempt paths:
  - /api/passport (webapp authentication)
  - /api/login, /api/logout, /api/oauth
  - /api/forgot-password
  - /api/system-features (webapp needs this to check license status)

This ensures both console users and webapp users are blocked when
license expires, maintaining consistent enforcement across all APIs.
2026-03-04 20:40:29 -08:00
GareArc
d2e1177478 fix: use UnauthorizedAndForceLogout to trigger frontend logout on license expiry
Change license check to raise UnauthorizedAndForceLogout exception instead
of returning generic JSON response. This ensures proper frontend handling:

Frontend behavior (service/base.ts line 588):
- Checks if code === 'unauthorized_and_force_logout'
- Executes globalThis.location.reload()
- Forces user logout and redirect to login page
- Login page displays license expiration UI (already exists)

Response format:
- HTTP 401 (not 403)
- code: "unauthorized_and_force_logout"
- Triggers frontend reload which clears auth state

This completes the license enforcement flow:
1. Backend blocks all business APIs when license expires
2. Backend returns proper error code to trigger logout
3. Frontend reloads and redirects to login
4. Login page shows license expiration message
2026-03-04 20:40:29 -08:00
GareArc
8a21fd88fd feat: add global license check middleware to block API access on expiry
Add before_request middleware that validates enterprise license status
for all /console/api endpoints when ENTERPRISE_ENABLED is true.

Behavior:
- Checks license status before each console API request
- Returns 403 with clear error message when license is expired/inactive/lost
- Exempts auth endpoints (login, oauth, forgot-password, etc.)
- Exempts /console/api/features so frontend can fetch license status
- Gracefully handles errors to avoid service disruption

This ensures all business APIs are blocked when license expires,
addressing the issue where APIs remained callable after expiry.
2026-03-04 20:40:29 -08:00
GareArc
1c1bcc67da fix: handle enterprise API errors properly to prevent KeyError crashes
When enterprise API returns 403/404, the response contains error JSON
instead of expected data structure. Code was accessing fields directly
causing KeyError → 500 Internal Server Error.

Changes:
- Add enterprise-specific error classes (EnterpriseAPIError, etc.)
- Implement centralized error validation in EnterpriseRequest.send_request()
- Extract error messages from API responses (message/error/detail fields)
- Raise domain-specific errors based on HTTP status codes
- Preserve backward compatibility with raise_for_status parameter

This prevents KeyError crashes and returns proper HTTP error codes
(403/404) instead of 500 errors.
2026-03-04 19:55:03 -08:00
207 changed files with 15847 additions and 5670 deletions

View File

@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repos current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless theres a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict
class UserProfile(TypedDict):
user_id: str
email: str
created_at: datetime
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -1,16 +1,38 @@
import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check:
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@@ -31,6 +53,39 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request

View File

@@ -2668,3 +2668,77 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--filename",
required=True,
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
try:
validated_filename = AppMessageExportService.validate_export_filename(filename)
except ValueError as e:
raise click.BadParameter(str(e), param_hint="--filename") from e
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=validated_filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise

View File

@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -0,0 +1,76 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.
:param message_file: MessageFile instance
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
:return: Dictionary containing file information
"""
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
return {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}

View File

@@ -194,6 +194,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Create a new database session
with self._session_factory() as session:
existing_model = session.get(WorkflowRun, db_model.id)
if existing_model:
if existing_model.tenant_id != self._tenant_id:
raise ValueError("Unauthorized access to workflow run")
# Preserve the original start time for pause/resume flows.
db_model.created_at = existing_model.created_at
# SQLAlchemy merge intelligently handles both insert and update operations
# based on the presence of the primary key
session.merge(db_model)

View File

@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -66,6 +66,7 @@ def run_migrations_offline():
context.configure(
url=url, target_metadata=get_metadata(), literal_binds=True
)
logger.info("Generating offline migration SQL with url: %s", url)
with context.begin_transaction():
context.run_migrations()

View File

@@ -1,5 +1,6 @@
[pytest]
addopts = --cov=./api --cov-report=json
pythonpath = .
addopts = --cov=./api --cov-report=json --import-mode=importlib
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
@@ -19,7 +20,7 @@ env =
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
MOCK_SWITCH = true

View File

@@ -6,6 +6,13 @@ from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
EnterpriseAPIBadRequestError,
EnterpriseAPIError,
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
)
logger = logging.getLogger(__name__)
@@ -41,7 +48,6 @@ class BaseRequest:
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@@ -64,10 +70,51 @@ class BaseRequest:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
# Validate HTTP status and raise domain-specific errors
if not response.is_success:
cls._handle_error_response(response)
return response.json()
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
"""
Handle non-2xx HTTP responses by raising appropriate domain errors.
Attempts to extract error message from JSON response body,
falls back to status text if parsing fails.
"""
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
# Try to extract error message from JSON response
try:
error_data = response.json()
if isinstance(error_data, dict):
# Common error response formats:
# {"error": "...", "message": "..."}
# {"message": "..."}
# {"detail": "..."}
error_message = (
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
)
except Exception:
# If JSON parsing fails, use the default message
logger.debug(
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
)
# Raise specific error based on status code
if response.status_code == 400:
raise EnterpriseAPIBadRequestError(error_message)
elif response.status_code == 401:
raise EnterpriseAPIUnauthorizedError(error_message)
elif response.status_code == 403:
raise EnterpriseAPIForbiddenError(error_message)
elif response.status_code == 404:
raise EnterpriseAPINotFoundError(error_message)
else:
raise EnterpriseAPIError(error_message, status_code=response.status_code)
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")

View File

@@ -1,15 +1,26 @@
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppSettings(BaseModel):
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
@@ -115,7 +126,6 @@ class EnterpriseService:
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
@@ -223,3 +233,64 @@ class EnterpriseService:
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
balances prompt license-fix detection against DoS mitigation — without
caching, every request on an expired license would hit the enterprise API.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.warning("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result."""
from services.feature_service import LicenseStatus
try:
info = cls.get_info()
license_info = info.get("License")
if not license_info:
return None
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
ttl = (
VALID_LICENSE_CACHE_TTL
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
else INVALID_LICENSE_CACHE_TTL
)
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
except Exception:
logger.warning("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.exception("Failed to get enterprise license status")
return None

View File

@@ -7,6 +7,7 @@ from . import (
conversation,
dataset,
document,
enterprise,
file,
index,
message,
@@ -21,6 +22,7 @@ __all__ = [
"conversation",
"dataset",
"document",
"enterprise",
"file",
"index",
"message",

View File

@@ -0,0 +1,45 @@
"""Enterprise service errors."""
from services.errors.base import BaseServiceError
class EnterpriseServiceError(BaseServiceError):
"""Base exception for enterprise service errors."""
def __init__(self, description: str | None = None, status_code: int | None = None):
super().__init__(description)
self.status_code = status_code
class EnterpriseAPIError(EnterpriseServiceError):
"""Generic enterprise API error (non-2xx response)."""
pass
class EnterpriseAPINotFoundError(EnterpriseServiceError):
"""Enterprise API returned 404 Not Found."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=404)
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
"""Enterprise API returned 403 Forbidden."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=403)
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
"""Enterprise API returned 401 Unauthorized."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=401)
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
"""Enterprise API returned 400 Bad Request."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=400)

View File

@@ -379,14 +379,19 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if is_authenticated and (license_info := enterprise_info.get("License")):
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
# so the login page can detect an expired/inactive license after force-logout.
# All other license details (expiry date, workspace usage) remain auth-gated.
# This behavior reflects prior internal review of information-leakage risks.
if license_info := enterprise_info.get("License"):
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if is_authenticated:
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]

View File

@@ -63,7 +63,12 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

View File

@@ -0,0 +1,304 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: str
updated_at: str
model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
conversation_id: str
message_id: str
query: str
answer: str
inputs: dict[str, Any]
retriever_resources: list[Any] = Field(default_factory=list)
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
batches: int = 0
total_messages: int = 0
messages_with_feedback: int = 0
total_feedbacks: int = 0
model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
@staticmethod
def validate_export_filename(filename: str) -> str:
normalized = filename.strip()
if not normalized:
raise ValueError("--filename must not be empty.")
normalized_lower = normalized.lower()
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
if normalized.startswith("/"):
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
if "\\" in normalized:
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
if "//" in normalized:
raise ValueError("--filename must not contain empty path segments ('//').")
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
for ch in normalized:
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
raise ValueError("--filename must not contain control characters or NUL.")
parts = PurePosixPath(normalized).parts
if not parts:
raise ValueError("--filename must include a file name.")
if any(part in (".", "..") for part in parts):
raise ValueError("--filename must not contain '.' or '..' path segments.")
return normalized
@property
def output_gz_name(self) -> str:
return f"{self._filename_base}.jsonl.gz"
@property
def output_jsonl_name(self) -> str:
return f"{self._filename_base}.jsonl"
def __init__(
self,
app_id: str,
end_before: datetime.datetime,
filename: str,
*,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
use_cloud_storage: bool = False,
dry_run: bool = False,
) -> None:
if start_from and start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
self._app_id = app_id
self._end_before = end_before
self._start_from = start_from
self._filename_base = self.validate_export_filename(filename)
self._batch_size = batch_size
self._use_cloud_storage = use_cloud_storage
self._dry_run = dry_run
def run(self) -> AppMessageExportStats:
stats = AppMessageExportStats()
logger.info(
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
self._app_id,
self._start_from,
self._end_before,
self._dry_run,
self._use_cloud_storage,
self.output_gz_name,
)
if self._dry_run:
for _ in self._iter_records_with_stats(stats):
pass
self._finalize_stats(stats)
return stats
if self._use_cloud_storage:
self._export_to_cloud(stats)
else:
self._export_to_local(stats)
self._finalize_stats(stats)
return stats
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
for batch in self._iter_record_batches():
yield from batch
@staticmethod
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
for record in records:
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
def _export_to_local(self, stats: AppMessageExportStats) -> None:
output_path = Path.cwd() / self.output_gz_name
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as output_file:
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
tmp.seek(0)
data = tmp.read()
storage.save(self.output_gz_name, data)
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
for record in self.iter_records():
self._update_stats(stats, record)
yield record
@staticmethod
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
stats.total_messages += 1
if record.feedback:
stats.messages_with_feedback += 1
stats.total_feedbacks += len(record.feedback)
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
if stats.total_messages == 0:
stats.batches = 0
return
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
cursor: tuple[datetime.datetime, str] | None = None
while True:
rows, cursor = self._fetch_batch(cursor)
if not rows:
break
message_ids = [str(row.id) for row in rows]
feedbacks_map = self._fetch_feedbacks(message_ids)
yield [self._build_record(row, feedbacks_map) for row in rows]
def _fetch_batch(
self, cursor: tuple[datetime.datetime, str] | None
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(
Message.id,
Message.conversation_id,
Message.query,
Message.answer,
Message._inputs, # pyright: ignore[reportPrivateUsage]
Message.message_metadata,
Message.created_at,
)
.where(
Message.app_id == self._app_id,
Message.created_at < self._end_before,
)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
stmt = stmt.where(Message.created_at >= self._start_from)
if cursor:
stmt = stmt.where(
tuple_(Message.created_at, Message.id)
> tuple_(
sa.literal(cursor[0], type_=sa.DateTime()),
sa.literal(cursor[1], type_=Message.id.type),
)
)
rows = list(session.execute(stmt).all())
if not rows:
return [], cursor
last = rows[-1]
return rows, (last.created_at, last.id)
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
if not message_ids:
return {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(MessageFeedback)
.where(
MessageFeedback.message_id.in_(message_ids),
MessageFeedback.from_source == "user",
)
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
)
feedbacks = list(session.scalars(stmt).all())
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
for feedback in feedbacks:
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
return result
@staticmethod
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
retriever_resources: list[Any] = []
if row.message_metadata:
try:
metadata = json.loads(row.message_metadata)
value = metadata.get("retriever_resources", [])
if isinstance(value, list):
retriever_resources = value
except (json.JSONDecodeError, TypeError):
pass
message_id = str(row.id)
return AppMessageExportRecord(
conversation_id=str(row.conversation_id),
message_id=message_id,
query=row.query,
answer=row.answer,
inputs=row._inputs if isinstance(row._inputs, dict) else {},
retriever_resources=retriever_resources,
feedback=feedbacks_map.get(message_id, []),
)

View File

@@ -6,7 +6,6 @@ import typing
import click
from celery import shared_task
from core.helper.marketplace import record_install_plugin_event
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
@@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task(
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",

View File

@@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement
for App descriptions across all creation and editing endpoints.
"""
import os
import sys
import pytest
# Add the API root to Python path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
class TestAppDescriptionValidationUnit:
"""Unit tests for description validation function"""

View File

@@ -10,8 +10,11 @@ more reliable and realistic test scenarios.
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Protocol, TypeVar
import psycopg2
import pytest
from flask import Flask
from flask.testing import FlaskClient
@@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
logger = logging.getLogger(__name__)
class _CloserProtocol(Protocol):
"""_Closer is any type which implement the close() method."""
def close(self):
"""close the current object, release any external resouece (file, transaction, connection etc.)
associated with it.
"""
pass
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
@contextmanager
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
yield closer
closer.close()
class DifyTestContainers:
"""
Manages all test containers required for Dify integration tests.
@@ -97,45 +119,28 @@ class DifyTestContainers:
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
logger.info("PostgreSQL container is ready and accepting connections")
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
try:
import psycopg2
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
cursor.close()
conn.close()
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
with _auto_close(conn):
with conn.cursor() as cursor:
# Install uuid-ossp extension for UUID generation
logger.info("Installing uuid-ossp extension...")
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
logger.info("uuid-ossp extension installed successfully")
except Exception as e:
logger.warning("Failed to install uuid-ossp extension: %s", e)
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
try:
conn = psycopg2.connect(
host=db_host,
port=db_port,
user=self.postgres.username,
password=self.postgres.password,
database=self.postgres.dbname,
)
conn.autocommit = True
cursor = conn.cursor()
cursor.execute("CREATE DATABASE dify_plugin;")
cursor.close()
conn.close()
# NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement
# inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block.
with _auto_close(conn.cursor()) as cursor:
# Create plugin database for dify-plugin-daemon
logger.info("Creating plugin database...")
cursor.execute("CREATE DATABASE dify_plugin;")
logger.info("Plugin database created successfully")
except Exception as e:
logger.warning("Failed to create plugin database: %s", e)
# Set up storage environment variables
os.environ.setdefault("STORAGE_TYPE", "opendal")
@@ -258,23 +263,16 @@ class DifyTestContainers:
containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon]
for container in containers:
if container:
try:
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
except Exception as e:
# Log error but don't fail the test cleanup
logger.warning("Failed to stop container %s: %s", container, e)
container_name = container.image
logger.info("Stopping container: %s", container_name)
container.stop()
logger.info("Successfully stopped container: %s", container_name)
# Stop and remove the network
if self.network:
try:
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
except Exception as e:
logger.warning("Failed to remove Docker network: %s", e)
logger.info("Removing Docker network...")
self.network.remove()
logger.info("Successfully removed Docker network")
self._containers_started = False
logger.info("All test containers stopped and cleaned up successfully")

View File

@@ -358,10 +358,9 @@ class TestFeatureService:
assert result is not None
assert isinstance(result, SystemFeatureModel)
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
# Ensure only essential UI flags are returned to unauthenticated clients
# to keep the payload lightweight and adhere to architectural boundaries.
assert result.license.status == LicenseStatus.NONE
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
# Detailed license info (expiry, workspaces) remains auth-gated.
assert result.license.status == LicenseStatus.ACTIVE
assert result.license.expired_at == ""
assert result.license.workspaces.enabled is False
assert result.license.workspaces.limit == 0

View File

@@ -0,0 +1,233 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()
app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation
@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message
def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)
first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)
user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()
service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)
assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id
assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []
assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2

View File

@@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
os.environ.setdefault("STORAGE_TYPE", "opendal")
# Add the API directory to Python path to ensure proper imports
import sys
sys.path.insert(0, PROJECT_DIR)
from core.db.session_factory import configure_session_factory, session_factory
from extensions import ext_redis

View File

@@ -1,13 +1,8 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import dify_graph.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module

View File

@@ -0,0 +1,425 @@
"""
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
This test suite ensures that the files array is correctly populated in the message_end
SSE event, which is critical for vision/image chat responses to render correctly.
Test Coverage:
- Files array populated when MessageFile records exist
- Files array is None when no MessageFile records exist
- Correct signed URL generation for LOCAL_FILE transfer method
- Correct URL handling for REMOTE_URL transfer method
- Correct URL handling for TOOL_FILE transfer method
- Proper file metadata formatting (filename, mime_type, size, extension)
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
class TestMessageEndStreamResponseFiles:
"""Test suite for files array population in message_end SSE event."""
@pytest.fixture
def mock_pipeline(self):
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
pipeline._message_id = str(uuid.uuid4())
pipeline._task_state = Mock()
pipeline._task_state.metadata = Mock()
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
pipeline._task_state.llm_result = Mock()
pipeline._task_state.llm_result.usage = Mock()
pipeline._application_generate_entity = Mock()
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
return pipeline
@pytest.fixture
def mock_message_file_local(self):
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_remote(self):
"""Create a mock MessageFile with REMOTE_URL transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_tool(self):
"""Create a mock MessageFile with TOOL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
"""Create a mock UploadFile."""
upload_file = Mock(spec=UploadFile)
upload_file.id = mock_message_file_local.upload_file_id
upload_file.name = "test_image.png"
upload_file.mime_type = "image/png"
upload_file.size = 1024
upload_file.extension = "png"
return upload_file
def test_message_end_with_no_files(self, mock_pipeline):
"""Test that files array is None when no MessageFile records exist."""
# Arrange
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = []
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is None
assert result.id == mock_pipeline._message_id
assert result.metadata == {"test": "metadata"}
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_local.id
assert file_dict["filename"] == "test_image.png"
assert file_dict["mime_type"] == "image/png"
assert file_dict["size"] == 1024
assert file_dict["extension"] == ".png"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
assert "https://example.com/signed-url" in file_dict["url"]
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
assert file_dict["remote_url"] == ""
# Verify database queries
# Should be called twice: once for MessageFile, once for UploadFile
assert mock_session.scalars.call_count == 2
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
# Arrange
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_remote]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_remote.id
assert file_dict["filename"] == "image.jpg"
assert file_dict["url"] == "https://example.com/image.jpg"
assert file_dict["extension"] == ".jpg"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
assert file_dict["remote_url"] == "https://example.com/image.jpg"
assert file_dict["upload_file_id"] == mock_message_file_remote.id
# Verify only one query for message_files is made
mock_session.scalars.assert_called_once()
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "https://example.com/tool_file.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["url"] == "https://example.com/tool_file.png"
assert file_dict["filename"] == "tool_file.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with local path."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_123.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
assert file_dict["filename"] == "tool_file_123.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
# Verify tool file signing was called
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_abc.verylongextension"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed.bin"
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
assert result.files is not None
file_dict = result.files[0]
assert file_dict["extension"] == ".bin"
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
def test_message_end_with_multiple_files(
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
"""Test that files array contains all MessageFile records when multiple exist."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 2
# Verify both files are present
file_ids = [f["related_id"] for f in result.files]
assert mock_message_file_local.id in file_ids
assert mock_message_file_remote.id in file_ids
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query) - returns empty list (not found)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [] # UploadFile not found
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/fallback-url" in file_dict["url"]
# Verify fallback URL was generated using upload_file_id from message_file
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))

View File

@@ -0,0 +1,84 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
engine = create_engine("sqlite:///:memory:")
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
user = MagicMock(spec=Account)
user.id = str(uuid4())
user.current_tenant_id = str(uuid4())
repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=real_session_factory,
user=user,
app_id="app-id",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
session_context = MagicMock()
session_context.__enter__.return_value = session
session_context.__exit__.return_value = False
repository._session_factory = MagicMock(return_value=session_context)
return repository
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
return WorkflowExecution.new(
id_=execution_id,
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0.0",
graph={"nodes": [], "edges": []},
inputs={"query": "hello"},
started_at=started_at,
)
def test_save_uses_execution_started_at_when_record_does_not_exist():
session = MagicMock()
session.get.return_value = None
repository = _build_repository_with_mocked_session(session)
started_at = datetime(2026, 1, 1, 12, 0, 0)
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == started_at
session.commit.assert_called_once()
def test_save_preserves_existing_created_at_when_record_already_exists():
session = MagicMock()
repository = _build_repository_with_mocked_session(session)
execution_id = str(uuid4())
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
existing_run = WorkflowRun()
existing_run.id = execution_id
existing_run.tenant_id = repository._tenant_id
existing_run.created_at = existing_created_at
session.get.return_value = existing_run
execution = _build_execution(
execution_id=execution_id,
started_at=datetime(2026, 1, 1, 12, 30, 0),
)
repository.save(execution)
saved_model = session.merge.call_args.args[0]
assert saved_model.created_at == existing_created_at
session.commit.assert_called_once()

View File

@@ -2,15 +2,7 @@
Simple test to verify MockNodeFactory works with iteration nodes.
"""
import sys
from pathlib import Path
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
# Add api directory to path
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
sys.path.insert(0, str(api_dir))
from dify_graph.enums import NodeType
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory

View File

@@ -3,14 +3,8 @@ Simple test to validate the auto-mock system without external dependencies.
"""
import sys
from pathlib import Path
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
# Add api directory to path
api_dir = Path(__file__).parent.parent.parent.parent.parent.parent
sys.path.insert(0, str(api_dir))
from dify_graph.enums import NodeType
from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig
from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory

View File

@@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""
from unittest.mock import patch
@@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()

View File

@@ -0,0 +1,43 @@
import datetime
import pytest
from services.retention.conversation.message_export_service import AppMessageExportService
def test_validate_export_filename_accepts_relative_path():
assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01"
@pytest.mark.parametrize(
"filename",
[
"test01.jsonl.gz",
"test01.jsonl",
"test01.gz",
"/tmp/test01",
"exports/../test01",
"bad\x00name",
"bad\tname",
"a" * 1025,
],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
with pytest.raises(ValueError):
AppMessageExportService.validate_export_filename(filename)
def test_service_derives_output_names_from_filename_base():
service = AppMessageExportService(
app_id="736b9b03-20f2-4697-91da-8d00f6325900",
start_from=None,
end_before=datetime.datetime(2026, 3, 1),
filename="exports/2026/test01",
batch_size=1000,
use_cloud_storage=True,
dry_run=True,
)
assert service._filename_base == "exports/2026/test01"
assert service.output_gz_name == "exports/2026/test01.jsonl.gz"
assert service.output_jsonl_name == "exports/2026/test01.jsonl"

View File

@@ -295,7 +295,24 @@ describe('Pricing Modal Flow', () => {
})
})
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
// ─── 6. Close Handling ───────────────────────────────────────────────────
describe('Close handling', () => {
it('should call onCancel when pressing ESC key', () => {
render(<Pricing onCancel={onCancel} />)
// ahooks useKeyPress listens on document for keydown events
document.dispatchEvent(new KeyboardEvent('keydown', {
key: 'Escape',
code: 'Escape',
keyCode: 27,
bubbles: true,
}))
expect(onCancel).toHaveBeenCalledTimes(1)
})
})
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
describe('Pricing page URL', () => {
it('should render pricing link with correct URL', () => {
render(<Pricing onCancel={onCancel} />)

View File

@@ -0,0 +1,139 @@
import * as amplitude from '@amplitude/analytics-browser'
import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser'
import { render } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider'
const mockConfig = vi.hoisted(() => ({
AMPLITUDE_API_KEY: 'test-api-key',
IS_CLOUD_EDITION: true,
}))
vi.mock('@/config', () => mockConfig)
vi.mock('@amplitude/analytics-browser', () => ({
init: vi.fn(),
add: vi.fn(),
}))
vi.mock('@amplitude/plugin-session-replay-browser', () => ({
sessionReplayPlugin: vi.fn(() => ({ name: 'session-replay' })),
}))
describe('AmplitudeProvider', () => {
beforeEach(() => {
vi.clearAllMocks()
mockConfig.AMPLITUDE_API_KEY = 'test-api-key'
mockConfig.IS_CLOUD_EDITION = true
})
describe('isAmplitudeEnabled', () => {
it('returns true when cloud edition and api key present', () => {
expect(isAmplitudeEnabled()).toBe(true)
})
it('returns false when cloud edition but no api key', () => {
mockConfig.AMPLITUDE_API_KEY = ''
expect(isAmplitudeEnabled()).toBe(false)
})
it('returns false when not cloud edition', () => {
mockConfig.IS_CLOUD_EDITION = false
expect(isAmplitudeEnabled()).toBe(false)
})
})
describe('Component', () => {
it('initializes amplitude when enabled', () => {
render(<AmplitudeProvider sessionReplaySampleRate={0.8} />)
expect(amplitude.init).toHaveBeenCalledWith('test-api-key', expect.any(Object))
expect(sessionReplayPlugin).toHaveBeenCalledWith({ sampleRate: 0.8 })
expect(amplitude.add).toHaveBeenCalledTimes(2)
})
it('does not initialize amplitude when disabled', () => {
mockConfig.AMPLITUDE_API_KEY = ''
render(<AmplitudeProvider />)
expect(amplitude.init).not.toHaveBeenCalled()
expect(amplitude.add).not.toHaveBeenCalled()
})
it('pageNameEnrichmentPlugin logic works as expected', async () => {
render(<AmplitudeProvider />)
const plugin = vi.mocked(amplitude.add).mock.calls[0]?.[0] as amplitude.Types.EnrichmentPlugin | undefined
expect(plugin).toBeDefined()
if (!plugin?.execute || !plugin.setup)
throw new Error('Expected page-name-enrichment plugin with setup/execute')
expect(plugin.name).toBe('page-name-enrichment')
const execute = plugin.execute
const setup = plugin.setup
type SetupFn = NonNullable<amplitude.Types.EnrichmentPlugin['setup']>
const getPageTitle = (evt: amplitude.Types.Event | null | undefined) =>
(evt?.event_properties as Record<string, unknown> | undefined)?.['[Amplitude] Page Title']
await setup(
{} as Parameters<SetupFn>[0],
{} as Parameters<SetupFn>[1],
)
const originalWindowLocation = window.location
try {
Object.defineProperty(window, 'location', {
value: { pathname: '/datasets' },
writable: true,
})
const event: amplitude.Types.Event = {
event_type: '[Amplitude] Page Viewed',
event_properties: {},
}
const result = await execute(event)
expect(getPageTitle(result)).toBe('Knowledge')
window.location.pathname = '/'
await execute(event)
expect(getPageTitle(event)).toBe('Home')
window.location.pathname = '/apps'
await execute(event)
expect(getPageTitle(event)).toBe('Studio')
window.location.pathname = '/explore'
await execute(event)
expect(getPageTitle(event)).toBe('Explore')
window.location.pathname = '/tools'
await execute(event)
expect(getPageTitle(event)).toBe('Tools')
window.location.pathname = '/account'
await execute(event)
expect(getPageTitle(event)).toBe('Account')
window.location.pathname = '/signin'
await execute(event)
expect(getPageTitle(event)).toBe('Sign In')
window.location.pathname = '/signup'
await execute(event)
expect(getPageTitle(event)).toBe('Sign Up')
window.location.pathname = '/unknown'
await execute(event)
expect(getPageTitle(event)).toBe('Unknown')
const otherEvent = {
event_type: 'Button Clicked',
event_properties: {},
} as amplitude.Types.Event
const otherResult = await execute(otherEvent)
expect(getPageTitle(otherResult)).toBeUndefined()
const noPropsEvent = {
event_type: '[Amplitude] Page Viewed',
} as amplitude.Types.Event
const noPropsResult = await execute(noPropsEvent)
expect(noPropsResult?.event_properties).toBeUndefined()
}
finally {
Object.defineProperty(window, 'location', {
value: originalWindowLocation,
writable: true,
})
}
})
})
})

View File

@@ -0,0 +1,32 @@
import { describe, expect, it } from 'vitest'
import AmplitudeProvider, { isAmplitudeEnabled } from './AmplitudeProvider'
import indexDefault, {
isAmplitudeEnabled as indexIsAmplitudeEnabled,
resetUser,
setUserId,
setUserProperties,
trackEvent,
} from './index'
import {
resetUser as utilsResetUser,
setUserId as utilsSetUserId,
setUserProperties as utilsSetUserProperties,
trackEvent as utilsTrackEvent,
} from './utils'
describe('Amplitude index exports', () => {
it('exports AmplitudeProvider as default', () => {
expect(indexDefault).toBe(AmplitudeProvider)
})
it('exports isAmplitudeEnabled', () => {
expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled)
})
it('exports utils', () => {
expect(resetUser).toBe(utilsResetUser)
expect(setUserId).toBe(utilsSetUserId)
expect(setUserProperties).toBe(utilsSetUserProperties)
expect(trackEvent).toBe(utilsTrackEvent)
})
})

View File

@@ -0,0 +1,119 @@
import { resetUser, setUserId, setUserProperties, trackEvent } from './utils'
const mockState = vi.hoisted(() => ({
enabled: true,
}))
const mockTrack = vi.hoisted(() => vi.fn())
const mockSetUserId = vi.hoisted(() => vi.fn())
const mockIdentify = vi.hoisted(() => vi.fn())
const mockReset = vi.hoisted(() => vi.fn())
const MockIdentify = vi.hoisted(() =>
class {
setCalls: Array<[string, unknown]> = []
set(key: string, value: unknown) {
this.setCalls.push([key, value])
return this
}
},
)
vi.mock('./AmplitudeProvider', () => ({
isAmplitudeEnabled: () => mockState.enabled,
}))
vi.mock('@amplitude/analytics-browser', () => ({
track: (...args: unknown[]) => mockTrack(...args),
setUserId: (...args: unknown[]) => mockSetUserId(...args),
identify: (...args: unknown[]) => mockIdentify(...args),
reset: (...args: unknown[]) => mockReset(...args),
Identify: MockIdentify,
}))
describe('amplitude utils', () => {
beforeEach(() => {
vi.clearAllMocks()
mockState.enabled = true
})
describe('trackEvent', () => {
it('should call amplitude.track when amplitude is enabled', () => {
trackEvent('dataset_created', { source: 'wizard' })
expect(mockTrack).toHaveBeenCalledTimes(1)
expect(mockTrack).toHaveBeenCalledWith('dataset_created', { source: 'wizard' })
})
it('should not call amplitude.track when amplitude is disabled', () => {
mockState.enabled = false
trackEvent('dataset_created', { source: 'wizard' })
expect(mockTrack).not.toHaveBeenCalled()
})
})
describe('setUserId', () => {
it('should call amplitude.setUserId when amplitude is enabled', () => {
setUserId('user-123')
expect(mockSetUserId).toHaveBeenCalledTimes(1)
expect(mockSetUserId).toHaveBeenCalledWith('user-123')
})
it('should not call amplitude.setUserId when amplitude is disabled', () => {
mockState.enabled = false
setUserId('user-123')
expect(mockSetUserId).not.toHaveBeenCalled()
})
})
describe('setUserProperties', () => {
it('should build identify event and call amplitude.identify when amplitude is enabled', () => {
const properties: Record<string, unknown> = {
role: 'owner',
seats: 3,
verified: true,
}
setUserProperties(properties)
expect(mockIdentify).toHaveBeenCalledTimes(1)
const identifyArg = mockIdentify.mock.calls[0][0] as InstanceType<typeof MockIdentify>
expect(identifyArg).toBeInstanceOf(MockIdentify)
expect(identifyArg.setCalls).toEqual([
['role', 'owner'],
['seats', 3],
['verified', true],
])
})
it('should not call amplitude.identify when amplitude is disabled', () => {
mockState.enabled = false
setUserProperties({ role: 'owner' })
expect(mockIdentify).not.toHaveBeenCalled()
})
})
describe('resetUser', () => {
it('should call amplitude.reset when amplitude is enabled', () => {
resetUser()
expect(mockReset).toHaveBeenCalledTimes(1)
})
it('should not call amplitude.reset when amplitude is disabled', () => {
mockState.enabled = false
resetUser()
expect(mockReset).not.toHaveBeenCalled()
})
})
})

View File

@@ -0,0 +1,148 @@
import { AudioPlayerManager } from '../audio.player.manager'
type AudioCallback = ((event: string) => void) | null
type AudioPlayerCtorArgs = [
string,
boolean,
string | undefined,
string | null | undefined,
string | undefined,
AudioCallback,
]
type MockAudioPlayerInstance = {
setCallback: ReturnType<typeof vi.fn>
pauseAudio: ReturnType<typeof vi.fn>
resetMsgId: ReturnType<typeof vi.fn>
cacheBuffers: Array<ArrayBuffer>
sourceBuffer: {
abort: ReturnType<typeof vi.fn>
} | undefined
}
const mockState = vi.hoisted(() => ({
instances: [] as MockAudioPlayerInstance[],
}))
const mockAudioPlayerConstructor = vi.hoisted(() => vi.fn())
const MockAudioPlayer = vi.hoisted(() => {
return class MockAudioPlayerClass {
setCallback = vi.fn()
pauseAudio = vi.fn()
resetMsgId = vi.fn()
cacheBuffers = [new ArrayBuffer(1)]
sourceBuffer = { abort: vi.fn() }
constructor(...args: AudioPlayerCtorArgs) {
mockAudioPlayerConstructor(...args)
mockState.instances.push(this as unknown as MockAudioPlayerInstance)
}
}
})
vi.mock('@/app/components/base/audio-btn/audio', () => ({
default: MockAudioPlayer,
}))
describe('AudioPlayerManager', () => {
beforeEach(() => {
vi.clearAllMocks()
mockState.instances = []
Reflect.set(AudioPlayerManager, 'instance', undefined)
})
describe('getInstance', () => {
it('should return the same singleton instance across calls', () => {
const first = AudioPlayerManager.getInstance()
const second = AudioPlayerManager.getInstance()
expect(first).toBe(second)
})
})
describe('getAudioPlayer', () => {
it('should create a new audio player when no existing player is cached', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
const result = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledWith(
'/text-to-audio',
false,
'msg-1',
'hello',
'en-US',
callback,
)
expect(result).toBe(mockState.instances[0])
})
it('should reuse existing player and update callback when msg id is unchanged', () => {
const manager = AudioPlayerManager.getInstance()
const firstCallback = vi.fn()
const secondCallback = vi.fn()
const first = manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', firstCallback)
const second = manager.getAudioPlayer('/ignored', true, 'msg-1', 'ignored', 'fr-FR', secondCallback)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(1)
expect(first).toBe(second)
expect(mockState.instances[0].setCallback).toHaveBeenCalledTimes(1)
expect(mockState.instances[0].setCallback).toHaveBeenCalledWith(secondCallback)
})
it('should cleanup existing player and create a new one when msg id changes', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
const previous = mockState.instances[0]
const next = manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback)
expect(previous.pauseAudio).toHaveBeenCalledTimes(1)
expect(previous.cacheBuffers).toEqual([])
expect(previous.sourceBuffer?.abort).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2)
expect(next).toBe(mockState.instances[1])
})
it('should swallow cleanup errors and still create a new player', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
const previous = mockState.instances[0]
previous.pauseAudio.mockImplementation(() => {
throw new Error('cleanup failure')
})
expect(() => {
manager.getAudioPlayer('/apps/1/text-to-audio', false, 'msg-2', 'world', 'en-US', callback)
}).not.toThrow()
expect(previous.pauseAudio).toHaveBeenCalledTimes(1)
expect(mockAudioPlayerConstructor).toHaveBeenCalledTimes(2)
})
})
describe('resetMsgId', () => {
it('should forward reset message id to the cached audio player when present', () => {
const manager = AudioPlayerManager.getInstance()
const callback = vi.fn()
manager.getAudioPlayer('/text-to-audio', false, 'msg-1', 'hello', 'en-US', callback)
manager.resetMsgId('msg-updated')
expect(mockState.instances[0].resetMsgId).toHaveBeenCalledTimes(1)
expect(mockState.instances[0].resetMsgId).toHaveBeenCalledWith('msg-updated')
})
it('should not throw when resetting message id without an audio player', () => {
const manager = AudioPlayerManager.getInstance()
expect(() => manager.resetMsgId('msg-updated')).not.toThrow()
})
})
})

View File

@@ -0,0 +1,610 @@
import { Buffer } from 'node:buffer'
import { waitFor } from '@testing-library/react'
import { AppSourceType } from '@/service/share'
import AudioPlayer from '../audio'
const mockToastNotify = vi.hoisted(() => vi.fn())
const mockTextToAudioStream = vi.hoisted(() => vi.fn())
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (...args: unknown[]) => mockToastNotify(...args),
},
}))
vi.mock('@/service/share', () => ({
AppSourceType: {
webApp: 'webApp',
installedApp: 'installedApp',
},
textToAudioStream: (...args: unknown[]) => mockTextToAudioStream(...args),
}))
type AudioEventName = 'ended' | 'paused' | 'loaded' | 'play' | 'timeupdate' | 'loadeddate' | 'canplay' | 'error' | 'sourceopen'
type AudioEventListener = () => void
type ReaderResult = {
value: Uint8Array | undefined
done: boolean
}
type Reader = {
read: () => Promise<ReaderResult>
}
type AudioResponse = {
status: number
body: {
getReader: () => Reader
}
}
class MockSourceBuffer {
updating = false
appendBuffer = vi.fn((_buffer: ArrayBuffer) => undefined)
abort = vi.fn(() => undefined)
}
class MockMediaSource {
readyState: 'open' | 'closed' = 'open'
sourceBuffer = new MockSourceBuffer()
private listeners: Partial<Record<AudioEventName, AudioEventListener[]>> = {}
addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => {
const listeners = this.listeners[event] || []
listeners.push(listener)
this.listeners[event] = listeners
})
addSourceBuffer = vi.fn((_contentType: string) => this.sourceBuffer)
endOfStream = vi.fn(() => undefined)
emit(event: AudioEventName) {
const listeners = this.listeners[event] || []
listeners.forEach((listener) => {
listener()
})
}
}
class MockAudio {
src = ''
autoplay = false
disableRemotePlayback = false
controls = false
paused = true
ended = false
played: unknown = null
private listeners: Partial<Record<AudioEventName, AudioEventListener[]>> = {}
addEventListener = vi.fn((event: AudioEventName, listener: AudioEventListener) => {
const listeners = this.listeners[event] || []
listeners.push(listener)
this.listeners[event] = listeners
})
play = vi.fn(async () => {
this.paused = false
})
pause = vi.fn(() => {
this.paused = true
})
emit(event: AudioEventName) {
const listeners = this.listeners[event] || []
listeners.forEach((listener) => {
listener()
})
}
}
class MockAudioContext {
state: 'running' | 'suspended' = 'running'
destination = {}
connect = vi.fn(() => undefined)
createMediaElementSource = vi.fn((_audio: MockAudio) => ({
connect: this.connect,
}))
resume = vi.fn(async () => {
this.state = 'running'
})
suspend = vi.fn(() => {
this.state = 'suspended'
})
}
const testState = {
mediaSources: [] as MockMediaSource[],
audios: [] as MockAudio[],
audioContexts: [] as MockAudioContext[],
}
class MockMediaSourceCtor extends MockMediaSource {
constructor() {
super()
testState.mediaSources.push(this)
}
}
class MockAudioCtor extends MockAudio {
constructor() {
super()
testState.audios.push(this)
}
}
class MockAudioContextCtor extends MockAudioContext {
constructor() {
super()
testState.audioContexts.push(this)
}
}
const originalAudio = globalThis.Audio
const originalAudioContext = globalThis.AudioContext
const originalCreateObjectURL = globalThis.URL.createObjectURL
const originalMediaSource = window.MediaSource
const originalManagedMediaSource = window.ManagedMediaSource
const setMediaSourceSupport = (options: { mediaSource: boolean, managedMediaSource: boolean }) => {
Object.defineProperty(window, 'MediaSource', {
configurable: true,
writable: true,
value: options.mediaSource ? MockMediaSourceCtor : undefined,
})
Object.defineProperty(window, 'ManagedMediaSource', {
configurable: true,
writable: true,
value: options.managedMediaSource ? MockMediaSourceCtor : undefined,
})
}
const makeAudioResponse = (status: number, reads: ReaderResult[]): AudioResponse => {
const read = vi.fn<() => Promise<ReaderResult>>()
reads.forEach((result) => {
read.mockResolvedValueOnce(result)
})
return {
status,
body: {
getReader: () => ({ read }),
},
}
}
describe('AudioPlayer', () => {
beforeEach(() => {
vi.clearAllMocks()
testState.mediaSources = []
testState.audios = []
testState.audioContexts = []
Object.defineProperty(globalThis, 'Audio', {
configurable: true,
writable: true,
value: MockAudioCtor,
})
Object.defineProperty(globalThis, 'AudioContext', {
configurable: true,
writable: true,
value: MockAudioContextCtor,
})
Object.defineProperty(globalThis.URL, 'createObjectURL', {
configurable: true,
writable: true,
value: vi.fn(() => 'blob:mock-url'),
})
setMediaSourceSupport({ mediaSource: true, managedMediaSource: false })
})
afterAll(() => {
Object.defineProperty(globalThis, 'Audio', {
configurable: true,
writable: true,
value: originalAudio,
})
Object.defineProperty(globalThis, 'AudioContext', {
configurable: true,
writable: true,
value: originalAudioContext,
})
Object.defineProperty(globalThis.URL, 'createObjectURL', {
configurable: true,
writable: true,
value: originalCreateObjectURL,
})
Object.defineProperty(window, 'MediaSource', {
configurable: true,
writable: true,
value: originalMediaSource,
})
Object.defineProperty(window, 'ManagedMediaSource', {
configurable: true,
writable: true,
value: originalManagedMediaSource,
})
})
describe('constructor behavior', () => {
it('should initialize media source, audio, and media element source when MediaSource exists', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
const mediaSource = testState.mediaSources[0]
expect(player.mediaSource).toBe(mediaSource as unknown as MediaSource)
expect(globalThis.URL.createObjectURL).toHaveBeenCalledTimes(1)
expect(audio.src).toBe('blob:mock-url')
expect(audio.autoplay).toBe(true)
expect(audioContext.createMediaElementSource).toHaveBeenCalledWith(audio)
expect(audioContext.connect).toHaveBeenCalledTimes(1)
})
it('should notify unsupported browser when no MediaSource implementation exists', () => {
setMediaSourceSupport({ mediaSource: false, managedMediaSource: false })
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const audio = testState.audios[0]
expect(player.mediaSource).toBeNull()
expect(audio.src).toBe('')
expect(mockToastNotify).toHaveBeenCalledTimes(1)
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
)
})
it('should configure fallback audio controls when ManagedMediaSource is used', () => {
setMediaSourceSupport({ mediaSource: false, managedMediaSource: true })
// Create with callback to ensure constructor path completes with fallback source.
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, vi.fn())
const audio = testState.audios[0]
expect(player.mediaSource).not.toBeNull()
expect(audio.disableRemotePlayback).toBe(true)
expect(audio.controls).toBe(true)
})
})
describe('event wiring', () => {
it('should forward registered audio events to callback', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.emit('play')
audio.emit('ended')
audio.emit('error')
audio.emit('paused')
audio.emit('loaded')
audio.emit('timeupdate')
audio.emit('loadeddate')
audio.emit('canplay')
expect(player.callback).toBe(callback)
expect(callback).toHaveBeenCalledWith('play')
expect(callback).toHaveBeenCalledWith('ended')
expect(callback).toHaveBeenCalledWith('error')
expect(callback).toHaveBeenCalledWith('paused')
expect(callback).toHaveBeenCalledWith('loaded')
expect(callback).toHaveBeenCalledWith('timeupdate')
expect(callback).toHaveBeenCalledWith('loadeddate')
expect(callback).toHaveBeenCalledWith('canplay')
})
it('should initialize source buffer only once when sourceopen fires multiple times', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', vi.fn())
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.emit('sourceopen')
expect(mediaSource.addSourceBuffer).toHaveBeenCalledTimes(1)
expect(player.sourceBuffer).toBe(mediaSource.sourceBuffer)
})
})
describe('playback control', () => {
it('should request streaming audio when playAudio is called before loading', async () => {
mockTextToAudioStream.mockResolvedValue(
makeAudioResponse(200, [
{ value: new Uint8Array([4, 5]), done: false },
{ value: new Uint8Array([1, 2, 3]), done: true },
]),
)
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', vi.fn())
player.playAudio()
await waitFor(() => {
expect(mockTextToAudioStream).toHaveBeenCalledTimes(1)
})
expect(mockTextToAudioStream).toHaveBeenCalledWith(
'/text-to-audio',
AppSourceType.webApp,
{ content_type: 'audio/mpeg' },
{
message_id: 'msg-1',
streaming: true,
voice: 'en-US',
text: 'hello',
},
)
expect(player.isLoadData).toBe(true)
})
it('should emit error callback and reset load flag when stream response status is not 200', async () => {
const callback = vi.fn()
mockTextToAudioStream.mockResolvedValue(
makeAudioResponse(500, [{ value: new Uint8Array([1]), done: true }]),
)
const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback)
player.playAudio()
await waitFor(() => {
expect(callback).toHaveBeenCalledWith('error')
})
expect(player.isLoadData).toBe(false)
})
it('should resume and play immediately when playAudio is called in suspended loaded state', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'suspended'
player.playAudio()
await Promise.resolve()
expect(audioContext.resume).toHaveBeenCalledTimes(1)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should play ended audio when data is already loaded', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'running'
audio.ended = true
player.playAudio()
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should only emit play callback without replaying when loaded audio is already playing', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', false, 'msg-1', 'hello', undefined, callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.isLoadData = true
audioContext.state = 'running'
audio.ended = false
player.playAudio()
expect(audio.play).not.toHaveBeenCalled()
expect(callback).toHaveBeenCalledWith('play')
})
it('should emit error callback when stream request throws', async () => {
const callback = vi.fn()
mockTextToAudioStream.mockRejectedValue(new Error('network failed'))
const player = new AudioPlayer('/text-to-audio', false, 'msg-2', 'world', undefined, callback)
player.playAudio()
await waitFor(() => {
expect(callback).toHaveBeenCalledWith('error')
})
expect(player.isLoadData).toBe(false)
})
it('should call pause flow and notify paused event when pauseAudio is invoked', () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
player.pauseAudio()
expect(callback).toHaveBeenCalledWith('paused')
expect(audio.pause).toHaveBeenCalledTimes(1)
expect(audioContext.suspend).toHaveBeenCalledTimes(1)
})
})
describe('message and direct-audio helpers', () => {
it('should update message id through resetMsgId', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
player.resetMsgId('msg-2')
expect(player.msgId).toBe('msg-2')
})
it('should end stream without playback when playAudioWithAudio receives empty content', async () => {
vi.useFakeTimers()
try {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const mediaSource = testState.mediaSources[0]
await player.playAudioWithAudio('', true)
await vi.advanceTimersByTimeAsync(40)
expect(player.isLoadData).toBe(false)
expect(player.cacheBuffers).toHaveLength(0)
expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1)
expect(callback).not.toHaveBeenCalledWith('play')
}
finally {
vi.useRealTimers()
}
})
it('should decode base64 and start playback when playAudioWithAudio is called with playable content', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
const mediaSource = testState.mediaSources[0]
const audioBase64 = Buffer.from('hello').toString('base64')
mediaSource.emit('sourceopen')
audio.paused = true
await player.playAudioWithAudio(audioBase64, true)
await Promise.resolve()
expect(player.isLoadData).toBe(true)
expect(player.cacheBuffers).toHaveLength(0)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
const appendedAudioData = mediaSource.sourceBuffer.appendBuffer.mock.calls[0][0]
expect(appendedAudioData).toBeInstanceOf(ArrayBuffer)
expect(appendedAudioData.byteLength).toBeGreaterThan(0)
expect(audioContext.resume).toHaveBeenCalledTimes(1)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should skip playback when playAudioWithAudio is called with play=false', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
const audioContext = testState.audioContexts[0]
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), false)
expect(player.isLoadData).toBe(false)
expect(audioContext.resume).not.toHaveBeenCalled()
expect(audio.play).not.toHaveBeenCalled()
expect(callback).not.toHaveBeenCalledWith('play')
})
it('should play immediately for ended audio in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = true
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
it('should not replay when played list exists in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = false
audio.played = {}
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).not.toHaveBeenCalled()
expect(callback).not.toHaveBeenCalledWith('play')
})
it('should replay when paused is false and played list is empty in playAudioWithAudio', async () => {
const callback = vi.fn()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', callback)
const audio = testState.audios[0]
audio.paused = false
audio.ended = false
audio.played = null
await player.playAudioWithAudio(Buffer.from('hello').toString('base64'), true)
expect(audio.play).toHaveBeenCalledTimes(1)
expect(callback).toHaveBeenCalledWith('play')
})
})
describe('buffering internals', () => {
it('should finish stream when receiveAudioData gets an undefined chunk', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const finishStream = vi
.spyOn(player as unknown as { finishStream: () => void }, 'finishStream')
.mockImplementation(() => { })
; (player as unknown as { receiveAudioData: (data: Uint8Array | undefined) => void }).receiveAudioData(undefined)
expect(finishStream).toHaveBeenCalledTimes(1)
})
it('should finish stream when receiveAudioData gets empty bytes while source is open', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const finishStream = vi
.spyOn(player as unknown as { finishStream: () => void }, 'finishStream')
.mockImplementation(() => { })
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array(0))
expect(finishStream).toHaveBeenCalledTimes(1)
})
it('should queue incoming buffer when source buffer is updating', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.sourceBuffer.updating = true
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([1, 2, 3]))
expect(player.cacheBuffers.length).toBe(1)
})
it('should append previously queued buffer before new one when source buffer is idle', () => {
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
const existingBuffer = new ArrayBuffer(2)
player.cacheBuffers = [existingBuffer]
mediaSource.sourceBuffer.updating = false
; (player as unknown as { receiveAudioData: (data: Uint8Array) => void }).receiveAudioData(new Uint8Array([9]))
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledWith(existingBuffer)
expect(player.cacheBuffers.length).toBe(1)
})
it('should append cache chunks and end stream when finishStream drains buffers', () => {
vi.useFakeTimers()
const player = new AudioPlayer('/text-to-audio', true, 'msg-1', 'hello', 'en-US', null)
const mediaSource = testState.mediaSources[0]
mediaSource.emit('sourceopen')
mediaSource.sourceBuffer.updating = false
player.cacheBuffers = [new ArrayBuffer(3)]
; (player as unknown as { finishStream: () => void }).finishStream()
vi.advanceTimersByTime(50)
expect(mediaSource.sourceBuffer.appendBuffer).toHaveBeenCalledTimes(1)
expect(mediaSource.endOfStream).toHaveBeenCalledTimes(1)
vi.useRealTimers()
})
})
})

View File

@@ -26,6 +26,7 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
useEffect(() => {
const audio = audioRef.current
/* v8 ignore next 2 - @preserve */
if (!audio)
return
@@ -217,6 +218,7 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
const drawWaveform = useCallback(() => {
const canvas = canvasRef.current
/* v8 ignore next 2 - @preserve */
if (!canvas)
return
@@ -268,14 +270,20 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
drawWaveform()
}, [drawWaveform, bufferedTime, hasStartedPlaying])
const handleMouseMove = useCallback((e: React.MouseEvent) => {
const handleMouseMove = useCallback((e: React.MouseEvent<HTMLCanvasElement> | React.TouchEvent<HTMLCanvasElement>) => {
const canvas = canvasRef.current
const audio = audioRef.current
if (!canvas || !audio)
return
const clientX = 'touches' in e
? e.touches[0]?.clientX ?? e.changedTouches[0]?.clientX
: e.clientX
if (clientX === undefined)
return
const rect = canvas.getBoundingClientRect()
const percent = Math.min(Math.max(0, e.clientX - rect.left), rect.width) / rect.width
const percent = Math.min(Math.max(0, clientX - rect.left), rect.width) / rect.width
const time = percent * duration
// Check if the hovered position is within a buffered range before updating hoverTime
@@ -289,7 +297,7 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
return (
<div className="flex h-9 min-w-[240px] max-w-[420px] items-center gap-2 rounded-[10px] border border-components-panel-border-subtle bg-components-chat-input-audio-bg-alt p-2 shadow-xs backdrop-blur-sm">
<audio ref={audioRef} src={src} preload="auto">
<audio ref={audioRef} src={src} preload="auto" data-testid="audio-player">
{/* If srcs array is provided, render multiple source elements */}
{srcs && srcs.map((srcUrl, index) => (
<source key={index} src={srcUrl} />
@@ -297,12 +305,8 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
</audio>
<button type="button" data-testid="play-pause-btn" className="inline-flex shrink-0 cursor-pointer items-center justify-center border-none text-text-accent transition-all hover:text-text-accent-secondary disabled:text-components-button-primary-bg-disabled" onClick={togglePlay} disabled={!isAudioAvailable}>
{isPlaying
? (
<div className="i-ri-pause-circle-fill h-5 w-5" />
)
: (
<div className="i-ri-play-large-fill h-5 w-5" />
)}
? (<div className="i-ri-pause-circle-fill h-5 w-5" />)
: (<div className="i-ri-play-large-fill h-5 w-5" />)}
</button>
<div className={cn(isAudioAvailable && 'grow')} hidden={!isAudioAvailable}>
<div className="flex h-8 items-center justify-center">
@@ -313,6 +317,8 @@ const AudioPlayer: React.FC<AudioPlayerProps> = ({ src, srcs }) => {
onClick={handleCanvasInteraction}
onMouseMove={handleMouseMove}
onMouseDown={handleCanvasInteraction}
onTouchMove={handleMouseMove}
onTouchStart={handleCanvasInteraction}
/>
<div className="inline-flex min-w-[50px] items-center justify-center text-text-accent-secondary system-xs-medium">
<span className="rounded-[10px] px-0.5 py-1">{formatTime(duration)}</span>

View File

@@ -1,8 +1,7 @@
import type { ToastHandle } from '@/app/components/base/toast'
import { act, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import useThemeMock from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AudioPlayer from '../AudioPlayer'
@@ -45,6 +44,13 @@ async function advanceWaveformTimer() {
})
}
// eslint-disable-next-line ts/no-explicit-any
type ReactEventHandler = ((...args: any[]) => void) | undefined
function getReactProps<T extends Element>(el: T): Record<string, ReactEventHandler> {
const key = Object.keys(el).find(k => k.startsWith('__reactProps$'))
return key ? (el as unknown as Record<string, Record<string, ReactEventHandler>>)[key] : {}
}
// ─── Setup / teardown ─────────────────────────────────────────────────────────
beforeEach(() => {
@@ -56,8 +62,12 @@ beforeEach(() => {
HTMLMediaElement.prototype.load = vi.fn()
})
afterEach(() => {
vi.runOnlyPendingTimers()
afterEach(async () => {
await act(async () => {
vi.runOnlyPendingTimers()
await Promise.resolve()
await Promise.resolve()
})
vi.useRealTimers()
vi.unstubAllGlobals()
})
@@ -300,36 +310,47 @@ describe('AudioPlayer — waveform generation', () => {
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
it('should use webkitAudioContext when AudioContext is unavailable', async () => {
vi.stubGlobal('AudioContext', undefined)
vi.stubGlobal('webkitAudioContext', buildAudioContext(320))
stubFetchOk(256)
render(<AudioPlayer src="https://cdn.example/audio.mp3" />)
await advanceWaveformTimer()
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
})
// ─── Canvas interactions ──────────────────────────────────────────────────────
async function renderWithDuration(src = 'https://example.com/audio.mp3', durationVal = 120) {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src={src} />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: durationVal, configurable: true })
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => durationVal },
configurable: true,
})
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
})
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
canvas.getBoundingClientRect = () =>
({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
return { audio, canvas }
}
describe('AudioPlayer — canvas seek interactions', () => {
async function renderWithDuration(src = 'https://example.com/audio.mp3', durationVal = 120) {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src={src} />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: durationVal, configurable: true })
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => durationVal },
configurable: true,
})
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
})
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
canvas.getBoundingClientRect = () =>
({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
return { audio, canvas }
}
it('should seek to clicked position and start playback', async () => {
const { audio, canvas } = await renderWithDuration()
@@ -392,3 +413,309 @@ describe('AudioPlayer — canvas seek interactions', () => {
})
})
})
// ─── Missing coverage tests ───────────────────────────────────────────────────
describe('AudioPlayer — missing coverage', () => {
it('should handle unmounting without crashing (clears timeout)', () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
unmount()
// Timer is cleared, no state update should happen after unmount
})
it('should handle getContext returning null safely', () => {
const originalGetContext = HTMLCanvasElement.prototype.getContext
HTMLCanvasElement.prototype.getContext = vi.fn().mockReturnValue(null)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
HTMLCanvasElement.prototype.getContext = originalGetContext
})
it('should fallback to fillRect when roundRect is missing in drawWaveform', async () => {
// Note: React 18 / testing-library wraps updates automatically, but we still wait for advanceWaveformTimer
const originalGetContext = HTMLCanvasElement.prototype.getContext
let fillRectCalled = false
HTMLCanvasElement.prototype.getContext = function (this: HTMLCanvasElement, ...args: Parameters<typeof HTMLCanvasElement.prototype.getContext>) {
const ctx = originalGetContext.apply(this, args) as CanvasRenderingContext2D | null
if (ctx) {
Object.defineProperty(ctx, 'roundRect', { value: undefined, configurable: true })
const origFillRect = ctx.fillRect
ctx.fillRect = function (...fArgs: Parameters<CanvasRenderingContext2D['fillRect']>) {
fillRectCalled = true
return origFillRect.apply(this, fArgs)
}
}
return ctx as CanvasRenderingContext2D
} as typeof HTMLCanvasElement.prototype.getContext
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
expect(fillRectCalled).toBe(true)
HTMLCanvasElement.prototype.getContext = originalGetContext
})
it('should handle play error gracefully when togglePlay is clicked', async () => {
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
vi.spyOn(HTMLMediaElement.prototype, 'play').mockRejectedValue(new Error('play failed'))
render(<AudioPlayer src="https://example.com/audio.mp3" />)
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(errorSpy).toHaveBeenCalled()
errorSpy.mockRestore()
})
it('should notify error when audio.play() fails during canvas seek', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 120, configurable: true })
canvas.getBoundingClientRect = () => ({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
vi.spyOn(HTMLMediaElement.prototype, 'play').mockRejectedValue(new Error('play failed'))
await act(async () => {
fireEvent.click(canvas, { clientX: 100 })
})
// We can observe the error by checking document body for toast if Toast acts synchronously
// Or we just ensure the execution branched into catch naturally.
expect(HTMLMediaElement.prototype.play).toHaveBeenCalled()
})
it('should support touch events on canvas', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas') as HTMLCanvasElement
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 120, configurable: true })
canvas.getBoundingClientRect = () => ({ left: 0, width: 200, top: 0, height: 10, right: 200, bottom: 10 }) as DOMRect
await act(async () => {
// Use touch events
fireEvent.touchStart(canvas, {
touches: [{ clientX: 50 }],
})
})
expect(HTMLMediaElement.prototype.play).toHaveBeenCalled()
})
it('should gracefully handle interaction when canvas/audio refs are null', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/audio.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
unmount()
expect(canvas).toBeTruthy()
})
it('should keep play button disabled when source is unavailable', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="blob:https://example.com" />)
await advanceWaveformTimer() // sets isAudioAvailable to false (invalid protocol)
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(btn).toBeDisabled()
expect(HTMLMediaElement.prototype.play).not.toHaveBeenCalled()
expect(toastSpy).not.toHaveBeenCalled()
toastSpy.mockRestore()
})
it('should notify when toggle is invoked while audio is unavailable', async () => {
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
const btn = screen.getByTestId('play-pause-btn')
const props = getReactProps(btn)
await act(async () => {
props.onClick?.()
})
expect(toastSpy).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'Audio element not found',
}))
toastSpy.mockRestore()
})
})
describe('AudioPlayer — additional branch coverage', () => {
it('should render multiple source elements when srcs is provided', () => {
render(<AudioPlayer srcs={['a.mp3', 'b.ogg']} />)
const audio = screen.getByTestId('audio-player')
const sources = audio.querySelectorAll('source')
expect(sources).toHaveLength(2)
})
it('should handle handleMouseMove with empty touch list', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/a.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas')
await act(async () => {
fireEvent.touchMove(canvas, {
touches: [],
changedTouches: [{ clientX: 50 }],
})
})
})
it('should handle handleMouseMove with missing clientX', async () => {
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/a.mp3" />)
await advanceWaveformTimer()
const canvas = screen.getByTestId('waveform-canvas')
await act(async () => {
fireEvent.touchMove(canvas, {
touches: [{}] as unknown as TouchList,
})
})
})
it('should render "Audio source unavailable" when isAudioAvailable is false', async () => {
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
expect(screen.queryByTestId('play-pause-btn')).toBeDisabled()
})
it('should update current time on timeupdate event', async () => {
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'currentTime', { value: 10, configurable: true })
await act(async () => {
audio.dispatchEvent(new Event('timeupdate'))
})
})
it('should ignore toggle click after audio error marks source unavailable', async () => {
const toastSpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({} as unknown as ToastHandle))
render(<AudioPlayer src="https://example.com/a.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
await act(async () => {
audio.dispatchEvent(new Event('error'))
})
const btn = screen.getByTestId('play-pause-btn')
await act(async () => {
fireEvent.click(btn)
})
expect(btn).toBeDisabled()
expect(HTMLMediaElement.prototype.play).not.toHaveBeenCalled()
expect(toastSpy).not.toHaveBeenCalled()
toastSpy.mockRestore()
})
it('should cover Dark theme waveform states', async () => {
; (useThemeMock as ReturnType<typeof vi.fn>).mockReturnValue({ theme: Theme.dark })
vi.stubGlobal('AudioContext', buildAudioContext(300))
stubFetchOk(128)
render(<AudioPlayer src="https://example.com/audio.mp3" />)
const audio = document.querySelector('audio') as HTMLAudioElement
Object.defineProperty(audio, 'duration', { value: 100, configurable: true })
Object.defineProperty(audio, 'currentTime', { value: 50, configurable: true })
await act(async () => {
audio.dispatchEvent(new Event('loadedmetadata'))
audio.dispatchEvent(new Event('timeupdate'))
})
await advanceWaveformTimer()
expect(screen.getByTestId('waveform-canvas')).toBeInTheDocument()
})
it('should handle missing canvas/audio in handleCanvasInteraction/handleMouseMove', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
unmount()
fireEvent.click(canvas)
fireEvent.mouseMove(canvas)
})
it('should cover waveform branches for hover and played states', async () => {
const { audio, canvas } = await renderWithDuration('https://example.com/a.mp3', 100)
// Set some progress
Object.defineProperty(audio, 'currentTime', { value: 20, configurable: true })
// Trigger hover on a buffered range
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => 100 },
configurable: true,
})
await act(async () => {
fireEvent.mouseMove(canvas, { clientX: 50 }) // 50s hover
audio.dispatchEvent(new Event('timeupdate'))
})
expect(canvas).toBeInTheDocument()
})
it('should hit null-ref guards in canvas handlers after unmount', async () => {
const { unmount } = render(<AudioPlayer src="https://example.com/a.mp3" />)
const canvas = screen.getByTestId('waveform-canvas')
const props = getReactProps(canvas)
unmount()
await act(async () => {
props.onClick?.({ preventDefault: vi.fn(), clientX: 10 })
props.onMouseMove?.({ clientX: 10 })
})
})
it('should execute non-matching buffered branch in hover loop', async () => {
const { audio, canvas } = await renderWithDuration('https://example.com/a.mp3', 100)
Object.defineProperty(audio, 'buffered', {
value: { length: 1, start: () => 0, end: () => 10 },
configurable: true,
})
await act(async () => {
fireEvent.mouseMove(canvas, { clientX: 180 }) // time near 90, outside 0-10
})
expect(canvas).toBeInTheDocument()
})
})

View File

@@ -1,24 +1,9 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
// AudioGallery.spec.tsx
import { describe, expect, it, vi } from 'vitest'
import AudioGallery from '../index'
// Mock AudioPlayer so we only assert prop forwarding
const audioPlayerMock = vi.fn()
vi.mock('../AudioPlayer', () => ({
default: (props: { srcs: string[] }) => {
audioPlayerMock(props)
return <div data-testid="audio-player" />
},
}))
describe('AudioGallery', () => {
afterEach(() => {
audioPlayerMock.mockClear()
vi.resetModules()
beforeEach(() => {
vi.spyOn(HTMLMediaElement.prototype, 'load').mockImplementation(() => { })
})
it('returns null when srcs array is empty', () => {
@@ -33,11 +18,15 @@ describe('AudioGallery', () => {
expect(screen.queryByTestId('audio-player')).toBeNull()
})
it('filters out falsy srcs and passes valid srcs to AudioPlayer', () => {
it('filters out falsy srcs and renders only valid sources in AudioPlayer', () => {
render(<AudioGallery srcs={['a.mp3', '', 'b.mp3']} />)
expect(screen.getByTestId('audio-player')).toBeInTheDocument()
expect(audioPlayerMock).toHaveBeenCalledTimes(1)
expect(audioPlayerMock).toHaveBeenCalledWith({ srcs: ['a.mp3', 'b.mp3'] })
const audio = screen.getByTestId('audio-player')
const sources = audio.querySelectorAll('source')
expect(audio).toBeInTheDocument()
expect(sources).toHaveLength(2)
expect(sources[0]?.getAttribute('src')).toBe('a.mp3')
expect(sources[1]?.getAttribute('src')).toBe('b.mp3')
})
it('wraps AudioPlayer inside container with expected class', () => {
@@ -45,5 +34,6 @@ describe('AudioGallery', () => {
const root = container.firstChild as HTMLElement
expect(root).toBeTruthy()
expect(root.className).toContain('my-3')
expect(screen.getByTestId('audio-player')).toBeInTheDocument()
})
})

View File

@@ -1,6 +1,18 @@
import type { ChatItemInTree } from '../types'
import type { IChatItem } from '../chat/type'
import type { ChatItem, ChatItemInTree } from '../types'
import { get } from 'es-toolkit/compat'
import { buildChatItemTree, getThreadMessages } from '../utils'
import { UUID_NIL } from '../constants'
import {
buildChatItemTree,
getLastAnswer,
getProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams,
getProcessedUserVariablesFromUrlParams,
getRawInputsFromUrlParams,
getRawUserVariablesFromUrlParams,
getThreadMessages,
isValidGeneratedAnswer,
} from '../utils'
import branchedTestMessages from './branchedTestMessages.json'
import legacyTestMessages from './legacyTestMessages.json'
import mixedTestMessages from './mixedTestMessages.json'
@@ -13,6 +25,15 @@ function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatI
return get(tree, path)
}
class MockDecompressionStream {
readable: unknown
writable: unknown
constructor() {
this.readable = {}
this.writable = {}
}
}
describe('build chat item tree and get thread messages', () => {
const tree1 = buildChatItemTree(branchedTestMessages as ChatItemInTree[])
@@ -247,12 +268,12 @@ describe('build chat item tree and get thread messages', () => {
expect(tree6).toMatchSnapshot()
})
it ('should get thread messages from tree6, using the last message as target', () => {
it('should get thread messages from tree6, using the last message as target', () => {
const threadMessages6_1 = getThreadMessages(tree6)
expect(threadMessages6_1).toMatchSnapshot()
})
it ('should get thread messages from tree6, using specified message as target', () => {
it('should get thread messages from tree6, using specified message as target', () => {
const threadMessages6_2 = getThreadMessages(tree6, 'ff4c2b43-48a5-47ad-9dc5-08b34ddba61b')
expect(threadMessages6_2).toMatchSnapshot()
})
@@ -269,3 +290,285 @@ describe('build chat item tree and get thread messages', () => {
expect(tree8).toMatchSnapshot()
})
})
describe('chat utils - url params and answer helpers', () => {
const setSearch = (search: string) => {
window.history.replaceState({}, '', `${window.location.pathname}${search}`)
}
beforeEach(() => {
vi.clearAllMocks()
vi.stubGlobal('DecompressionStream', MockDecompressionStream)
vi.stubGlobal('TextDecoder', class {
decode() { return 'decompressed_text' }
})
const mockPipeThrough = vi.fn().mockReturnValue({})
vi.stubGlobal('Response', class {
body = { pipeThrough: mockPipeThrough }
arrayBuffer = vi.fn().mockResolvedValue(new ArrayBuffer(8))
})
setSearch('')
})
afterEach(() => {
vi.unstubAllGlobals()
})
describe('URL Parameter Extractors', () => {
it('getRawInputsFromUrlParams extracts inputs except sys. and user.', async () => {
setSearch('?custom=123&sys.param=456&user.param=789&encoded=a%20b')
const res = await getRawInputsFromUrlParams()
expect(res).toEqual({ custom: '123', encoded: 'a b' })
})
it('getRawUserVariablesFromUrlParams extracts only user. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789&user.encoded=a%20b')
const res = await getRawUserVariablesFromUrlParams()
expect(res).toEqual({ param: '789', encoded: 'a b' })
})
it('getProcessedInputsFromUrlParams decompresses base64 inputs', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedInputsFromUrlParams()
expect(res).toEqual({ custom: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams decompresses sys. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams parses redirect_url without query string', async () => {
setSearch(`?redirect_url=${encodeURIComponent('http://example.com')}&sys.param=456`)
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('getProcessedSystemVariablesFromUrlParams parses redirect_url', async () => {
setSearch(`?redirect_url=${encodeURIComponent('http://example.com?sys.redirected=abc')}&sys.param=456`)
const res = await getProcessedSystemVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text', redirected: 'decompressed_text' })
})
it('getProcessedUserVariablesFromUrlParams decompresses user. prefixed params', async () => {
setSearch('?custom=123&sys.param=456&user.param=789')
const res = await getProcessedUserVariablesFromUrlParams()
expect(res).toEqual({ param: 'decompressed_text' })
})
it('decodeBase64AndDecompress failure returns undefined softly', async () => {
vi.stubGlobal('atob', () => {
throw new Error('invalid')
})
setSearch('?custom=invalid_base64')
const res = await getProcessedInputsFromUrlParams()
expect(res).toEqual({ custom: undefined })
})
})
describe('Answer Validation', () => {
it('isValidGeneratedAnswer returns true for typical answers', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: '123', isOpeningStatement: false } as ChatItem)).toBe(true)
})
it('isValidGeneratedAnswer returns false for placeholders', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: 'answer-placeholder-123', isOpeningStatement: false } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for opening statements', () => {
expect(isValidGeneratedAnswer({ isAnswer: true, id: '123', isOpeningStatement: true } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for questions', () => {
expect(isValidGeneratedAnswer({ isAnswer: false, id: '123', isOpeningStatement: false } as ChatItem)).toBe(false)
})
it('isValidGeneratedAnswer returns false for falsy items', () => {
expect(isValidGeneratedAnswer(undefined)).toBe(false)
})
it('getLastAnswer returns the last valid answer from a list', () => {
const list = [
{ isAnswer: false, id: 'q1', isOpeningStatement: false },
{ isAnswer: true, id: 'a1', isOpeningStatement: false },
{ isAnswer: false, id: 'q2', isOpeningStatement: false },
{ isAnswer: true, id: 'answer-placeholder-2', isOpeningStatement: false },
] as ChatItem[]
expect(getLastAnswer(list)?.id).toBe('a1')
})
it('getLastAnswer returns null if no valid answer', () => {
const list = [
{ isAnswer: false, id: 'q1', isOpeningStatement: false },
{ isAnswer: true, id: 'answer-placeholder-2', isOpeningStatement: false },
] as ChatItem[]
expect(getLastAnswer(list)).toBeNull()
})
})
describe('ChatItem Tree Builders', () => {
it('buildChatItemTree builds a flat tree for legacy messages (parentMessageId = UUID_NIL)', () => {
const list: IChatItem[] = [
{ id: 'q1', isAnswer: false, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'a1', isAnswer: true, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'q2', isAnswer: false, parentMessageId: UUID_NIL } as IChatItem,
{ id: 'a2', isAnswer: true, parentMessageId: UUID_NIL } as IChatItem,
]
const tree = buildChatItemTree(list)
expect(tree.length).toBe(1)
expect(tree[0].id).toBe('q1')
expect(tree[0].children?.[0].id).toBe('a1')
expect(tree[0].children?.[0].children?.[0].id).toBe('q2')
expect(tree[0].children?.[0].children?.[0].children?.[0].id).toBe('a2')
expect(tree[0].children?.[0].children?.[0].children?.[0].siblingIndex).toBe(0)
})
it('buildChatItemTree builds nested tree based on parentMessageId', () => {
const list: IChatItem[] = [
{ id: 'q1', isAnswer: false, parentMessageId: null } as IChatItem,
{ id: 'a1', isAnswer: true } as IChatItem,
{ id: 'q2', isAnswer: false, parentMessageId: 'a1' } as IChatItem,
{ id: 'a2', isAnswer: true } as IChatItem,
{ id: 'q3', isAnswer: false, parentMessageId: 'a1' } as IChatItem,
{ id: 'a3', isAnswer: true } as IChatItem,
{ id: 'q4', isAnswer: false, parentMessageId: 'missing-parent' } as IChatItem,
{ id: 'a4', isAnswer: true } as IChatItem,
]
const tree = buildChatItemTree(list)
expect(tree.length).toBe(2)
expect(tree[0].id).toBe('q1')
expect(tree[1].id).toBe('q4')
const a1 = tree[0].children![0]
expect(a1.id).toBe('a1')
expect(a1.children?.length).toBe(2)
expect(a1.children![0].id).toBe('q2')
expect(a1.children![1].id).toBe('q3')
expect(a1.children![0].children![0].siblingIndex).toBe(0)
expect(a1.children![1].children![0].siblingIndex).toBe(1)
})
it('getThreadMessages node without children', () => {
const tree = [{ id: 'q1', isAnswer: false }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'q1')
expect(thread.length).toBe(1)
expect(thread[0].id).toBe('q1')
})
it('getThreadMessages target not found', () => {
const tree = [{ id: 'q1', isAnswer: false, children: [] }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'missing')
expect(thread.length).toBe(0)
})
it('getThreadMessages target not found with undefined children', () => {
const tree = [{ id: 'q1', isAnswer: false }]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'missing')
expect(thread.length).toBe(0)
})
it('getThreadMessages flat path logic', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[])
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q2', 'a2'])
expect(thread[1].siblingCount).toBe(1)
expect(thread[3].siblingCount).toBe(1)
})
it('getThreadMessages to specific target', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}, {
id: 'q3',
isAnswer: false,
children: [{
id: 'a3',
isAnswer: true,
siblingIndex: 1,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'a3')
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q3', 'a3'])
expect(thread[3].prevSibling).toBe('a2')
expect(thread[3].nextSibling).toBeUndefined()
})
it('getThreadMessages targetNode has descendants', () => {
const tree = [{
id: 'q1',
isAnswer: false,
children: [{
id: 'a1',
isAnswer: true,
siblingIndex: 0,
children: [{
id: 'q2',
isAnswer: false,
children: [{
id: 'a2',
isAnswer: true,
siblingIndex: 0,
children: [],
}],
}, {
id: 'q3',
isAnswer: false,
children: [{
id: 'a3',
isAnswer: true,
siblingIndex: 1,
children: [],
}],
}],
}],
}]
const thread = getThreadMessages(tree as unknown as ChatItemInTree[], 'a1')
expect(thread.length).toBe(4)
expect(thread.map(t => t.id)).toEqual(['q1', 'a1', 'q3', 'a3'])
expect(thread[3].prevSibling).toBe('a2')
})
})
})

View File

@@ -4,12 +4,11 @@ import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import type { HumanInputFormData } from '@/types/workflow'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { InputVarType } from '@/app/components/workflow/types'
import {
fetchSuggestedQuestions,
stopChatMessageResponding,
submitHumanInputForm,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { useChat } from '../../chat/hooks'
@@ -501,6 +500,34 @@ describe('ChatWrapper', () => {
expect(handleSwitchSibling).toHaveBeenCalledWith('1', expect.any(Object))
})
it('should call fetchSuggestedQuestions from workflow resumption options callback', () => {
const handleSwitchSibling = vi.fn()
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [],
handleSwitchSibling,
} as unknown as ChatHookReturn)
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appPrevChatTree: [{
id: 'resume-node',
content: 'Paused answer',
isAnswer: true,
workflow_run_id: 'workflow-1',
humanInputFormDataList: [{ label: 'resume' }] as unknown as HumanInputFormData[],
children: [],
}],
})
render(<ChatWrapper />)
expect(handleSwitchSibling).toHaveBeenCalledWith('resume-node', expect.any(Object))
const resumeOptions = handleSwitchSibling.mock.calls[0][1]
resumeOptions.onGetSuggestedQuestions('response-from-resume')
expect(fetchSuggestedQuestions).toHaveBeenCalledWith('response-from-resume', 'webApp', 'test-app-id')
})
it('should handle workflow resumption with nested children (DFS)', () => {
const handleSwitchSibling = vi.fn()
vi.mocked(useChat).mockReturnValue({
@@ -760,6 +787,47 @@ describe('ChatWrapper', () => {
})
})
it('should handle human input form submission for web app', async () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
isInstalledApp: false,
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [
{ id: 'q1', content: 'Question' },
{
id: 'a1',
isAnswer: true,
content: '',
humanInputFormDataList: [{
id: 'node1',
form_id: 'form1',
form_token: 'token-web-1',
node_id: 'node1',
node_title: 'Node Web 1',
display_in_ui: true,
form_content: '{{#$output.test#}}',
inputs: [{ variable: 'test', label: 'Test', type: 'paragraph', required: true, output_variable_name: 'test', default: { type: 'text', value: '' } }],
actions: [{ id: 'run', title: 'Run', button_style: 'primary' }],
}] as unknown as HumanInputFormData[],
},
],
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
expect(await screen.findByText('Node Web 1')).toBeInTheDocument()
const input = screen.getAllByRole('textbox').find(el => el.closest('.chat-answer-container')) || screen.getAllByRole('textbox')[0]
fireEvent.change(input, { target: { value: 'web-test' } })
fireEvent.click(screen.getByText('Run'))
await waitFor(() => {
expect(submitHumanInputForm).toHaveBeenCalledWith('token-web-1', expect.any(Object))
})
})
it('should filter opening statement in new conversation with single item', () => {
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
@@ -888,8 +956,16 @@ describe('ChatWrapper', () => {
})
it('should render answer icon when configured', () => {
const appDataWithAnswerIcon = {
site: {
...mockAppData.site,
use_icon_as_answer_icon: true,
},
} as unknown as AppData
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appData: appDataWithAnswerIcon,
} as ChatWithHistoryContextValue)
vi.mocked(useChat).mockReturnValue({
@@ -899,6 +975,7 @@ describe('ChatWrapper', () => {
render(<ChatWrapper />)
expect(screen.getByText('Answer')).toBeInTheDocument()
expect(screen.getByAltText('answer icon')).toBeInTheDocument()
})
it('should render question icon when user avatar is available', () => {
@@ -920,6 +997,26 @@ describe('ChatWrapper', () => {
expect(avatar).toBeInTheDocument()
})
it('should use fallback values for nullable appData, appMeta and user name', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appData: null as unknown as AppData,
appMeta: null as unknown as AppMeta,
initUserVariables: {
avatar_url: 'https://example.com/avatar-fallback.png',
},
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [{ id: 'q1', content: 'Question with fallback avatar name' }],
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
expect(screen.getByText('Question with fallback avatar name')).toBeInTheDocument()
expect(screen.getByAltText('user')).toBeInTheDocument()
})
it('should set handleStop on currentChatInstanceRef', () => {
const handleStop = vi.fn()
const currentChatInstanceRef = { current: { handleStop: vi.fn() } } as ChatWithHistoryContextValue['currentChatInstanceRef']
@@ -1212,20 +1309,45 @@ describe('ChatWrapper', () => {
it('should handle doRegenerate with editedQuestion', async () => {
const handleSend = vi.fn()
const mockFiles = [
{
id: 'file-q1',
name: 'q1.txt',
type: 'text/plain',
size: 100,
url: 'https://example.com/q1.txt',
extension: 'txt',
mime_type: 'text/plain',
} as unknown as FileEntity,
] as FileEntity[]
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
chatList: [
{ id: 'q1', content: 'Original question', message_files: [] },
{ id: 'q1', content: 'Original question', message_files: mockFiles },
{ id: 'a1', isAnswer: true, content: 'Answer', parentMessageId: 'q1' },
],
handleSend,
} as unknown as ChatHookReturn)
const { container } = render(<ChatWrapper />)
render(<ChatWrapper />)
// This would test line 198-200 - the editedQuestion path
// The actual regenerate with edited question happens through the UI
expect(container).toBeInTheDocument()
fireEvent.click(await screen.findByTestId('edit-btn'))
const editedTextarea = await screen.findByDisplayValue('Original question')
fireEvent.change(editedTextarea, { target: { value: 'Edited question text' } })
fireEvent.click(screen.getByTestId('save-edit-btn'))
await waitFor(() => {
expect(handleSend).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
query: 'Edited question text',
files: mockFiles,
}),
expect.any(Object),
)
})
})
it('should handle doRegenerate when parentAnswer is not a valid generated answer', async () => {
@@ -1692,4 +1814,31 @@ describe('ChatWrapper', () => {
// Should not be disabled because it's not required
expect(container).not.toBeInTheDocument()
})
it('should handle fallback branches for appParams, appId and empty chat instance ref', async () => {
const handleSend = vi.fn()
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
appParams: undefined as unknown as ChatConfig,
appId: '',
currentConversationId: '',
currentChatInstanceRef: { current: null } as unknown as ChatWithHistoryContextValue['currentChatInstanceRef'],
})
vi.mocked(useChat).mockReturnValue({
...defaultChatHookReturn,
handleSend,
} as unknown as ChatHookReturn)
render(<ChatWrapper />)
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'trigger fallback path' } })
fireEvent.keyDown(textarea, { key: 'Enter', code: 'Enter', keyCode: 13 })
await waitFor(() => {
expect(handleSend).toHaveBeenCalled()
})
})
})

View File

@@ -1,9 +1,9 @@
import type { i18n } from 'i18next'
import type { ChatConfig } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import type { AppData, AppMeta } from '@/models/share'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as ReactI18next from 'react-i18next'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useChatWithHistoryContext } from '../context'
import HeaderInMobile from '../header-in-mobile'
@@ -80,7 +80,14 @@ vi.mock('@/app/components/base/modal', () => ({
// Sidebar mock removed to use real component
const mockAppData = { site: { title: 'Test Chat', chat_color_theme: 'blue' } } as unknown as AppData
const mockAppData: AppData = {
app_id: 'test-app',
custom_config: null,
site: {
title: 'Test Chat',
chat_color_theme: 'blue',
},
}
const defaultContextValue: ChatWithHistoryContextValue = {
appData: mockAppData,
currentConversationId: '',
@@ -104,18 +111,27 @@ const defaultContextValue: ChatWithHistoryContextValue = {
currentChatInstanceRef: { current: { handleStop: vi.fn() } } as ChatWithHistoryContextValue['currentChatInstanceRef'],
setIsResponding: vi.fn(),
setClearChatList: vi.fn(),
appParams: { system_parameters: { vision_config: { enabled: false } } } as unknown as ChatConfig,
appMeta: {} as AppMeta,
appParams: {
system_parameters: {
audio_file_size_limit: 10,
file_size_limit: 10,
image_file_size_limit: 10,
video_file_size_limit: 10,
workflow_file_upload_limit: 10,
},
more_like_this: { enabled: false },
} as ChatConfig,
appMeta: { tool_icons: {} } as AppMeta,
appPrevChatTree: [],
newConversationInputs: {},
newConversationInputsRef: { current: {} } as ChatWithHistoryContextValue['newConversationInputsRef'],
newConversationInputsRef: { current: {} },
appChatListDataLoading: false,
chatShouldReloadKey: '',
isMobile: true,
currentConversationInputs: null,
setCurrentConversationInputs: vi.fn(),
allInputsHidden: false,
conversationRenaming: false, // Added missing property
conversationRenaming: false,
}
describe('HeaderInMobile', () => {
@@ -134,7 +150,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
})
render(<HeaderInMobile />)
@@ -270,7 +286,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handlePinConversation: handlePin,
pinnedConversationList: [],
})
@@ -292,9 +308,9 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleUnpinConversation: handleUnpin,
pinnedConversationList: [{ id: '1' }] as unknown as ConversationItem[],
pinnedConversationList: [{ id: '1', name: 'Conv 1', inputs: null, introduction: '' }],
})
render(<HeaderInMobile />)
@@ -314,7 +330,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
@@ -342,7 +358,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
@@ -373,7 +389,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleRenameConversation: vi.fn(),
conversationRenaming: true, // Loading state
pinnedConversationList: [],
@@ -396,7 +412,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
@@ -422,7 +438,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
@@ -454,7 +470,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: '' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: '', inputs: null, introduction: '' },
})
render(<HeaderInMobile />)
@@ -485,16 +501,17 @@ describe('HeaderInMobile', () => {
})
it('should render app icon and title correctly', () => {
const appDataWithIcon = {
const appDataWithIcon: AppData = {
app_id: 'test-app',
custom_config: null,
site: {
title: 'My App',
icon: 'emoji',
icon_type: 'emoji',
icon_url: '',
icon_background: '#FF0000',
chat_color_theme: 'blue',
},
} as unknown as AppData
}
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
@@ -512,7 +529,7 @@ describe('HeaderInMobile', () => {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1' } as unknown as ConversationItem,
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleRenameConversation: handleRename,
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
@@ -524,4 +541,59 @@ describe('HeaderInMobile', () => {
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
expect(screen.queryByText('share.chat.deleteConversation.title')).not.toBeInTheDocument()
})
it('should use empty string fallback for delete content translation', async () => {
const handleDelete = vi.fn()
const useTranslationSpy = vi.spyOn(ReactI18next, 'useTranslation')
useTranslationSpy.mockReturnValue({
t: (key: string) => key === 'chat.deleteConversation.content' ? '' : key,
i18n: {} as unknown as i18n,
ready: true,
tReady: true,
} as unknown as ReturnType<typeof ReactI18next.useTranslation>)
try {
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: 'Conv 1', inputs: null, introduction: '' },
handleDeleteConversation: handleDelete,
pinnedConversationList: [],
})
render(<HeaderInMobile />)
fireEvent.click(await screen.findByText('Conv 1'))
fireEvent.click(await screen.findByText(/sidebar\.action\.delete/i))
expect(await screen.findByRole('button', { name: /common\.operation\.confirm|operation\.confirm/i })).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.confirm|operation\.confirm/i }))
expect(handleDelete).toHaveBeenCalledWith('1', expect.any(Object))
}
finally {
useTranslationSpy.mockRestore()
}
})
it('should use empty string fallback for rename modal name', async () => {
const handleRename = vi.fn()
vi.mocked(useChatWithHistoryContext).mockReturnValue({
...defaultContextValue,
currentConversationId: '1',
currentConversationItem: { id: '1', name: '', inputs: null, introduction: '' },
handleRenameConversation: handleRename,
pinnedConversationList: [],
})
const { container } = render(<HeaderInMobile />)
const operationTrigger = container.querySelector('.system-md-semibold')?.parentElement as HTMLElement
fireEvent.click(operationTrigger)
fireEvent.click(await screen.findByText(/explore\.sidebar\.action\.rename|sidebar\.action\.rename/i))
const input = await screen.findByRole('textbox')
expect(input).toHaveValue('')
fireEvent.change(input, { target: { value: 'Renamed from empty' } })
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i }))
expect(handleRename).toHaveBeenCalledWith('1', 'Renamed from empty', expect.any(Object))
})
})

View File

@@ -2,9 +2,7 @@ import type { RefObject } from 'react'
import type { ChatConfig } from '../../types'
import type { InstalledApp } from '@/models/explore'
import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useDocumentTitle from '@/hooks/use-document-title'
import { useChatWithHistory } from '../hooks'
@@ -113,81 +111,22 @@ describe('ChatWithHistory', () => {
vi.mocked(useChatWithHistory).mockReturnValue(defaultHookReturn)
})
it('renders desktop view with expanded sidebar and builds theme', () => {
it('renders desktop view with expanded sidebar and builds theme', async () => {
vi.mocked(useBreakpoints).mockReturnValue(MediaType.pc)
render(<ChatWithHistory />)
// Checks if the desktop elements render correctly
// Checks if the desktop elements render correctly
// Sidebar real component doesn't have data-testid="sidebar", so we check for its presence via class or content.
// Sidebar usually has "New Chat" button or similar.
// However, looking at the Sidebar mock it was just a div.
// Real Sidebar -> web/app/components/base/chat/chat-with-history/sidebar/index.tsx
// It likely has some text or distinct element.
// ChatWrapper also removed mock.
// Header also removed mock.
// For now, let's verify some key elements that should be present in these components.
// Sidebar: "Explore" or "Chats" or verify navigation structure.
// Header: Title or similar.
// ChatWrapper: "Start a new chat" or similar.
// Given the complexity of real components and lack of testIds, we might need to rely on:
// 1. Adding testIds to real components (preferred but might be out of scope if I can't touch them? Guidelines say "don't mock base components", but adding testIds is fine).
// But I can't see those files right now.
// 2. Use getByText for known static content.
// Let's assume some content based on `mockAppData` title 'Test Chat'.
// Header should contain 'Test Chat'.
// Check for "Test Chat" - might appear multiple times (header, sidebar, document title etc)
// header-in-mobile renders 'Test Chat'.
const titles = screen.getAllByText('Test Chat')
expect(titles.length).toBeGreaterThan(0)
// Sidebar should be present.
// We can check for a specific element in sidebar, e.g. "New Chat" button if it exists.
// Or we can check for the sidebar container class if possible.
// Let's look at `index.tsx` logic.
// Sidebar is rendered.
// Let's try to query by something generic or update to use `container.querySelector`.
// But `screen` is better.
// ChatWrapper is rendered.
// It renders "ChatWrapper" text? No, it's the real component now.
// Real ChatWrapper renders "Welcome" or chat list.
// In `chat-wrapper.spec.tsx`, we saw it renders "Welcome" or "Q1".
// Here `defaultHookReturn` returns empty chat list/conversation.
// So it might render nothing or empty state?
// Let's wait and see what `chat-wrapper.spec.tsx` expectations were.
// It expects "Welcome" if `isOpeningStatement` is true.
// In `index.spec.tsx` mock hook return:
// `currentConversationItem` is undefined.
// `conversationList` is [].
// `appPrevChatTree` is [].
// So ChatWrapper might render empty or loading?
// This is an integration test now.
// We need to ensure the hook return makes sense for the child components.
// Let's just assert the document title since we know that works?
// And check if we can find *something*.
// For now, I'll comment out the specific testId checks and rely on visual/text checks that are likely to flourish.
// header-in-mobile renders 'Test Chat'.
// Sidebar?
// Actually, `ChatWithHistory` renders `Sidebar` in a div with width.
// We can check if that div exists?
// Let's update to checks that are likely to pass or allow us to debug.
// expect(document.title).toBe('Test Chat')
// Checks if the document title was set correctly
expect(useDocumentTitle).toHaveBeenCalledWith('Test Chat')
// Checks if the themeBuilder useEffect fired
expect(mockBuildTheme).toHaveBeenCalledWith('blue', false)
await waitFor(() => {
expect(mockBuildTheme).toHaveBeenCalledWith('blue', false)
})
})
it('renders desktop view with collapsed sidebar and tests hover effects', () => {

View File

@@ -46,6 +46,7 @@ const HeaderInMobile = () => {
setShowConfirm(null)
}, [])
const handleDelete = useCallback(() => {
/* v8 ignore next 2 -- @preserve */
if (showConfirm)
handleDeleteConversation(showConfirm.id, { onSuccess: handleCancelConfirm })
}, [showConfirm, handleDeleteConversation, handleCancelConfirm])
@@ -53,6 +54,7 @@ const HeaderInMobile = () => {
setShowRename(null)
}, [])
const handleRename = useCallback((newName: string) => {
/* v8 ignore next 2 -- @preserve */
if (showRename)
handleRenameConversation(showRename.id, newName, { onSuccess: handleCancelRename })
}, [showRename, handleRenameConversation, handleCancelRename])

View File

@@ -0,0 +1,128 @@
import type { InputForm } from '../type'
import { renderHook } from '@testing-library/react'
import { InputVarType } from '@/app/components/workflow/types'
import { TransferMethod } from '@/types/app'
import { useCheckInputsForms } from '../check-input-forms-hooks'
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({ notify: mockNotify }),
}))
describe('useCheckInputsForms', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should return true when no inputs required', () => {
const { result } = renderHook(() => useCheckInputsForms())
const isValid = result.current.checkInputsForm({}, [])
expect(isValid).toBe(true)
})
it('should return false and notify when a required input is missing', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_var', label: 'Test Variable', required: true, type: InputVarType.textInput as string }]
const isValid = result.current.checkInputsForm({}, inputsForm as InputForm[])
expect(isValid).toBe(false)
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
message: expect.stringContaining('appDebug.errorMessage.valueOfVarRequired'),
}),
)
})
it('should ignore missing but not required inputs', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_var', label: 'Test Variable', required: false, type: InputVarType.textInput as string }]
const isValid = result.current.checkInputsForm({}, inputsForm as InputForm[])
expect(isValid).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should notify and return undefined when a file is still uploading (singleFile)', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_file', label: 'Test File', required: true, type: InputVarType.singleFile as string }]
const inputs = {
test_file: { transferMethod: TransferMethod.local_file }, // no uploadedId means still uploading
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
it('should notify and return undefined when a file is still uploading (multiFiles)', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_files', label: 'Test Files', required: true, type: InputVarType.multiFiles as string }]
const inputs = {
test_files: [{ transferMethod: TransferMethod.local_file }], // no uploadedId
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
it('should return true when all files are uploaded and required variables are present', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [{ variable: 'test_file', label: 'Test File', required: true, type: InputVarType.singleFile as string }]
const inputs = {
test_file: { transferMethod: TransferMethod.local_file, uploadedId: '123' }, // uploaded
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBe(true)
expect(mockNotify).not.toHaveBeenCalled()
})
it('should short-circuit remaining fields after first required input is missing', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [
{ variable: 'missing_text', label: 'Missing Text', required: true, type: InputVarType.textInput as string },
{ variable: 'later_file', label: 'Later File', required: true, type: InputVarType.singleFile as string },
]
const inputs = {
later_file: { transferMethod: TransferMethod.local_file },
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBe(false)
expect(mockNotify).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: expect.stringContaining('appDebug.errorMessage.valueOfVarRequired'),
}))
})
it('should short-circuit remaining fields after detecting file upload in progress', () => {
const { result } = renderHook(() => useCheckInputsForms())
const inputsForm = [
{ variable: 'uploading_file', label: 'Uploading File', required: true, type: InputVarType.singleFile as string },
{ variable: 'later_required_text', label: 'Later Required Text', required: true, type: InputVarType.textInput as string },
]
const inputs = {
uploading_file: { transferMethod: TransferMethod.local_file }, // still uploading
later_required_text: '',
}
const isValid = result.current.checkInputsForm(inputs, inputsForm as InputForm[])
expect(isValid).toBeUndefined()
expect(mockNotify).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'info',
message: 'appDebug.errorMessage.waitForFileUpload',
}))
})
})

File diff suppressed because it is too large Load Diff

View File

@@ -5,7 +5,6 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import copy from 'copy-to-clipboard'
import * as React from 'react'
import { vi } from 'vitest'
import Toast from '../../../toast'
import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context'
@@ -169,7 +168,8 @@ describe('Question component', () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
renderWithProvider(makeItem(), onRegenerate)
const item = makeItem()
renderWithProvider(item, onRegenerate)
const editBtn = screen.getByTestId('edit-btn')
await user.click(editBtn)
@@ -184,7 +184,7 @@ describe('Question component', () => {
await user.click(resendBtn)
await waitFor(() => {
expect(onRegenerate).toHaveBeenCalledWith(makeItem(), { message: 'Edited question', files: [] })
expect(onRegenerate).toHaveBeenCalledWith(item, { message: 'Edited question', files: [] })
})
})
@@ -199,7 +199,7 @@ describe('Question component', () => {
await user.clear(textbox)
await user.type(textbox, 'Edited question')
const cancelBtn = screen.getByRole('button', { name: /operation.cancel/i })
const cancelBtn = await screen.findByTestId('cancel-edit-btn')
await user.click(cancelBtn)
await waitFor(() => {
@@ -349,4 +349,120 @@ describe('Question component', () => {
const contentContainer = screen.getByTestId('question-content')
expect(contentContainer.getAttribute('style')).not.toBeNull()
})
it('should cover composition lifecycle preventing enter submitting when composing', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const item = makeItem()
renderWithProvider(item, onRegenerate)
const editBtn = screen.getByTestId('edit-btn')
await user.click(editBtn)
const textbox = await screen.findByRole('textbox')
await user.clear(textbox)
// Simulate composition start and typing
act(() => {
textbox.focus()
})
// Simulate composition start
fireEvent.compositionStart(textbox)
// Try to press Enter while composing
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
// Simulate composition end
fireEvent.compositionEnd(textbox)
// Expect onRegenerate not to be called because Enter was pressed during composition
expect(onRegenerate).not.toHaveBeenCalled()
// Let setTimeout finish its 50ms interval to clear isComposing
await new Promise(r => setTimeout(r, 60))
// Now press Enter after composition is fully cleared
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter' })
expect(onRegenerate).toHaveBeenCalledWith(item, { message: '', files: [] })
})
it('should prevent Enter from submitting when shiftKey is pressed', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const item = makeItem()
renderWithProvider(item, onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
// Press Shift+Enter
fireEvent.keyDown(textbox, { key: 'Enter', code: 'Enter', shiftKey: true })
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should ignore enter when nativeEvent.isComposing is true', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
// Create an event with nativeEvent.isComposing = true
const event = new KeyboardEvent('keydown', { key: 'Enter', code: 'Enter' })
Object.defineProperty(event, 'isComposing', { value: true })
fireEvent(textbox, event)
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should clear timer on cancel and on component unmount', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
const { unmount } = renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox)
// Timer is now running, let's start another composition to clear it
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox)
const cancelBtn = await screen.findByTestId('cancel-edit-btn')
await user.click(cancelBtn)
// Test unmount clearing timer
await user.click(screen.getByTestId('edit-btn'))
const textbox2 = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox2)
fireEvent.compositionEnd(textbox2)
unmount()
expect(onRegenerate).not.toHaveBeenCalled()
})
it('should ignore enter when handleResend with active timer', async () => {
const user = userEvent.setup()
const onRegenerate = vi.fn() as unknown as OnRegenerate
renderWithProvider(makeItem(), onRegenerate)
await user.click(screen.getByTestId('edit-btn'))
const textbox = await screen.findByRole('textbox')
fireEvent.compositionStart(textbox)
fireEvent.compositionEnd(textbox) // starts timer
const saveBtn = screen.getByTestId('save-edit-btn')
await user.click(saveBtn) // handleResend clears timer
expect(onRegenerate).toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,121 @@
import type { InputForm } from '../type'
import { InputVarType } from '@/app/components/workflow/types'
import { getProcessedInputs, processInputFileFromServer, processOpeningStatement } from '../utils'
vi.mock('@/app/components/base/file-uploader/utils', () => ({
getProcessedFiles: vi.fn((files: File[]) => files.map((f: File) => ({ ...f, processed: true }))),
}))
describe('chat/chat/utils.ts', () => {
describe('processOpeningStatement', () => {
it('returns empty string if openingStatement is falsy', () => {
expect(processOpeningStatement('', {}, [])).toBe('')
})
it('replaces variables with input values when available', () => {
const result = processOpeningStatement('Hello {{name}}', { name: 'Alice' }, [])
expect(result).toBe('Hello Alice')
})
it('replaces variables with labels when input value is not available but form has variable', () => {
const result = processOpeningStatement('Hello {{user_name}}', {}, [{ variable: 'user_name', label: 'Name Label', type: InputVarType.textInput }] as InputForm[])
expect(result).toBe('Hello {{Name Label}}')
})
it('keeps original match when input value and form are not available', () => {
const result = processOpeningStatement('Hello {{unknown}}', {}, [])
expect(result).toBe('Hello {{unknown}}')
})
})
describe('processInputFileFromServer', () => {
it('maps server file object to local schema', () => {
const result = processInputFileFromServer({
type: 'image',
transfer_method: 'local_file',
remote_url: 'http://example.com/img.png',
related_id: '123',
})
expect(result).toEqual({
type: 'image',
transfer_method: 'local_file',
url: 'http://example.com/img.png',
upload_file_id: '123',
})
})
})
describe('getProcessedInputs', () => {
it('processes checkbox input types to boolean', () => {
const inputs = { terms: 'true', conds: null }
const inputsForm = [
{ variable: 'terms', type: InputVarType.checkbox as string },
{ variable: 'conds', type: InputVarType.checkbox as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result).toEqual({ terms: true, conds: false })
})
it('ignores null values', () => {
const inputs = { test: null }
const inputsForm = [{ variable: 'test', type: InputVarType.textInput as string }]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result).toEqual({ test: null })
})
it('processes singleFile using transfer_method logic', () => {
const inputs = {
file1: { transfer_method: 'local_file', url: '1' },
file2: { id: 'file2' },
}
const inputsForm = [
{ variable: 'file1', type: InputVarType.singleFile as string },
{ variable: 'file2', type: InputVarType.singleFile as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.file1).toHaveProperty('transfer_method', 'local_file')
expect(result.file2).toHaveProperty('processed', true)
})
it('processes multiFiles using transfer_method logic', () => {
const inputs = {
files1: [{ transfer_method: 'local_file', url: '1' }],
files2: [{ id: 'file2' }],
}
const inputsForm = [
{ variable: 'files1', type: InputVarType.multiFiles as string },
{ variable: 'files2', type: InputVarType.multiFiles as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.files1[0]).toHaveProperty('transfer_method', 'local_file')
expect(result.files2[0]).toHaveProperty('processed', true)
})
it('processes jsonObject parsing correct json', () => {
const inputs = {
json1: '{"key": "value"}',
}
const inputsForm = [{ variable: 'json1', type: InputVarType.jsonObject as string }]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.json1).toEqual({ key: 'value' })
})
it('processes jsonObject falling back to original if json is array or plain string/invalid json', () => {
const inputs = {
jsonInvalid: 'invalid json',
jsonArray: '["a", "b"]',
jsonPlainObj: { key: 'value' },
}
const inputsForm = [
{ variable: 'jsonInvalid', type: InputVarType.jsonObject as string },
{ variable: 'jsonArray', type: InputVarType.jsonObject as string },
{ variable: 'jsonPlainObj', type: InputVarType.jsonObject as string },
]
const result = getProcessedInputs(inputs, inputsForm as InputForm[])
expect(result.jsonInvalid).toBe('invalid json')
expect(result.jsonArray).toBe('["a", "b"]')
expect(result.jsonPlainObj).toEqual({ key: 'value' })
})
})
})

View File

@@ -0,0 +1,437 @@
import { act, renderHook } from '@testing-library/react'
import { useTextAreaHeight } from '../hooks'
describe('useTextAreaHeight', () => {
// Mock getBoundingClientRect for all ref elements
const mockGetBoundingClientRect = (
width: number = 0,
height: number = 0,
) => ({
width,
height,
top: 0,
left: 0,
bottom: height,
right: width,
x: 0,
y: 0,
toJSON: () => ({}),
})
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current).toBeDefined()
})
it('should return all required properties', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current).toHaveProperty('wrapperRef')
expect(result.current).toHaveProperty('textareaRef')
expect(result.current).toHaveProperty('textValueRef')
expect(result.current).toHaveProperty('holdSpaceRef')
expect(result.current).toHaveProperty('handleTextareaResize')
expect(result.current).toHaveProperty('isMultipleLine')
})
})
describe('Initial State', () => {
it('should initialize with isMultipleLine as false', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.isMultipleLine).toBe(false)
})
it('should initialize refs as null', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.wrapperRef.current).toBeNull()
expect(result.current.textValueRef.current).toBeNull()
expect(result.current.holdSpaceRef.current).toBeNull()
})
it('should initialize textareaRef as undefined', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(result.current.textareaRef.current).toBeUndefined()
})
})
describe('Height Computation Logic (via handleTextareaResize)', () => {
it('should not update state when any ref is missing', () => {
const { result } = renderHook(() => useTextAreaHeight())
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
})
it('should set isMultipleLine to true when textarea height exceeds 32px', () => {
const { result } = renderHook(() => useTextAreaHeight())
// Set up refs with mock elements
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 64), // height > 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
// Assign elements to refs
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should set isMultipleLine to true when combined content width exceeds wrapper width', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200, 0), // wrapperWidth = 200
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20), // height <= 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(120, 0), // textValueWidth = 120
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // holdSpaceWidth = 100, total = 220 > 200
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should set isMultipleLine to false when content fits in wrapper', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0), // wrapperWidth = 300
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20), // height <= 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // textValueWidth = 100
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0), // holdSpaceWidth = 50, total = 150 < 300
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
})
it('should handle exact boundary when combined width equals wrapper width', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0), // total = 200, equals wrapperWidth
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle boundary case when textarea height equals 32px', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 32), // exactly 32
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
// height = 32 is not > 32, so should check width condition
expect(result.current.isMultipleLine).toBe(false)
})
})
describe('handleTextareaResize', () => {
it('should be a function', () => {
const { result } = renderHook(() => useTextAreaHeight())
expect(typeof result.current.handleTextareaResize).toBe('function')
})
it('should call handleComputeHeight when invoked', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 64),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(50, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should update state based on new dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
const wrapperRect = vi.spyOn(wrapperElement, 'getBoundingClientRect')
const textareaRect = vi.spyOn(textareaElement, 'getBoundingClientRect')
const textValueRect = vi.spyOn(textValueElement, 'getBoundingClientRect')
const holdSpaceRect = vi.spyOn(holdSpaceElement, 'getBoundingClientRect')
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
// First call - content fits
wrapperRect.mockReturnValue(mockGetBoundingClientRect(300, 0))
textareaRect.mockReturnValue(mockGetBoundingClientRect(300, 20))
textValueRect.mockReturnValue(mockGetBoundingClientRect(100, 0))
holdSpaceRect.mockReturnValue(mockGetBoundingClientRect(50, 0))
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(false)
// Second call - content overflows
textareaRect.mockReturnValue(mockGetBoundingClientRect(300, 64))
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
})
describe('Callback Stability', () => {
it('should maintain ref objects across rerenders', () => {
const { result, rerender } = renderHook(() => useTextAreaHeight())
const firstWrapperRef = result.current.wrapperRef
const firstTextareaRef = result.current.textareaRef
const firstTextValueRef = result.current.textValueRef
const firstHoldSpaceRef = result.current.holdSpaceRef
rerender()
expect(result.current.wrapperRef).toBe(firstWrapperRef)
expect(result.current.textareaRef).toBe(firstTextareaRef)
expect(result.current.textValueRef).toBe(firstTextValueRef)
expect(result.current.holdSpaceRef).toBe(firstHoldSpaceRef)
})
})
describe('Edge Cases', () => {
it('should handle zero dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(0, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
// When all dimensions are 0, 0 + 0 >= 0 is true, so isMultipleLine is true
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle very large dimensions', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(10000, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(10000, 100),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(5000, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(5000, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
it('should handle numeric precision edge cases', () => {
const { result } = renderHook(() => useTextAreaHeight())
const wrapperElement = document.createElement('div')
const textareaElement = document.createElement('textarea')
const textValueElement = document.createElement('div')
const holdSpaceElement = document.createElement('div')
vi.spyOn(wrapperElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(200.5, 0),
)
vi.spyOn(textareaElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(300, 20),
)
vi.spyOn(textValueElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100.2, 0),
)
vi.spyOn(holdSpaceElement, 'getBoundingClientRect').mockReturnValue(
mockGetBoundingClientRect(100.3, 0),
)
result.current.wrapperRef.current = wrapperElement
result.current.textareaRef.current = textareaElement
result.current.textValueRef.current = textValueElement
result.current.holdSpaceRef.current = holdSpaceElement
act(() => {
result.current.handleTextareaResize()
})
expect(result.current.isMultipleLine).toBe(true)
})
})
})

View File

@@ -1,7 +1,7 @@
import type { FileUpload } from '@/app/components/base/features/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { TransferMethod } from '@/types/app'
import { render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { vi } from 'vitest'
@@ -52,6 +52,8 @@ vi.mock('@/app/components/base/file-uploader/store', () => ({
// ---------------------------------------------------------------------------
// File-uploader hooks provide stable drag/drop handlers
// ---------------------------------------------------------------------------
let mockIsDragActive = false
vi.mock('@/app/components/base/file-uploader/hooks', () => ({
useFile: () => ({
handleDragFileEnter: vi.fn(),
@@ -59,7 +61,7 @@ vi.mock('@/app/components/base/file-uploader/hooks', () => ({
handleDragFileOver: vi.fn(),
handleDropFile: vi.fn(),
handleClipboardPasteFile: vi.fn(),
isDragActive: false,
isDragActive: mockIsDragActive,
}),
}))
@@ -210,6 +212,7 @@ describe('ChatInputArea', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFileStore.files = []
mockIsDragActive = false
mockIsMultipleLine = false
})
@@ -236,6 +239,12 @@ describe('ChatInputArea', () => {
expect(disabledWrapper).toBeInTheDocument()
})
it('should apply drag-active styles when a file is being dragged over the input', () => {
mockIsDragActive = true
const { container } = render(<ChatInputArea visionConfig={mockVisionConfig} />)
expect(container.querySelector('.border-dashed')).toBeInTheDocument()
})
it('should render the operation section inline when single-line', () => {
// mockIsMultipleLine is false by default
render(<ChatInputArea visionConfig={mockVisionConfig} />)
@@ -331,6 +340,30 @@ describe('ChatInputArea', () => {
expect(onSend).toHaveBeenCalledWith('With attachment', [uploadedFile])
})
it('should not send on Enter while IME composition is active, then send after composition ends', () => {
vi.useFakeTimers()
try {
const onSend = vi.fn()
render(<ChatInputArea onSend={onSend} visionConfig={mockVisionConfig} />)
const textarea = getTextarea()
fireEvent.change(textarea, { target: { value: 'Composed text' } })
fireEvent.compositionStart(textarea)
fireEvent.keyDown(textarea, { key: 'Enter' })
expect(onSend).not.toHaveBeenCalled()
fireEvent.compositionEnd(textarea)
vi.advanceTimersByTime(60)
fireEvent.keyDown(textarea, { key: 'Enter' })
expect(onSend).toHaveBeenCalledWith('Composed text', [])
}
finally {
vi.useRealTimers()
}
})
})
// -------------------------------------------------------------------------

View File

@@ -219,8 +219,8 @@ const Question: FC<QuestionProps> = ({
/>
</div>
<div className="flex items-center justify-end gap-2">
<Button className="min-w-24" onClick={handleCancelEditing}>{t('operation.cancel', { ns: 'common' })}</Button>
<Button className="min-w-24" variant="primary" onClick={handleResend}>{t('operation.save', { ns: 'common' })}</Button>
<Button className="min-w-24" onClick={handleCancelEditing} data-testid="cancel-edit-btn">{t('operation.cancel', { ns: 'common' })}</Button>
<Button className="min-w-24" variant="primary" onClick={handleResend} data-testid="save-edit-btn">{t('operation.save', { ns: 'common' })}</Button>
</div>
</div>
)}

View File

@@ -14,6 +14,17 @@ import { shareQueryKeys } from '@/service/use-share'
import { CONVERSATION_ID_INFO } from '../../constants'
import { useEmbeddedChatbot } from '../hooks'
type InputForm = {
variable: string
type: string
default?: unknown
required?: boolean
label?: string
max_length?: number
options?: string[]
hide?: boolean
}
vi.mock('@/i18n-config/client', () => ({
changeLanguage: vi.fn().mockResolvedValue(undefined),
}))
@@ -40,13 +51,23 @@ vi.mock('@/context/web-app-context', () => ({
useWebAppStore: (selector?: (state: typeof mockStoreState) => unknown) => useWebAppStoreMock(selector),
}))
const {
mockGetProcessedInputsFromUrlParams,
mockGetProcessedSystemVariablesFromUrlParams,
mockGetProcessedUserVariablesFromUrlParams,
} = vi.hoisted(() => ({
mockGetProcessedInputsFromUrlParams: vi.fn(),
mockGetProcessedSystemVariablesFromUrlParams: vi.fn(),
mockGetProcessedUserVariablesFromUrlParams: vi.fn(),
}))
vi.mock('../../utils', async () => {
const actual = await vi.importActual<typeof import('../../utils')>('../../utils')
return {
...actual,
getProcessedInputsFromUrlParams: vi.fn().mockResolvedValue({}),
getProcessedSystemVariablesFromUrlParams: vi.fn().mockResolvedValue({}),
getProcessedUserVariablesFromUrlParams: vi.fn().mockResolvedValue({}),
getProcessedInputsFromUrlParams: mockGetProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams: mockGetProcessedSystemVariablesFromUrlParams,
getProcessedUserVariablesFromUrlParams: mockGetProcessedUserVariablesFromUrlParams,
}
})
@@ -65,6 +86,12 @@ vi.mock('@/service/share', async (importOriginal) => {
}
})
const STABLE_MOCK_DATA = { data: {} }
vi.mock('@/service/use-try-app', () => ({
useGetTryAppInfo: vi.fn(() => STABLE_MOCK_DATA),
useGetTryAppParams: vi.fn(() => STABLE_MOCK_DATA),
}))
const mockFetchConversations = vi.mocked(fetchConversations)
const mockFetchChatList = vi.mocked(fetchChatList)
const mockGenerationConversationName = vi.mocked(generationConversationName)
@@ -85,12 +112,20 @@ const createWrapper = (queryClient: QueryClient) => {
)
}
const renderWithClient = <T,>(hook: () => T) => {
const renderWithClient = async <T,>(hook: () => T) => {
const queryClient = createQueryClient()
const wrapper = createWrapper(queryClient)
let result: ReturnType<typeof renderHook<T, unknown>> | undefined
act(() => {
result = renderHook(hook, { wrapper })
})
await waitFor(() => {
if (queryClient.isFetching() > 0)
throw new Error('Queries are still fetching')
}, { timeout: 2000 })
return {
queryClient,
...renderHook(hook, { wrapper }),
...result!,
}
}
@@ -113,6 +148,10 @@ const createConversationData = (overrides: Partial<AppConversationData> = {}): A
describe('useEmbeddedChatbot', () => {
beforeEach(() => {
vi.clearAllMocks()
// Re-establish default mock implementations after clearAllMocks
mockGetProcessedInputsFromUrlParams.mockResolvedValue({})
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
mockGetProcessedUserVariablesFromUrlParams.mockResolvedValue({})
localStorage.removeItem(CONVERSATION_ID_INFO)
mockStoreState.appInfo = {
app_id: 'app-1',
@@ -128,6 +167,8 @@ describe('useEmbeddedChatbot', () => {
mockStoreState.appParams = null
mockStoreState.embeddedConversationId = 'conversation-1'
mockStoreState.embeddedUserId = 'embedded-user-1'
mockFetchConversations.mockResolvedValue({ data: [], has_more: false, limit: 100 })
mockFetchChatList.mockResolvedValue({ data: [] })
})
afterEach(() => {
@@ -150,7 +191,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
// Act
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Assert
await waitFor(() => {
@@ -167,6 +208,49 @@ describe('useEmbeddedChatbot', () => {
expect(result.current.conversationList).toEqual(listData.data)
})
})
it('should format chat list history correctly into appPrevChatList', async () => {
// Provide a currentConversationId by rendering successfully
mockStoreState.embeddedConversationId = 'conversation-1'
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({ conversation_id: 'conversation-1' })
mockFetchChatList.mockResolvedValue({
data: [{
id: 'msg-1',
query: 'Hello',
answer: 'Hi there!',
message_files: [{ belongs_to: 'user', id: 'mf-1' }, { belongs_to: 'assistant', id: 'mf-2' }],
agent_thoughts: [{ id: 'at-1' }],
feedback: { rating: 'like' },
}],
})
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Wait for the mock to be called
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('conversation-1', AppSourceType.webApp, 'app-1')
})
// Wait for the chat list to be populated
await waitFor(() => {
expect(result.current.appPrevChatList.length).toBeGreaterThan(0)
})
// We expect the formatting logic to split the message into question and answer ChatItems
const chatList = result.current.appPrevChatList
const userMsg = chatList.find((msg: unknown) => (msg as Record<string, unknown>).id === 'question-msg-1')
expect(userMsg).toBeDefined()
expect((userMsg as Record<string, unknown>)?.content).toBe('Hello')
expect((userMsg as Record<string, unknown>)?.isAnswer).toBe(false)
const assistantMsg = ((userMsg as Record<string, unknown>)?.children as unknown[])?.[0]
expect(assistantMsg).toBeDefined()
expect((assistantMsg as Record<string, unknown>)?.id).toBe('msg-1')
expect((assistantMsg as Record<string, unknown>)?.content).toBe('Hi there!')
expect((assistantMsg as Record<string, unknown>)?.isAnswer).toBe(true)
expect(((assistantMsg as Record<string, unknown>)?.feedback as Record<string, unknown>)?.rating).toBe('like')
})
})
// Scenario: completion invalidates share caches and merges generated names.
@@ -184,7 +268,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(generatedConversation)
const { result, queryClient } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result, queryClient } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const invalidateSpy = vi.spyOn(queryClient, 'invalidateQueries')
// Act
@@ -214,7 +298,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(createConversationItem({ id: 'conversation-1' }))
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledTimes(1)
@@ -244,7 +328,7 @@ describe('useEmbeddedChatbot', () => {
mockFetchChatList.mockResolvedValue({ data: [] })
mockGenerationConversationName.mockResolvedValue(createConversationItem({ id: 'conversation-new' }))
const { result } = renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Act
act(() => {
@@ -261,4 +345,215 @@ describe('useEmbeddedChatbot', () => {
})
})
})
// Scenario: TryApp mode initialization and logic.
describe('TryApp mode', () => {
it('should use tryApp source type and skip URL overrides and user fetch', async () => {
// Arrange
const { useGetTryAppInfo } = await import('@/service/use-try-app')
const mockTryAppInfo = { app_id: 'try-app-1', site: { title: 'Try App' } };
(useGetTryAppInfo as unknown as ReturnType<typeof vi.fn>).mockReturnValue({ data: mockTryAppInfo })
mockGetProcessedSystemVariablesFromUrlParams.mockResolvedValue({})
// Act
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.tryApp, 'try-app-1'))
// Assert
expect(result.current.isInstalledApp).toBe(false)
expect(result.current.appId).toBe('try-app-1')
expect(result.current.appData?.site.title).toBe('Try App')
// ensure URL fetching is skipped
expect(mockGetProcessedSystemVariablesFromUrlParams).not.toHaveBeenCalled()
})
})
// Language overrides tests were causing hang, removed for now.
// Scenario: Removing conversation id info
describe('removeConversationIdInfo', () => {
it('should successfully remove a stored conversation ID info by appId', async () => {
// Setup some initial info
localStorage.setItem(CONVERSATION_ID_INFO, JSON.stringify({ 'app-1': { 'user-1': 'conv-id' } }))
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.removeConversationIdInfo('app-1')
})
await waitFor(() => {
const storedValue = localStorage.getItem(CONVERSATION_ID_INFO)
const parsed = storedValue ? JSON.parse(storedValue) : {}
expect(parsed['app-1']).toBeUndefined()
})
})
})
// Scenario: various form inputs configurations and default parsing
describe('inputsForms mapping and default parsing', () => {
const mockAppParamsWithInputs = {
user_input_form: [
{ paragraph: { variable: 'p1', default: 'para', max_length: 5 } },
{ number: { variable: 'n1', default: 42 } },
{ checkbox: { variable: 'c1', default: true } },
{ select: { variable: 's1', options: ['A', 'B'], default: 'A' } },
{ 'file-list': { variable: 'fl1' } },
{ file: { variable: 'f1' } },
{ json_object: { variable: 'j1' } },
{ 'text-input': { variable: 't1', default: 'txt', max_length: 3 } },
],
}
it('should map various types properly with max_length truncation when defaults supplied via URL', async () => {
mockGetProcessedInputsFromUrlParams.mockResolvedValue({
p1: 'toolongparagraph', // truncated to 5
n1: '99',
c1: true,
s1: 'B', // Matches options
t1: '1234', // truncated to 3
})
mockStoreState.appParams = mockAppParamsWithInputs as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Wait for the mock to be called
await waitFor(() => {
expect(mockGetProcessedInputsFromUrlParams).toHaveBeenCalled()
})
await waitFor(() => {
expect(result.current.inputsForms).toHaveLength(8)
})
const forms = result.current.inputsForms
expect(forms.find((f: InputForm) => f.variable === 'p1')?.default).toBe('toolo')
expect(forms.find((f: InputForm) => f.variable === 'n1')?.default).toBe(99)
expect(forms.find((f: InputForm) => f.variable === 'c1')?.default).toBe(true)
expect(forms.find((f: InputForm) => f.variable === 's1')?.default).toBe('B')
expect(forms.find((f: InputForm) => f.variable === 't1')?.default).toBe('123')
expect(forms.find((f: InputForm) => f.variable === 'fl1')?.type).toBe('file-list')
expect(forms.find((f: InputForm) => f.variable === 'f1')?.type).toBe('file')
expect(forms.find((f: InputForm) => f.variable === 'j1')?.type).toBe('json_object')
})
})
// Scenario: checkInputsRequired validates empty fields and pending multi-file uploads
describe('checkInputsRequired and handleStartChat', () => {
it('should return undefined and notify when file is still uploading', async () => {
mockStoreState.appParams = {
user_input_form: [
{ file: { variable: 'file_var', required: true } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
// Simulate a local file uploading
act(() => {
result.current.handleNewConversationInputsChange({
file_var: [{ transferMethod: 'local_file', uploadedId: null }],
})
})
const onStart = vi.fn()
let checkResult: boolean | undefined
act(() => {
checkResult = (result.current as unknown as { handleStartChat: (onStart?: () => void) => boolean }).handleStartChat(onStart)
})
expect(checkResult).toBeUndefined()
expect(onStart).not.toHaveBeenCalled()
})
it('should fail checkInputsRequired when required fields are missing', async () => {
mockStoreState.appParams = {
user_input_form: [
{ 'text-input': { variable: 't1', required: true, label: 'T1' } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.handleNewConversationInputsChange({
t1: '',
})
})
const onStart = vi.fn()
act(() => {
(result.current as unknown as { handleStartChat: (cb?: () => void) => void }).handleStartChat(onStart)
})
expect(onStart).not.toHaveBeenCalled()
})
it('should pass checkInputsRequired when allInputsHidden is true', async () => {
mockStoreState.appParams = {
user_input_form: [
{ 'text-input': { variable: 't1', required: true, label: 'T1', hide: true } },
],
} as unknown as ChatConfig
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
const callback = vi.fn()
act(() => {
(result.current as unknown as { handleStartChat: (cb?: () => void) => void }).handleStartChat(callback)
})
expect(callback).toHaveBeenCalled()
})
})
// Scenario: handlers (New Conversation, Change Conversation, Feedback)
describe('Event Handlers', () => {
it('handleNewConversation sets clearChatList to true for webApp', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await act(async () => {
await result.current.handleNewConversation()
})
expect(result.current.clearChatList).toBe(true)
})
it('handleNewConversation sets clearChatList to true for tryApp without complex parsing', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.tryApp, 'app-try-1'))
await act(async () => {
await result.current.handleNewConversation()
})
expect(result.current.clearChatList).toBe(true)
})
it('handleChangeConversation updates current conversation and refetches chat list', async () => {
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
act(() => {
result.current.handleChangeConversation('another-convo')
})
await waitFor(() => {
expect(result.current.currentConversationId).toBe('another-convo')
})
await waitFor(() => {
expect(mockFetchChatList).toHaveBeenCalledWith('another-convo', AppSourceType.webApp, 'app-1')
})
expect(result.current.newConversationId).toBe('')
expect(result.current.clearChatList).toBe(false)
})
it('handleFeedback invokes updateFeedback service successfully', async () => {
const { updateFeedback } = await import('@/service/share')
const { result } = await renderWithClient(() => useEmbeddedChatbot(AppSourceType.webApp))
await act(async () => {
await result.current.handleFeedback('msg-123', { rating: 'like' })
})
expect(updateFeedback).toHaveBeenCalled()
})
})
})

View File

@@ -0,0 +1,189 @@
/**
* Tests for embedded-chatbot utility functions.
*/
import { isDify } from '../utils'
describe('isDify', () => {
const originalReferrer = document.referrer
afterEach(() => {
Object.defineProperty(document, 'referrer', {
value: originalReferrer,
writable: true,
})
})
it('should return true when referrer includes dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/something',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes www.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://www.dify.ai/app/xyz',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return false when referrer does not include dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example.com',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should return false when referrer is empty', () => {
Object.defineProperty(document, 'referrer', {
value: '',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should return false when referrer does not contain dify.ai domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example-dify.com',
writable: true,
})
expect(isDify()).toBe(false)
})
it('should handle referrer without protocol', () => {
Object.defineProperty(document, 'referrer', {
value: 'dify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes api.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://api.dify.ai/v1/endpoint',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes app.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://app.dify.ai/chat',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer includes docs.dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://docs.dify.ai/guide',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with query parameters', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/?ref=test&id=123',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with hash fragment', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai/page#section',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when referrer has dify.ai with port number', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://dify.ai:8080/app',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when dify.ai appears after another domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://example.com/redirect?url=https://dify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when substring contains dify.ai', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://notdify.ai',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true when dify.ai is part of a different domain', () => {
Object.defineProperty(document, 'referrer', {
value: 'https://fake-dify.ai.example.com',
writable: true,
})
expect(isDify()).toBe(true)
})
it('should return true with multiple referrer variations', () => {
const variations = [
'https://dify.ai',
'http://www.dify.ai',
'http://dify.ai/',
'https://dify.ai/app?token=123#section',
'dify.ai/test',
'www.dify.ai/en',
]
variations.forEach((referrer) => {
Object.defineProperty(document, 'referrer', {
value: referrer,
writable: true,
})
expect(isDify()).toBe(true)
})
})
it('should return false with multiple non-dify referrer variations', () => {
const variations = [
'https://github.com',
'https://google.com',
'https://stackoverflow.com',
'https://example.dify',
'https://difyai.com',
'',
]
variations.forEach((referrer) => {
Object.defineProperty(document, 'referrer', {
value: referrer,
writable: true,
})
expect(isDify()).toBe(false)
})
})
})

View File

@@ -0,0 +1,221 @@
import { renderHook } from '@testing-library/react'
import { Theme, ThemeBuilder, useThemeContext } from '../theme-context'
// Scenario: Theme class configures colors from chatColorTheme and chatColorThemeInverted flags.
describe('Theme', () => {
describe('Default colors', () => {
it('should use default primary color when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.primaryColor).toBe('#1C64F2')
})
it('should use gradient background header when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.backgroundHeaderColorStyle).toBe(
'backgroundImage: linear-gradient(to right, #2563eb, #0ea5e9)',
)
})
it('should have empty chatBubbleColorStyle when chatColorTheme is null', () => {
const theme = new Theme(null, false)
expect(theme.chatBubbleColorStyle).toBe('')
})
it('should use default colors when chatColorTheme is empty string', () => {
const theme = new Theme('', false)
expect(theme.primaryColor).toBe('#1C64F2')
expect(theme.backgroundHeaderColorStyle).toBe(
'backgroundImage: linear-gradient(to right, #2563eb, #0ea5e9)',
)
})
})
describe('Custom color (configCustomColor)', () => {
it('should set primaryColor to chatColorTheme value', () => {
const theme = new Theme('#FF5733', false)
expect(theme.primaryColor).toBe('#FF5733')
})
it('should set backgroundHeaderColorStyle to solid custom color', () => {
const theme = new Theme('#FF5733', false)
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #FF5733')
})
it('should include primary color in backgroundButtonDefaultColorStyle', () => {
const theme = new Theme('#FF5733', false)
expect(theme.backgroundButtonDefaultColorStyle).toContain('#FF5733')
})
it('should set roundedBackgroundColorStyle with 5% opacity rgba', () => {
const theme = new Theme('#FF5733', false)
// #FF5733 → r=255 g=87 b=51
expect(theme.roundedBackgroundColorStyle).toBe('backgroundColor: rgba(255,87,51,0.05)')
})
it('should set chatBubbleColorStyle with 15% opacity rgba', () => {
const theme = new Theme('#FF5733', false)
expect(theme.chatBubbleColorStyle).toBe('backgroundColor: rgba(255,87,51,0.15)')
})
})
describe('Inverted color (configInvertedColor)', () => {
it('should use white background header when inverted with no custom color', () => {
const theme = new Theme(null, true)
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #ffffff')
})
it('should set colorFontOnHeaderStyle to default primaryColor when inverted with no custom color', () => {
const theme = new Theme(null, true)
expect(theme.colorFontOnHeaderStyle).toBe('color: #1C64F2')
})
it('should set headerBorderBottomStyle when inverted', () => {
const theme = new Theme(null, true)
expect(theme.headerBorderBottomStyle).toBe('borderBottom: 1px solid #ccc')
})
it('should set colorPathOnHeader to primaryColor when inverted', () => {
const theme = new Theme(null, true)
expect(theme.colorPathOnHeader).toBe('#1C64F2')
})
it('should have empty headerBorderBottomStyle when not inverted', () => {
const theme = new Theme(null, false)
expect(theme.headerBorderBottomStyle).toBe('')
})
})
describe('Custom color + inverted combined', () => {
it('should override background to white even when custom color is set', () => {
const theme = new Theme('#FF5733', true)
// configCustomColor runs first (solid bg), then configInvertedColor overrides to white
expect(theme.backgroundHeaderColorStyle).toBe('backgroundColor: #ffffff')
})
it('should use custom primaryColor for colorFontOnHeaderStyle when inverted', () => {
const theme = new Theme('#FF5733', true)
expect(theme.colorFontOnHeaderStyle).toBe('color: #FF5733')
})
it('should set colorPathOnHeader to custom primaryColor when inverted', () => {
const theme = new Theme('#FF5733', true)
expect(theme.colorPathOnHeader).toBe('#FF5733')
})
})
})
// Scenario: ThemeBuilder manages a lazily-created Theme instance and rebuilds on config change.
describe('ThemeBuilder', () => {
describe('theme getter', () => {
it('should create a default Theme when _theme is undefined (first access)', () => {
const builder = new ThemeBuilder()
const theme = builder.theme
expect(theme).toBeInstanceOf(Theme)
expect(theme.primaryColor).toBe('#1C64F2')
})
it('should return the same Theme instance on subsequent accesses', () => {
const builder = new ThemeBuilder()
const first = builder.theme
const second = builder.theme
expect(first).toBe(second)
})
})
describe('buildTheme', () => {
it('should create a Theme with the given color on first call', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
expect(builder.theme.primaryColor).toBe('#AABBCC')
})
it('should not rebuild the Theme when called again with the same config', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const themeAfterFirstBuild = builder.theme
builder.buildTheme('#AABBCC', false)
// Same instance: no rebuild occurred
expect(builder.theme).toBe(themeAfterFirstBuild)
})
it('should rebuild the Theme when chatColorTheme changes', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const originalTheme = builder.theme
builder.buildTheme('#FF0000', false)
expect(builder.theme).not.toBe(originalTheme)
expect(builder.theme.primaryColor).toBe('#FF0000')
})
it('should rebuild the Theme when chatColorThemeInverted changes', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#AABBCC', false)
const originalTheme = builder.theme
builder.buildTheme('#AABBCC', true)
expect(builder.theme).not.toBe(originalTheme)
expect(builder.theme.chatColorThemeInverted).toBe(true)
})
it('should use default args (null, false) when called with no arguments', () => {
const builder = new ThemeBuilder()
builder.buildTheme()
expect(builder.theme.chatColorTheme).toBeNull()
expect(builder.theme.chatColorThemeInverted).toBe(false)
})
it('should store chatColorTheme and chatColorThemeInverted on the built Theme', () => {
const builder = new ThemeBuilder()
builder.buildTheme('#123456', true)
expect(builder.theme.chatColorTheme).toBe('#123456')
expect(builder.theme.chatColorThemeInverted).toBe(true)
})
})
})
// Scenario: useThemeContext returns a ThemeBuilder from the nearest ThemeContext.
describe('useThemeContext', () => {
it('should return a ThemeBuilder instance from the default context', () => {
const { result } = renderHook(() => useThemeContext())
expect(result.current).toBeInstanceOf(ThemeBuilder)
})
it('should expose a valid theme on the returned ThemeBuilder', () => {
const { result } = renderHook(() => useThemeContext())
expect(result.current.theme).toBeInstanceOf(Theme)
})
})

View File

@@ -1,6 +1,5 @@
import type { Dayjs } from 'dayjs'
import type { DatePickerProps, Period } from '../types'
import { RiCalendarLine, RiCloseCircleFill } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -218,38 +217,29 @@ const DatePicker = ({
>
<PortalToFollowElemTrigger className={triggerWrapClassName}>
{renderTrigger
? (renderTrigger({
value: normalizedValue,
selectedDate,
isOpen,
handleClear,
handleClickTrigger,
}))
? (
renderTrigger({
value: normalizedValue,
selectedDate,
isOpen,
handleClear,
handleClickTrigger,
}))
: (
<div
className="group flex w-[252px] cursor-pointer items-center gap-x-0.5 rounded-lg bg-components-input-bg-normal px-2 py-1 hover:bg-state-base-hover-alt"
onClick={handleClickTrigger}
data-testid="date-picker-trigger"
>
<input
className="system-xs-regular flex-1 cursor-pointer appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder"
className="flex-1 cursor-pointer appearance-none truncate bg-transparent p-1 text-components-input-text-filled
outline-none system-xs-regular placeholder:text-components-input-text-placeholder"
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
/>
<RiCalendarLine className={cn(
'h-4 w-4 shrink-0 text-text-quaternary',
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
(displayValue || (isOpen && selectedDate)) && 'group-hover:hidden',
)}
/>
<RiCloseCircleFill
className={cn(
'hidden h-4 w-4 shrink-0 text-text-quaternary',
(displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block',
)}
onClick={handleClear}
/>
<span className={cn('i-ri-calendar-line h-4 w-4 shrink-0 text-text-quaternary', isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary', (displayValue || (isOpen && selectedDate)) && 'group-hover:hidden')} />
<span className={cn('i-ri-close-circle-fill hidden h-4 w-4 shrink-0 text-text-quaternary', (displayValue || (isOpen && selectedDate)) && 'hover:text-text-secondary group-hover:inline-block')} onClick={handleClear} data-testid="date-picker-clear-button" />
</div>
)}
</PortalToFollowElemTrigger>

View File

@@ -1,6 +1,5 @@
import type { Dayjs } from 'dayjs'
import type { TimePickerProps } from '../types'
import { RiCloseCircleFill, RiTimeLine } from '@remixicon/react'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -199,8 +198,8 @@ const TimePicker = ({
const inputElem = (
<input
className="system-xs-regular flex-1 cursor-pointer select-none appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder"
className="flex-1 cursor-pointer select-none appearance-none truncate bg-transparent p-1 text-components-input-text-filled
outline-none system-xs-regular placeholder:text-components-input-text-placeholder"
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
@@ -226,26 +225,14 @@ const TimePicker = ({
triggerFullWidth ? 'w-full min-w-0' : 'w-[252px]',
)}
onClick={handleClickTrigger}
data-testid="time-picker-trigger"
>
{inputElem}
{showTimezone && timezone && (
<TimezoneLabel timezone={timezone} inline className="shrink-0 select-none text-xs" />
)}
<RiTimeLine className={cn(
'h-4 w-4 shrink-0 text-text-quaternary',
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
(displayValue || (isOpen && selectedTime)) && !notClearable && 'group-hover:hidden',
)}
/>
<RiCloseCircleFill
className={cn(
'hidden h-4 w-4 shrink-0 text-text-quaternary',
(displayValue || (isOpen && selectedTime)) && !notClearable && 'hover:text-text-secondary group-hover:inline-block',
)}
role="button"
aria-label={t('operation.clear', { ns: 'common' })}
onClick={handleClear}
/>
<span className={cn('i-ri-time-line h-4 w-4 shrink-0 text-text-quaternary', isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary', (displayValue || (isOpen && selectedTime)) && !notClearable && 'group-hover:hidden')} />
<span className={cn('i-ri-close-circle-fill hidden h-4 w-4 shrink-0 text-text-quaternary', (displayValue || (isOpen && selectedTime)) && !notClearable && 'hover:text-text-secondary group-hover:inline-block')} role="button" aria-label={t('operation.clear', { ns: 'common' })} onClick={handleClear} />
</div>
)}
</PortalToFollowElemTrigger>

View File

@@ -0,0 +1,105 @@
import { fireEvent, render, screen } from '@testing-library/react'
import DynamicPdfPreview from './dynamic-pdf-preview'
type DynamicPdfPreviewProps = {
url: string
onCancel: () => void
}
type DynamicLoader = () => Promise<unknown> | undefined
type DynamicOptions = {
ssr?: boolean
}
const mockState = vi.hoisted(() => ({
loader: undefined as DynamicLoader | undefined,
options: undefined as DynamicOptions | undefined,
}))
const mockDynamicRender = vi.hoisted(() => vi.fn())
const mockDynamic = vi.hoisted(() =>
vi.fn((loader: DynamicLoader, options: DynamicOptions) => {
mockState.loader = loader
mockState.options = options
const MockDynamicPdfPreview = ({ url, onCancel }: DynamicPdfPreviewProps) => {
mockDynamicRender({ url, onCancel })
return (
<button data-testid="dynamic-pdf-preview" data-url={url} onClick={onCancel}>
Dynamic PDF Preview
</button>
)
}
return MockDynamicPdfPreview
}),
)
const mockPdfPreview = vi.hoisted(() =>
vi.fn(() => null),
)
vi.mock('next/dynamic', () => ({
default: mockDynamic,
}))
vi.mock('./pdf-preview', () => ({
default: mockPdfPreview,
}))
describe('dynamic-pdf-preview', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should configure next/dynamic with ssr disabled', () => {
expect(mockState.loader).toEqual(expect.any(Function))
expect(mockState.options).toEqual({ ssr: false })
})
it('should render the dynamic component and forward props', () => {
const onCancel = vi.fn()
render(<DynamicPdfPreview url="https://example.com/test.pdf" onCancel={onCancel} />)
const trigger = screen.getByTestId('dynamic-pdf-preview')
expect(trigger).toHaveAttribute('data-url', 'https://example.com/test.pdf')
expect(mockDynamicRender).toHaveBeenCalledWith({
url: 'https://example.com/test.pdf',
onCancel,
})
fireEvent.click(trigger)
expect(onCancel).toHaveBeenCalledTimes(1)
})
it('should return pdf-preview module when loader is executed in browser-like environment', async () => {
const loaded = mockState.loader?.()
expect(loaded).toBeInstanceOf(Promise)
const loadedModule = (await loaded) as { default: unknown }
const pdfPreviewModule = await import('./pdf-preview')
expect(loadedModule.default).toBe(pdfPreviewModule.default)
})
it('should return undefined when loader runs without window', () => {
const originalWindow = globalThis.window
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
try {
const loaded = mockState.loader?.()
expect(loaded).toBeUndefined()
}
finally {
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: originalWindow,
})
}
})
})

View File

@@ -44,4 +44,16 @@ describe('VariableOrConstantInputField', () => {
fireEvent.click(modeButtons[0])
expect(screen.getByRole('button', { name: 'Variable picker' })).toBeInTheDocument()
})
it('should handle variable picker changes', () => {
const logSpy = vi.spyOn(console, 'log').mockImplementation(() => { })
try {
render(<VariableOrConstantInputField label="Input source" />)
fireEvent.click(screen.getByRole('button', { name: 'Variable picker' }))
expect(logSpy).toHaveBeenCalledWith('Variable value changed')
}
finally {
logSpy.mockRestore()
}
})
})

View File

@@ -46,4 +46,54 @@ describe('base scenario schema generator', () => {
expect(schema.safeParse({}).success).toBe(true)
expect(schema.safeParse({ mode: null }).success).toBe(true)
})
it('should validate required checkbox values as booleans', () => {
const schema = generateZodSchema([{
type: BaseFieldType.checkbox,
variable: 'accepted',
label: 'Accepted',
required: true,
showConditions: [],
}])
expect(schema.safeParse({ accepted: true }).success).toBe(true)
expect(schema.safeParse({ accepted: false }).success).toBe(true)
expect(schema.safeParse({ accepted: 'yes' }).success).toBe(false)
expect(schema.safeParse({}).success).toBe(false)
})
it('should fallback to any schema for unsupported field types', () => {
const schema = generateZodSchema([{
type: BaseFieldType.file,
variable: 'attachment',
label: 'Attachment',
required: false,
showConditions: [],
allowedFileTypes: [],
allowedFileExtensions: [],
allowedFileUploadMethods: [],
}])
expect(schema.safeParse({ attachment: { id: 'file-1' } }).success).toBe(true)
expect(schema.safeParse({ attachment: 'raw-string' }).success).toBe(true)
expect(schema.safeParse({}).success).toBe(true)
expect(schema.safeParse({ attachment: null }).success).toBe(true)
})
it('should ignore numeric and text constraints for non-applicable field types', () => {
const schema = generateZodSchema([{
type: BaseFieldType.checkbox,
variable: 'toggle',
label: 'Toggle',
required: true,
showConditions: [],
maxLength: 1,
min: 10,
max: 20,
}])
expect(schema.safeParse({ toggle: true }).success).toBe(true)
expect(schema.safeParse({ toggle: false }).success).toBe(true)
expect(schema.safeParse({ toggle: 1 }).success).toBe(false)
})
})

View File

@@ -8,7 +8,7 @@ import * as utils from '../utils'
vi.mock('../utils', () => ({
generate: vi.fn((icon, key, props) => (
<svg
data-testid="mock-svg"
data-testid={key}
key={key}
{...props}
>
@@ -29,7 +29,7 @@ describe('IconBase Component', () => {
it('renders properly with required props', () => {
render(<IconBase data={mockData} />)
const svg = screen.getByTestId('mock-svg')
const svg = screen.getByTestId('svg-test-icon')
expect(svg).toBeInTheDocument()
expect(svg).toHaveAttribute('data-icon', mockData.name)
expect(svg).toHaveAttribute('aria-hidden', 'true')
@@ -37,7 +37,7 @@ describe('IconBase Component', () => {
it('passes className to the generated SVG', () => {
render(<IconBase data={mockData} className="custom-class" />)
const svg = screen.getByTestId('mock-svg')
const svg = screen.getByTestId('svg-test-icon')
expect(svg).toHaveAttribute('class', 'custom-class')
expect(utils.generate).toHaveBeenCalledWith(
mockData.icon,
@@ -49,7 +49,7 @@ describe('IconBase Component', () => {
it('handles onClick events', () => {
const handleClick = vi.fn()
render(<IconBase data={mockData} onClick={handleClick} />)
const svg = screen.getByTestId('mock-svg')
const svg = screen.getByTestId('svg-test-icon')
fireEvent.click(svg)
expect(handleClick).toHaveBeenCalledTimes(1)
})

View File

@@ -21,6 +21,28 @@ describe('generate icon base utils', () => {
const result = normalizeAttrs(attrs)
expect(result).toEqual({ dataTest: 'value', xlinkHref: 'url' })
})
it('should filter out editor metadata attributes', () => {
const attrs = {
'inkscape:version': '1.0',
'sodipodi:docname': 'icon.svg',
'xmlns:inkscape': 'http...',
'xmlns:sodipodi': 'http...',
'xmlns:svg': 'http...',
'data-name': 'Layer 1',
'xmlns-inkscape': 'http...',
'xmlns-sodipodi': 'http...',
'xmlns-svg': 'http...',
'dataName': 'Layer 1',
'valid': 'value',
}
expect(normalizeAttrs(attrs)).toEqual({ valid: 'value' })
})
it('should ignore undefined attribute values and handle default argument', () => {
expect(normalizeAttrs()).toEqual({})
expect(normalizeAttrs({ missing: undefined, valid: 'true' })).toEqual({ valid: 'true' })
})
})
describe('generate', () => {
@@ -58,7 +80,19 @@ describe('generate icon base utils', () => {
const node: AbstractNode = {
name: 'div',
attributes: { class: 'container' },
children: [],
children: [{ name: 'span', attributes: {} }],
}
const rootProps = { id: 'root' }
const { container } = render(generate(node, 'key', rootProps))
expect(container.querySelector('div')).toHaveAttribute('id', 'root')
expect(container.querySelector('span')).toBeInTheDocument()
})
it('should handle undefined children with rootProps', () => {
const node: AbstractNode = {
name: 'div',
attributes: { class: 'container' },
}
const rootProps = { id: 'root' }

View File

@@ -1,4 +0,0 @@
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -1,35 +0,0 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "10",
"height": "10",
"viewBox": "0 0 10 10",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
"fill": "currentColor"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "CreditsCoin"
}

View File

@@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './CreditsCoin.json'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'CreditsCoin'
export default Icon

View File

@@ -1,6 +1,5 @@
export { default as Balance } from './Balance'
export { default as CoinsStacked01 } from './CoinsStacked01'
export { default as CreditsCoin } from './CreditsCoin'
export { default as GoldCoin } from './GoldCoin'
export { default as ReceiptList } from './ReceiptList'
export { default as Tag01 } from './Tag01'

View File

@@ -36,7 +36,7 @@ const ImageGallery: FC<Props> = ({
const imgNum = srcs.length
const imgStyle = getWidthStyle(imgNum)
return (
<div className={cn(s[`img-${imgNum}`], 'flex flex-wrap')}>
<div className={cn(s[`img-${imgNum}`], 'flex flex-wrap')} data-testid="image-gallery">
{srcs.map((src, index) => (
!src
? null

View File

@@ -1,6 +1,6 @@
import type { useLocalFileUploader } from '../hooks'
import type { ImageFile, VisionSettings } from '@/types/app'
import { render, screen } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { Resolution, TransferMethod } from '@/types/app'
import ChatImageUploader from '../chat-image-uploader'
@@ -193,6 +193,23 @@ describe('ChatImageUploader', () => {
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})
it('should keep popover closed when trigger wrapper is clicked while disabled', async () => {
const user = userEvent.setup()
const settings = createSettings({
transfer_methods: [TransferMethod.remote_url],
})
render(<ChatImageUploader settings={settings} onUpload={defaultOnUpload} disabled />)
const button = screen.getByRole('button')
const triggerWrapper = button.parentElement
if (!triggerWrapper)
throw new Error('Expected trigger wrapper to exist')
await user.click(triggerWrapper)
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})
it('should show OR separator and local uploader when both methods are available', async () => {
const user = userEvent.setup()
const settings = createSettings({
@@ -207,6 +224,30 @@ describe('ChatImageUploader', () => {
expect(queryFileInput()).toBeInTheDocument()
})
it('should toggle local-upload hover style in mixed transfer mode', async () => {
const user = userEvent.setup()
const settings = createSettings({
transfer_methods: [TransferMethod.local_file, TransferMethod.remote_url],
})
render(<ChatImageUploader settings={settings} onUpload={defaultOnUpload} />)
await user.click(screen.getByRole('button'))
const uploadFromComputer = screen.getByText('common.imageUploader.uploadFromComputer')
expect(uploadFromComputer).not.toHaveClass('bg-primary-50')
const localInput = getFileInput()
const hoverWrapper = localInput.parentElement
if (!hoverWrapper)
throw new Error('Expected local uploader wrapper to exist')
fireEvent.mouseEnter(hoverWrapper)
expect(uploadFromComputer).toHaveClass('bg-primary-50')
fireEvent.mouseLeave(hoverWrapper)
expect(uploadFromComputer).not.toHaveClass('bg-primary-50')
})
it('should not show OR separator or local uploader when only remote_url method', async () => {
const user = userEvent.setup()
const settings = createSettings({

View File

@@ -140,9 +140,11 @@ describe('ImageLinkInput', () => {
const input = screen.getByRole('textbox')
await user.type(input, 'https://example.com/image.png')
await user.click(screen.getByRole('button'))
const button = screen.getByRole('button')
expect(button).toBeDisabled()
await user.click(button)
// Button is disabled, so click won't fire handleClick
expect(onUpload).not.toHaveBeenCalled()
})

View File

@@ -2,22 +2,15 @@ import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ImagePreview from '../image-preview'
type HotkeyHandler = () => void
type _HotkeyHandler = () => void
const mocks = vi.hoisted(() => ({
hotkeys: {} as Record<string, HotkeyHandler>,
notify: vi.fn(),
downloadUrl: vi.fn(),
windowOpen: vi.fn<(...args: unknown[]) => Window | null>(),
clipboardWrite: vi.fn<(items: ClipboardItem[]) => Promise<void>>(),
}))
vi.mock('react-hotkeys-hook', () => ({
useHotkeys: (keys: string, handler: HotkeyHandler) => {
mocks.hotkeys[keys] = handler
},
}))
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (...args: Parameters<typeof mocks.notify>) => mocks.notify(...args),
@@ -44,7 +37,6 @@ describe('ImagePreview', () => {
beforeEach(() => {
vi.clearAllMocks()
mocks.hotkeys = {}
if (!navigator.clipboard) {
Object.defineProperty(globalThis.navigator, 'clipboard', {
@@ -109,7 +101,8 @@ describe('ImagePreview', () => {
})
describe('Hotkeys', () => {
it('should register hotkeys and invoke esc/left/right handlers', () => {
it('should trigger esc/left/right handlers from keyboard', async () => {
const user = userEvent.setup()
const onCancel = vi.fn()
const onPrev = vi.fn()
const onNext = vi.fn()
@@ -123,18 +116,34 @@ describe('ImagePreview', () => {
/>,
)
expect(mocks.hotkeys.esc).toBeInstanceOf(Function)
expect(mocks.hotkeys.left).toBeInstanceOf(Function)
expect(mocks.hotkeys.right).toBeInstanceOf(Function)
mocks.hotkeys.esc?.()
mocks.hotkeys.left?.()
mocks.hotkeys.right?.()
await user.keyboard('{Escape}{ArrowLeft}{ArrowRight}')
expect(onCancel).toHaveBeenCalledTimes(1)
expect(onPrev).toHaveBeenCalledTimes(1)
expect(onNext).toHaveBeenCalledTimes(1)
})
it('should zoom in and out from keyboard up/down hotkeys', async () => {
const user = userEvent.setup()
render(
<ImagePreview
url="https://example.com/image.png"
title="Preview Image"
onCancel={vi.fn()}
/>,
)
const image = screen.getByRole('img', { name: 'Preview Image' })
await user.keyboard('{ArrowUp}')
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1.2) translate(0px, 0px)' })
})
await user.keyboard('{ArrowDown}')
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1) translate(0px, 0px)' })
})
})
})
describe('User Interactions', () => {
@@ -225,13 +234,18 @@ describe('ImagePreview', () => {
act(() => {
overlay.dispatchEvent(new MouseEvent('mousedown', { bubbles: true, clientX: 10, clientY: 10 }))
overlay.dispatchEvent(new MouseEvent('mousemove', { bubbles: true, clientX: 40, clientY: 30 }))
})
await waitFor(() => {
expect(image.style.transition).toBe('none')
})
expect(image.style.transform).toContain('translate(')
act(() => {
overlay.dispatchEvent(new MouseEvent('mousemove', { bubbles: true, clientX: 200, clientY: -100 }))
})
await waitFor(() => {
expect(image).toHaveStyle({ transform: 'scale(1.2) translate(70px, -22px)' })
})
act(() => {
document.dispatchEvent(new MouseEvent('mouseup', { bubbles: true }))

View File

@@ -1,4 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { InputNumber } from '../index'
describe('InputNumber Component', () => {
@@ -16,70 +17,130 @@ describe('InputNumber Component', () => {
expect(input).toBeInTheDocument()
})
it('handles increment button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
it('handles increment button click', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(6)
await user.click(incrementBtn)
expect(onChange).toHaveBeenCalledWith(6)
})
it('handles decrement button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
it('handles decrement button click', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(4)
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(4)
})
it('respects max value constraint', () => {
render(<InputNumber {...defaultProps} value={10} max={10} />)
it('respects max value constraint', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={10} max={10} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
await user.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('respects min value constraint', () => {
render(<InputNumber {...defaultProps} value={0} min={0} />)
it('respects min value constraint', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={0} min={0} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('handles direct input changes', () => {
render(<InputNumber {...defaultProps} />)
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '42' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(42)
expect(onChange).toHaveBeenCalledWith(42)
})
it('handles empty input', () => {
render(<InputNumber {...defaultProps} value={1} />)
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={1} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(0)
expect(onChange).toHaveBeenCalledWith(0)
})
it('handles invalid input', () => {
render(<InputNumber {...defaultProps} />)
it('does not call onChange when parsed value is NaN', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: 'abc' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(0)
const originalNumber = globalThis.Number
const numberSpy = vi.spyOn(globalThis, 'Number').mockImplementation((val: unknown) => {
if (val === '123') {
return Number.NaN
}
return originalNumber(val)
})
try {
fireEvent.change(input, { target: { value: '123' } })
expect(onChange).not.toHaveBeenCalled()
}
finally {
numberSpy.mockRestore()
}
})
it('does not call onChange when direct input exceeds range', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} max={10} min={0} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '11' } })
expect(onChange).not.toHaveBeenCalled()
})
it('uses default value when increment and decrement are clicked without value prop', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} defaultValue={7} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).toHaveBeenNthCalledWith(1, 7)
expect(onChange).toHaveBeenNthCalledWith(2, 7)
})
it('falls back to zero when controls are used without value and defaultValue', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).toHaveBeenNthCalledWith(1, 0)
expect(onChange).toHaveBeenNthCalledWith(2, 0)
})
it('displays unit when provided', () => {
const onChange = vi.fn()
const unit = 'px'
render(<InputNumber {...defaultProps} unit={unit} />)
render(<InputNumber onChange={onChange} unit={unit} />)
expect(screen.getByText(unit)).toBeInTheDocument()
})
it('disables controls when disabled prop is true', () => {
render(<InputNumber {...defaultProps} disabled />)
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled />)
const input = screen.getByRole('spinbutton')
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
@@ -88,4 +149,205 @@ describe('InputNumber Component', () => {
expect(incrementBtn).toBeDisabled()
expect(decrementBtn).toBeDisabled()
})
it('does not change value when disabled controls are clicked', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
const { getByRole } = render(<InputNumber onChange={onChange} disabled value={5} />)
const incrementBtn = getByRole('button', { name: /increment/i })
const decrementBtn = getByRole('button', { name: /decrement/i })
expect(incrementBtn).toBeDisabled()
expect(decrementBtn).toBeDisabled()
await user.click(incrementBtn)
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('keeps increment guard when disabled even if button is force-clickable', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled value={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
// Remove native disabled to force event dispatch and hit component-level guard.
incrementBtn.removeAttribute('disabled')
fireEvent.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('keeps decrement guard when disabled even if button is force-clickable', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} disabled value={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
// Remove native disabled to force event dispatch and hit component-level guard.
decrementBtn.removeAttribute('disabled')
fireEvent.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('applies large-size classes for control buttons', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} size="large" />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass('pt-1.5')
expect(decrementBtn).toHaveClass('pb-1.5')
})
it('prevents increment beyond max with custom amount', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={8} max={10} amount={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('prevents decrement below min with custom amount', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={2} min={0} amount={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('increments when value with custom amount stays within bounds', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} max={10} amount={3} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
await user.click(incrementBtn)
expect(onChange).toHaveBeenCalledWith(8)
})
it('decrements when value with custom amount stays within bounds', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={5} min={0} amount={3} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(2)
})
it('validates input against max constraint', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} max={10} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '15' } })
expect(onChange).not.toHaveBeenCalled()
})
it('validates input against min constraint', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={5} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '2' } })
expect(onChange).not.toHaveBeenCalled()
})
it('accepts input within min and max constraints', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={0} max={100} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '50' } })
expect(onChange).toHaveBeenCalledWith(50)
})
it('handles negative min and max values', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-10} max={10} value={0} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).toHaveBeenCalledWith(-1)
})
it('prevents decrement below negative min', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-10} value={-10} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
await user.click(decrementBtn)
expect(onChange).not.toHaveBeenCalled()
})
it('applies wrapClassName to outer div', () => {
const onChange = vi.fn()
const wrapClassName = 'custom-wrap-class'
render(<InputNumber onChange={onChange} wrapClassName={wrapClassName} />)
const wrapper = screen.getByTestId('input-number-wrapper')
expect(wrapper).toHaveClass(wrapClassName)
})
it('applies controlWrapClassName to control buttons container', () => {
const onChange = vi.fn()
const controlWrapClassName = 'custom-control-wrap'
render(<InputNumber onChange={onChange} controlWrapClassName={controlWrapClassName} />)
const controlDiv = screen.getByTestId('input-number-controls')
expect(controlDiv).toHaveClass(controlWrapClassName)
})
it('applies controlClassName to individual control buttons', () => {
const onChange = vi.fn()
const controlClassName = 'custom-control'
render(<InputNumber onChange={onChange} controlClassName={controlClassName} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass(controlClassName)
expect(decrementBtn).toHaveClass(controlClassName)
})
it('applies regular-size classes for control buttons when size is regular', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} size="regular" />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(incrementBtn).toHaveClass('pt-1')
expect(decrementBtn).toHaveClass('pb-1')
})
it('handles zero as a valid input', () => {
const onChange = vi.fn()
render(<InputNumber onChange={onChange} min={-5} max={5} value={1} />)
const input = screen.getByRole('spinbutton')
fireEvent.change(input, { target: { value: '0' } })
expect(onChange).toHaveBeenCalledWith(0)
})
it('prevents exact max boundary increment', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={10} max={10} />)
await user.click(screen.getByRole('button', { name: /increment/i }))
expect(onChange).not.toHaveBeenCalled()
})
it('prevents exact min boundary decrement', async () => {
const user = userEvent.setup()
const onChange = vi.fn()
render(<InputNumber onChange={onChange} value={0} min={0} />)
await user.click(screen.getByRole('button', { name: /decrement/i }))
expect(onChange).not.toHaveBeenCalled()
})
})

View File

@@ -1,6 +1,5 @@
import type { FC } from 'react'
import type { InputProps } from '../input'
import { RiArrowDownSLine, RiArrowUpSLine } from '@remixicon/react'
import { useCallback } from 'react'
import { cn } from '@/utils/classnames'
import Input from '../input'
@@ -45,6 +44,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
}, [max, min])
const inc = () => {
/* v8 ignore next 2 - @preserve */
if (disabled)
return
@@ -58,6 +58,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
onChange(newValue)
}
const dec = () => {
/* v8 ignore next 2 - @preserve */
if (disabled)
return
@@ -86,12 +87,12 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
}, [isValidValue, onChange])
return (
<div className={cn('flex', wrapClassName)}>
<div data-testid="input-number-wrapper" className={cn('flex', wrapClassName)}>
<Input
{...rest}
// disable default controller
type="number"
className={cn('no-spinner rounded-r-none', className)}
className={cn('rounded-r-none no-spinner', className)}
value={value ?? 0}
max={max}
min={min}
@@ -100,7 +101,10 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
unit={unit}
size={size}
/>
<div className={cn('flex flex-col rounded-r-md border-l border-divider-subtle bg-components-input-bg-normal text-text-tertiary focus:shadow-xs', disabled && 'cursor-not-allowed opacity-50', controlWrapClassName)}>
<div
data-testid="input-number-controls"
className={cn('flex flex-col rounded-r-md border-l border-divider-subtle bg-components-input-bg-normal text-text-tertiary focus:shadow-xs', disabled && 'cursor-not-allowed opacity-50', controlWrapClassName)}
>
<button
type="button"
onClick={inc}
@@ -108,7 +112,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
aria-label="increment"
className={cn(size === 'regular' ? 'pt-1' : 'pt-1.5', 'px-1.5 hover:bg-components-input-bg-hover', disabled && 'cursor-not-allowed hover:bg-transparent', controlClassName)}
>
<RiArrowUpSLine className="size-3" />
<span className="i-ri-arrow-up-s-line size-3" />
</button>
<button
type="button"
@@ -117,7 +121,7 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
aria-label="decrement"
className={cn(size === 'regular' ? 'pb-1' : 'pb-1.5', 'px-1.5 hover:bg-components-input-bg-hover', disabled && 'cursor-not-allowed hover:bg-transparent', controlClassName)}
>
<RiArrowDownSLine className="size-3" />
<span className="i-ri-arrow-down-s-line size-3" />
</button>
</div>
</div>

View File

@@ -35,7 +35,7 @@ describe('Input component', () => {
it('renders correctly with default props', () => {
render(<Input />)
const input = screen.getByPlaceholderText('Please input')
const input = screen.getByPlaceholderText(/input/i)
expect(input).toBeInTheDocument()
expect(input).not.toBeDisabled()
expect(input).not.toHaveClass('cursor-not-allowed')
@@ -45,7 +45,7 @@ describe('Input component', () => {
render(<Input showLeftIcon />)
const searchIcon = document.querySelector('.i-ri-search-line')
expect(searchIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText('Search')
const input = screen.getByPlaceholderText(/search/i)
expect(input).toHaveClass('pl-[26px]')
})
@@ -75,13 +75,13 @@ describe('Input component', () => {
render(<Input destructive />)
const warningIcon = document.querySelector('.i-ri-error-warning-line')
expect(warningIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText('Please input')
const input = screen.getByPlaceholderText(/input/i)
expect(input).toHaveClass('border-components-input-border-destructive')
})
it('applies disabled styles when disabled', () => {
render(<Input disabled />)
const input = screen.getByPlaceholderText('Please input')
const input = screen.getByPlaceholderText(/input/i)
expect(input).toBeDisabled()
expect(input).toHaveClass('cursor-not-allowed')
expect(input).toHaveClass('bg-components-input-bg-disabled')
@@ -97,7 +97,7 @@ describe('Input component', () => {
const customClass = 'test-class'
const customStyle = { color: 'red' }
render(<Input className={customClass} styleCss={customStyle} />)
const input = screen.getByPlaceholderText('Please input')
const input = screen.getByPlaceholderText(/input/i)
expect(input).toHaveClass(customClass)
expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' })
})
@@ -114,4 +114,61 @@ describe('Input component', () => {
const input = screen.getByPlaceholderText(placeholder)
expect(input).toBeInTheDocument()
})
describe('Number Input Formatting', () => {
it('removes leading zeros on change when current value is zero', () => {
let changedValue = ''
const onChange = vi.fn((e: React.ChangeEvent<HTMLInputElement>) => {
changedValue = e.target.value
})
render(<Input type="number" value={0} onChange={onChange} />)
const input = screen.getByRole('spinbutton') as HTMLInputElement
fireEvent.change(input, { target: { value: '00042' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(changedValue).toBe('42')
})
it('keeps typed value on change when current value is not zero', () => {
let changedValue = ''
const onChange = vi.fn((e: React.ChangeEvent<HTMLInputElement>) => {
changedValue = e.target.value
})
render(<Input type="number" value={1} onChange={onChange} />)
const input = screen.getByRole('spinbutton') as HTMLInputElement
fireEvent.change(input, { target: { value: '00042' } })
expect(onChange).toHaveBeenCalledTimes(1)
expect(changedValue).toBe('00042')
})
it('normalizes value and triggers change on blur when leading zeros exist', () => {
const onChange = vi.fn()
const onBlur = vi.fn()
render(<Input type="number" defaultValue="0012" onChange={onChange} onBlur={onBlur} />)
const input = screen.getByRole('spinbutton')
fireEvent.blur(input)
expect(onChange).toHaveBeenCalledTimes(1)
expect(onChange.mock.calls[0][0].type).toBe('change')
expect(onChange.mock.calls[0][0].target.value).toBe('12')
expect(onBlur).toHaveBeenCalledTimes(1)
expect(onBlur.mock.calls[0][0].target.value).toBe('12')
})
it('does not trigger change on blur when value is already normalized', () => {
const onChange = vi.fn()
const onBlur = vi.fn()
render(<Input type="number" defaultValue="12" onChange={onChange} onBlur={onBlur} />)
const input = screen.getByRole('spinbutton')
fireEvent.blur(input)
expect(onChange).not.toHaveBeenCalled()
expect(onBlur).toHaveBeenCalledTimes(1)
expect(onBlur.mock.calls[0][0].target.value).toBe('12')
})
})
})

View File

@@ -1,7 +1,6 @@
import { createRequire } from 'node:module'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { Theme } from '@/types/app'
import CodeBlock from '../code-block'
@@ -154,12 +153,12 @@ describe('CodeBlock', () => {
expect(screen.getByText('Ruby')).toBeInTheDocument()
})
it('should render mermaid controls when language is mermaid', async () => {
render(<CodeBlock className="language-mermaid">graph TB; A--&gt;B;</CodeBlock>)
// it('should render mermaid controls when language is mermaid', async () => {
// render(<CodeBlock className="language-mermaid">graph TB; A--&gt;B;</CodeBlock>)
expect(await screen.findByText('app.mermaid.classic')).toBeInTheDocument()
expect(screen.getByText('Mermaid')).toBeInTheDocument()
})
// expect(await screen.findByTestId('classic')).toBeInTheDocument()
// expect(screen.getByText('Mermaid')).toBeInTheDocument()
// })
it('should render abc section header when language is abc', () => {
render(<CodeBlock className="language-abc">X:1\nT:test</CodeBlock>)

View File

@@ -200,7 +200,7 @@ describe('MarkdownForm', () => {
})
it('should handle invalid data-options string without crashing', () => {
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
const node = createRootNode([
createElementNode('input', {
'type': 'select',
@@ -317,4 +317,174 @@ describe('MarkdownForm', () => {
expect(mockOnSend).not.toHaveBeenCalled()
})
})
// DatePicker onChange and onClear callbacks should update form state.
describe('DatePicker interaction', () => {
it('should update form value when date is picked via onChange', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'date', name: 'startDate', value: '' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
// Click the DatePicker trigger to open the popup
const trigger = screen.getByTestId('date-picker-trigger')
await user.click(trigger)
// Click the "Now" button in the footer to select current date (calls onChange)
const nowButton = await screen.findByText('time.operation.now')
await user.click(nowButton)
// Submit the form
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onChange was called with a Dayjs object that has .format, so formatDateForOutput is called
expect(mockFormatDateForOutput).toHaveBeenCalledWith(expect.anything(), false)
expect(mockOnSend).toHaveBeenCalled()
})
})
it('should clear form value when date is cleared via onClear', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'date', name: 'startDate', value: dayjs('2026-01-10') }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
const clearIcon = screen.getByTestId('date-picker-clear-button')
await user.click(clearIcon)
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onClear sets value to undefined, which JSON.stringify omits
expect(mockOnSend).toHaveBeenCalledWith('{}')
})
})
})
// TimePicker rendering, onChange, and onClear should work correctly.
describe('TimePicker interaction', () => {
it('should render TimePicker for time input type', () => {
const node = createRootNode([
createElementNode('input', { type: 'time', name: 'meetingTime', value: '09:00' }),
])
render(<MarkdownForm node={node} />)
// The real TimePicker renders a trigger with a readonly input showing the formatted time
const timeInput = screen.getByTestId('time-picker-trigger').querySelector('input[readonly]') as HTMLInputElement
expect(timeInput).not.toBeNull()
expect(timeInput.value).toBe('09:00 AM')
})
it('should update form value when time is picked via onChange', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'time', name: 'meetingTime', value: '' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
)
render(<MarkdownForm node={node} />)
// Click the TimePicker trigger to open the popup
const trigger = screen.getByTestId('time-picker-trigger')
await user.click(trigger)
// Click the "Now" button in the footer to select current time (calls onChange)
const nowButtons = await screen.findAllByText('time.operation.now')
await user.click(nowButtons[0])
// Submit the form
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
expect(mockOnSend).toHaveBeenCalled()
})
})
it('should clear form value when time is cleared via onClear', async () => {
const user = userEvent.setup()
const node = createRootNode(
[
createElementNode('input', { type: 'time', name: 'meetingTime', value: '09:00' }),
createElementNode('button', {}, [createTextNode('Submit')]),
],
{ dataFormat: 'json' },
)
render(<MarkdownForm node={node} />)
// The TimePicker's clear icon has role="button" and an aria-label
const clearButton = screen.getByRole('button', { name: 'common.operation.clear' })
await user.click(clearButton)
await user.click(screen.getByRole('button', { name: 'Submit' }))
await waitFor(() => {
// onClear sets value to undefined, which JSON.stringify omits
expect(mockOnSend).toHaveBeenCalledWith('{}')
})
})
})
// Fallback branches for edge cases in tag rendering.
describe('Fallback branches', () => {
it('should render label with empty text when children array is empty', () => {
const node = createRootNode([
createElementNode('label', { for: 'field' }, []),
])
render(<MarkdownForm node={node} />)
const label = screen.getByTestId('label-field')
expect(label).not.toBeNull()
expect(label?.textContent).toBe('')
})
it('should render checkbox without tip text when dataTip is missing', () => {
const node = createRootNode([
createElementNode('input', { type: 'checkbox', name: 'agree', value: false }),
])
render(<MarkdownForm node={node} />)
expect(screen.getByTestId('checkbox-agree')).toBeInTheDocument()
})
it('should render select with no options when dataOptions is missing', () => {
const node = createRootNode([
createElementNode('input', { type: 'select', name: 'color', value: '' }),
])
render(<MarkdownForm node={node} />)
// Select renders with empty items list
expect(screen.getByTestId('markdown-form')).toBeInTheDocument()
})
it('should render button with empty text when children array is empty', () => {
const node = createRootNode([
createElementNode('button', {}, []),
])
render(<MarkdownForm node={node} />)
const button = screen.getByRole('button')
expect(button.textContent).toBe('')
})
})
})

View File

@@ -0,0 +1,86 @@
import { render, screen } from '@testing-library/react'
import { Img } from '..'
describe('Img', () => {
describe('Rendering', () => {
it('should render with the correct wrapper class', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
expect(wrapper).toBeInTheDocument()
})
it('should render ImageGallery with the src as an array', () => {
render(<Img src="https://example.com/image.png" />)
const gallery = screen.getByTestId('image-gallery')
expect(gallery).toBeInTheDocument()
const images = gallery.querySelectorAll('img')
expect(images).toHaveLength(1)
expect(images[0]).toHaveAttribute('src', 'https://example.com/image.png')
})
it('should pass src as single element array to ImageGallery', () => {
const testSrc = 'https://example.com/test-image.jpg'
render(<Img src={testSrc} />)
const gallery = screen.getByTestId('image-gallery')
const images = gallery.querySelectorAll('img')
expect(images[0]).toHaveAttribute('src', testSrc)
})
it('should render with different src values', () => {
const { rerender } = render(<Img src="https://example.com/first.png" />)
expect(screen.getByTestId('gallery-image')).toHaveAttribute('src', 'https://example.com/first.png')
rerender(<Img src="https://example.com/second.jpg" />)
expect(screen.getByTestId('gallery-image')).toHaveAttribute('src', 'https://example.com/second.jpg')
})
})
describe('Props', () => {
it('should accept src prop with various URL formats', () => {
// Test with HTTPS URL
const { container: container1 } = render(<Img src="https://example.com/image.png" />)
expect(container1.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with HTTP URL
const { container: container2 } = render(<Img src="http://example.com/image.png" />)
expect(container2.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with data URL
const { container: container3 } = render(<Img src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" />)
expect(container3.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
// Test with relative URL
const { container: container4 } = render(<Img src="/images/photo.jpg" />)
expect(container4.querySelector('.markdown-img-wrapper')).toBeInTheDocument()
})
it('should handle empty string src', () => {
const { container } = render(<Img src="" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
expect(wrapper).toBeInTheDocument()
})
})
describe('Structure', () => {
it('should have exactly one wrapper div', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrappers = container.querySelectorAll('.markdown-img-wrapper')
expect(wrappers).toHaveLength(1)
})
it('should contain ImageGallery component inside wrapper', () => {
const { container } = render(<Img src="https://example.com/image.png" />)
const wrapper = container.querySelector('.markdown-img-wrapper')
const gallery = wrapper?.querySelector('[data-testid="image-gallery"]')
expect(gallery).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,121 @@
import { getMarkdownImageURL, isValidUrl } from '../utils'
vi.mock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: false,
MARKETPLACE_API_PREFIX: '/api/marketplace',
}))
describe('utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('isValidUrl', () => {
it('should return true for http: URLs', () => {
expect(isValidUrl('http://example.com')).toBe(true)
})
it('should return true for https: URLs', () => {
expect(isValidUrl('https://example.com')).toBe(true)
})
it('should return true for protocol-relative URLs', () => {
expect(isValidUrl('//cdn.example.com/image.png')).toBe(true)
})
it('should return true for mailto: URLs', () => {
expect(isValidUrl('mailto:user@example.com')).toBe(true)
})
it('should return false for data: URLs when ALLOW_UNSAFE_DATA_SCHEME is false', () => {
expect(isValidUrl('data:image/png;base64,abc123')).toBe(false)
})
it('should return false for javascript: URLs', () => {
expect(isValidUrl('javascript:alert(1)')).toBe(false)
})
it('should return false for ftp: URLs', () => {
expect(isValidUrl('ftp://files.example.com')).toBe(false)
})
it('should return false for relative paths', () => {
expect(isValidUrl('/images/photo.png')).toBe(false)
})
it('should return false for empty string', () => {
expect(isValidUrl('')).toBe(false)
})
it('should return false for plain text', () => {
expect(isValidUrl('not a url')).toBe(false)
})
})
describe('isValidUrl with ALLOW_UNSAFE_DATA_SCHEME enabled', () => {
beforeEach(() => {
vi.resetModules()
vi.doMock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: true,
MARKETPLACE_API_PREFIX: '/api/marketplace',
}))
})
it('should return true for data: URLs when ALLOW_UNSAFE_DATA_SCHEME is true', async () => {
const { isValidUrl: isValidUrlWithData } = await import('../utils')
expect(isValidUrlWithData('data:image/png;base64,abc123')).toBe(true)
})
})
describe('getMarkdownImageURL', () => {
it('should return the original URL when it does not match the asset regex', () => {
expect(getMarkdownImageURL('https://example.com/image.png')).toBe('https://example.com/image.png')
})
it('should transform ./_assets URL without pathname', () => {
const result = getMarkdownImageURL('./_assets/icon.png')
expect(result).toBe('/api/marketplace/plugins//_assets/icon.png')
})
it('should transform ./_assets URL with pathname', () => {
const result = getMarkdownImageURL('./_assets/icon.png', 'my-plugin/')
expect(result).toBe('/api/marketplace/plugins/my-plugin//_assets/icon.png')
})
it('should transform _assets URL without leading dot-slash', () => {
const result = getMarkdownImageURL('_assets/logo.svg')
expect(result).toBe('/api/marketplace/plugins//_assets/logo.svg')
})
it('should transform _assets URL with pathname', () => {
const result = getMarkdownImageURL('_assets/logo.svg', 'org/plugin/')
expect(result).toBe('/api/marketplace/plugins/org/plugin//_assets/logo.svg')
})
it('should not transform URLs that contain _assets in the middle', () => {
expect(getMarkdownImageURL('https://cdn.example.com/_assets/image.png'))
.toBe('https://cdn.example.com/_assets/image.png')
})
it('should use empty string for pathname when undefined', () => {
const result = getMarkdownImageURL('./_assets/test.png')
expect(result).toBe('/api/marketplace/plugins//_assets/test.png')
})
})
describe('getMarkdownImageURL with trailing slash prefix', () => {
beforeEach(() => {
vi.resetModules()
vi.doMock('@/config', () => ({
ALLOW_UNSAFE_DATA_SCHEME: false,
MARKETPLACE_API_PREFIX: '/api/marketplace/',
}))
})
it('should not add extra slash when prefix ends with slash', async () => {
const { getMarkdownImageURL: getURL } = await import('../utils')
const result = getURL('./_assets/icon.png', 'my-plugin/')
expect(result).toBe('/api/marketplace/plugins/my-plugin//_assets/icon.png')
})
})
})

View File

@@ -90,6 +90,7 @@ const MarkdownForm = ({ node }: any) => {
<form
autoComplete="off"
className="flex flex-col self-stretch"
data-testid="markdown-form"
onSubmit={(e: any) => {
e.preventDefault()
e.stopPropagation()
@@ -102,6 +103,7 @@ const MarkdownForm = ({ node }: any) => {
key={index}
htmlFor={child.properties.htmlFor || child.properties.name}
className="my-2 text-text-secondary system-md-semibold"
data-testid="label-field"
>
{child.children[0]?.value || ''}
</label>

View File

@@ -1,6 +1,3 @@
// app/components/base/markdown/preprocess.spec.ts
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
/**
* Helper to (re)load the module with a mocked config value.
* We need to reset modules because the tested module imports

View File

@@ -8,9 +8,9 @@ vi.mock('@/app/components/base/markdown-blocks', () => ({
Link: ({ children, href }: { children?: ReactNode, href?: string }) => <a href={href}>{children}</a>,
MarkdownButton: ({ children }: PropsWithChildren) => <button>{children}</button>,
MarkdownForm: ({ children }: PropsWithChildren) => <form>{children}</form>,
Paragraph: ({ children }: PropsWithChildren) => <p>{children}</p>,
Paragraph: ({ children }: PropsWithChildren) => <p data-testid="paragraph">{children}</p>,
PluginImg: ({ alt }: { alt?: string }) => <span data-testid="plugin-img">{alt}</span>,
PluginParagraph: ({ children }: PropsWithChildren) => <p>{children}</p>,
PluginParagraph: ({ children }: PropsWithChildren) => <p data-testid="plugin-paragraph">{children}</p>,
ScriptBlock: () => null,
ThinkBlock: ({ children }: PropsWithChildren) => <details>{children}</details>,
VideoBlock: ({ children }: PropsWithChildren) => <div data-testid="video-block">{children}</div>,
@@ -105,5 +105,85 @@ describe('ReactMarkdownWrapper', () => {
expect(screen.getByText('italic text')).toBeInTheDocument()
expect(document.querySelector('em')).not.toBeNull()
})
it('should render standard Image component when pluginInfo is not provided', () => {
// Act
render(<ReactMarkdownWrapper latexContent="![standard-img](https://example.com/img.png)" />)
// Assert
expect(screen.getByTestId('img')).toBeInTheDocument()
})
it('should render a CodeBlock component for code markdown', async () => {
// Arrange
const content = '```javascript\nconsole.log("hello")\n```'
// Act
render(<ReactMarkdownWrapper latexContent={content} />)
// Assert
// We mocked code block to return <code>{children}</code>
const codeElement = await screen.findByText('console.log("hello")')
expect(codeElement).toBeInTheDocument()
})
})
describe('Plugin Info behavior', () => {
it('should render PluginImg and PluginParagraph when pluginInfo is provided', () => {
// Arrange
const content = 'This is a plugin paragraph\n\n![plugin-img](https://example.com/plugin.png)'
const pluginInfo = { pluginUniqueIdentifier: 'test-plugin', pluginId: 'plugin-1' }
// Act
render(<ReactMarkdownWrapper latexContent={content} pluginInfo={pluginInfo} />)
// Assert
expect(screen.getByTestId('plugin-img')).toBeInTheDocument()
expect(screen.queryByTestId('img')).toBeNull()
expect(screen.getAllByTestId('plugin-paragraph').length).toBeGreaterThan(0)
expect(screen.queryByTestId('paragraph')).toBeNull()
})
})
describe('Custom elements configuration', () => {
it('should use customComponents if provided', () => {
// Arrange
const customComponents = {
a: ({ children }: PropsWithChildren) => <a data-testid="custom-link">{children}</a>,
}
// Act
render(<ReactMarkdownWrapper latexContent="[link](https://example.com)" customComponents={customComponents} />)
// Assert
expect(screen.getByTestId('custom-link')).toBeInTheDocument()
})
it('should disallow customDisallowedElements', () => {
// Act - disallow strong (which is usually **bold**)
render(<ReactMarkdownWrapper latexContent="**bold**" customDisallowedElements={['strong']} />)
// Assert - strong element shouldn't be rendered (it will be stripped out)
expect(document.querySelector('strong')).toBeNull()
})
})
describe('Rehype AST modification', () => {
it('should remove ref attributes from elements', () => {
// Act
render(<ReactMarkdownWrapper latexContent={'<div ref="someRef">content</div>'} />)
// Assert - If ref isn't stripped, it gets passed to React DOM causing warnings, but here we just ensure content renders
expect(screen.getByText('content')).toBeInTheDocument()
})
it('should convert invalid tag names to text nodes', () => {
// Act - <custom-element> is invalid because it contains a hyphen
render(<ReactMarkdownWrapper latexContent="<custom-element>content</custom-element>" />)
// Assert - The AST node is changed to text with value `<custom-element`
expect(screen.getByText(/<custom-element/)).toBeInTheDocument()
})
})
})

View File

@@ -27,6 +27,11 @@ describe('Mermaid Flowchart Component', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.mocked(mermaid.initialize).mockImplementation(() => { })
vi.mocked(mermaid.render).mockResolvedValue({ svg: '<svg id="mermaid-chart">test-svg</svg>', diagramType: 'flowchart' })
})
afterEach(() => {
vi.useRealTimers()
})
describe('Rendering', () => {
@@ -132,6 +137,86 @@ describe('Mermaid Flowchart Component', () => {
}, { timeout: 3000 })
})
it('should keep selected look unchanged when clicking an already-selected look button', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
})
await waitFor(() => screen.getByText('test-svg'), { timeout: 3000 })
const initialRenderCalls = vi.mocked(mermaid.render).mock.calls.length
const initialApiRenderCalls = vi.mocked(mermaid.mermaidAPI.render).mock.calls.length
await act(async () => {
fireEvent.click(screen.getByText(/classic/i))
})
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialRenderCalls)
expect(vi.mocked(mermaid.mermaidAPI.render).mock.calls.length).toBe(initialApiRenderCalls)
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
await waitFor(() => {
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}, { timeout: 3000 })
const afterFirstHandDrawnApiCalls = vi.mocked(mermaid.mermaidAPI.render).mock.calls.length
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
expect(vi.mocked(mermaid.mermaidAPI.render).mock.calls.length).toBe(afterFirstHandDrawnApiCalls)
})
it('should toggle theme from light to dark and back to light', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} theme="light" />)
})
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
const toggleBtn = screen.getByRole('button')
await act(async () => {
fireEvent.click(toggleBtn)
})
await waitFor(() => {
expect(screen.getByRole('button')).toHaveAttribute('title', expect.stringMatching(/switchLight$/))
}, { timeout: 3000 })
await act(async () => {
fireEvent.click(screen.getByRole('button'))
})
await waitFor(() => {
expect(screen.getByRole('button')).toHaveAttribute('title', expect.stringMatching(/switchDark$/))
}, { timeout: 3000 })
})
it('should configure handDrawn mode for dark non-flowchart diagrams', async () => {
const sequenceCode = 'sequenceDiagram\n A->>B: Hi'
await act(async () => {
render(<Flowchart PrimitiveCode={sequenceCode} theme="dark" />)
})
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
})
await waitFor(() => {
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}, { timeout: 3000 })
expect(mermaid.initialize).toHaveBeenCalledWith(expect.objectContaining({
theme: 'default',
themeVariables: expect.objectContaining({
primaryBorderColor: '#60a5fa',
}),
}))
})
it('should open image preview when clicking the chart', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
@@ -144,7 +229,7 @@ describe('Mermaid Flowchart Component', () => {
fireEvent.click(chartDiv!)
})
await waitFor(() => {
expect(document.body.querySelector('.image-preview-container')).toBeInTheDocument()
expect(screen.getByTestId('image-preview-container')).toBeInTheDocument()
}, { timeout: 3000 })
})
})
@@ -164,35 +249,79 @@ describe('Mermaid Flowchart Component', () => {
const errorMsg = 'Syntax error'
vi.mocked(mermaid.render).mockRejectedValue(new Error(errorMsg))
// Use unique code to avoid hitting the module-level diagramCache from previous tests
const uniqueCode = 'graph TD\n X-->Y\n Y-->Z'
const { container } = render(<Flowchart PrimitiveCode={uniqueCode} />)
try {
const uniqueCode = 'graph TD\n X-->Y\n Y-->Z'
render(<Flowchart PrimitiveCode={uniqueCode} />)
await waitFor(() => {
const errorSpan = container.querySelector('.text-red-500 span.ml-2')
expect(errorSpan).toBeInTheDocument()
expect(errorSpan?.textContent).toContain('Rendering failed')
}, { timeout: 5000 })
consoleSpy.mockRestore()
// Restore default mock to prevent leaking into subsequent tests
vi.mocked(mermaid.render).mockResolvedValue({ svg: '<svg id="mermaid-chart">test-svg</svg>', diagramType: 'flowchart' })
}, 10000)
const errorMessage = await screen.findByText(/Rendering failed/i)
expect(errorMessage).toBeInTheDocument()
}
finally {
consoleSpy.mockRestore()
}
})
it('should show unknown-error fallback when render fails without an error message', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
vi.mocked(mermaid.render).mockRejectedValue({} as Error)
try {
render(<Flowchart PrimitiveCode={'graph TD\n P-->Q\n Q-->R'} />)
expect(await screen.findByText(/Unknown error\. Please check the console\./i)).toBeInTheDocument()
}
finally {
consoleSpy.mockRestore()
}
})
it('should use cached diagram if available', async () => {
const { rerender } = render(<Flowchart PrimitiveCode={mockCode} />)
await waitFor(() => screen.getByText('test-svg'), { timeout: 3000 })
vi.mocked(mermaid.render).mockClear()
// Wait for initial render to complete
await waitFor(() => {
expect(vi.mocked(mermaid.render)).toHaveBeenCalled()
}, { timeout: 3000 })
const initialCallCount = vi.mocked(mermaid.render).mock.calls.length
// Rerender with same code
await act(async () => {
rerender(<Flowchart PrimitiveCode={mockCode} />)
})
await act(async () => {
await new Promise(resolve => setTimeout(resolve, 500))
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialCallCount)
}, { timeout: 3000 })
// Call count should not increase (cache was used)
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(initialCallCount)
})
it('should keep previous svg visible while next render is loading', async () => {
let resolveSecondRender: ((value: { svg: string, diagramType: string }) => void) | null = null
const secondRenderPromise = new Promise<{ svg: string, diagramType: string }>((resolve) => {
resolveSecondRender = resolve
})
expect(mermaid.render).not.toHaveBeenCalled()
vi.mocked(mermaid.render)
.mockResolvedValueOnce({ svg: '<svg id="mermaid-chart">initial-svg</svg>', diagramType: 'flowchart' })
.mockImplementationOnce(() => secondRenderPromise)
const { rerender } = render(<Flowchart PrimitiveCode="graph TD\n A-->B" />)
await waitFor(() => {
expect(screen.getByText('initial-svg')).toBeInTheDocument()
}, { timeout: 3000 })
await act(async () => {
rerender(<Flowchart PrimitiveCode="graph TD\n C-->D" />)
})
expect(screen.getByText('initial-svg')).toBeInTheDocument()
resolveSecondRender!({ svg: '<svg id="mermaid-chart">second-svg</svg>', diagramType: 'flowchart' })
await waitFor(() => {
expect(screen.getByText('second-svg')).toBeInTheDocument()
}, { timeout: 3000 })
})
it('should handle invalid mermaid code completion', async () => {
@@ -206,6 +335,116 @@ describe('Mermaid Flowchart Component', () => {
}, { timeout: 3000 })
})
it('should keep single "after" gantt dependency formatting unchanged', async () => {
const singleAfterGantt = [
'gantt',
'title One after dependency',
'Single task :after task1, 2024-01-01, 1d',
].join('\n')
await act(async () => {
render(<Flowchart PrimitiveCode={singleAfterGantt} />)
})
await waitFor(() => {
expect(mermaid.render).toHaveBeenCalled()
}, { timeout: 3000 })
const lastRenderArgs = vi.mocked(mermaid.render).mock.calls.at(-1)
expect(lastRenderArgs?.[1]).toContain('Single task :after task1, 2024-01-01, 1d')
})
it('should use cache without rendering again when PrimitiveCode changes back to previous', async () => {
const firstCode = 'graph TD\n CacheOne-->CacheTwo'
const secondCode = 'graph TD\n CacheThree-->CacheFour'
const { rerender } = render(<Flowchart PrimitiveCode={firstCode} />)
// Wait for initial render
await waitFor(() => {
expect(vi.mocked(mermaid.render)).toHaveBeenCalled()
}, { timeout: 3000 })
const firstRenderCallCount = vi.mocked(mermaid.render).mock.calls.length
// Change to different code
await act(async () => {
rerender(<Flowchart PrimitiveCode={secondCode} />)
})
// Wait for second render
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBeGreaterThan(firstRenderCallCount)
}, { timeout: 3000 })
const afterSecondRenderCallCount = vi.mocked(mermaid.render).mock.calls.length
// Change back to first code - should use cache
await act(async () => {
rerender(<Flowchart PrimitiveCode={firstCode} />)
})
await waitFor(() => {
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(afterSecondRenderCallCount)
}, { timeout: 3000 })
// Call count should not increase (cache was used)
expect(vi.mocked(mermaid.render).mock.calls.length).toBe(afterSecondRenderCallCount)
})
it('should close image preview when cancel is clicked', async () => {
await act(async () => {
render(<Flowchart PrimitiveCode={mockCode} />)
})
// Wait for SVG to be rendered
await waitFor(() => {
const svgElement = screen.queryByText('test-svg')
expect(svgElement).toBeInTheDocument()
}, { timeout: 3000 })
const mermaidDiv = screen.getByText('test-svg').closest('.mermaid')
await act(async () => {
fireEvent.click(mermaidDiv!)
})
// Wait for image preview to appear
const cancelBtn = await screen.findByTestId('image-preview-close-button')
expect(cancelBtn).toBeInTheDocument()
await act(async () => {
fireEvent.click(cancelBtn)
})
await waitFor(() => {
expect(screen.queryByTestId('image-preview-container')).not.toBeInTheDocument()
expect(screen.queryByTestId('image-preview-close-button')).not.toBeInTheDocument()
})
})
it('should handle configuration failure during configureMermaid', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => { })
const originalMock = vi.mocked(mermaid.initialize).getMockImplementation()
vi.mocked(mermaid.initialize).mockImplementation(() => {
throw new Error('Config fail')
})
try {
await act(async () => {
render(<Flowchart PrimitiveCode="graph TD\n G-->H" />)
})
await waitFor(() => {
expect(consoleSpy).toHaveBeenCalledWith('Config error:', expect.any(Error))
})
}
finally {
consoleSpy.mockRestore()
if (originalMock) {
vi.mocked(mermaid.initialize).mockImplementation(originalMock)
}
else {
vi.mocked(mermaid.initialize).mockImplementation(() => { })
}
}
})
it('should handle unmount cleanup', async () => {
const { unmount } = render(<Flowchart PrimitiveCode={mockCode} />)
await act(async () => {
@@ -219,6 +458,20 @@ describe('Mermaid Flowchart Component Module Isolation', () => {
const mockCode = 'graph TD\n A-->B'
let mermaidFresh: typeof mermaid
const setWindowUndefined = () => {
const descriptor = Object.getOwnPropertyDescriptor(globalThis, 'window')
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
return descriptor
}
const restoreWindowDescriptor = (descriptor?: PropertyDescriptor) => {
if (descriptor)
Object.defineProperty(globalThis, 'window', descriptor)
}
beforeEach(async () => {
vi.resetModules()
@@ -295,5 +548,212 @@ describe('Mermaid Flowchart Component Module Isolation', () => {
})
consoleSpy.mockRestore()
})
it('should load module safely when window is undefined', async () => {
const descriptor = setWindowUndefined()
try {
vi.resetModules()
const { default: FlowchartFresh } = await import('../index')
expect(FlowchartFresh).toBeDefined()
}
finally {
restoreWindowDescriptor(descriptor)
}
})
it('should skip configuration when window is unavailable before debounce execution', async () => {
const { default: FlowchartFresh } = await import('../index')
const descriptor = Object.getOwnPropertyDescriptor(globalThis, 'window')
vi.useFakeTimers()
try {
await act(async () => {
render(<FlowchartFresh PrimitiveCode={mockCode} />)
})
await Promise.resolve()
Object.defineProperty(globalThis, 'window', {
configurable: true,
writable: true,
value: undefined,
})
await vi.advanceTimersByTimeAsync(350)
expect(mermaidFresh.render).not.toHaveBeenCalled()
}
finally {
if (descriptor)
Object.defineProperty(globalThis, 'window', descriptor)
vi.useRealTimers()
}
})
it.skip('should show container-not-found error when container ref remains null', async () => {
vi.resetModules()
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
let pendingContainerRef: ReturnType<typeof reactActual.useRef> | null = null
let patchedContainerRef = false
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef, 'current', {
configurable: true,
get() {
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
try {
const { default: FlowchartFresh } = await import('../index')
render(<FlowchartFresh PrimitiveCode={mockCode} />)
expect(await screen.findByText('Container element not found')).toBeInTheDocument()
}
finally {
vi.doUnmock('react')
}
})
it('should tolerate missing hidden container during classic render and cleanup', async () => {
vi.resetModules()
let pendingContainerRef: unknown | null = null
let patchedContainerRef = false
let patchedTimeoutRef = false
let containerReadCount = 0
const virtualContainer = { innerHTML: 'seed' } as HTMLDivElement
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef as { current: unknown }, 'current', {
configurable: true,
get() {
containerReadCount += 1
if (containerReadCount === 1)
return virtualContainer
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
if (patchedContainerRef && !patchedTimeoutRef && initialValue === undefined) {
patchedTimeoutRef = true
Object.defineProperty(ref, 'current', {
configurable: true,
get() {
return undefined
},
set(_value: NodeJS.Timeout | undefined) { },
})
return ref
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
try {
const { default: FlowchartFresh } = await import('../index')
const { unmount } = render(<FlowchartFresh PrimitiveCode={mockCode} />)
await waitFor(() => {
expect(screen.getByText('test-svg')).toBeInTheDocument()
}, { timeout: 3000 })
unmount()
}
finally {
vi.doUnmock('react')
}
})
it('should tolerate missing hidden container during handDrawn render', async () => {
vi.resetModules()
let pendingContainerRef: unknown | null = null
let patchedContainerRef = false
let containerReadCount = 0
const virtualContainer = { innerHTML: 'seed' } as HTMLDivElement
vi.doMock('react', async () => {
const reactActual = await vi.importActual<typeof import('react')>('react')
const mockedUseRef = ((initialValue: unknown) => {
const ref = reactActual.useRef(initialValue as never)
if (!patchedContainerRef && initialValue === null)
pendingContainerRef = ref
if (!patchedContainerRef
&& pendingContainerRef
&& typeof initialValue === 'string'
&& initialValue.startsWith('mermaid-chart-')) {
Object.defineProperty(pendingContainerRef as { current: unknown }, 'current', {
configurable: true,
get() {
containerReadCount += 1
if (containerReadCount === 1)
return virtualContainer
return null
},
set(_value: HTMLDivElement | null) { },
})
patchedContainerRef = true
pendingContainerRef = null
}
return ref
}) as typeof reactActual.useRef
return {
...reactActual,
useRef: mockedUseRef,
}
})
vi.useFakeTimers()
try {
const { default: FlowchartFresh } = await import('../index')
const { rerender } = render(<FlowchartFresh PrimitiveCode="graph" />)
await act(async () => {
fireEvent.click(screen.getByText(/handDrawn/i))
rerender(<FlowchartFresh PrimitiveCode={mockCode} />)
await vi.advanceTimersByTimeAsync(350)
})
await Promise.resolve()
expect(screen.getByText('test-svg-api')).toBeInTheDocument()
}
finally {
vi.useRealTimers()
vi.doUnmock('react')
}
})
})
})

View File

@@ -1,6 +1,4 @@
import type { MermaidConfig } from 'mermaid'
import { ExclamationTriangleIcon } from '@heroicons/react/24/outline'
import { MoonIcon, SunIcon } from '@heroicons/react/24/solid'
import mermaid from 'mermaid'
import * as React from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
@@ -22,7 +20,7 @@ import {
// Global flags and cache for mermaid
let isMermaidInitialized = false
const diagramCache = new Map<string, string>()
let mermaidAPI: any = null
let mermaidAPI: typeof mermaid.mermaidAPI | null = null
if (typeof window !== 'undefined')
mermaidAPI = mermaid.mermaidAPI
@@ -135,6 +133,7 @@ const Flowchart = (props: FlowchartProps) => {
const renderMermaidChart = async (code: string, style: 'classic' | 'handDrawn') => {
if (style === 'handDrawn') {
// Special handling for hand-drawn style
/* v8 ignore next */
if (containerRef.current)
containerRef.current.innerHTML = `<div id="${chartId}"></div>`
await new Promise(resolve => setTimeout(resolve, 30))
@@ -152,6 +151,7 @@ const Flowchart = (props: FlowchartProps) => {
else {
// Standard rendering for classic style - using the extracted waitForDOMElement function
const renderWithRetry = async () => {
/* v8 ignore next */
if (containerRef.current)
containerRef.current.innerHTML = `<div id="${chartId}"></div>`
await new Promise(resolve => setTimeout(resolve, 30))
@@ -207,20 +207,16 @@ const Flowchart = (props: FlowchartProps) => {
}, [props.theme])
const renderFlowchart = useCallback(async (primitiveCode: string) => {
/* v8 ignore next */
if (!isInitialized || !containerRef.current) {
/* v8 ignore next */
setIsLoading(false)
/* v8 ignore next */
setErrMsg(!isInitialized ? 'Mermaid initialization failed' : 'Container element not found')
return
}
// Return cached result if available
const cacheKey = `${primitiveCode}-${look}-${currentTheme}`
if (diagramCache.has(cacheKey)) {
setErrMsg('')
setSvgString(diagramCache.get(cacheKey) || null)
setIsLoading(false)
return
}
setIsLoading(true)
setErrMsg('')
@@ -248,9 +244,7 @@ const Flowchart = (props: FlowchartProps) => {
// Rule 1: Correct multiple "after" dependencies ONLY if they exist.
// This is a common mistake, e.g., "..., after task1, after task2, ..."
const afterCount = (paramsStr.match(/after /g) || []).length
if (afterCount > 1)
paramsStr = paramsStr.replace(/,\s*after\s+/g, ' ')
paramsStr = paramsStr.replace(/,\s*after\s+/g, ' ')
// Rule 2: Normalize spacing between parameters for consistency.
const finalParams = paramsStr.replace(/\s*,\s*/g, ', ').trim()
@@ -286,10 +280,8 @@ const Flowchart = (props: FlowchartProps) => {
// Step 4: Clean up SVG code
const cleanedSvg = cleanUpSvgCode(processedSvg)
if (cleanedSvg && typeof cleanedSvg === 'string') {
diagramCache.set(cacheKey, cleanedSvg)
setSvgString(cleanedSvg)
}
diagramCache.set(cacheKey, cleanedSvg as string)
setSvgString(cleanedSvg as string)
setIsLoading(false)
}
@@ -421,7 +413,7 @@ const Flowchart = (props: FlowchartProps) => {
const cacheKey = `${props.PrimitiveCode}-${look}-${currentTheme}`
if (diagramCache.has(cacheKey)) {
setErrMsg('')
setSvgString(diagramCache.get(cacheKey) || null)
setSvgString(diagramCache.get(cacheKey)!)
setIsLoading(false)
return
}
@@ -431,26 +423,23 @@ const Flowchart = (props: FlowchartProps) => {
}, 300) // 300ms debounce
return () => {
if (renderTimeoutRef.current)
clearTimeout(renderTimeoutRef.current)
clearTimeout(renderTimeoutRef.current)
}
}, [props.PrimitiveCode, look, currentTheme, isInitialized, configureMermaid, renderFlowchart])
// Cleanup on unmount
useEffect(() => {
return () => {
if (containerRef.current)
containerRef.current.innerHTML = ''
if (renderTimeoutRef.current)
clearTimeout(renderTimeoutRef.current)
}
}, [])
const handlePreviewClick = async () => {
if (svgString) {
const base64 = await svgToBase64(svgString)
setImagePreviewUrl(base64)
}
if (!svgString)
return
const base64 = await svgToBase64(svgString)
setImagePreviewUrl(base64)
}
const toggleTheme = () => {
@@ -484,20 +473,24 @@ const Flowchart = (props: FlowchartProps) => {
'text-gray-300': currentTheme === Theme.dark,
}),
themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
'border border-gray-200 bg-white/80 text-gray-700 hover:bg-white hover:shadow-lg': currentTheme === Theme.light,
'border border-slate-600 bg-slate-800/80 text-yellow-300 hover:bg-slate-700 hover:shadow-lg': currentTheme === Theme.dark,
}),
}
// Style classes for look options
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
return cn(
'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
'mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary system-sm-medium',
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
)
}
const themeToggleTitleByTheme = {
light: t('theme.switchDark', { ns: 'app' }),
dark: t('theme.switchLight', { ns: 'app' }),
} as const
return (
<div ref={props.ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
@@ -555,10 +548,10 @@ const Flowchart = (props: FlowchartProps) => {
toggleTheme()
}}
className={themeClasses.themeToggle}
title={(currentTheme === Theme.light ? t('theme.switchDark', { ns: 'app' }) : t('theme.switchLight', { ns: 'app' })) || ''}
title={themeToggleTitleByTheme[currentTheme] || ''}
style={{ transform: 'translate3d(0, 0, 0)' }}
>
{currentTheme === Theme.light ? <MoonIcon className="h-5 w-5" /> : <SunIcon className="h-5 w-5" />}
{currentTheme === Theme.light ? <span className="i-heroicons-moon-solid h-5 w-5" /> : <span className="i-heroicons-sun-solid h-5 w-5" />}
</button>
</div>
@@ -572,7 +565,7 @@ const Flowchart = (props: FlowchartProps) => {
{errMsg && (
<div className={themeClasses.errorMessage}>
<div className="flex items-center">
<ExclamationTriangleIcon className={themeClasses.errorIcon} />
<span className={`i-heroicons-exclamation-triangle ${themeClasses.errorIcon}`} />
<span className="ml-2">{errMsg}</span>
</div>
</div>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,209 @@
import type { LexicalEditor } from 'lexical'
import { act, waitFor } from '@testing-library/react'
import {
$createParagraphNode,
$createTextNode,
$getRoot,
$getSelection,
$isRangeSelection,
ParagraphNode,
TextNode,
} from 'lexical'
import {
createLexicalTestEditor,
expectInlineWrapperDom,
getNodeCount,
getNodesByType,
readEditorStateValue,
readRootTextContent,
renderLexicalEditor,
selectRootEnd,
setEditorRootText,
waitForEditorReady,
} from '../test-helpers'
describe('test-helpers', () => {
describe('renderLexicalEditor & waitForEditorReady', () => {
it('should render the editor and wait for it', async () => {
const { getEditor } = renderLexicalEditor({
namespace: 'TestNamespace',
nodes: [ParagraphNode, TextNode],
children: null,
})
const editor = await waitForEditorReady(getEditor)
expect(editor).toBeDefined()
expect(editor).toBe(getEditor())
})
it('should throw if wait times out without editor', async () => {
await expect(waitForEditorReady(() => null)).rejects.toThrow()
})
it('should throw if editor is null after waitFor completes', async () => {
let callCount = 0
await expect(
waitForEditorReady(() => {
callCount++
// Return non-null on the last check of `waitFor` so it passes,
// then null when actually retrieving the editor
return callCount === 1 ? ({} as LexicalEditor) : null
}),
).rejects.toThrow('Editor is not available')
})
it('should surface errors through configured onError callback', async () => {
const { getEditor } = renderLexicalEditor({
namespace: 'TestNamespace',
nodes: [ParagraphNode, TextNode],
children: null,
})
const editor = await waitForEditorReady(getEditor)
expect(() => {
editor.update(() => {
throw new Error('test error')
}, { discrete: true })
}).toThrow('test error')
})
})
describe('selectRootEnd', () => {
it('should select the end of the root', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
selectRootEnd(editor)
await waitFor(() => {
let isRangeSelection = false
editor.getEditorState().read(() => {
const selection = $getSelection()
isRangeSelection = $isRangeSelection(selection)
})
expect(isRangeSelection).toBe(true)
})
})
})
describe('Content Reading/Writing Helpers', () => {
it('should read root text content', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
const paragraph = $createParagraphNode()
paragraph.append($createTextNode('Hello World'))
root.append(paragraph)
}, { discrete: true })
})
let content = ''
act(() => {
content = readRootTextContent(editor)
})
expect(content).toBe('Hello World')
})
it('should set editor root text and select end', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
setEditorRootText(editor, 'New Text', $createTextNode)
await waitFor(() => {
let content = ''
editor.getEditorState().read(() => {
content = $getRoot().getTextContent()
})
expect(content).toBe('New Text')
})
})
})
describe('Node Selection Helpers', () => {
it('should get node count', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
root.append($createParagraphNode())
root.append($createParagraphNode())
}, { discrete: true })
})
let count = 0
act(() => {
count = getNodeCount(editor, ParagraphNode)
})
expect(count).toBe(2)
})
it('should get nodes by type', async () => {
const { getEditor } = renderLexicalEditor({ namespace: 'test', nodes: [ParagraphNode, TextNode], children: null })
const editor = await waitForEditorReady(getEditor)
act(() => {
editor.update(() => {
const root = $getRoot()
root.clear()
root.append($createParagraphNode())
}, { discrete: true })
})
let nodes: ParagraphNode[] = []
act(() => {
nodes = getNodesByType(editor, ParagraphNode)
})
expect(nodes).toHaveLength(1)
expect(nodes[0]).not.toBeUndefined()
})
})
describe('readEditorStateValue', () => {
it('should read primitive values from editor state', () => {
const editor = createLexicalTestEditor('test', [ParagraphNode, TextNode])
const val = readEditorStateValue(editor, () => {
return $getRoot().isEmpty()
})
expect(val).toBe(true)
})
it('should throw if value is undefined', () => {
const editor = createLexicalTestEditor('test', [ParagraphNode, TextNode])
expect(() => {
readEditorStateValue(editor, () => undefined)
}).toThrow('Failed to read editor state value')
})
})
describe('createLexicalTestEditor', () => {
it('should expose createLexicalTestEditor with onError throw', () => {
const editor = createLexicalTestEditor('custom-namespace', [ParagraphNode, TextNode])
expect(editor).toBeDefined()
expect(() => {
editor.update(() => {
throw new Error('test error')
}, { discrete: true })
}).toThrow('test error')
})
})
describe('expectInlineWrapperDom', () => {
it('should assert wrapper properties on a valid DOM element', () => {
const div = document.createElement('div')
div.classList.add('inline-flex', 'items-center', 'align-middle', 'extra1', 'extra2')
expectInlineWrapperDom(div, ['extra1', 'extra2']) // Does not throw
})
})
})

View File

@@ -0,0 +1,300 @@
import type { RootNode } from 'lexical'
import { $createParagraphNode, $createTextNode, $getRoot, ParagraphNode, TextNode } from 'lexical'
import { describe, expect, it, vi } from 'vitest'
import { createTestEditor, withEditorUpdate } from './utils'
describe('Prompt Editor Test Utils', () => {
describe('createTestEditor', () => {
it('should create an editor without crashing', () => {
const editor = createTestEditor()
expect(editor).toBeDefined()
})
it('should create an editor with no nodes by default', () => {
const editor = createTestEditor()
expect(editor).toBeDefined()
})
it('should create an editor with provided nodes', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
expect(editor).toBeDefined()
})
it('should set up root element for the editor', () => {
const editor = createTestEditor()
// The editor should be properly initialized with a root element
expect(editor).toBeDefined()
})
it('should throw errors when they occur', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
expect(() => {
editor.update(() => {
throw new Error('Test error')
}, { discrete: true })
}).toThrow('Test error')
})
it('should allow multiple editors to be created independently', () => {
const editor1 = createTestEditor()
const editor2 = createTestEditor()
expect(editor1).not.toBe(editor2)
})
it('should initialize with basic node types', () => {
const nodes = [ParagraphNode, TextNode]
const editor = createTestEditor(nodes)
let content: string = ''
editor.update(() => {
const root = $getRoot()
const paragraph = $createParagraphNode()
const text = $createTextNode('Hello World')
paragraph.append(text)
root.append(paragraph)
content = root.getTextContent()
}, { discrete: true })
expect(content).toBe('Hello World')
})
})
describe('withEditorUpdate', () => {
it('should execute update function without crashing', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
const updateFn = vi.fn()
withEditorUpdate(editor, updateFn)
expect(updateFn).toHaveBeenCalled()
})
it('should pass discrete: true option to editor.update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
const updateSpy = vi.spyOn(editor, 'update')
withEditorUpdate(editor, () => {
$getRoot()
})
expect(updateSpy).toHaveBeenCalledWith(expect.any(Function), { discrete: true })
})
it('should allow updating editor state', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let textContent: string = ''
withEditorUpdate(editor, () => {
const root = $getRoot()
const paragraph = $createParagraphNode()
const text = $createTextNode('Test Content')
paragraph.append(text)
root.append(paragraph)
})
withEditorUpdate(editor, () => {
textContent = $getRoot().getTextContent()
})
expect(textContent).toBe('Test Content')
})
it('should handle multiple consecutive updates', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p1 = $createParagraphNode()
p1.append($createTextNode('First'))
root.append(p1)
})
withEditorUpdate(editor, () => {
const root = $getRoot()
const p2 = $createParagraphNode()
p2.append($createTextNode('Second'))
root.append(p2)
})
let content: string = ''
withEditorUpdate(editor, () => {
content = $getRoot().getTextContent()
})
expect(content).toContain('First')
expect(content).toContain('Second')
})
it('should provide access to editor state within update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let capturedState: RootNode | null = null
withEditorUpdate(editor, () => {
const root = $getRoot()
capturedState = root
})
expect(capturedState).toBeDefined()
})
it('should execute update function immediately', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let executed = false
withEditorUpdate(editor, () => {
executed = true
})
// Update should be executed synchronously in discrete mode
expect(executed).toBe(true)
})
it('should handle complex editor operations within update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
let nodeCount: number = 0
withEditorUpdate(editor, () => {
const root = $getRoot()
for (let i = 0; i < 3; i++) {
const paragraph = $createParagraphNode()
paragraph.append($createTextNode(`Paragraph ${i}`))
root.append(paragraph)
}
// Count child nodes
nodeCount = root.getChildrenSize()
})
expect(nodeCount).toBe(3)
})
it('should allow reading editor state after update', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const paragraph = $createParagraphNode()
paragraph.append($createTextNode('Read Test'))
root.append(paragraph)
})
let readContent: string = ''
withEditorUpdate(editor, () => {
readContent = $getRoot().getTextContent()
})
expect(readContent).toBe('Read Test')
})
it('should handle error thrown within update function', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
expect(() => {
withEditorUpdate(editor, () => {
throw new Error('Update error')
})
}).toThrow('Update error')
})
it('should preserve editor state across multiple updates', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Persistent'))
root.append(p)
})
let persistedContent: string = ''
withEditorUpdate(editor, () => {
persistedContent = $getRoot().getTextContent()
})
expect(persistedContent).toBe('Persistent')
})
})
describe('Integration', () => {
it('should work together to create and update editor', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Integration Test'))
root.append(p)
})
let result: string = ''
withEditorUpdate(editor, () => {
result = $getRoot().getTextContent()
})
expect(result).toBe('Integration Test')
})
it('should support multiple editors with isolated state', () => {
const editor1 = createTestEditor([ParagraphNode, TextNode])
const editor2 = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor1, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Editor 1'))
root.append(p)
})
withEditorUpdate(editor2, () => {
const root = $getRoot()
const p = $createParagraphNode()
p.append($createTextNode('Editor 2'))
root.append(p)
})
let content1: string = ''
let content2: string = ''
withEditorUpdate(editor1, () => {
content1 = $getRoot().getTextContent()
})
withEditorUpdate(editor2, () => {
content2 = $getRoot().getTextContent()
})
expect(content1).toBe('Editor 1')
expect(content2).toBe('Editor 2')
})
it('should handle nested paragraph and text nodes', () => {
const editor = createTestEditor([ParagraphNode, TextNode])
withEditorUpdate(editor, () => {
const root = $getRoot()
const p1 = $createParagraphNode()
const p2 = $createParagraphNode()
p1.append($createTextNode('First Para'))
p2.append($createTextNode('Second Para'))
root.append(p1)
root.append(p2)
})
let content: string = ''
withEditorUpdate(editor, () => {
content = $getRoot().getTextContent()
})
expect(content).toContain('First Para')
expect(content).toContain('Second Para')
})
})
})

View File

@@ -1,112 +1,251 @@
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import type { LexicalEditor } from 'lexical'
import type { JSX, RefObject } from 'react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { act, render, screen } from '@testing-library/react'
import DraggableBlockPlugin from '..'
const CONTENT_EDITABLE_TEST_ID = 'draggable-content-editable'
let namespaceCounter = 0
function renderWithEditor(anchorElem?: HTMLElement) {
render(
<LexicalComposer
initialConfig={{
namespace: `draggable-plugin-test-${namespaceCounter++}`,
onError: (error: Error) => { throw error },
}}
>
<RichTextPlugin
contentEditable={<ContentEditable data-testid={CONTENT_EDITABLE_TEST_ID} />}
placeholder={null}
ErrorBoundary={LexicalErrorBoundary}
/>
<DraggableBlockPlugin anchorElem={anchorElem} />
</LexicalComposer>,
)
return screen.getByTestId(CONTENT_EDITABLE_TEST_ID)
type DraggableExperimentalProps = {
anchorElem: HTMLElement
menuRef: RefObject<HTMLDivElement>
targetLineRef: RefObject<HTMLDivElement>
menuComponent: JSX.Element | null
targetLineComponent: JSX.Element
isOnMenu: (element: HTMLElement) => boolean
onElementChanged: (element: HTMLElement | null) => void
}
function appendChildToRoot(rootElement: HTMLElement, className = '') {
const element = document.createElement('div')
element.className = className
rootElement.appendChild(element)
return element
type MouseMoveHandler = (event: MouseEvent) => void
const { draggableMockState } = vi.hoisted(() => ({
draggableMockState: {
latestProps: null as DraggableExperimentalProps | null,
},
}))
vi.mock('@lexical/react/LexicalComposerContext')
vi.mock('@lexical/react/LexicalDraggableBlockPlugin', () => ({
DraggableBlockPlugin_EXPERIMENTAL: (props: DraggableExperimentalProps) => {
draggableMockState.latestProps = props
return (
<div data-testid="draggable-plugin-experimental-mock">
{props.menuComponent}
{props.targetLineComponent}
</div>
)
},
}))
function createRootElementMock() {
let mouseMoveHandler: MouseMoveHandler | null = null
const addEventListener = vi.fn((eventName: string, handler: EventListenerOrEventListenerObject) => {
if (eventName === 'mousemove' && typeof handler === 'function')
mouseMoveHandler = handler as MouseMoveHandler
})
const removeEventListener = vi.fn()
return {
rootElement: {
addEventListener,
removeEventListener,
} as unknown as HTMLElement,
addEventListener,
removeEventListener,
getMouseMoveHandler: () => mouseMoveHandler,
}
}
function getRegisteredMouseMoveHandler(
rootMock: ReturnType<typeof createRootElementMock>,
): MouseMoveHandler {
const handler = rootMock.getMouseMoveHandler()
if (!handler)
throw new Error('Expected mousemove handler to be registered')
return handler
}
function setupEditorRoot(rootElement: HTMLElement | null) {
const editor = {
getRootElement: vi.fn(() => rootElement),
} as unknown as LexicalEditor
vi.mocked(useLexicalComposerContext).mockReturnValue([
editor,
{},
] as unknown as ReturnType<typeof useLexicalComposerContext>)
return editor
}
describe('DraggableBlockPlugin', () => {
beforeEach(() => {
vi.clearAllMocks()
draggableMockState.latestProps = null
})
describe('Rendering', () => {
it('should use body as default anchor and render target line', () => {
renderWithEditor()
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
const targetLine = screen.getByTestId('draggable-target-line')
expect(targetLine).toBeInTheDocument()
expect(document.body.contains(targetLine)).toBe(true)
render(<DraggableBlockPlugin />)
expect(draggableMockState.latestProps?.anchorElem).toBe(document.body)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
it('should render inside custom anchor element when provided', () => {
const customAnchor = document.createElement('div')
document.body.appendChild(customAnchor)
it('should render with custom anchor when provided', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
const anchorElem = document.createElement('div')
renderWithEditor(customAnchor)
render(<DraggableBlockPlugin anchorElem={anchorElem} />)
const targetLine = screen.getByTestId('draggable-target-line')
expect(customAnchor.contains(targetLine)).toBe(true)
expect(draggableMockState.latestProps?.anchorElem).toBe(anchorElem)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
})
customAnchor.remove()
it('should return early when editor root element is null', () => {
const editor = setupEditorRoot(null)
render(<DraggableBlockPlugin />)
expect(editor.getRootElement).toHaveBeenCalledTimes(1)
expect(screen.getByTestId('draggable-target-line')).toBeInTheDocument()
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
})
describe('Drag Support Detection', () => {
it('should render drag menu when mouse moves over a support-drag element', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
describe('Drag support detection', () => {
it('should show menu when target has support-drag class', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const target = document.createElement('div')
target.className = 'support-drag'
act(() => {
onMove({ target } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should show menu when target contains a support-drag descendant', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const target = document.createElement('div')
target.appendChild(Object.assign(document.createElement('span'), { className: 'support-drag' }))
act(() => {
onMove({ target } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should show menu when target is inside a support-drag ancestor', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const ancestor = document.createElement('div')
ancestor.className = 'support-drag'
const child = document.createElement('span')
ancestor.appendChild(child)
act(() => {
onMove({ target: child } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
it('should hide menu when target does not support drag', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
const plainTarget = document.createElement('div')
act(() => {
onMove({ target: plainTarget } as unknown as MouseEvent)
})
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
fireEvent.mouseMove(supportDragTarget)
await waitFor(() => {
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
})
})
it('should hide drag menu when support-drag target is removed and mouse moves again', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
it('should keep menu hidden when event target becomes null', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
fireEvent.mouseMove(supportDragTarget)
await waitFor(() => {
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
expect(screen.getByTestId('draggable-menu')).toBeInTheDocument()
act(() => {
onMove({ target: null } as unknown as MouseEvent)
})
supportDragTarget.remove()
fireEvent.mouseMove(rootElement)
await waitFor(() => {
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
expect(screen.queryByTestId('draggable-menu')).not.toBeInTheDocument()
})
})
describe('Menu Detection Contract', () => {
it('should render menu with draggable-block-menu class and keep non-menu elements outside it', async () => {
const rootElement = renderWithEditor()
const supportDragTarget = appendChildToRoot(rootElement, 'support-drag')
describe('Forwarded callbacks', () => {
it('should forward isOnMenu and detect menu membership correctly', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
render(<DraggableBlockPlugin />)
fireEvent.mouseMove(supportDragTarget)
const onMove = getRegisteredMouseMoveHandler(rootMock)
const supportDragTarget = document.createElement('div')
supportDragTarget.className = 'support-drag'
act(() => {
onMove({ target: supportDragTarget } as unknown as MouseEvent)
})
const menuIcon = await screen.findByTestId('draggable-menu-icon')
expect(menuIcon.closest('.draggable-block-menu')).not.toBeNull()
const renderedMenu = screen.getByTestId('draggable-menu')
const isOnMenu = draggableMockState.latestProps?.isOnMenu
if (!isOnMenu)
throw new Error('Expected isOnMenu callback')
const normalElement = document.createElement('div')
document.body.appendChild(normalElement)
expect(normalElement.closest('.draggable-block-menu')).toBeNull()
normalElement.remove()
const menuIcon = screen.getByTestId('draggable-menu-icon')
const outsideElement = document.createElement('div')
expect(isOnMenu(menuIcon)).toBe(true)
expect(isOnMenu(renderedMenu)).toBe(true)
expect(isOnMenu(outsideElement)).toBe(false)
})
it('should register and cleanup mousemove listener on mount and unmount', () => {
const rootMock = createRootElementMock()
setupEditorRoot(rootMock.rootElement)
const { unmount } = render(<DraggableBlockPlugin />)
const onMove = getRegisteredMouseMoveHandler(rootMock)
expect(rootMock.addEventListener).toHaveBeenCalledWith('mousemove', expect.any(Function))
unmount()
expect(rootMock.removeEventListener).toHaveBeenCalledWith('mousemove', onMove)
})
})
})

View File

@@ -1,8 +1,10 @@
import type { LexicalCommand } from 'lexical'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { createCommand } from 'lexical'
import * as React from 'react'
import { useState } from 'react'
import ShortcutsPopupPlugin, { SHORTCUTS_EMPTY_CONTENT } from '../index'
@@ -21,6 +23,9 @@ const mockDOMRect = {
toJSON: () => ({}),
}
const originalRangeGetClientRects = Range.prototype.getClientRects
const originalRangeGetBoundingClientRect = Range.prototype.getBoundingClientRect
beforeAll(() => {
// Mock getClientRects on Range prototype
Range.prototype.getClientRects = vi.fn(() => {
@@ -34,12 +39,31 @@ beforeAll(() => {
Range.prototype.getBoundingClientRect = vi.fn(() => mockDOMRect as DOMRect)
})
afterAll(() => {
Range.prototype.getClientRects = originalRangeGetClientRects
Range.prototype.getBoundingClientRect = originalRangeGetBoundingClientRect
})
const CONTAINER_ID = 'host'
const CONTENT_EDITABLE_ID = 'ce'
const MinimalEditor: React.FC<{
type MinimalEditorProps = {
withContainer?: boolean
}> = ({ withContainer = true }) => {
hotkey?: string | string[] | string[][] | ((e: KeyboardEvent) => boolean)
children?: React.ReactNode | ((close: () => void, onInsert: (command: LexicalCommand<unknown>, params: unknown[]) => void) => React.ReactNode)
className?: string
onOpen?: () => void
onClose?: () => void
}
const MinimalEditor: React.FC<MinimalEditorProps> = ({
withContainer = true,
hotkey,
children,
className,
onOpen,
onClose,
}) => {
const initialConfig = {
namespace: 'shortcuts-popup-plugin-test',
onError: (e: Error) => {
@@ -58,25 +82,35 @@ const MinimalEditor: React.FC<{
/>
<ShortcutsPopupPlugin
container={withContainer ? containerEl : undefined}
/>
hotkey={hotkey}
className={className}
onOpen={onOpen}
onClose={onClose}
>
{children}
</ShortcutsPopupPlugin>
</div>
</LexicalComposer>
)
}
/** Helper: focus the content editable and trigger a hotkey. */
function focusAndTriggerHotkey(key: string, modifiers: Partial<Record<'ctrlKey' | 'metaKey' | 'altKey' | 'shiftKey', boolean>> = { ctrlKey: true }) {
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key, ...modifiers })
}
describe('ShortcutsPopupPlugin', () => {
// ─── Basic open / close ───
it('opens on hotkey when editor is focused', async () => {
render(<MinimalEditor />)
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true }) // 模拟 Ctrl+/
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not open when editor is not focused', async () => {
render(<MinimalEditor />)
// 未聚焦
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
@@ -85,10 +119,7 @@ describe('ShortcutsPopupPlugin', () => {
it('closes on Escape', async () => {
render(<MinimalEditor />)
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
fireEvent.keyDown(document, { key: 'Escape' })
@@ -111,24 +142,370 @@ describe('ShortcutsPopupPlugin', () => {
})
})
// ─── Container / portal ───
it('portals into provided container when container is set', async () => {
render(<MinimalEditor withContainer />)
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
const host = screen.getByTestId(CONTAINER_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
focusAndTriggerHotkey('/')
const portalContent = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(host).toContainElement(portalContent)
})
it('falls back to document.body when container is not provided', async () => {
render(<MinimalEditor withContainer={false} />)
const ce = screen.getByTestId(CONTENT_EDITABLE_ID)
ce.focus()
fireEvent.keyDown(document, { key: '/', ctrlKey: true })
focusAndTriggerHotkey('/')
const portalContent = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(document.body).toContainElement(portalContent)
})
// ─── matchHotkey: string hotkey ───
it('matches a string hotkey like "mod+/"', async () => {
render(<MinimalEditor hotkey="mod+/" />)
focusAndTriggerHotkey('/', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches ctrl+/ when hotkey is "mod+/" (mod matches ctrl or meta)', async () => {
render(<MinimalEditor hotkey="mod+/" />)
focusAndTriggerHotkey('/', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: string[] hotkey ───
it('matches when hotkey is a string array like ["mod", "/"]', async () => {
render(<MinimalEditor hotkey={['mod', '/']} />)
focusAndTriggerHotkey('/', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: string[][] (nested) hotkey ───
it('matches when hotkey is a nested array (any combo matches)', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('k', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches the second combo in a nested array', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('j', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match nested array when no combo matches', async () => {
render(<MinimalEditor hotkey={[['ctrl', 'k'], ['meta', 'j']]} />)
focusAndTriggerHotkey('x', { ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: function hotkey ───
it('matches when hotkey is a custom function returning true', async () => {
const customMatcher = (e: KeyboardEvent) => e.key === 'F1'
render(<MinimalEditor hotkey={customMatcher} />)
focusAndTriggerHotkey('F1', {})
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match when custom function returns false', async () => {
const customMatcher = (e: KeyboardEvent) => e.key === 'F1'
render(<MinimalEditor hotkey={customMatcher} />)
focusAndTriggerHotkey('F2', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: modifier aliases ───
it('matches meta/cmd/command aliases', async () => {
render(<MinimalEditor hotkey="cmd+k" />)
focusAndTriggerHotkey('k', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches "command" alias for meta', async () => {
render(<MinimalEditor hotkey="command+k" />)
focusAndTriggerHotkey('k', { metaKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match meta alias when meta is not pressed', async () => {
render(<MinimalEditor hotkey="cmd+k" />)
focusAndTriggerHotkey('k', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches alt/option alias', async () => {
render(<MinimalEditor hotkey="alt+a" />)
focusAndTriggerHotkey('a', { altKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match alt alias when alt is not pressed', async () => {
render(<MinimalEditor hotkey="alt+a" />)
focusAndTriggerHotkey('a', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches shift alias', async () => {
render(<MinimalEditor hotkey="shift+s" />)
focusAndTriggerHotkey('s', { shiftKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match shift alias when shift is not pressed', async () => {
render(<MinimalEditor hotkey="shift+s" />)
focusAndTriggerHotkey('s', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
it('matches ctrl alias', async () => {
render(<MinimalEditor hotkey="ctrl+b" />)
focusAndTriggerHotkey('b', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match ctrl alias when ctrl is not pressed', async () => {
render(<MinimalEditor hotkey="ctrl+b" />)
focusAndTriggerHotkey('b', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── matchHotkey: space key normalization ───
it('normalizes space key to "space" for matching', async () => {
render(<MinimalEditor hotkey="ctrl+space" />)
focusAndTriggerHotkey(' ', { ctrlKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── matchHotkey: key mismatch ───
it('does not match when expected key does not match pressed key', async () => {
render(<MinimalEditor hotkey="ctrl+z" />)
focusAndTriggerHotkey('x', { ctrlKey: true })
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
// ─── Children rendering ───
it('renders children as ReactNode when provided', async () => {
render(
<MinimalEditor>
<div data-testid="custom-content">My Content</div>
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
expect(await screen.findByTestId('custom-content')).toBeInTheDocument()
expect(screen.getByText('My Content')).toBeInTheDocument()
})
it('renders children as render function and provides close/onInsert', async () => {
const TEST_COMMAND = createCommand<unknown>('TEST_COMMAND')
const childrenFn = vi.fn((close: () => void, onInsert: (cmd: LexicalCommand<unknown>, params: unknown[]) => void) => (
<div>
<button type="button" data-testid="close-btn" onClick={close}>Close</button>
<button type="button" data-testid="insert-btn" onClick={() => onInsert(TEST_COMMAND, ['param1'])}>Insert</button>
</div>
))
render(
<MinimalEditor>
{childrenFn}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
// Children render function should have been called
expect(await screen.findByTestId('close-btn')).toBeInTheDocument()
expect(screen.getByTestId('insert-btn')).toBeInTheDocument()
})
it('renders SHORTCUTS_EMPTY_CONTENT when children is undefined', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
// ─── handleInsert callback ───
it('calls close after insert via children render function', async () => {
const TEST_COMMAND = createCommand<unknown>('TEST_INSERT_COMMAND')
render(
<MinimalEditor>
{(close: () => void, onInsert: (cmd: LexicalCommand<unknown>, params: unknown[]) => void) => (
<div>
<button type="button" data-testid="insert-btn" onClick={() => onInsert(TEST_COMMAND, ['value'])}>Insert</button>
</div>
)}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const insertBtn = await screen.findByTestId('insert-btn')
fireEvent.click(insertBtn)
// After insert, the popup should close
await waitFor(() => {
expect(screen.queryByTestId('insert-btn')).not.toBeInTheDocument()
})
})
it('calls close via children render function close callback', async () => {
render(
<MinimalEditor>
{(close: () => void) => (
<button type="button" data-testid="close-via-fn" onClick={close}>Close</button>
)}
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const closeBtn = await screen.findByTestId('close-via-fn')
fireEvent.click(closeBtn)
await waitFor(() => {
expect(screen.queryByTestId('close-via-fn')).not.toBeInTheDocument()
})
})
// ─── onOpen / onClose callbacks ───
it('calls onOpen when popup opens', async () => {
const onOpen = vi.fn()
render(<MinimalEditor onOpen={onOpen} />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
expect(onOpen).toHaveBeenCalledTimes(1)
})
it('calls onClose when popup closes', async () => {
const onClose = vi.fn()
render(<MinimalEditor onClose={onClose} />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
fireEvent.keyDown(document, { key: 'Escape' })
await waitFor(() => {
expect(onClose).toHaveBeenCalledTimes(1)
})
})
// ─── className prop ───
it('applies custom className to floating popup', async () => {
render(<MinimalEditor className="custom-popup-class" />)
focusAndTriggerHotkey('/')
const content = await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
const floatingDiv = content.closest('div')
expect(floatingDiv).toHaveClass('custom-popup-class')
})
// ─── mousedown inside portal should not close ───
it('does not close on mousedown inside the portal', async () => {
render(
<MinimalEditor>
<div data-testid="portal-inner">Inner content</div>
</MinimalEditor>,
)
focusAndTriggerHotkey('/')
const inner = await screen.findByTestId('portal-inner')
fireEvent.mouseDown(inner)
// Should still be open
await waitFor(() => {
expect(screen.getByTestId('portal-inner')).toBeInTheDocument()
})
})
it('prevents default and stops propagation on Escape when open', async () => {
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
await screen.findByText(SHORTCUTS_EMPTY_CONTENT)
const preventDefaultSpy = vi.fn()
const stopPropagationSpy = vi.fn()
// Use a custom event to capture preventDefault/stopPropagation calls
const escEvent = new KeyboardEvent('keydown', { key: 'Escape', bubbles: true, cancelable: true })
Object.defineProperty(escEvent, 'preventDefault', { value: preventDefaultSpy })
Object.defineProperty(escEvent, 'stopPropagation', { value: stopPropagationSpy })
document.dispatchEvent(escEvent)
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
expect(preventDefaultSpy).toHaveBeenCalledTimes(1)
expect(stopPropagationSpy).toHaveBeenCalledTimes(1)
})
// ─── Zero-rect fallback in openPortal ───
it('handles zero-size range rects by falling back to node bounding rect', async () => {
// Temporarily override getClientRects to return zero-size rect
const zeroRect = { x: 0, y: 0, width: 0, height: 0, top: 0, right: 0, bottom: 0, left: 0, toJSON: () => ({}) }
const originalGetClientRects = Range.prototype.getClientRects
const originalGetBoundingClientRect = Range.prototype.getBoundingClientRect
Range.prototype.getClientRects = vi.fn(() => {
const rectList = [zeroRect] as unknown as DOMRectList
Object.defineProperty(rectList, 'length', { value: 1 })
Object.defineProperty(rectList, 'item', { value: () => zeroRect })
return rectList
})
Range.prototype.getBoundingClientRect = vi.fn(() => zeroRect as DOMRect)
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
// Restore
Range.prototype.getClientRects = originalGetClientRects
Range.prototype.getBoundingClientRect = originalGetBoundingClientRect
})
it('handles empty getClientRects by using getBoundingClientRect fallback', async () => {
const originalGetClientRects = Range.prototype.getClientRects
const originalGetBoundingClientRect = Range.prototype.getBoundingClientRect
Range.prototype.getClientRects = vi.fn(() => {
const rectList = [] as unknown as DOMRectList
Object.defineProperty(rectList, 'length', { value: 0 })
Object.defineProperty(rectList, 'item', { value: () => null })
return rectList
})
Range.prototype.getBoundingClientRect = vi.fn(() => mockDOMRect as DOMRect)
render(<MinimalEditor />)
focusAndTriggerHotkey('/')
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
Range.prototype.getClientRects = originalGetClientRects
Range.prototype.getBoundingClientRect = originalGetBoundingClientRect
})
// ─── Combined modifier hotkeys ───
it('matches hotkey with multiple modifiers: ctrl+shift+k', async () => {
render(<MinimalEditor hotkey="ctrl+shift+k" />)
focusAndTriggerHotkey('k', { ctrlKey: true, shiftKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('matches "option" alias for alt', async () => {
render(<MinimalEditor hotkey="option+o" />)
focusAndTriggerHotkey('o', { altKey: true })
expect(await screen.findByText(SHORTCUTS_EMPTY_CONTENT)).toBeInTheDocument()
})
it('does not match mod hotkey when neither ctrl nor meta is pressed', async () => {
render(<MinimalEditor hotkey="mod+k" />)
focusAndTriggerHotkey('k', {})
await waitFor(() => {
expect(screen.queryByText(SHORTCUTS_EMPTY_CONTENT)).not.toBeInTheDocument()
})
})
})

View File

@@ -1,5 +1,5 @@
import type { Item } from '../index'
import { render, screen } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Select, { PortalSelect, SimpleSelect } from '../index'
@@ -14,7 +14,6 @@ describe('Select', () => {
vi.clearAllMocks()
})
// Rendering and edge behavior for default select.
describe('Rendering', () => {
it('should show the default selected item when defaultValue matches an item', () => {
render(
@@ -28,9 +27,50 @@ describe('Select', () => {
expect(screen.getByTitle('Banana')).toBeInTheDocument()
})
it('should render null selectedItem when defaultValue does not match any item', () => {
render(
<Select
items={items}
defaultValue="missing"
allowSearch={false}
onSelect={vi.fn()}
/>,
)
// No item title should appear for a non-matching default
expect(screen.queryByTitle('Apple')).not.toBeInTheDocument()
expect(screen.queryByTitle('Banana')).not.toBeInTheDocument()
})
it('should render with allowSearch=true (input mode)', () => {
render(
<Select
items={items}
defaultValue="apple"
allowSearch={true}
onSelect={vi.fn()}
/>,
)
expect(screen.getByRole('combobox')).toBeInTheDocument()
})
it('should apply custom bgClassName', () => {
render(
<Select
items={items}
defaultValue="apple"
allowSearch={false}
onSelect={vi.fn()}
bgClassName="bg-custom-color"
/>,
)
expect(screen.getByTitle('Apple')).toBeInTheDocument()
})
})
// User interactions for default select.
describe('User Interactions', () => {
it('should call onSelect when choosing an option from default select', async () => {
const user = userEvent.setup()
@@ -73,15 +113,174 @@ describe('Select', () => {
expect(screen.queryByText('Citrus')).not.toBeInTheDocument()
expect(onSelect).not.toHaveBeenCalled()
})
it('should filter items when searching with allowSearch=true', async () => {
const user = userEvent.setup()
render(
<Select
items={items}
defaultValue="apple"
allowSearch={true}
onSelect={vi.fn()}
/>,
)
// First, click the chevron button to open the dropdown
const buttons = screen.getAllByRole('button')
await user.click(buttons[0])
// Now type in the search input to filter
const input = screen.getByRole('combobox')
await user.clear(input)
await user.type(input, 'ban')
// Citrus should be filtered away
expect(screen.queryByText('Citrus')).not.toBeInTheDocument()
})
it('should not filter or update query when disabled and allowSearch=true', async () => {
render(
<Select
items={items}
defaultValue="apple"
allowSearch={true}
disabled={true}
onSelect={vi.fn()}
/>,
)
const input = screen.getByRole('combobox') as HTMLInputElement
// we must use fireEvent because userEvent throws on disabled inputs
fireEvent.change(input, { target: { value: 'ban' } })
// We just want to ensure it doesn't throw and covers the !disabled branch in onChange.
// Since it's disabled, no search dropdown should appear.
expect(screen.queryByRole('listbox')).not.toBeInTheDocument()
})
it('should not call onSelect when a disabled Combobox value changes externally', () => {
// In Headless UI, disabled elements do not fire events via React.
// To cover the defensive `if (!disabled)` branches inside the callbacks,
// we temporarily remove the disabled attribute from the DOM to force the event through.
const onSelect = vi.fn()
render(
<Select
items={items}
defaultValue="apple"
allowSearch={false}
disabled={true}
onSelect={onSelect}
/>,
)
const button = screen.getAllByRole('button')[0] as HTMLButtonElement
button.removeAttribute('disabled')
button.removeAttribute('aria-disabled')
fireEvent.click(button)
expect(onSelect).not.toHaveBeenCalled()
})
it('should not open dropdown when clicking ComboboxButton while disabled and allowSearch=false', () => {
// Covers line 128-141 where disabled check prevents open state toggle
render(
<Select
items={items}
defaultValue="apple"
allowSearch={false}
disabled={true}
onSelect={vi.fn()}
/>,
)
// The main trigger button should be disabled
const button = screen.getAllByRole('button')[0] as HTMLButtonElement
button.removeAttribute('disabled')
const chevron = screen.getAllByRole('button')[1] as HTMLButtonElement
chevron.removeAttribute('disabled')
fireEvent.click(button)
fireEvent.click(chevron)
// Dropdown options should not appear because the internal `if (!disabled)` guards it
expect(screen.queryByText('Banana')).not.toBeInTheDocument()
})
it('should handle missing item nicely in renderTrigger', () => {
render(
<SimpleSelect
items={items}
defaultValue="non-existent"
onSelect={vi.fn()}
renderTrigger={(selected) => {
return (
<span>
{/* eslint-disable-next-line style/jsx-one-expression-per-line */}
Custom: {selected?.name ?? 'Fallback'}
</span>
)
}}
/>,
)
expect(screen.getByText('Custom: Fallback')).toBeInTheDocument()
})
it('should render with custom renderOption', async () => {
const user = userEvent.setup()
render(
<Select
items={items}
defaultValue="apple"
allowSearch={false}
onSelect={vi.fn()}
renderOption={({ item, selected }) => (
<span data-testid={`custom-opt-${item.value}`}>
{item.name}
{selected ? ' ✓' : ''}
</span>
)}
/>,
)
await user.click(screen.getByTitle('Apple'))
expect(screen.getByTestId('custom-opt-apple')).toBeInTheDocument()
expect(screen.getByTestId('custom-opt-banana')).toBeInTheDocument()
})
it('should show ChevronUpIcon when open and ChevronDownIcon when closed', async () => {
const user = userEvent.setup()
render(
<Select
items={items}
defaultValue="apple"
allowSearch={false}
onSelect={vi.fn()}
/>,
)
// Initially closed — should have a chevron button
await user.click(screen.getByTitle('Apple'))
// Dropdown is now open
expect(screen.getByText('Banana')).toBeInTheDocument()
})
})
})
// ──────────────────────────────────────────────────────────────
// SimpleSelect (Listbox-based)
// ──────────────────────────────────────────────────────────────
describe('SimpleSelect', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering and placeholder fallback behavior.
describe('Rendering', () => {
it('should render i18n placeholder when no selection exists', () => {
render(
@@ -107,9 +306,106 @@ describe('SimpleSelect', () => {
expect(screen.getByText('Pick one')).toBeInTheDocument()
})
it('should render selected item name when defaultValue matches', () => {
render(
<SimpleSelect
items={items}
defaultValue="banana"
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('Banana')).toBeInTheDocument()
})
it('should render with isLoading=true showing spinner', () => {
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={vi.fn()}
isLoading={true}
/>,
)
// Loader icon should be rendered (RiLoader4Line has aria hidden)
expect(screen.getByText('Apple')).toBeInTheDocument()
})
it('should render group items as non-selectable headers', async () => {
const user = userEvent.setup()
const groupItems: Item[] = [
{ value: 'fruits-group', name: 'Fruits', isGroup: true },
{ value: 'apple', name: 'Apple' },
{ value: 'banana', name: 'Banana' },
]
render(
<SimpleSelect
items={groupItems}
defaultValue="apple"
onSelect={vi.fn()}
/>,
)
await user.click(screen.getByRole('button'))
expect(screen.getByText('Fruits')).toBeInTheDocument()
})
it('should not render ListboxOptions when disabled', () => {
render(
<SimpleSelect
items={items}
defaultValue="apple"
disabled={true}
onSelect={vi.fn()}
/>,
)
expect(screen.getByText('Apple')).toBeInTheDocument()
})
it('should not open SimpleSelect when disabled', async () => {
const user = userEvent.setup()
render(
<SimpleSelect
items={items}
defaultValue="apple"
disabled={true}
onSelect={vi.fn()}
/>,
)
const button = screen.getByRole('button')
await user.click(button)
// Banana should not be visible as it won't open
expect(screen.queryByText('Banana')).not.toBeInTheDocument()
})
it('should not trigger onSelect via onChange when Listbox is disabled', () => {
// Covers line 228 (!disabled check) inside Listbox onChange
const onSelect = vi.fn()
render(
<SimpleSelect
items={items}
defaultValue="apple"
disabled={true}
onSelect={onSelect}
/>,
)
const button = screen.getByRole('button') as HTMLButtonElement
button.removeAttribute('disabled')
button.removeAttribute('aria-disabled')
fireEvent.click(button)
expect(onSelect).not.toHaveBeenCalled()
})
})
// User interactions and callback behavior.
describe('User Interactions', () => {
it('should call onSelect and update display when an option is chosen', async () => {
const user = userEvent.setup()
@@ -151,15 +447,133 @@ describe('SimpleSelect', () => {
await user.click(screen.getByText('none-closed'))
expect(screen.getByText('none-open')).toBeInTheDocument()
})
it('should clear selection when XMark is clicked (notClearable=false)', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={onSelect}
notClearable={false}
/>,
)
// The clear button (XMarkIcon) should be visible when an item is selected
const clearBtn = screen.getByRole('button').querySelector('[aria-hidden="false"]')
expect(clearBtn).toBeInTheDocument()
await user.click(clearBtn!)
expect(onSelect).toHaveBeenCalledWith({ name: '', value: '' })
})
it('should not show clear button when notClearable is true', () => {
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={vi.fn()}
notClearable={true}
/>,
)
const clearBtn = screen.getByRole('button').querySelector('[aria-hidden="false"]')
expect(clearBtn).not.toBeInTheDocument()
})
it('should hide check marks when hideChecked is true', async () => {
const user = userEvent.setup()
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={vi.fn()}
hideChecked={true}
/>,
)
await user.click(screen.getByRole('button'))
// The selected item should be visible but without a check icon
expect(screen.getAllByText('Apple').length).toBeGreaterThanOrEqual(1)
})
it('should render with custom renderOption in SimpleSelect', async () => {
const user = userEvent.setup()
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={vi.fn()}
renderOption={({ item, selected }) => (
<span data-testid={`simple-opt-${item.value}`}>
{item.name}
{selected ? ' (selected)' : ''}
</span>
)}
/>,
)
await user.click(screen.getByRole('button'))
expect(screen.getByTestId('simple-opt-apple')).toBeInTheDocument()
expect(screen.getByTestId('simple-opt-banana')).toBeInTheDocument()
// Verify the custom render shows selected state
expect(screen.getByTestId('simple-opt-apple')).toHaveTextContent('Apple (selected)')
})
it('should call onOpenChange when the button is clicked', async () => {
const user = userEvent.setup()
const onOpenChange = vi.fn()
render(
<SimpleSelect
items={items}
defaultValue="apple"
onSelect={vi.fn()}
onOpenChange={onOpenChange}
/>,
)
await user.click(screen.getByRole('button'))
expect(onOpenChange).toHaveBeenCalled()
})
it('should handle disabled items that cannot be selected', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
const disabledItems: Item[] = [
{ value: 'apple', name: 'Apple' },
{ value: 'banana', name: 'Banana', disabled: true },
{ value: 'citrus', name: 'Citrus' },
]
render(
<SimpleSelect
items={disabledItems}
defaultValue="apple"
onSelect={onSelect}
/>,
)
await user.click(screen.getByRole('button'))
// Banana should be rendered but not selectable
expect(screen.getByText('Banana')).toBeInTheDocument()
})
})
})
// ──────────────────────────────────────────────────────────────
// PortalSelect
// ──────────────────────────────────────────────────────────────
describe('PortalSelect', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// Rendering for edge case when value is empty.
describe('Rendering', () => {
it('should show placeholder when value is empty', () => {
render(
@@ -172,9 +586,76 @@ describe('PortalSelect', () => {
expect(screen.getByText(/select/i)).toBeInTheDocument()
})
it('should show selected item name when value matches', () => {
render(
<PortalSelect
value="banana"
items={items}
onSelect={vi.fn()}
/>,
)
expect(screen.getByTitle('Banana')).toBeInTheDocument()
})
it('should render with custom placeholder', () => {
render(
<PortalSelect
value=""
items={items}
onSelect={vi.fn()}
placeholder="Choose fruit"
/>,
)
expect(screen.getByText('Choose fruit')).toBeInTheDocument()
})
it('should render with renderTrigger', () => {
render(
<PortalSelect
value="apple"
items={items}
onSelect={vi.fn()}
renderTrigger={item => (
<span data-testid="custom-trigger">{item?.name ?? 'None'}</span>
)}
/>,
)
expect(screen.getByTestId('custom-trigger')).toHaveTextContent('Apple')
})
it('should show INSTALLED badge when installedValue differs from selected value', () => {
render(
<PortalSelect
value="banana"
items={items}
onSelect={vi.fn()}
installedValue="apple"
/>,
)
expect(screen.getByTitle('Banana')).toBeInTheDocument()
})
it('should apply triggerClassNameFn', () => {
const triggerClassNameFn = vi.fn((open: boolean) => open ? 'trigger-open' : 'trigger-closed')
render(
<PortalSelect
value="apple"
items={items}
onSelect={vi.fn()}
triggerClassNameFn={triggerClassNameFn}
/>,
)
expect(triggerClassNameFn).toHaveBeenCalledWith(false)
})
})
// Interaction and readonly behavior.
describe('User Interactions', () => {
it('should call onSelect when choosing an option from portal dropdown', async () => {
const user = userEvent.setup()
@@ -212,5 +693,74 @@ describe('PortalSelect', () => {
await user.click(screen.getByText(/select/i))
expect(screen.queryByTitle('Citrus')).not.toBeInTheDocument()
})
it('should show check mark for selected item when hideChecked is false', async () => {
const user = userEvent.setup()
render(
<PortalSelect
value="banana"
items={items}
onSelect={vi.fn()}
/>,
)
await user.click(screen.getByTitle('Banana'))
// Banana option in the dropdown should be displayed
const allBananas = screen.getAllByText('Banana')
expect(allBananas.length).toBeGreaterThanOrEqual(1)
})
it('should hide check marks when hideChecked is true', async () => {
const user = userEvent.setup()
render(
<PortalSelect
value="banana"
items={items}
onSelect={vi.fn()}
hideChecked={true}
/>,
)
await user.click(screen.getByTitle('Banana'))
expect(screen.getAllByText('Banana').length).toBeGreaterThanOrEqual(1)
})
it('should display INSTALLED badge in dropdown for installed items', async () => {
const user = userEvent.setup()
render(
<PortalSelect
value="banana"
items={items}
onSelect={vi.fn()}
installedValue="apple"
/>,
)
await user.click(screen.getByTitle('Banana'))
// The installed badge should appear in the dropdown
expect(screen.getByText('INSTALLED')).toBeInTheDocument()
})
it('should render item.extra content in dropdown', async () => {
const user = userEvent.setup()
const extraItems: Item[] = [
{ value: 'apple', name: 'Apple', extra: <span data-testid="extra-apple">Extra</span> },
{ value: 'banana', name: 'Banana' },
]
render(
<PortalSelect
value=""
items={extraItems}
onSelect={vi.fn()}
/>,
)
await user.click(screen.getByText(/select/i))
expect(screen.getByTestId('extra-apple')).toBeInTheDocument()
})
})
})

View File

@@ -1,5 +1,6 @@
import type { ReactNode } from 'react'
import { act, render, screen, waitFor } from '@testing-library/react'
import type { ToastHandle } from '../index'
import { act, render, screen, waitFor, within } from '@testing-library/react'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import Toast, { ToastProvider } from '..'
@@ -19,6 +20,13 @@ const TestComponent = () => {
}
describe('Toast', () => {
const getToastElementByMessage = (message: string): HTMLElement => {
const messageElement = screen.getByText(message)
const toastElement = messageElement.closest('.fixed')
expect(toastElement).toBeInTheDocument()
return toastElement as HTMLElement
}
beforeEach(() => {
vi.useFakeTimers({ shouldAdvanceTime: true })
})
@@ -46,7 +54,9 @@ describe('Toast', () => {
</ToastProvider>,
)
expect(document.querySelector('.text-text-success')).toBeInTheDocument()
const successToast = getToastElementByMessage('Success message')
const successIcon = within(successToast).getByTestId('toast-icon-success')
expect(successIcon).toHaveClass('text-text-success')
rerender(
<ToastProvider>
@@ -54,7 +64,9 @@ describe('Toast', () => {
</ToastProvider>,
)
expect(document.querySelector('.text-text-destructive')).toBeInTheDocument()
const errorToast = getToastElementByMessage('Error message')
const errorIcon = within(errorToast).getByTestId('toast-icon-error')
expect(errorIcon).toHaveClass('text-text-destructive')
})
it('renders with custom component', () => {
@@ -100,8 +112,58 @@ describe('Toast', () => {
)
expect(screen.getByText('No close button')).toBeInTheDocument()
// Ensure the close button is not rendered
expect(document.querySelector('.h-4.w-4.shrink-0.text-text-tertiary')).not.toBeInTheDocument()
const toastElement = getToastElementByMessage('No close button')
expect(within(toastElement).queryByRole('button')).not.toBeInTheDocument()
})
it('returns null when message is not a string', () => {
const { container } = render(
<ToastProvider>
{/* @ts-expect-error - testing invalid input */}
<Toast message={<div>Invalid</div>} />
</ToastProvider>,
)
// Toast returns null, and provider adds no DOM elements
expect(container.firstChild).toBeNull()
})
it('renders with size sm', () => {
const { rerender } = render(
<ToastProvider>
<Toast type="info" message="Small size" size="sm" />
</ToastProvider>,
)
const infoToast = getToastElementByMessage('Small size')
const infoIcon = within(infoToast).getByTestId('toast-icon-info')
expect(infoIcon).toHaveClass('text-text-accent', 'h-4', 'w-4')
expect(infoIcon.parentElement).toHaveClass('p-1')
rerender(
<ToastProvider>
<Toast type="success" message="Small size" size="sm" />
</ToastProvider>,
)
const successToast = getToastElementByMessage('Small size')
const successIcon = within(successToast).getByTestId('toast-icon-success')
expect(successIcon).toHaveClass('text-text-success', 'h-4', 'w-4')
rerender(
<ToastProvider>
<Toast type="warning" message="Small size" size="sm" />
</ToastProvider>,
)
const warningToast = getToastElementByMessage('Small size')
const warningIcon = within(warningToast).getByTestId('toast-icon-warning')
expect(warningIcon).toHaveClass('text-text-warning-secondary', 'h-4', 'w-4')
rerender(
<ToastProvider>
<Toast type="error" message="Small size" size="sm" />
</ToastProvider>,
)
const errorToast = getToastElementByMessage('Small size')
const errorIcon = within(errorToast).getByTestId('toast-icon-error')
expect(errorIcon).toHaveClass('text-text-destructive', 'h-4', 'w-4')
})
})
@@ -152,6 +214,37 @@ describe('Toast', () => {
expect(screen.queryByText('Notification message')).not.toBeInTheDocument()
})
})
it('automatically hides toast after duration for error type in provider', async () => {
const TestComponentError = () => {
const { notify } = useToastContext()
return (
<button type="button" onClick={() => notify({ message: 'Error notify', type: 'error' })}>
Show Error
</button>
)
}
render(
<ToastProvider>
<TestComponentError />
</ToastProvider>,
)
act(() => {
screen.getByText('Show Error').click()
})
expect(screen.getByText('Error notify')).toBeInTheDocument()
// Error type uses 6000ms default
act(() => {
vi.advanceTimersByTime(6000)
})
await waitFor(() => {
expect(screen.queryByText('Error notify')).not.toBeInTheDocument()
})
})
})
describe('Toast.notify static method', () => {
@@ -195,5 +288,61 @@ describe('Toast', () => {
expect(onCloseMock).toHaveBeenCalled()
})
})
it('closes when close button is clicked in static toast', async () => {
const onCloseMock = vi.fn()
act(() => {
Toast.notify({ message: 'Static close test', type: 'info', onClose: onCloseMock })
})
expect(screen.getByText('Static close test')).toBeInTheDocument()
const toastElement = getToastElementByMessage('Static close test')
const closeButton = within(toastElement).getByRole('button')
act(() => {
closeButton.click()
})
expect(screen.queryByText('Static close test')).not.toBeInTheDocument()
expect(onCloseMock).toHaveBeenCalled()
})
it('does not auto close when duration is 0', async () => {
act(() => {
Toast.notify({ message: 'No auto close', type: 'info', duration: 0 })
})
expect(screen.getByText('No auto close')).toBeInTheDocument()
act(() => {
vi.advanceTimersByTime(10000)
})
expect(screen.getByText('No auto close')).toBeInTheDocument()
// manual clear to clean up
act(() => {
const toastElement = getToastElementByMessage('No auto close')
within(toastElement).getByRole('button').click()
})
})
it('returns a toast handler that can clear the toast', async () => {
let handler: ToastHandle = {}
const onCloseMock = vi.fn()
act(() => {
handler = Toast.notify({ message: 'Clearable toast', type: 'warning', onClose: onCloseMock })
})
expect(screen.getByText('Clearable toast')).toBeInTheDocument()
act(() => {
handler.clear?.()
})
expect(screen.queryByText('Clearable toast')).not.toBeInTheDocument()
expect(onCloseMock).toHaveBeenCalled()
})
})
})

View File

@@ -1,13 +1,6 @@
'use client'
import type { ReactNode } from 'react'
import type { IToastProps } from './context'
import {
RiAlertFill,
RiCheckboxCircleFill,
RiCloseLine,
RiErrorWarningFill,
RiInformation2Fill,
} from '@remixicon/react'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useEffect, useState } from 'react'
@@ -53,10 +46,10 @@ const Toast = ({
/>
<div className={cn('flex', size === 'md' ? 'gap-1' : 'gap-0.5')}>
<div className={cn('flex items-center justify-center', size === 'md' ? 'p-0.5' : 'p-1')}>
{type === 'success' && <RiCheckboxCircleFill className={cn('text-text-success', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} aria-hidden="true" />}
{type === 'error' && <RiErrorWarningFill className={cn('text-text-destructive', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} aria-hidden="true" />}
{type === 'warning' && <RiAlertFill className={cn('text-text-warning-secondary', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} aria-hidden="true" />}
{type === 'info' && <RiInformation2Fill className={cn('text-text-accent', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} aria-hidden="true" />}
{type === 'success' && <span className={cn('i-ri-checkbox-circle-fill', 'text-text-success', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} data-testid="toast-icon-success" aria-hidden="true" />}
{type === 'error' && <span className={cn('i-ri-error-warning-fill', 'text-text-destructive', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} data-testid="toast-icon-error" aria-hidden="true" />}
{type === 'warning' && <span className={cn('i-ri-alert-fill', 'text-text-warning-secondary', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} data-testid="toast-icon-warning" aria-hidden="true" />}
{type === 'info' && <span className={cn('i-ri-information-2-fill', 'text-text-accent', size === 'md' ? 'h-5 w-5' : 'h-4 w-4')} data-testid="toast-icon-info" aria-hidden="true" />}
</div>
<div className={cn('flex grow flex-col items-start gap-1 py-1', size === 'md' ? 'px-1' : 'px-0.5')}>
<div className="flex items-center gap-1">
@@ -71,8 +64,8 @@ const Toast = ({
</div>
{close
&& (
<ActionButton className="z-[1000]" onClick={close}>
<RiCloseLine className="h-4 w-4 shrink-0 text-text-tertiary" />
<ActionButton data-testid="toast-close-button" className="z-[1000]" onClick={close}>
<span className="i-ri-close-line h-4 w-4 shrink-0 text-text-tertiary" />
</ActionButton>
)}
</div>

View File

@@ -0,0 +1,129 @@
import { tooltipManager } from '../TooltipManager'
describe('TooltipManager', () => {
// Test the singleton instance directly
let manager: typeof tooltipManager
beforeEach(() => {
// Get fresh reference to the singleton
manager = tooltipManager
// Clean up any active tooltip by calling closeActiveTooltip
// This ensures each test starts with a clean state
manager.closeActiveTooltip()
})
describe('register', () => {
it('should register a close function', () => {
const closeFn = vi.fn()
manager.register(closeFn)
expect(closeFn).not.toHaveBeenCalled()
})
it('should call the existing close function when registering a new one', () => {
const firstCloseFn = vi.fn()
const secondCloseFn = vi.fn()
manager.register(firstCloseFn)
manager.register(secondCloseFn)
expect(firstCloseFn).toHaveBeenCalledTimes(1)
expect(secondCloseFn).not.toHaveBeenCalled()
})
it('should replace the active closer with the new one', () => {
const firstCloseFn = vi.fn()
const secondCloseFn = vi.fn()
// Register first function
manager.register(firstCloseFn)
// Register second function - this should call firstCloseFn and replace it
manager.register(secondCloseFn)
// Verify firstCloseFn was called during register (replacement behavior)
expect(firstCloseFn).toHaveBeenCalledTimes(1)
// Now close the active tooltip - this should call secondCloseFn
manager.closeActiveTooltip()
// Verify secondCloseFn was called, not firstCloseFn
expect(secondCloseFn).toHaveBeenCalledTimes(1)
})
})
describe('clear', () => {
it('should not clear if the close function does not match', () => {
const closeFn = vi.fn()
const otherCloseFn = vi.fn()
manager.register(closeFn)
manager.clear(otherCloseFn)
manager.closeActiveTooltip()
expect(closeFn).toHaveBeenCalledTimes(1)
})
it('should clear the close function if it matches', () => {
const closeFn = vi.fn()
manager.register(closeFn)
manager.clear(closeFn)
manager.closeActiveTooltip()
expect(closeFn).not.toHaveBeenCalled()
})
it('should not call the close function when clearing', () => {
const closeFn = vi.fn()
manager.register(closeFn)
manager.clear(closeFn)
expect(closeFn).not.toHaveBeenCalled()
})
})
describe('closeActiveTooltip', () => {
it('should do nothing when no active closer is registered', () => {
expect(() => manager.closeActiveTooltip()).not.toThrow()
})
it('should call the active closer function', () => {
const closeFn = vi.fn()
manager.register(closeFn)
manager.closeActiveTooltip()
expect(closeFn).toHaveBeenCalledTimes(1)
})
it('should clear the active closer after calling it', () => {
const closeFn = vi.fn()
manager.register(closeFn)
manager.closeActiveTooltip()
manager.closeActiveTooltip()
expect(closeFn).toHaveBeenCalledTimes(1)
})
it('should handle multiple register and close cycles', () => {
const closeFn1 = vi.fn()
const closeFn2 = vi.fn()
const closeFn3 = vi.fn()
manager.register(closeFn1)
manager.closeActiveTooltip()
manager.register(closeFn2)
manager.closeActiveTooltip()
manager.register(closeFn3)
manager.closeActiveTooltip()
expect(closeFn1).toHaveBeenCalledTimes(1)
expect(closeFn2).toHaveBeenCalledTimes(1)
expect(closeFn3).toHaveBeenCalledTimes(1)
})
})
})

View File

@@ -1,8 +1,13 @@
import { act, cleanup, fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import Tooltip from '../index'
import { tooltipManager } from '../TooltipManager'
afterEach(cleanup)
afterEach(() => {
cleanup()
vi.clearAllTimers()
vi.useRealTimers()
})
describe('Tooltip', () => {
describe('Rendering', () => {
@@ -22,6 +27,27 @@ describe('Tooltip', () => {
)
expect(getByText('Hover me').textContent).toBe('Hover me')
})
it('should render correctly when asChild is false', () => {
const { container } = render(
<Tooltip popupContent="Tooltip" asChild={false} triggerClassName="custom-parent-trigger">
<span>Trigger</span>
</Tooltip>,
)
const trigger = container.querySelector('.custom-parent-trigger')
expect(trigger).not.toBeNull()
})
it('should render with a fallback question icon when children are null', () => {
const { container } = render(
<Tooltip popupContent="Tooltip" triggerClassName="custom-fallback-trigger">
{null}
</Tooltip>,
)
const trigger = container.querySelector('.custom-fallback-trigger')
expect(trigger).not.toBeNull()
expect(trigger?.querySelector('svg')).not.toBeNull()
})
})
describe('Disabled state', () => {
@@ -37,6 +63,10 @@ describe('Tooltip', () => {
})
describe('Trigger methods', () => {
beforeEach(() => {
vi.useFakeTimers()
})
it('should open on hover when triggerMethod is hover', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerClassName={triggerClassName} />)
@@ -47,7 +77,7 @@ describe('Tooltip', () => {
expect(screen.queryByText('Tooltip content')).toBeInTheDocument()
})
it('should close on mouse leave when triggerMethod is hover', () => {
it('should close on mouse leave when triggerMethod is hover and needsDelay is false', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerClassName={triggerClassName} needsDelay={false} />)
const trigger = container.querySelector(`.${triggerClassName}`)
@@ -66,17 +96,198 @@ describe('Tooltip', () => {
fireEvent.click(trigger!)
})
expect(screen.queryByText('Tooltip content')).toBeInTheDocument()
// Test toggle off
act(() => {
fireEvent.click(trigger!)
})
expect(screen.queryByText('Tooltip content')).not.toBeInTheDocument()
})
it('should not close immediately on mouse leave when needsDelay is true', () => {
it('should do nothing on mouse enter if triggerMethod is click', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="hover" needsDelay triggerClassName={triggerClassName} />)
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="click" triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.mouseEnter(trigger!)
})
expect(screen.queryByText('Tooltip content')).not.toBeInTheDocument()
})
it('should delay closing on mouse leave when needsDelay is true', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="hover" needsDelay triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.mouseEnter(trigger!)
})
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
act(() => {
fireEvent.mouseLeave(trigger!)
})
expect(screen.queryByText('Tooltip content')).toBeInTheDocument()
// Shouldn't close immediately
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
act(() => {
vi.advanceTimersByTime(350)
})
// Should close after delay
expect(screen.queryByText('Tooltip content')).not.toBeInTheDocument()
})
it('should not close if mouse enters popup before delay finishes', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="hover" needsDelay triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.mouseEnter(trigger!)
})
const popup = screen.getByText('Tooltip content')
expect(popup).toBeInTheDocument()
act(() => {
fireEvent.mouseLeave(trigger!)
})
act(() => {
vi.advanceTimersByTime(150)
// Simulate mouse entering popup area itself during the delay timeframe
fireEvent.mouseEnter(popup)
})
act(() => {
vi.advanceTimersByTime(200) // Complete the 300ms original delay
})
// Should still be open because we are hovering the popup
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
// Now mouse leaves popup
act(() => {
fireEvent.mouseLeave(popup)
})
act(() => {
vi.advanceTimersByTime(350)
})
// Should now close
expect(screen.queryByText('Tooltip content')).not.toBeInTheDocument()
})
it('should do nothing on mouse enter/leave of popup when triggerMethod is not hover', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="click" needsDelay triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.click(trigger!)
})
const popup = screen.getByText('Tooltip content')
act(() => {
fireEvent.mouseEnter(popup)
fireEvent.mouseLeave(popup)
vi.advanceTimersByTime(350)
})
// Should still be open because click method requires another click to close, not hover leave
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
})
it('should clear close timeout if trigger is hovered again before delay finishes', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="hover" needsDelay triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.mouseEnter(trigger!)
})
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
act(() => {
fireEvent.mouseLeave(trigger!)
})
act(() => {
vi.advanceTimersByTime(150)
// Re-hover trigger before it closes
fireEvent.mouseEnter(trigger!)
})
act(() => {
vi.advanceTimersByTime(200) // Original 300ms would be up
})
// Should still be open because we reset it
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
})
it('should test clear close timeout if trigger is hovered again before delay finishes and isHoverPopupRef is true', () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerMethod="hover" needsDelay triggerClassName={triggerClassName} />)
const trigger = container.querySelector(`.${triggerClassName}`)
act(() => {
fireEvent.mouseEnter(trigger!)
})
const popup = screen.getByText('Tooltip content')
expect(popup).toBeInTheDocument()
act(() => {
fireEvent.mouseEnter(popup)
fireEvent.mouseLeave(trigger!)
})
act(() => {
vi.advanceTimersByTime(350)
})
// Should still be open because we are hovering the popup
expect(screen.getByText('Tooltip content')).toBeInTheDocument()
})
})
describe('TooltipManager', () => {
it('should close active tooltips when triggered centrally, overriding other closes', () => {
const triggerClassName1 = 'custom-trigger-1'
const triggerClassName2 = 'custom-trigger-2'
const { container } = render(
<div>
<Tooltip popupContent="Tooltip content 1" triggerMethod="hover" triggerClassName={triggerClassName1} />
<Tooltip popupContent="Tooltip content 2" triggerMethod="hover" triggerClassName={triggerClassName2} />
</div>,
)
const trigger1 = container.querySelector(`.${triggerClassName1}`)
const trigger2 = container.querySelector(`.${triggerClassName2}`)
expect(trigger2).not.toBeNull()
// Open first tooltip
act(() => {
fireEvent.mouseEnter(trigger1!)
})
expect(screen.queryByText('Tooltip content 1')).toBeInTheDocument()
// TooltipManager should keep track of it
// Next, immediately open the second one without leaving first (e.g., via TooltipManager)
// TooltipManager registers the newest one and closes the old one when doing full external operations, but internally the manager allows direct closing
act(() => {
tooltipManager.closeActiveTooltip()
})
expect(screen.queryByText('Tooltip content 1')).not.toBeInTheDocument()
// Safe to call again
expect(() => tooltipManager.closeActiveTooltip()).not.toThrow()
})
})
@@ -88,6 +299,11 @@ describe('Tooltip', () => {
expect(trigger?.className).toContain('custom-trigger')
})
it('should pass triggerTestId to the fallback icon wrapper', () => {
render(<Tooltip popupContent="Tooltip content" triggerTestId="test-tooltip-icon" />)
expect(screen.getByTestId('test-tooltip-icon')).toBeInTheDocument()
})
it('should apply custom popup className', async () => {
const triggerClassName = 'custom-trigger'
const { container } = render(<Tooltip popupContent="Tooltip content" triggerClassName={triggerClassName} popupClassName="custom-popup" />)

View File

@@ -43,24 +43,20 @@ type DialogContentProps = {
children: React.ReactNode
className?: string
overlayClassName?: string
backdropProps?: React.ComponentPropsWithoutRef<typeof BaseDialog.Backdrop>
}
export function DialogContent({
children,
className,
overlayClassName,
backdropProps,
}: DialogContentProps) {
return (
<DialogPortal>
<BaseDialog.Backdrop
{...backdropProps}
className={cn(
'fixed inset-0 z-50 bg-background-overlay',
'transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
overlayClassName,
backdropProps?.className,
)}
/>
<BaseDialog.Popup

View File

@@ -1,10 +1,9 @@
import { render, screen, waitFor } from '@testing-library/react'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { audioToText } from '@/service/share'
import VoiceInput from '../index'
const { mockState, MockRecorder } = vi.hoisted(() => {
const { mockState, MockRecorder, rafState } = vi.hoisted(() => {
const state = {
params: {} as Record<string, string>,
pathname: '/test',
@@ -12,6 +11,9 @@ const { mockState, MockRecorder } = vi.hoisted(() => {
startOverride: null as (() => Promise<void>) | null,
analyseData: new Uint8Array(1024).fill(150) as Uint8Array,
}
const rafStateObj = {
callback: null as (() => void) | null,
}
class MockRecorderClass {
start = vi.fn((..._args: unknown[]) => {
@@ -33,7 +35,7 @@ const { mockState, MockRecorder } = vi.hoisted(() => {
}
}
return { mockState: state, MockRecorder: MockRecorderClass }
return { mockState: state, MockRecorder: MockRecorderClass, rafState: rafStateObj }
})
vi.mock('js-audio-recorder', () => ({
@@ -54,6 +56,17 @@ vi.mock('../utils', () => ({
convertToMp3: vi.fn(() => new Blob(['test'], { type: 'audio/mp3' })),
}))
vi.mock('ahooks', async (importOriginal) => {
const actual = await importOriginal<typeof import('ahooks')>()
return {
...actual,
useRafInterval: vi.fn((fn) => {
rafState.callback = fn
return vi.fn()
}),
}
})
describe('VoiceInput', () => {
const onConverted = vi.fn()
const onCancel = vi.fn()
@@ -64,6 +77,7 @@ describe('VoiceInput', () => {
mockState.pathname = '/test'
mockState.recorderInstances = []
mockState.startOverride = null
rafState.callback = null
// Ensure canvas has non-zero dimensions for initCanvas()
HTMLCanvasElement.prototype.getBoundingClientRect = vi.fn(() => ({
@@ -257,4 +271,268 @@ describe('VoiceInput', () => {
})
})
})
it('should use fallback rect when canvas roundRect is not available', async () => {
const user = userEvent.setup()
vi.mocked(audioToText).mockResolvedValueOnce({ text: 'test' })
mockState.params = { token: 'abc' }
mockState.analyseData = new Uint8Array(1024).fill(150)
const oldGetContext = HTMLCanvasElement.prototype.getContext
HTMLCanvasElement.prototype.getContext = vi.fn(() => ({
scale: vi.fn(),
clearRect: vi.fn(),
beginPath: vi.fn(),
moveTo: vi.fn(),
rect: vi.fn(),
fill: vi.fn(),
closePath: vi.fn(),
})) as unknown as typeof HTMLCanvasElement.prototype.getContext
let rafCalls = 0
vi.spyOn(window, 'requestAnimationFrame').mockImplementation((cb) => {
rafCalls++
if (rafCalls <= 1)
cb(0)
return rafCalls
})
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await user.click(await screen.findByTestId('voice-input-stop'))
await waitFor(() => {
expect(onConverted).toHaveBeenCalled()
})
HTMLCanvasElement.prototype.getContext = oldGetContext
})
it('should display timer in MM:SS format correctly', async () => {
mockState.params = { token: 'abc' }
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const timer = await screen.findByTestId('voice-input-timer')
expect(timer).toHaveTextContent('00:00')
await act(async () => {
if (rafState.callback)
rafState.callback()
})
expect(timer).toHaveTextContent('00:01')
for (let i = 0; i < 9; i++) {
await act(async () => {
if (rafState.callback)
rafState.callback()
})
}
expect(timer).toHaveTextContent('00:10')
})
it('should show timer element with formatted time', async () => {
mockState.params = { token: 'abc' }
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const timer = screen.getByTestId('voice-input-timer')
expect(timer).toBeInTheDocument()
// Initial state should show 00:00
expect(timer.textContent).toMatch(/0\d:\d{2}/)
})
it('should handle data values in normal range (between 128 and 178)', async () => {
mockState.analyseData = new Uint8Array(1024).fill(150)
let rafCalls = 0
vi.spyOn(window, 'requestAnimationFrame').mockImplementation((cb) => {
rafCalls++
if (rafCalls <= 2)
cb(0)
return rafCalls
})
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await screen.findByTestId('voice-input-stop')
// eslint-disable-next-line ts/no-explicit-any
const recorder = mockState.recorderInstances[0] as any
expect(recorder.getRecordAnalyseData).toHaveBeenCalled()
})
it('should handle canvas context and device pixel ratio', async () => {
const dprSpy = vi.spyOn(window, 'devicePixelRatio', 'get')
dprSpy.mockReturnValue(2)
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await screen.findByTestId('voice-input-stop')
expect(screen.getByTestId('voice-input-stop')).toBeInTheDocument()
dprSpy.mockRestore()
})
it('should handle empty params with no token or appId', async () => {
const user = userEvent.setup()
vi.mocked(audioToText).mockResolvedValueOnce({ text: 'test' })
mockState.params = {}
mockState.pathname = '/test'
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const stopBtn = await screen.findByTestId('voice-input-stop')
await user.click(stopBtn)
await waitFor(() => {
// Should call audioToText with empty URL when neither token nor appId is present
expect(audioToText).toHaveBeenCalledWith('', 'installedApp', expect.any(FormData))
})
})
it('should render speaking state indicator', async () => {
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
expect(await screen.findByText('common.voiceInput.speaking')).toBeInTheDocument()
})
it('should cleanup on unmount', () => {
const { unmount } = render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
// eslint-disable-next-line ts/no-explicit-any
const recorder = mockState.recorderInstances[0] as any
unmount()
expect(recorder.stop).toHaveBeenCalled()
})
it('should handle all data in recordAnalyseData for canvas drawing', async () => {
const allDataValues = []
for (let i = 0; i < 256; i++) {
allDataValues.push(i)
}
mockState.analyseData = new Uint8Array(allDataValues)
let rafCalls = 0
vi.spyOn(window, 'requestAnimationFrame').mockImplementation((cb) => {
rafCalls++
if (rafCalls <= 2)
cb(0)
return rafCalls
})
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await screen.findByTestId('voice-input-stop')
// eslint-disable-next-line ts/no-explicit-any
const recorder = mockState.recorderInstances[0] as any
expect(recorder.getRecordAnalyseData).toHaveBeenCalled()
})
it('should pass multiple props correctly', async () => {
const user = userEvent.setup()
vi.mocked(audioToText).mockResolvedValueOnce({ text: 'test' })
mockState.params = { token: 'token123' }
render(
<VoiceInput
onConverted={onConverted}
onCancel={onCancel}
wordTimestamps="enabled"
/>,
)
const stopBtn = await screen.findByTestId('voice-input-stop')
await user.click(stopBtn)
await waitFor(() => {
const calls = vi.mocked(audioToText).mock.calls
expect(calls.length).toBeGreaterThan(0)
const [url, sourceType, formData] = calls[0]
expect(url).toBe('/audio-to-text')
expect(sourceType).toBe('webApp')
expect(formData.get('word_timestamps')).toBe('enabled')
})
})
it('should handle pathname with explore/installed correctly when appId exists', async () => {
const user = userEvent.setup()
vi.mocked(audioToText).mockResolvedValueOnce({ text: 'test' })
mockState.params = { appId: 'app-id-123' }
mockState.pathname = '/explore/installed/app-details'
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const stopBtn = await screen.findByTestId('voice-input-stop')
await user.click(stopBtn)
await waitFor(() => {
expect(audioToText).toHaveBeenCalledWith(
'/installed-apps/app-id-123/audio-to-text',
'installedApp',
expect.any(FormData),
)
})
})
it('should render timer with initial 00:00 value', () => {
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const timer = screen.getByTestId('voice-input-timer')
expect(timer).toHaveTextContent('00:00')
})
it('should render stop button during recording', async () => {
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
expect(await screen.findByTestId('voice-input-stop')).toBeInTheDocument()
})
it('should render converting UI after stopping', async () => {
const user = userEvent.setup()
vi.mocked(audioToText).mockImplementation(() => new Promise(() => { }))
mockState.params = { token: 'abc' }
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
const stopBtn = await screen.findByTestId('voice-input-stop')
await user.click(stopBtn)
await screen.findByTestId('voice-input-loader')
expect(screen.getByTestId('voice-input-converting-text')).toBeInTheDocument()
expect(screen.getByTestId('voice-input-cancel')).toBeInTheDocument()
})
it('should auto-stop recording and convert audio when duration reaches 10 minutes (600s)', async () => {
vi.mocked(audioToText).mockResolvedValueOnce({ text: 'auto-stopped text' })
mockState.params = { token: 'abc' }
render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
expect(await screen.findByTestId('voice-input-stop')).toBeInTheDocument()
for (let i = 0; i < 601; i++) {
await act(async () => {
if (rafState.callback)
rafState.callback()
})
}
expect(await screen.findByTestId('voice-input-converting-text')).toBeInTheDocument()
await waitFor(() => {
expect(onConverted).toHaveBeenCalledWith('auto-stopped text')
})
}, 10000)
it('should handle null canvas element gracefully during initialization', async () => {
const getElementByIdMock = vi.spyOn(document, 'getElementById').mockReturnValue(null)
const { unmount } = render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await screen.findByTestId('voice-input-stop')
unmount()
getElementByIdMock.mockRestore()
})
it('should handle getContext returning null gracefully during initialization', async () => {
const oldGetContext = HTMLCanvasElement.prototype.getContext
HTMLCanvasElement.prototype.getContext = vi.fn().mockReturnValue(null)
const { unmount } = render(<VoiceInput onConverted={onConverted} onCancel={onCancel} />)
await screen.findByTestId('voice-input-stop')
unmount()
HTMLCanvasElement.prototype.getContext = oldGetContext
})
})

View File

@@ -0,0 +1,196 @@
import { convertToMp3 } from '../utils'
// ── Hoisted mocks ──
const mocks = vi.hoisted(() => {
const readHeader = vi.fn()
const encodeBuffer = vi.fn()
const flush = vi.fn()
return { readHeader, encodeBuffer, flush }
})
vi.mock('lamejs', () => ({
default: {
WavHeader: {
readHeader: mocks.readHeader,
},
Mp3Encoder: class MockMp3Encoder {
encodeBuffer = mocks.encodeBuffer
flush = mocks.flush
},
},
}))
vi.mock('lamejs/src/js/BitStream', () => ({ default: {} }))
vi.mock('lamejs/src/js/Lame', () => ({ default: {} }))
vi.mock('lamejs/src/js/MPEGMode', () => ({ default: {} }))
// ── helpers ──
/** Build a fake recorder whose getChannelData returns DataView-like objects with .buffer and .byteLength. */
function createMockRecorder(opts: {
channels: number
sampleRate: number
leftSamples: number[]
rightSamples?: number[]
}) {
const toDataView = (samples: number[]) => {
const buf = new ArrayBuffer(samples.length * 2)
const view = new DataView(buf)
samples.forEach((v, i) => {
view.setInt16(i * 2, v, true)
})
return view
}
const leftView = toDataView(opts.leftSamples)
const rightView = opts.rightSamples ? toDataView(opts.rightSamples) : null
mocks.readHeader.mockReturnValue({
channels: opts.channels,
sampleRate: opts.sampleRate,
})
return {
getWAV: vi.fn(() => new ArrayBuffer(44)),
getChannelData: vi.fn(() => ({
left: leftView,
right: rightView,
})),
}
}
describe('convertToMp3', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should convert mono WAV data to an MP3 blob', () => {
const recorder = createMockRecorder({
channels: 1,
sampleRate: 44100,
leftSamples: [100, 200, 300, 400],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array([1, 2, 3]))
mocks.flush.mockReturnValue(new Int8Array([4, 5]))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
expect(result.type).toBe('audio/mp3')
expect(mocks.encodeBuffer).toHaveBeenCalled()
// Mono: encodeBuffer called with only left data
const firstCall = mocks.encodeBuffer.mock.calls[0]
expect(firstCall).toHaveLength(1)
expect(mocks.flush).toHaveBeenCalled()
})
it('should convert stereo WAV data to an MP3 blob', () => {
const recorder = createMockRecorder({
channels: 2,
sampleRate: 48000,
leftSamples: [100, 200],
rightSamples: [300, 400],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array([10, 20]))
mocks.flush.mockReturnValue(new Int8Array([30]))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
expect(result.type).toBe('audio/mp3')
// Stereo: encodeBuffer called with left AND right
const firstCall = mocks.encodeBuffer.mock.calls[0]
expect(firstCall).toHaveLength(2)
})
it('should skip empty encoded buffers', () => {
const recorder = createMockRecorder({
channels: 1,
sampleRate: 44100,
leftSamples: [100, 200],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array(0))
mocks.flush.mockReturnValue(new Int8Array(0))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
expect(result.type).toBe('audio/mp3')
expect(result.size).toBe(0)
})
it('should include flush data when flush returns non-empty buffer', () => {
const recorder = createMockRecorder({
channels: 1,
sampleRate: 22050,
leftSamples: [1],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array(0))
mocks.flush.mockReturnValue(new Int8Array([99, 98, 97]))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
expect(result.size).toBe(3)
})
it('should omit flush data when flush returns empty buffer', () => {
const recorder = createMockRecorder({
channels: 1,
sampleRate: 44100,
leftSamples: [10, 20],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array([1, 2]))
mocks.flush.mockReturnValue(new Int8Array(0))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
expect(result.size).toBe(2)
})
it('should process multiple chunks when sample count exceeds maxSamples (1152)', () => {
const samples = Array.from({ length: 2400 }, (_, i) => i % 32767)
const recorder = createMockRecorder({
channels: 1,
sampleRate: 44100,
leftSamples: samples,
})
mocks.encodeBuffer.mockReturnValue(new Int8Array([1]))
mocks.flush.mockReturnValue(new Int8Array(0))
const result = convertToMp3(recorder)
expect(mocks.encodeBuffer.mock.calls.length).toBeGreaterThan(1)
expect(result).toBeInstanceOf(Blob)
})
it('should encode stereo with right channel subarray', () => {
const recorder = createMockRecorder({
channels: 2,
sampleRate: 44100,
leftSamples: [100, 200, 300],
rightSamples: [400, 500, 600],
})
mocks.encodeBuffer.mockReturnValue(new Int8Array([5, 6, 7]))
mocks.flush.mockReturnValue(new Int8Array([8]))
const result = convertToMp3(recorder)
expect(result).toBeInstanceOf(Blob)
for (const call of mocks.encodeBuffer.mock.calls) {
expect(call).toHaveLength(2)
expect(call[0]).toBeInstanceOf(Int16Array)
expect(call[1]).toBeInstanceOf(Int16Array)
}
})
})

View File

@@ -3,10 +3,11 @@ import BitStream from 'lamejs/src/js/BitStream'
import Lame from 'lamejs/src/js/Lame'
import MPEGMode from 'lamejs/src/js/MPEGMode'
/* v8 ignore next - @preserve */
if (globalThis) {
(globalThis as any).MPEGMode = MPEGMode
;(globalThis as any).Lame = Lame
;(globalThis as any).BitStream = BitStream
; (globalThis as any).Lame = Lame
; (globalThis as any).BitStream = BitStream
}
export const convertToMp3 = (recorder: any) => {

View File

@@ -0,0 +1,123 @@
describe('zendesk/utils', () => {
// Create mock for window.zE
const mockZE = vi.fn()
beforeEach(() => {
vi.resetModules()
vi.clearAllMocks()
// Set up window.zE mock before each test
window.zE = mockZE
})
afterEach(() => {
// Clean up window.zE after each test
window.zE = mockZE
})
describe('setZendeskConversationFields', () => {
it('should call window.zE with correct arguments when not CE edition and zE exists', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { setZendeskConversationFields } = await import('../utils')
const fields = [
{ id: 'field1', value: 'value1' },
{ id: 'field2', value: 'value2' },
]
const callback = vi.fn()
setZendeskConversationFields(fields, callback)
expect(window.zE).toHaveBeenCalledWith(
'messenger:set',
'conversationFields',
fields,
callback,
)
})
it('should not call window.zE when IS_CE_EDITION is true', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: true }))
const { setZendeskConversationFields } = await import('../utils')
const fields = [{ id: 'field1', value: 'value1' }]
setZendeskConversationFields(fields)
expect(window.zE).not.toHaveBeenCalled()
})
it('should work without callback', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { setZendeskConversationFields } = await import('../utils')
const fields = [{ id: 'field1', value: 'value1' }]
setZendeskConversationFields(fields)
expect(window.zE).toHaveBeenCalledWith(
'messenger:set',
'conversationFields',
fields,
undefined,
)
})
})
describe('setZendeskWidgetVisibility', () => {
it('should call window.zE to show widget when visible is true', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { setZendeskWidgetVisibility } = await import('../utils')
setZendeskWidgetVisibility(true)
expect(window.zE).toHaveBeenCalledWith('messenger', 'show')
})
it('should call window.zE to hide widget when visible is false', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { setZendeskWidgetVisibility } = await import('../utils')
setZendeskWidgetVisibility(false)
expect(window.zE).toHaveBeenCalledWith('messenger', 'hide')
})
it('should not call window.zE when IS_CE_EDITION is true', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: true }))
const { setZendeskWidgetVisibility } = await import('../utils')
setZendeskWidgetVisibility(true)
expect(window.zE).not.toHaveBeenCalled()
})
})
describe('toggleZendeskWindow', () => {
it('should call window.zE to open messenger when open is true', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { toggleZendeskWindow } = await import('../utils')
toggleZendeskWindow(true)
expect(window.zE).toHaveBeenCalledWith('messenger', 'open')
})
it('should call window.zE to close messenger when open is false', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: false }))
const { toggleZendeskWindow } = await import('../utils')
toggleZendeskWindow(false)
expect(window.zE).toHaveBeenCalledWith('messenger', 'close')
})
it('should not call window.zE when IS_CE_EDITION is true', async () => {
vi.doMock('@/config', () => ({ IS_CE_EDITION: true }))
const { toggleZendeskWindow } = await import('../utils')
toggleZendeskWindow(true)
expect(window.zE).not.toHaveBeenCalled()
})
})
})

View File

@@ -1,7 +1,7 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { CategoryEnum } from '..'
import Footer from '../footer'
import { CategoryEnum } from '../types'
vi.mock('next/link', () => ({
default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (

View File

@@ -1,16 +1,7 @@
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { Dialog } from '@/app/components/base/ui/dialog'
import Header from '../header'
function renderHeader(onClose: () => void) {
return render(
<Dialog open>
<Header onClose={onClose} />
</Dialog>,
)
}
describe('Header', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -20,7 +11,7 @@ describe('Header', () => {
it('should render title and description translations', () => {
const handleClose = vi.fn()
renderHeader(handleClose)
render(<Header onClose={handleClose} />)
expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument()
expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument()
@@ -31,7 +22,7 @@ describe('Header', () => {
describe('Props', () => {
it('should invoke onClose when close button is clicked', () => {
const handleClose = vi.fn()
renderHeader(handleClose)
render(<Header onClose={handleClose} />)
fireEvent.click(screen.getByRole('button'))
@@ -41,7 +32,7 @@ describe('Header', () => {
describe('Edge Cases', () => {
it('should render structural elements with translation keys', () => {
const { container } = renderHeader(vi.fn())
const { container } = render(<Header onClose={vi.fn()} />)
expect(container.querySelector('span')).toBeInTheDocument()
expect(container.querySelector('p')).toBeInTheDocument()

View File

@@ -74,11 +74,15 @@ describe('Pricing', () => {
})
describe('Props', () => {
it('should allow switching categories', () => {
render(<Pricing onCancel={vi.fn()} />)
it('should allow switching categories and handle esc key', () => {
const handleCancel = vi.fn()
render(<Pricing onCancel={handleCancel} />)
fireEvent.click(screen.getByText('billing.plansCommon.self'))
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
fireEvent.keyDown(window, { key: 'Escape', keyCode: 27 })
expect(handleCancel).toHaveBeenCalled()
})
})

View File

@@ -1,9 +1,10 @@
import type { Category } from './types'
import type { Category } from '.'
import { RiArrowRightUpLine } from '@remixicon/react'
import Link from 'next/link'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
import { CategoryEnum } from './types'
import { CategoryEnum } from '.'
type FooterProps = {
pricingPageURL: string
@@ -33,7 +34,7 @@ const Footer = ({
>
{t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })}
</Link>
<span aria-hidden="true" className="i-ri-arrow-right-up-line size-4" />
<RiArrowRightUpLine className="size-4" />
</span>
</div>
</div>

Some files were not shown because too many files have changed in this diff Show More