mirror of
https://github.com/langgenius/dify.git
synced 2026-03-09 17:25:10 +00:00
Compare commits
29 Commits
3-6-type-c
...
deploy/cle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
232b8eb248 | ||
|
|
4c0d81029f | ||
|
|
bacf70c00a | ||
|
|
942729ba48 | ||
|
|
4f5af0b43c | ||
|
|
a28cb993b8 | ||
|
|
0aef09d630 | ||
|
|
c6dd2ef25a | ||
|
|
d2208ad43e | ||
|
|
4a2ba058bb | ||
|
|
654e41d47f | ||
|
|
ec5409756e | ||
|
|
b0e8becd14 | ||
|
|
8b1ea3a8f5 | ||
|
|
f2d3feca66 | ||
|
|
0590b09958 | ||
|
|
66f9fde2fe | ||
|
|
1811a855ab | ||
|
|
322cd37de1 | ||
|
|
f90e0d781a | ||
|
|
2cc0de9c1b | ||
|
|
46098b2be6 | ||
|
|
7dcf94f48f | ||
|
|
7869551afd | ||
|
|
c925d17e8f | ||
|
|
dc2a53d834 | ||
|
|
05ab107e73 | ||
|
|
c016793efb | ||
|
|
a5bcbaebb7 |
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -25,6 +25,10 @@ updates:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "storybook"
|
||||
@@ -33,5 +37,7 @@ updates:
|
||||
patterns:
|
||||
- "*"
|
||||
exclude-patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
|
||||
@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
|
||||
|
||||
- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
|
||||
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
|
||||
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
|
||||
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
|
||||
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
user_id: str
|
||||
email: str
|
||||
created_at: datetime
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
|
||||
206
api/commands.py
206
api/commands.py
@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
@@ -936,6 +937,12 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
|
||||
is_flag=True,
|
||||
help="Preview cleanup results without deleting any workflow run data.",
|
||||
)
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_workflow_runs(
|
||||
before_days: int,
|
||||
batch_size: int,
|
||||
@@ -944,10 +951,13 @@ def clean_workflow_runs(
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
@@ -967,13 +977,17 @@ def clean_workflow_runs(
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
@@ -2598,15 +2612,29 @@ def migrate_oss(
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
|
||||
)
|
||||
@click.option(
|
||||
"--before-days",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
|
||||
)
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
|
||||
@click.option(
|
||||
"--graceful-period",
|
||||
@@ -2615,33 +2643,99 @@ def migrate_oss(
|
||||
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
|
||||
)
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
|
||||
@click.option(
|
||||
"--task-label",
|
||||
default="daily",
|
||||
show_default=True,
|
||||
help="Stable label value used to distinguish multiple cleanup CronJobs in metrics.",
|
||||
)
|
||||
def clean_expired_messages(
|
||||
batch_size: int,
|
||||
graceful_period: int,
|
||||
start_from: datetime.datetime,
|
||||
end_before: datetime.datetime,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
from_days_ago: int | None,
|
||||
before_days: int | None,
|
||||
dry_run: bool,
|
||||
task_label: str,
|
||||
):
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
abs_mode = start_from is not None and end_before is not None
|
||||
rel_mode = before_days is not None
|
||||
|
||||
if abs_mode and rel_mode:
|
||||
raise click.UsageError(
|
||||
"Options are mutually exclusive: use either (--start-from,--end-before) "
|
||||
"or (--from-days-ago,--before-days)."
|
||||
)
|
||||
|
||||
if from_days_ago is not None and before_days is None:
|
||||
raise click.UsageError("--from-days-ago must be used together with --before-days.")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
|
||||
|
||||
if not abs_mode and not rel_mode:
|
||||
raise click.UsageError(
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
|
||||
)
|
||||
|
||||
if rel_mode:
|
||||
assert before_days is not None
|
||||
if before_days < 0:
|
||||
raise click.UsageError("--before-days must be >= 0.")
|
||||
if from_days_ago is not None:
|
||||
if from_days_ago < 0:
|
||||
raise click.UsageError("--from-days-ago must be >= 0.")
|
||||
if from_days_ago <= before_days:
|
||||
raise click.UsageError("--from-days-ago must be greater than --before-days.")
|
||||
|
||||
# Create policy based on billing configuration
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
# Create and run the cleanup service
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
service = MessagesCleanService.from_days(
|
||||
policy=policy,
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
assert from_days_ago is not None
|
||||
now = naive_utc_now()
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=now - datetime.timedelta(days=from_days_ago),
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
@@ -2666,5 +2760,81 @@ def clean_expired_messages(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
|
||||
@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.app.task_pipeline.message_file_utils import prepare_file_dict
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
|
||||
# Fetch files associated with this message
|
||||
files = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
|
||||
if message_files:
|
||||
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
|
||||
upload_file_ids = list(
|
||||
dict.fromkeys(
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
)
|
||||
)
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
files_list = []
|
||||
for message_file in message_files:
|
||||
file_dict = prepare_file_dict(message_file, upload_files_map)
|
||||
files_list.append(file_dict)
|
||||
|
||||
files = files_list or None
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=metadata_dict,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def _record_files(self):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
if not message_files:
|
||||
return None
|
||||
|
||||
files_list = []
|
||||
upload_file_ids = [
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
]
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
for message_file in message_files:
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
# Fallback: generate URL even if upload_file not found
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
# For tool files, use URL directly if it's HTTP, otherwise sign it
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
else:
|
||||
# Extract tool file id and extension from URL
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0] # Remove query params first
|
||||
# Use rsplit to correctly handle filenames with multiple dots
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
file_dict = {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
files_list.append(file_dict)
|
||||
return files_list or None
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
from threading import Thread, Timer
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -96,9 +95,9 @@ class MessageCycleManager:
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
time.sleep(1)
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
thread = Timer(
|
||||
1,
|
||||
self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation_id,
|
||||
|
||||
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
:param message_file: MessageFile instance
|
||||
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
|
||||
:return: Dictionary containing file information
|
||||
"""
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method.value
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
return {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
@@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
|
||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -82,8 +83,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
|
||||
value = variable.value
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
if isinstance(value, list):
|
||||
value = list(filter(lambda x: x, value))
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
if not value:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": ArrayStringSegment(value=[])},
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = [
|
||||
@@ -111,6 +122,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
else:
|
||||
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
|
||||
except DocumentExtractorError as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
@@ -385,6 +397,32 @@ def parser_docx_part(block, doc: Document, content_items, i):
|
||||
content_items.append((i, "table", Table(block, doc)))
|
||||
|
||||
|
||||
def _normalize_docx_zip(file_content: bytes) -> bytes:
|
||||
"""
|
||||
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
|
||||
ZIP entry names use backslash (\\) as path separator instead of the forward
|
||||
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
|
||||
"word\\document.xml" is never found when python-docx looks for
|
||||
"word/document.xml", which triggers a KeyError about a missing relationship.
|
||||
|
||||
This function rewrites the ZIP in-memory, normalizing all entry names to
|
||||
use forward slashes without touching any actual document content.
|
||||
"""
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
|
||||
out_buf = io.BytesIO()
|
||||
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
|
||||
for item in zin.infolist():
|
||||
data = zin.read(item.filename)
|
||||
# Normalize backslash path separators to forward slash
|
||||
item.filename = item.filename.replace("\\", "/")
|
||||
zout.writestr(item, data)
|
||||
return out_buf.getvalue()
|
||||
except zipfile.BadZipFile:
|
||||
# Not a valid zip — return as-is and let python-docx report the real error
|
||||
return file_content
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
@@ -392,7 +430,15 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
try:
|
||||
doc = docx.Document(doc_file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
|
||||
# Some DOCX files exported by tools like Evernote on Windows use
|
||||
# backslash path separators in ZIP entries and/or single-quoted XML
|
||||
# attributes, both of which break python-docx on Linux. Normalize and retry.
|
||||
file_content = _normalize_docx_zip(file_content)
|
||||
doc = docx.Document(io.BytesIO(file_content))
|
||||
text = []
|
||||
|
||||
# Keep track of paragraph and table positions
|
||||
|
||||
@@ -23,7 +23,11 @@ from dify_graph.variables import (
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayObjectSegment
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .entities import (
|
||||
Condition,
|
||||
KnowledgeRetrievalNodeData,
|
||||
MetadataFilteringCondition,
|
||||
)
|
||||
from .exc import (
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
@@ -171,6 +175,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if node_data.metadata_filtering_mode is not None:
|
||||
metadata_filtering_mode = node_data.metadata_filtering_mode
|
||||
|
||||
resolved_metadata_conditions = (
|
||||
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
|
||||
if node_data.metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
@@ -189,7 +199,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
model_mode=model.mode,
|
||||
model_name=model.name,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
query=query,
|
||||
)
|
||||
@@ -247,7 +257,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
)
|
||||
@@ -256,6 +266,48 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
usage = self._rag_retrieval.llm_usage
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _resolve_metadata_filtering_conditions(
|
||||
self, conditions: MetadataFilteringCondition
|
||||
) -> MetadataFilteringCondition:
|
||||
if conditions.conditions is None:
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator,
|
||||
conditions=None,
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
else:
|
||||
resolved_value = value
|
||||
resolved_conditions.append(
|
||||
Condition(
|
||||
name=cond.name,
|
||||
comparison_operator=cond.comparison_operator,
|
||||
value=resolved_value,
|
||||
)
|
||||
)
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator or "and",
|
||||
conditions=resolved_conditions,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from opentelemetry import trace
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
from opentelemetry.propagators.b3 import B3Format
|
||||
from opentelemetry.propagators.composite import CompositePropagator
|
||||
@@ -31,9 +31,29 @@ def setup_context_propagation() -> None:
|
||||
|
||||
|
||||
def shutdown_tracer() -> None:
|
||||
flush_telemetry()
|
||||
|
||||
|
||||
def flush_telemetry() -> None:
|
||||
"""
|
||||
Best-effort flush for telemetry providers.
|
||||
|
||||
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
|
||||
so counters/histograms are exported before the process exits.
|
||||
"""
|
||||
provider = trace.get_tracer_provider()
|
||||
if hasattr(provider, "force_flush"):
|
||||
provider.force_flush()
|
||||
try:
|
||||
provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush trace provider")
|
||||
|
||||
metric_provider = metrics.get_meter_provider()
|
||||
if hasattr(metric_provider, "force_flush"):
|
||||
try:
|
||||
metric_provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush metric provider")
|
||||
|
||||
|
||||
def is_celery_worker():
|
||||
|
||||
@@ -63,7 +63,12 @@ class RagPipelineTransformService:
|
||||
):
|
||||
node = self._deal_file_extensions(node)
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if dataset.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("Unauthorized")
|
||||
node = self._deal_knowledge_index(
|
||||
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
|
||||
)
|
||||
new_nodes.append(node)
|
||||
if new_nodes:
|
||||
graph["nodes"] = new_nodes
|
||||
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self,
|
||||
knowledge_configuration: KnowledgeConfiguration,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
|
||||
304
api/services/retention/conversation/message_export_service.py
Normal file
304
api/services/retention/conversation/message_export_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Export app messages to JSONL.GZ format.
|
||||
|
||||
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
|
||||
retriever_resources (from message_metadata), feedback (user feedbacks array).
|
||||
|
||||
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
|
||||
Does NOT touch Message.inputs / Message.user_feedback properties.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, tuple_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILENAME_BASE_LENGTH = 1024
|
||||
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportRecord(BaseModel):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
query: str
|
||||
answer: str
|
||||
inputs: dict[str, Any]
|
||||
retriever_resources: list[Any] = Field(default_factory=list)
|
||||
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportStats(BaseModel):
|
||||
batches: int = 0
|
||||
total_messages: int = 0
|
||||
messages_with_feedback: int = 0
|
||||
total_feedbacks: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportService:
|
||||
@staticmethod
|
||||
def validate_export_filename(filename: str) -> str:
|
||||
normalized = filename.strip()
|
||||
if not normalized:
|
||||
raise ValueError("--filename must not be empty.")
|
||||
|
||||
normalized_lower = normalized.lower()
|
||||
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
|
||||
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
|
||||
|
||||
if normalized.startswith("/"):
|
||||
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
|
||||
|
||||
if "\\" in normalized:
|
||||
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
|
||||
|
||||
if "//" in normalized:
|
||||
raise ValueError("--filename must not contain empty path segments ('//').")
|
||||
|
||||
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
|
||||
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
|
||||
|
||||
for ch in normalized:
|
||||
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
|
||||
raise ValueError("--filename must not contain control characters or NUL.")
|
||||
|
||||
parts = PurePosixPath(normalized).parts
|
||||
if not parts:
|
||||
raise ValueError("--filename must include a file name.")
|
||||
|
||||
if any(part in (".", "..") for part in parts):
|
||||
raise ValueError("--filename must not contain '.' or '..' path segments.")
|
||||
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def output_gz_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl.gz"
|
||||
|
||||
@property
|
||||
def output_jsonl_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
*,
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
use_cloud_storage: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if start_from and start_from >= end_before:
|
||||
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
|
||||
|
||||
self._app_id = app_id
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._filename_base = self.validate_export_filename(filename)
|
||||
self._batch_size = batch_size
|
||||
self._use_cloud_storage = use_cloud_storage
|
||||
self._dry_run = dry_run
|
||||
|
||||
def run(self) -> AppMessageExportStats:
|
||||
stats = AppMessageExportStats()
|
||||
|
||||
logger.info(
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
|
||||
self._app_id,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
self._dry_run,
|
||||
self._use_cloud_storage,
|
||||
self.output_gz_name,
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
for _ in self._iter_records_with_stats(stats):
|
||||
pass
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
if self._use_cloud_storage:
|
||||
self._export_to_cloud(stats)
|
||||
else:
|
||||
self._export_to_local(stats)
|
||||
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for batch in self._iter_record_batches():
|
||||
yield from batch
|
||||
|
||||
@staticmethod
|
||||
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
|
||||
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
|
||||
for record in records:
|
||||
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
|
||||
|
||||
def _export_to_local(self, stats: AppMessageExportStats) -> None:
|
||||
output_path = Path.cwd() / self.output_gz_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("wb") as output_file:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
|
||||
|
||||
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
|
||||
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
|
||||
tmp.seek(0)
|
||||
data = tmp.read()
|
||||
|
||||
storage.save(self.output_gz_name, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
|
||||
|
||||
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for record in self.iter_records():
|
||||
self._update_stats(stats, record)
|
||||
yield record
|
||||
|
||||
@staticmethod
|
||||
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
|
||||
stats.total_messages += 1
|
||||
if record.feedback:
|
||||
stats.messages_with_feedback += 1
|
||||
stats.total_feedbacks += len(record.feedback)
|
||||
|
||||
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
|
||||
if stats.total_messages == 0:
|
||||
stats.batches = 0
|
||||
return
|
||||
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
|
||||
|
||||
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
|
||||
cursor: tuple[datetime.datetime, str] | None = None
|
||||
while True:
|
||||
rows, cursor = self._fetch_batch(cursor)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
message_ids = [str(row.id) for row in rows]
|
||||
feedbacks_map = self._fetch_feedbacks(message_ids)
|
||||
yield [self._build_record(row, feedbacks_map) for row in rows]
|
||||
|
||||
def _fetch_batch(
|
||||
self, cursor: tuple[datetime.datetime, str] | None
|
||||
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(
|
||||
Message.id,
|
||||
Message.conversation_id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message._inputs, # pyright: ignore[reportPrivateUsage]
|
||||
Message.message_metadata,
|
||||
Message.created_at,
|
||||
)
|
||||
.where(
|
||||
Message.app_id == self._app_id,
|
||||
Message.created_at < self._end_before,
|
||||
)
|
||||
.order_by(Message.created_at, Message.id)
|
||||
.limit(self._batch_size)
|
||||
)
|
||||
|
||||
if self._start_from:
|
||||
stmt = stmt.where(Message.created_at >= self._start_from)
|
||||
|
||||
if cursor:
|
||||
stmt = stmt.where(
|
||||
tuple_(Message.created_at, Message.id)
|
||||
> tuple_(
|
||||
sa.literal(cursor[0], type_=sa.DateTime()),
|
||||
sa.literal(cursor[1], type_=Message.id.type),
|
||||
)
|
||||
)
|
||||
|
||||
rows = list(session.execute(stmt).all())
|
||||
|
||||
if not rows:
|
||||
return [], cursor
|
||||
|
||||
last = rows[-1]
|
||||
return rows, (last.created_at, last.id)
|
||||
|
||||
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
|
||||
if not message_ids:
|
||||
return {}
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.message_id.in_(message_ids),
|
||||
MessageFeedback.from_source == "user",
|
||||
)
|
||||
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
|
||||
)
|
||||
feedbacks = list(session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
|
||||
for feedback in feedbacks:
|
||||
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
|
||||
retriever_resources: list[Any] = []
|
||||
if row.message_metadata:
|
||||
try:
|
||||
metadata = json.loads(row.message_metadata)
|
||||
value = metadata.get("retriever_resources", [])
|
||||
if isinstance(value, list):
|
||||
retriever_resources = value
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
message_id = str(row.id)
|
||||
return AppMessageExportRecord(
|
||||
conversation_id=str(row.conversation_id),
|
||||
message_id=message_id,
|
||||
query=row.query,
|
||||
answer=row.answer,
|
||||
inputs=row._inputs if isinstance(row._inputs, dict) else {},
|
||||
retriever_resources=retriever_resources,
|
||||
feedback=feedbacks_map.get(message_id, []),
|
||||
)
|
||||
@@ -1,17 +1,18 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@@ -32,6 +33,128 @@ from services.retention.conversation.messages_clean_policy import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class MessagesCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
|
||||
|
||||
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
|
||||
dashboard-friendly for long-running CronJob executions.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_messages_scanned_total: "Counter | None"
|
||||
_messages_filtered_total: "Counter | None"
|
||||
_messages_deleted_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._messages_scanned_total = None
|
||||
self._messages_filtered_total = None
|
||||
self._messages_deleted_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "messages_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("messages_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"messages_cleanup_jobs_total",
|
||||
description="Total number of expired message cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"messages_cleanup_batches_total",
|
||||
description="Total number of message cleanup batches processed.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._messages_scanned_total = meter.create_counter(
|
||||
"messages_cleanup_scanned_messages_total",
|
||||
description="Total messages scanned by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_filtered_total = meter.create_counter(
|
||||
"messages_cleanup_filtered_messages_total",
|
||||
description="Total messages selected by cleanup policy.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_deleted_total = meter.create_counter(
|
||||
"messages_cleanup_deleted_messages_total",
|
||||
description="Total messages deleted by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_job_duration_seconds",
|
||||
description="Duration of expired message cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_batch_duration_seconds",
|
||||
description="Duration of expired message cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
scanned_messages: int,
|
||||
filtered_messages: int,
|
||||
deleted_messages: int,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._messages_scanned_total, scanned_messages, attributes)
|
||||
self._add(self._messages_filtered_total, filtered_messages, attributes)
|
||||
self._add(self._messages_deleted_total, deleted_messages, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class MessagesCleanService:
|
||||
"""
|
||||
Service for cleaning expired messages based on retention policies.
|
||||
@@ -47,6 +170,7 @@ class MessagesCleanService:
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the service with cleanup parameters.
|
||||
@@ -57,12 +181,20 @@ class MessagesCleanService:
|
||||
start_from: Optional start time (inclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
"""
|
||||
self._policy = policy
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._batch_size = batch_size
|
||||
self._dry_run = dry_run
|
||||
normalized_task_label = task_label.strip()
|
||||
self._task_label = normalized_task_label or "daily"
|
||||
self._metrics = MessagesCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=self._task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_time_range(
|
||||
@@ -72,6 +204,7 @@ class MessagesCleanService:
|
||||
end_before: datetime.datetime,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages within a specific time range.
|
||||
@@ -84,6 +217,7 @@ class MessagesCleanService:
|
||||
end_before: End time (exclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@@ -111,6 +245,7 @@ class MessagesCleanService:
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -120,6 +255,7 @@ class MessagesCleanService:
|
||||
days: int = 30,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages older than specified days.
|
||||
@@ -129,6 +265,7 @@ class MessagesCleanService:
|
||||
days: Number of days to look back from now
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Stable task label to distinguish multiple cleanup CronJobs
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@@ -142,7 +279,7 @@ class MessagesCleanService:
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
|
||||
|
||||
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
end_before = naive_utc_now() - datetime.timedelta(days=days)
|
||||
|
||||
logger.info(
|
||||
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
|
||||
@@ -152,7 +289,14 @@ class MessagesCleanService:
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
|
||||
return cls(
|
||||
policy=policy,
|
||||
end_before=end_before,
|
||||
start_from=None,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
def run(self) -> dict[str, int]:
|
||||
"""
|
||||
@@ -161,7 +305,18 @@ class MessagesCleanService:
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
return self._clean_messages_by_time_range()
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
try:
|
||||
return self._clean_messages_by_time_range()
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
def _clean_messages_by_time_range(self) -> dict[str, int]:
|
||||
"""
|
||||
@@ -196,11 +351,14 @@ class MessagesCleanService:
|
||||
self._end_before,
|
||||
)
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
while True:
|
||||
stats["batches"] += 1
|
||||
batch_start = time.monotonic()
|
||||
batch_scanned_messages = 0
|
||||
batch_filtered_messages = 0
|
||||
batch_deleted_messages = 0
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@@ -239,9 +397,16 @@ class MessagesCleanService:
|
||||
|
||||
# Track total messages fetched across all batches
|
||||
stats["total_messages"] += len(messages)
|
||||
batch_scanned_messages = len(messages)
|
||||
|
||||
if not messages:
|
||||
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
break
|
||||
|
||||
# Update cursor to the last message's (created_at, id)
|
||||
@@ -267,6 +432,12 @@ class MessagesCleanService:
|
||||
|
||||
if not apps:
|
||||
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
# Build app_id -> tenant_id mapping
|
||||
@@ -285,9 +456,16 @@ class MessagesCleanService:
|
||||
|
||||
if not message_ids_to_delete:
|
||||
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
stats["filtered_messages"] += len(message_ids_to_delete)
|
||||
batch_filtered_messages = len(message_ids_to_delete)
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
@@ -308,6 +486,7 @@ class MessagesCleanService:
|
||||
commit_ms = int((time.monotonic() - commit_start) * 1000)
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
batch_deleted_messages = messages_deleted
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s): processed %s messages, deleted %s messages",
|
||||
@@ -342,6 +521,13 @@ class MessagesCleanService:
|
||||
for msg_id in sampled_ids:
|
||||
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
|
||||
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
|
||||
stats["batches"],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@@ -20,6 +20,156 @@ from services.billing_service import BillingService, SubscriptionPlan
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class WorkflowRunCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
|
||||
|
||||
Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
|
||||
to keep dashboard and alert cardinality predictable in production clusters.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_runs_scanned_total: "Counter | None"
|
||||
_runs_targeted_total: "Counter | None"
|
||||
_runs_deleted_total: "Counter | None"
|
||||
_runs_skipped_total: "Counter | None"
|
||||
_related_records_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._runs_scanned_total = None
|
||||
self._runs_targeted_total = None
|
||||
self._runs_deleted_total = None
|
||||
self._runs_skipped_total = None
|
||||
self._related_records_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "workflow_run_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"workflow_run_cleanup_jobs_total",
|
||||
description="Total number of workflow run cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"workflow_run_cleanup_batches_total",
|
||||
description="Total number of processed cleanup batches.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._runs_scanned_total = meter.create_counter(
|
||||
"workflow_run_cleanup_scanned_runs_total",
|
||||
description="Total workflow runs scanned by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_targeted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_targeted_runs_total",
|
||||
description="Total workflow runs targeted by cleanup policy.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_deleted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_deleted_runs_total",
|
||||
description="Total workflow runs deleted by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_skipped_total = meter.create_counter(
|
||||
"workflow_run_cleanup_skipped_runs_total",
|
||||
description="Total workflow runs skipped because tenant is paid/unknown.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._related_records_total = meter.create_counter(
|
||||
"workflow_run_cleanup_related_records_total",
|
||||
description="Total related records processed by cleanup jobs.",
|
||||
unit="{record}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_job_duration_seconds",
|
||||
description="Duration of workflow run cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_batch_duration_seconds",
|
||||
description="Duration of workflow run cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
batch_rows: int,
|
||||
targeted_runs: int,
|
||||
skipped_runs: int,
|
||||
deleted_runs: int,
|
||||
related_counts: dict[str, int] | None,
|
||||
related_action: str | None,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._runs_scanned_total, batch_rows, attributes)
|
||||
self._add(self._runs_targeted_total, targeted_runs, attributes)
|
||||
self._add(self._runs_skipped_total, skipped_runs, attributes)
|
||||
self._add(self._runs_deleted_total, deleted_runs, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
if not related_counts or not related_action:
|
||||
return
|
||||
|
||||
for record_type, count in related_counts.items():
|
||||
self._add(
|
||||
self._related_records_total,
|
||||
count,
|
||||
self._attrs(action=related_action, record_type=record_type),
|
||||
)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class WorkflowRunCleanup:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,6 +179,7 @@ class WorkflowRunCleanup:
|
||||
end_before: datetime.datetime | None = None,
|
||||
workflow_run_repo: APIWorkflowRunRepository | None = None,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "daily",
|
||||
):
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise ValueError("start_from and end_before must be both set or both omitted.")
|
||||
@@ -46,6 +197,13 @@ class WorkflowRunCleanup:
|
||||
self.batch_size = batch_size
|
||||
self._cleanup_whitelist: set[str] | None = None
|
||||
self.dry_run = dry_run
|
||||
normalized_task_label = task_label.strip()
|
||||
self.task_label = normalized_task_label or "daily"
|
||||
self._metrics = WorkflowRunCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=self.task_label,
|
||||
)
|
||||
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
|
||||
self.workflow_run_repo: APIWorkflowRunRepository
|
||||
if workflow_run_repo:
|
||||
@@ -74,153 +232,193 @@ class WorkflowRunCleanup:
|
||||
related_totals = self._empty_related_counts() if self.dry_run else None
|
||||
batch_index = 0
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
try:
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=0,
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts=None,
|
||||
related_action=None,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="would_delete",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=counts["runs"],
|
||||
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="deleted",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = (
|
||||
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
)
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||
tenant_id_list = list(tenant_ids)
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
|
||||
|
||||
|
||||
class TestAppMessageExportServiceIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
yield
|
||||
db_session_with_containers.query(DatasetRetrieverResource).delete()
|
||||
db_session_with_containers.query(AppAnnotationHitHistory).delete()
|
||||
db_session_with_containers.query(SavedMessage).delete()
|
||||
db_session_with_containers.query(MessageFile).delete()
|
||||
db_session_with_containers.query(MessageAgentThought).delete()
|
||||
db_session_with_containers.query(MessageChain).delete()
|
||||
db_session_with_containers.query(MessageAnnotation).delete()
|
||||
db_session_with_containers.query(MessageFeedback).delete()
|
||||
db_session_with_containers.query(Message).delete()
|
||||
db_session_with_containers.query(Conversation).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_app_context(session: Session) -> tuple[App, Conversation]:
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="tester",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
session.add(join)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="export-app",
|
||||
description="integration test app",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=60,
|
||||
api_rph=3600,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=str(uuid.uuid4()),
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
mode="chat",
|
||||
name="conv",
|
||||
inputs={"seed": 1},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
return app, conversation
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
session: Session,
|
||||
app: App,
|
||||
conversation: Conversation,
|
||||
created_at: datetime.datetime,
|
||||
*,
|
||||
query: str,
|
||||
answer: str,
|
||||
inputs: dict,
|
||||
message_metadata: str | None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
answer=answer,
|
||||
message=[{"role": "assistant", "content": answer}],
|
||||
message_tokens=10,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=20,
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
message_metadata=message_metadata,
|
||||
from_source="api",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
session.add(message)
|
||||
session.flush()
|
||||
return message
|
||||
|
||||
def test_iter_records_with_stats(self, db_session_with_containers: Session):
|
||||
app, conversation = self._create_app_context(db_session_with_containers)
|
||||
|
||||
first_inputs = {
|
||||
"plain": "v1",
|
||||
"nested": {"a": 1, "b": [1, {"x": True}]},
|
||||
"list": ["x", 2, {"y": "z"}],
|
||||
}
|
||||
second_inputs = {"other": "value", "items": [1, 2, 3]}
|
||||
|
||||
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
|
||||
first_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time,
|
||||
query="q1",
|
||||
answer="a1",
|
||||
inputs=first_inputs,
|
||||
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
|
||||
)
|
||||
second_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time + datetime.timedelta(minutes=1),
|
||||
query="q2",
|
||||
answer="a2",
|
||||
inputs=second_inputs,
|
||||
message_metadata=None,
|
||||
)
|
||||
|
||||
user_feedback_1 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
user_feedback_2 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
admin_feedback = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
|
||||
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
|
||||
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
|
||||
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = AppMessageExportService(
|
||||
app_id=app.id,
|
||||
start_from=base_time - datetime.timedelta(minutes=1),
|
||||
end_before=base_time + datetime.timedelta(minutes=10),
|
||||
filename="unused",
|
||||
batch_size=1,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = AppMessageExportStats()
|
||||
records = list(service._iter_records_with_stats(stats))
|
||||
service._finalize_stats(stats)
|
||||
|
||||
assert len(records) == 2
|
||||
assert records[0].message_id == first_message.id
|
||||
assert records[1].message_id == second_message.id
|
||||
|
||||
assert records[0].inputs == first_inputs
|
||||
assert records[1].inputs == second_inputs
|
||||
|
||||
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
|
||||
assert records[1].retriever_resources == []
|
||||
|
||||
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
|
||||
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
|
||||
assert records[1].feedback == []
|
||||
|
||||
assert stats.batches == 2
|
||||
assert stats.total_messages == 2
|
||||
assert stats.messages_with_feedback == 1
|
||||
assert stats.total_feedbacks == 2
|
||||
188
api/tests/unit_tests/commands/test_clean_expired_messages.py
Normal file
188
api/tests/unit_tests/commands/test_clean_expired_messages.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import datetime
|
||||
import re
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from commands import clean_expired_messages
|
||||
|
||||
|
||||
def _mock_service() -> MagicMock:
|
||||
service = MagicMock()
|
||||
service.run.return_value = {
|
||||
"batches": 1,
|
||||
"total_messages": 10,
|
||||
"filtered_messages": 5,
|
||||
"total_deleted": 5,
|
||||
}
|
||||
return service
|
||||
|
||||
|
||||
def test_absolute_mode_calls_from_time_range():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=200,
|
||||
graceful_period=21,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
from_days_ago=None,
|
||||
before_days=None,
|
||||
dry_run=True,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=200,
|
||||
dry_run=True,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
|
||||
def test_relative_mode_before_days_only_calls_from_days():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
|
||||
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=500,
|
||||
graceful_period=14,
|
||||
start_from=None,
|
||||
end_before=None,
|
||||
from_days_ago=None,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_days.assert_called_once_with(
|
||||
policy=policy,
|
||||
days=30,
|
||||
batch_size=500,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_time_range.assert_not_called()
|
||||
|
||||
|
||||
def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
patch("commands.naive_utc_now", return_value=fixed_now),
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=1000,
|
||||
graceful_period=21,
|
||||
start_from=None,
|
||||
end_before=None,
|
||||
from_days_ago=60,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
policy=policy,
|
||||
start_from=fixed_now - datetime.timedelta(days=60),
|
||||
end_before=fixed_now - datetime.timedelta(days=30),
|
||||
batch_size=1000,
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("kwargs", "message"),
|
||||
[
|
||||
(
|
||||
{
|
||||
"start_from": datetime.datetime(2024, 1, 1),
|
||||
"end_before": datetime.datetime(2024, 2, 1),
|
||||
"from_days_ago": None,
|
||||
"before_days": 30,
|
||||
},
|
||||
"mutually exclusive",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": datetime.datetime(2024, 1, 1),
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": None,
|
||||
},
|
||||
"Both --start-from and --end-before are required",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": 10,
|
||||
"before_days": None,
|
||||
},
|
||||
"--from-days-ago must be used together with --before-days",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": -1,
|
||||
},
|
||||
"--before-days must be >= 0",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": 30,
|
||||
"before_days": 30,
|
||||
},
|
||||
"--from-days-ago must be greater than --before-days",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": None,
|
||||
},
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
|
||||
with pytest.raises(click.UsageError, match=re.escape(message)):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=1000,
|
||||
graceful_period=21,
|
||||
start_from=kwargs["start_from"],
|
||||
end_before=kwargs["end_before"],
|
||||
from_days_ago=kwargs["from_days_ago"],
|
||||
before_days=kwargs["before_days"],
|
||||
dry_run=False,
|
||||
task_label="daily",
|
||||
)
|
||||
70
api/tests/unit_tests/controllers/common/test_errors.py
Normal file
70
api/tests/unit_tests/controllers/common/test_errors.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
|
||||
|
||||
class TestFilenameNotExistsError:
|
||||
def test_defaults(self):
|
||||
error = FilenameNotExistsError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "The specified filename does not exist."
|
||||
|
||||
|
||||
class TestRemoteFileUploadError:
|
||||
def test_defaults(self):
|
||||
error = RemoteFileUploadError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "Error uploading remote file."
|
||||
|
||||
|
||||
class TestFileTooLargeError:
|
||||
def test_defaults(self):
|
||||
error = FileTooLargeError()
|
||||
|
||||
assert error.code == 413
|
||||
assert error.error_code == "file_too_large"
|
||||
assert error.description == "File size exceeded. {message}"
|
||||
|
||||
|
||||
class TestUnsupportedFileTypeError:
|
||||
def test_defaults(self):
|
||||
error = UnsupportedFileTypeError()
|
||||
|
||||
assert error.code == 415
|
||||
assert error.error_code == "unsupported_file_type"
|
||||
assert error.description == "File type not allowed."
|
||||
|
||||
|
||||
class TestBlockedFileExtensionError:
|
||||
def test_defaults(self):
|
||||
error = BlockedFileExtensionError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "file_extension_blocked"
|
||||
assert error.description == "The file extension is blocked for security reasons."
|
||||
|
||||
|
||||
class TestTooManyFilesError:
|
||||
def test_defaults(self):
|
||||
error = TooManyFilesError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "too_many_files"
|
||||
assert error.description == "Only one file is allowed."
|
||||
|
||||
|
||||
class TestNoFileUploadedError:
|
||||
def test_defaults(self):
|
||||
error = NoFileUploadedError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "no_file_uploaded"
|
||||
assert error.description == "Please upload your file."
|
||||
@@ -1,22 +1,95 @@
|
||||
from flask import Response
|
||||
|
||||
from controllers.common.file_response import enforce_download_for_html, is_html_content
|
||||
from controllers.common.file_response import (
|
||||
_normalize_mime_type,
|
||||
enforce_download_for_html,
|
||||
is_html_content,
|
||||
)
|
||||
|
||||
|
||||
class TestFileResponseHelpers:
|
||||
def test_is_html_content_detects_mime_type(self):
|
||||
class TestNormalizeMimeType:
|
||||
def test_returns_empty_string_for_none(self):
|
||||
assert _normalize_mime_type(None) == ""
|
||||
|
||||
def test_returns_empty_string_for_empty_string(self):
|
||||
assert _normalize_mime_type("") == ""
|
||||
|
||||
def test_normalizes_mime_type(self):
|
||||
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
|
||||
|
||||
|
||||
class TestIsHtmlContent:
|
||||
def test_detects_html_via_mime_type(self):
|
||||
mime_type = "text/html; charset=UTF-8"
|
||||
|
||||
result = is_html_content(mime_type, filename="file.txt", extension="txt")
|
||||
result = is_html_content(
|
||||
mime_type=mime_type,
|
||||
filename="file.txt",
|
||||
extension="txt",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_html_content_detects_extension(self):
|
||||
result = is_html_content("text/plain", filename="report.html", extension=None)
|
||||
def test_detects_html_via_extension_argument(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_enforce_download_for_html_sets_headers(self):
|
||||
def test_detects_html_via_filename_extension(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename="report.html",
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_no_html_detected_anywhere(self):
|
||||
"""
|
||||
Missing negative test:
|
||||
- MIME type is not HTML
|
||||
- filename has no HTML extension
|
||||
- extension argument is not HTML
|
||||
"""
|
||||
result = is_html_content(
|
||||
mime_type="application/json",
|
||||
filename="data.json",
|
||||
extension="json",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_all_inputs_are_none(self):
|
||||
result = is_html_content(
|
||||
mime_type=None,
|
||||
filename=None,
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestEnforceDownloadForHtml:
|
||||
def test_sets_attachment_when_filename_missing(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
response,
|
||||
mime_type="text/html",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert response.headers["Content-Disposition"] == "attachment"
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_sets_headers_when_filename_present(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
@@ -27,11 +100,12 @@ class TestFileResponseHelpers:
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Disposition"].startswith("attachment")
|
||||
assert "unsafe.html" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_enforce_download_for_html_no_change_for_non_html(self):
|
||||
def test_does_not_modify_response_for_non_html_content(self):
|
||||
response = Response("payload", mimetype="text/plain")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
|
||||
188
api/tests/unit_tests/controllers/common/test_helpers.py
Normal file
188
api/tests/unit_tests/controllers/common/test_helpers.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common import helpers
|
||||
from controllers.common.helpers import FileInfo, guess_file_info_from_response
|
||||
|
||||
|
||||
def make_response(
|
||||
url="https://example.com/file.txt",
|
||||
headers=None,
|
||||
content=None,
|
||||
):
|
||||
return httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers=headers or {},
|
||||
content=content or b"",
|
||||
)
|
||||
|
||||
|
||||
class TestGuessFileInfoFromResponse:
|
||||
def test_filename_from_url(self):
|
||||
response = make_response(
|
||||
url="https://example.com/test.pdf",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "test.pdf"
|
||||
assert info.extension == ".pdf"
|
||||
assert info.mimetype == "application/pdf"
|
||||
|
||||
def test_filename_from_content_disposition(self):
|
||||
headers = {
|
||||
"Content-Disposition": "attachment; filename=myfile.csv",
|
||||
"Content-Type": "text/csv",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
headers=headers,
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "myfile.csv"
|
||||
assert info.extension == ".csv"
|
||||
assert info.mimetype == "text/csv"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("magic_available", "expected_ext"),
|
||||
[
|
||||
(True, "txt"),
|
||||
(False, "bin"),
|
||||
],
|
||||
)
|
||||
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
|
||||
if magic_available:
|
||||
if helpers.magic is None:
|
||||
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
|
||||
else:
|
||||
monkeypatch.setattr(helpers, "magic", None)
|
||||
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
name, ext = info.filename.split(".")
|
||||
UUID(name)
|
||||
assert ext == expected_ext
|
||||
|
||||
def test_mimetype_from_header_when_unknown(self):
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = make_response(
|
||||
url="https://example.com/file.unknown",
|
||||
headers=headers,
|
||||
content=b'{"a": 1}',
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.mimetype == "application/json"
|
||||
|
||||
def test_extension_added_when_missing(self):
|
||||
headers = {"Content-Type": "image/png"}
|
||||
response = make_response(
|
||||
url="https://example.com/image",
|
||||
headers=headers,
|
||||
content=b"fakepngdata",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".png"
|
||||
assert info.filename.endswith(".png")
|
||||
|
||||
def test_content_length_used_as_size(self):
|
||||
headers = {
|
||||
"Content-Length": "1234",
|
||||
"Content-Type": "text/plain",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/a.txt",
|
||||
headers=headers,
|
||||
content=b"a" * 1234,
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == 1234
|
||||
|
||||
def test_size_minus_one_when_header_missing(self):
|
||||
response = make_response(url="https://example.com/a.txt")
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == -1
|
||||
|
||||
def test_fallback_to_bin_extension(self):
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
response = make_response(
|
||||
url="https://example.com/download",
|
||||
headers=headers,
|
||||
content=b"\x00\x01\x02\x03",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".bin"
|
||||
assert info.filename.endswith(".bin")
|
||||
|
||||
def test_return_type(self):
|
||||
response = make_response()
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert isinstance(info, FileInfo)
|
||||
|
||||
|
||||
class TestMagicImportWarnings:
|
||||
@pytest.mark.parametrize(
|
||||
("platform_name", "expected_message"),
|
||||
[
|
||||
("Windows", "pip install python-magic-bin"),
|
||||
("Darwin", "brew install libmagic"),
|
||||
("Linux", "sudo apt-get install libmagic1"),
|
||||
("Other", "install `libmagic`"),
|
||||
],
|
||||
)
|
||||
def test_magic_import_warning_per_platform(
|
||||
self,
|
||||
monkeypatch,
|
||||
platform_name,
|
||||
expected_message,
|
||||
):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
# Force ImportError when "magic" is imported
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "magic":
|
||||
raise ImportError("No module named magic")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
monkeypatch.setattr("platform.system", lambda: platform_name)
|
||||
|
||||
# Remove helpers so it imports fresh
|
||||
import sys
|
||||
|
||||
original_helpers = sys.modules.get(helpers.__name__)
|
||||
sys.modules.pop(helpers.__name__, None)
|
||||
|
||||
try:
|
||||
with pytest.warns(UserWarning, match="To use python-magic") as warning:
|
||||
imported_helpers = importlib.import_module(helpers.__name__)
|
||||
assert expected_message in str(warning[0].message)
|
||||
finally:
|
||||
if original_helpers is not None:
|
||||
sys.modules[helpers.__name__] = original_helpers
|
||||
189
api/tests/unit_tests/controllers/common/test_schema.py
Normal file
189
api/tests/unit_tests/controllers/common/test_schema.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import sys
|
||||
from enum import StrEnum
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ProductModel(BaseModel):
|
||||
id: int
|
||||
price: float
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_console_ns():
|
||||
"""Mock the console_ns to avoid circular imports during test collection."""
|
||||
mock_ns = MagicMock(spec=Namespace)
|
||||
mock_ns.models = {}
|
||||
|
||||
# Inject mock before importing schema module
|
||||
with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}):
|
||||
yield mock_ns
|
||||
|
||||
|
||||
def test_default_ref_template_value():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0
|
||||
|
||||
assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
|
||||
|
||||
|
||||
def test_register_schema_model_calls_namespace_schema_model():
|
||||
from controllers.common.schema import register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "UserModel"
|
||||
assert isinstance(schema, dict)
|
||||
assert "properties" in schema
|
||||
|
||||
|
||||
def test_register_schema_model_passes_schema_from_pydantic():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
|
||||
assert schema == expected_schema
|
||||
|
||||
|
||||
def test_register_schema_models_registers_multiple_models():
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["UserModel", "ProductModel"]
|
||||
|
||||
|
||||
def test_register_schema_models_calls_register_schema_model(monkeypatch):
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_register(ns, model):
|
||||
calls.append((ns, model))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"controllers.common.schema.register_schema_model",
|
||||
fake_register,
|
||||
)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert calls == [
|
||||
(namespace, UserModel),
|
||||
(namespace, ProductModel),
|
||||
]
|
||||
|
||||
|
||||
class StatusEnum(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class PriorityEnum(StrEnum):
|
||||
HIGH = "high"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
def test_get_or_create_model_returns_existing_model(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"TestModel": existing_model}
|
||||
|
||||
result = get_or_create_model("TestModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
mock_console_ns.models = {}
|
||||
new_model = MagicMock()
|
||||
mock_console_ns.model.return_value = new_model
|
||||
field_def = {"name": {"type": "string"}}
|
||||
|
||||
result = get_or_create_model("NewModel", field_def)
|
||||
|
||||
assert result == new_model
|
||||
mock_console_ns.model.assert_called_once_with("NewModel", field_def)
|
||||
|
||||
|
||||
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"ExistingModel": existing_model}
|
||||
|
||||
result = get_or_create_model("ExistingModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_register_enum_models_registers_single_enum():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "StatusEnum"
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
|
||||
def test_register_enum_models_registers_multiple_enums():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum, PriorityEnum)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["StatusEnum", "PriorityEnum"]
|
||||
|
||||
|
||||
def test_register_enum_models_uses_correct_ref_template():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
# Verify the schema contains enum values
|
||||
assert "enum" in schema or "anyOf" in schema
|
||||
@@ -124,12 +124,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def fake_thread(**kwargs):
|
||||
def fake_thread(*args, **kwargs):
|
||||
thread = DummyThread(**kwargs)
|
||||
captured["thread"] = thread
|
||||
return thread
|
||||
|
||||
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
|
||||
monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread)
|
||||
|
||||
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
|
||||
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
|
||||
|
||||
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
|
||||
|
||||
This test suite ensures that the files array is correctly populated in the message_end
|
||||
SSE event, which is critical for vision/image chat responses to render correctly.
|
||||
|
||||
Test Coverage:
|
||||
- Files array populated when MessageFile records exist
|
||||
- Files array is None when no MessageFile records exist
|
||||
- Correct signed URL generation for LOCAL_FILE transfer method
|
||||
- Correct URL handling for REMOTE_URL transfer method
|
||||
- Correct URL handling for TOOL_FILE transfer method
|
||||
- Proper file metadata formatting (filename, mime_type, size, extension)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
class TestMessageEndStreamResponseFiles:
|
||||
"""Test suite for files array population in message_end SSE event."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline(self):
|
||||
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
|
||||
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
|
||||
pipeline._message_id = str(uuid.uuid4())
|
||||
pipeline._task_state = Mock()
|
||||
pipeline._task_state.metadata = Mock()
|
||||
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
|
||||
pipeline._task_state.llm_result = Mock()
|
||||
pipeline._task_state.llm_result.usage = Mock()
|
||||
pipeline._application_generate_entity = Mock()
|
||||
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
|
||||
return pipeline
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_local(self):
|
||||
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.url = None
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_remote(self):
|
||||
"""Create a mock MessageFile with REMOTE_URL transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "https://example.com/image.jpg"
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_tool(self):
|
||||
"""Create a mock MessageFile with TOOL_FILE transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "tool_file_123.png"
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_upload_file(self, mock_message_file_local):
|
||||
"""Create a mock UploadFile."""
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.id = mock_message_file_local.upload_file_id
|
||||
upload_file.name = "test_image.png"
|
||||
upload_file.mime_type = "image/png"
|
||||
upload_file.size = 1024
|
||||
upload_file.extension = "png"
|
||||
return upload_file
|
||||
|
||||
def test_message_end_with_no_files(self, mock_pipeline):
|
||||
"""Test that files array is None when no MessageFile records exist."""
|
||||
# Arrange
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is None
|
||||
assert result.id == mock_pipeline._message_id
|
||||
assert result.metadata == {"test": "metadata"}
|
||||
|
||||
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
|
||||
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local]
|
||||
|
||||
# Second query: UploadFile (batch query to avoid N+1)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [mock_upload_file]
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["related_id"] == mock_message_file_local.id
|
||||
assert file_dict["filename"] == "test_image.png"
|
||||
assert file_dict["mime_type"] == "image/png"
|
||||
assert file_dict["size"] == 1024
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["type"] == "image"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
|
||||
assert "https://example.com/signed-url" in file_dict["url"]
|
||||
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
|
||||
assert file_dict["remote_url"] == ""
|
||||
|
||||
# Verify database queries
|
||||
# Should be called twice: once for MessageFile, once for UploadFile
|
||||
assert mock_session.scalars.call_count == 2
|
||||
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
|
||||
|
||||
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
|
||||
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
|
||||
# Arrange
|
||||
mock_message_file_remote.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_remote]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["related_id"] == mock_message_file_remote.id
|
||||
assert file_dict["filename"] == "image.jpg"
|
||||
assert file_dict["url"] == "https://example.com/image.jpg"
|
||||
assert file_dict["extension"] == ".jpg"
|
||||
assert file_dict["type"] == "image"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
|
||||
assert file_dict["remote_url"] == "https://example.com/image.jpg"
|
||||
assert file_dict["upload_file_id"] == mock_message_file_remote.id
|
||||
|
||||
# Verify only one query for message_files is made
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
|
||||
# Arrange
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "https://example.com/tool_file.png"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["url"] == "https://example.com/tool_file.png"
|
||||
assert file_dict["filename"] == "tool_file.png"
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
|
||||
|
||||
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that files array is populated correctly for TOOL_FILE with local path."""
|
||||
# Arrange
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "tool_file_123.png"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
|
||||
assert file_dict["filename"] == "tool_file_123.png"
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
|
||||
|
||||
# Verify tool file signing was called
|
||||
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
|
||||
|
||||
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "tool_file_abc.verylongextension"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
mock_sign_tool.return_value = "https://example.com/signed.bin"
|
||||
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
assert result.files is not None
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["extension"] == ".bin"
|
||||
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
|
||||
|
||||
def test_message_end_with_multiple_files(
|
||||
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
|
||||
):
|
||||
"""Test that files array contains all MessageFile records when multiple exist."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
mock_message_file_remote.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
|
||||
|
||||
# Second query: UploadFile (batch query to avoid N+1)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [mock_upload_file]
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 2
|
||||
|
||||
# Verify both files are present
|
||||
file_ids = [f["related_id"] for f in result.files]
|
||||
assert mock_message_file_local.id in file_ids
|
||||
assert mock_message_file_remote.id in file_ids
|
||||
|
||||
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
|
||||
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local]
|
||||
|
||||
# Second query: UploadFile (batch query) - returns empty list (not found)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [] # UploadFile not found
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert "https://example.com/fallback-url" in file_dict["url"]
|
||||
# Verify fallback URL was generated using upload_file_id from message_file
|
||||
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))
|
||||
@@ -8,7 +8,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.nodes.knowledge_retrieval.entities import (
|
||||
Condition,
|
||||
KnowledgeRetrievalNodeData,
|
||||
MetadataFilteringCondition,
|
||||
MultipleRetrievalConfig,
|
||||
RerankingModelConfig,
|
||||
SingleRetrievalConfig,
|
||||
@@ -593,3 +595,106 @@ class TestFetchDatasetRetriever:
|
||||
|
||||
# Assert
|
||||
assert version == "1"
|
||||
|
||||
def test_resolve_metadata_filtering_conditions_templates(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
):
|
||||
"""_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged."""
|
||||
# Arrange
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"id": node_id,
|
||||
"data": {
|
||||
"title": "Knowledge Retrieval",
|
||||
"type": "knowledge-retrieval",
|
||||
"dataset_ids": [str(uuid.uuid4())],
|
||||
"retrieval_mode": "multiple",
|
||||
},
|
||||
}
|
||||
# Variable in pool used by template
|
||||
mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))
|
||||
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
conditions = MetadataFilteringCondition(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"),
|
||||
Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]),
|
||||
Condition(name="year", comparison_operator="=", value=2025),
|
||||
],
|
||||
)
|
||||
|
||||
# Act
|
||||
resolved = node._resolve_metadata_filtering_conditions(conditions)
|
||||
|
||||
# Assert
|
||||
assert resolved.logical_operator == "and"
|
||||
assert resolved.conditions[0].value == "readme"
|
||||
assert isinstance(resolved.conditions[1].value, list)
|
||||
assert resolved.conditions[1].value[1] == "readme"
|
||||
assert resolved.conditions[2].value == 2025
|
||||
|
||||
def test_fetch_passes_resolved_metadata_conditions(
|
||||
self,
|
||||
mock_graph_init_params,
|
||||
mock_graph_runtime_state,
|
||||
mock_rag_retrieval,
|
||||
):
|
||||
"""_fetch_dataset_retriever should pass resolved metadata conditions into request."""
|
||||
# Arrange
|
||||
query = "hi"
|
||||
variables = {"query": query}
|
||||
mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme"))
|
||||
|
||||
node_data = KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="multiple",
|
||||
multiple_retrieval_config=MultipleRetrievalConfig(
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_enable=True,
|
||||
reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"),
|
||||
),
|
||||
metadata_filtering_mode="manual",
|
||||
metadata_filtering_conditions=MetadataFilteringCondition(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
node_id = str(uuid.uuid4())
|
||||
config = {"id": node_id, "data": node_data.model_dump()}
|
||||
node = KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=config,
|
||||
graph_init_params=mock_graph_init_params,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
rag_retrieval=mock_rag_retrieval,
|
||||
)
|
||||
|
||||
mock_rag_retrieval.knowledge_retrieval.return_value = []
|
||||
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
|
||||
# Act
|
||||
node._fetch_dataset_retriever(node_data=node_data, variables=variables)
|
||||
|
||||
# Assert the passed request has resolved value
|
||||
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
|
||||
request = call_args[1]["request"]
|
||||
assert request.metadata_filtering_conditions is not None
|
||||
assert request.metadata_filtering_conditions.conditions[0].value == "readme"
|
||||
|
||||
@@ -16,6 +16,7 @@ from dify_graph.nodes.document_extractor.node import (
|
||||
_extract_text_from_excel,
|
||||
_extract_text_from_pdf,
|
||||
_extract_text_from_plain_text,
|
||||
_normalize_docx_zip,
|
||||
)
|
||||
from dify_graph.variables import ArrayFileSegment
|
||||
from dify_graph.variables.segments import ArrayStringSegment
|
||||
@@ -86,6 +87,38 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
|
||||
assert "is not an ArrayFileSegment" in result.error
|
||||
|
||||
|
||||
def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
|
||||
"""Empty file list should return SUCCEEDED with empty documents and ArrayStringSegment([])."""
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
|
||||
# Provide an actual ArrayFileSegment with an empty list
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[])
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
|
||||
assert result.process_data.get("documents") == []
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=[])
|
||||
|
||||
|
||||
def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state):
|
||||
"""A file list containing only None (e.g., [None]) should be filtered to [] and succeed."""
|
||||
document_extractor_node.graph_runtime_state = mock_graph_runtime_state
|
||||
|
||||
# Use a Mock to bypass type validation for None entries in the list
|
||||
afs = Mock(spec=ArrayFileSegment)
|
||||
afs.value = [None]
|
||||
mock_graph_runtime_state.variable_pool.get.return_value = afs
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
|
||||
assert result.process_data.get("documents") == []
|
||||
assert result.outputs["text"] == ArrayStringSegment(value=[])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
|
||||
[
|
||||
@@ -385,3 +418,58 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
|
||||
expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n"
|
||||
|
||||
assert expected_manual == result
|
||||
|
||||
|
||||
def _make_docx_zip(use_backslash: bool) -> bytes:
|
||||
"""Helper to build a minimal in-memory DOCX zip.
|
||||
|
||||
When use_backslash=True the ZIP entry names use backslash separators
|
||||
(as produced by Evernote on Windows), otherwise forward slashes are used.
|
||||
"""
|
||||
import zipfile
|
||||
|
||||
sep = "\\" if use_backslash else "/"
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr("[Content_Types].xml", b"<Types/>")
|
||||
zf.writestr(f"_rels{sep}.rels", b"<Relationships/>")
|
||||
zf.writestr(f"word{sep}document.xml", b"<w:document/>")
|
||||
zf.writestr(f"word{sep}_rels{sep}document.xml.rels", b"<Relationships/>")
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def test_normalize_docx_zip_replaces_backslashes():
|
||||
"""ZIP entries with backslash separators must be rewritten to forward slashes."""
|
||||
import zipfile
|
||||
|
||||
malformed = _make_docx_zip(use_backslash=True)
|
||||
fixed = _normalize_docx_zip(malformed)
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(fixed)) as zf:
|
||||
names = zf.namelist()
|
||||
|
||||
assert "word/document.xml" in names
|
||||
assert "word/_rels/document.xml.rels" in names
|
||||
# No entry should contain a backslash after normalization
|
||||
assert all("\\" not in name for name in names)
|
||||
|
||||
|
||||
def test_normalize_docx_zip_leaves_forward_slash_unchanged():
|
||||
"""ZIP entries that already use forward slashes must not be modified."""
|
||||
import zipfile
|
||||
|
||||
normal = _make_docx_zip(use_backslash=False)
|
||||
fixed = _normalize_docx_zip(normal)
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(fixed)) as zf:
|
||||
names = zf.namelist()
|
||||
|
||||
assert "word/document.xml" in names
|
||||
assert "word/_rels/document.xml.rels" in names
|
||||
|
||||
|
||||
def test_normalize_docx_zip_returns_original_on_bad_zip():
|
||||
"""Non-zip bytes must be returned as-is without raising."""
|
||||
garbage = b"not a zip file at all"
|
||||
result = _normalize_docx_zip(garbage)
|
||||
assert result == garbage
|
||||
|
||||
@@ -265,6 +265,61 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup.run()
|
||||
|
||||
|
||||
def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
batches=[[FakeRun("run-free", "t_free", cutoff)]],
|
||||
delete_result={
|
||||
"runs": 0,
|
||||
"node_executions": 2,
|
||||
"offloads": 1,
|
||||
"app_logs": 3,
|
||||
"trigger_logs": 4,
|
||||
"pauses": 5,
|
||||
"pause_reasons": 6,
|
||||
},
|
||||
)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
batch_calls: list[dict[str, object]] = []
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kwargs: batch_calls.append(kwargs))
|
||||
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
|
||||
|
||||
cleanup.run()
|
||||
|
||||
assert len(batch_calls) == 1
|
||||
assert batch_calls[0]["batch_rows"] == 1
|
||||
assert batch_calls[0]["targeted_runs"] == 1
|
||||
assert batch_calls[0]["deleted_runs"] == 1
|
||||
assert batch_calls[0]["related_action"] == "deleted"
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "success"
|
||||
|
||||
|
||||
def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FailingRepo(FakeRepo):
|
||||
def delete_runs_with_related(
|
||||
self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None
|
||||
) -> dict[str, int]:
|
||||
raise RuntimeError("delete failed")
|
||||
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FailingRepo(batches=[[FakeRun("run-free", "t_free", cutoff)]])
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
|
||||
|
||||
with pytest.raises(RuntimeError, match="delete failed"):
|
||||
cleanup.run()
|
||||
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
|
||||
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
|
||||
def test_validate_export_filename_accepts_relative_path():
|
||||
assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filename",
|
||||
[
|
||||
"test01.jsonl.gz",
|
||||
"test01.jsonl",
|
||||
"test01.gz",
|
||||
"/tmp/test01",
|
||||
"exports/../test01",
|
||||
"bad\x00name",
|
||||
"bad\tname",
|
||||
"a" * 1025,
|
||||
],
|
||||
)
|
||||
def test_validate_export_filename_rejects_invalid_values(filename: str):
|
||||
with pytest.raises(ValueError):
|
||||
AppMessageExportService.validate_export_filename(filename)
|
||||
|
||||
|
||||
def test_service_derives_output_names_from_filename_base():
|
||||
service = AppMessageExportService(
|
||||
app_id="736b9b03-20f2-4697-91da-8d00f6325900",
|
||||
start_from=None,
|
||||
end_before=datetime.datetime(2026, 3, 1),
|
||||
filename="exports/2026/test01",
|
||||
batch_size=1000,
|
||||
use_cloud_storage=True,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
assert service._filename_base == "exports/2026/test01"
|
||||
assert service.output_gz_name == "exports/2026/test01.jsonl.gz"
|
||||
assert service.output_jsonl_name == "exports/2026/test01.jsonl"
|
||||
@@ -554,11 +554,9 @@ class TestMessagesCleanServiceFromDays:
|
||||
MessagesCleanService.from_days(policy=policy, days=-1)
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
|
||||
mock_now.return_value = fixed_now
|
||||
service = MessagesCleanService.from_days(policy=policy, days=0)
|
||||
|
||||
# Assert
|
||||
@@ -586,11 +584,9 @@ class TestMessagesCleanServiceFromDays:
|
||||
dry_run = True
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
|
||||
mock_now.return_value = fixed_now
|
||||
service = MessagesCleanService.from_days(
|
||||
policy=policy,
|
||||
days=days,
|
||||
@@ -613,11 +609,9 @@ class TestMessagesCleanServiceFromDays:
|
||||
policy = BillingDisabledPolicy()
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
|
||||
mock_now.return_value = fixed_now
|
||||
service = MessagesCleanService.from_days(policy=policy)
|
||||
|
||||
# Assert
|
||||
@@ -625,3 +619,53 @@ class TestMessagesCleanServiceFromDays:
|
||||
assert service._end_before == expected_end_before
|
||||
assert service._batch_size == 1000 # default
|
||||
assert service._dry_run is False # default
|
||||
|
||||
|
||||
class TestMessagesCleanServiceRun:
|
||||
"""Unit tests for MessagesCleanService.run instrumentation behavior."""
|
||||
|
||||
def test_run_records_completion_metrics_on_success(self):
|
||||
# Arrange
|
||||
service = MessagesCleanService(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=datetime.datetime(2024, 1, 1),
|
||||
end_before=datetime.datetime(2024, 1, 2),
|
||||
batch_size=100,
|
||||
dry_run=False,
|
||||
)
|
||||
expected_stats = {
|
||||
"batches": 1,
|
||||
"total_messages": 10,
|
||||
"filtered_messages": 5,
|
||||
"total_deleted": 5,
|
||||
}
|
||||
service._clean_messages_by_time_range = MagicMock(return_value=expected_stats) # type: ignore[method-assign]
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
# Act
|
||||
result = service.run()
|
||||
|
||||
# Assert
|
||||
assert result == expected_stats
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "success"
|
||||
|
||||
def test_run_records_completion_metrics_on_failure(self):
|
||||
# Arrange
|
||||
service = MessagesCleanService(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=datetime.datetime(2024, 1, 1),
|
||||
end_before=datetime.datetime(2024, 1, 2),
|
||||
batch_size=100,
|
||||
dry_run=False,
|
||||
)
|
||||
service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed")) # type: ignore[method-assign]
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="clean failed"):
|
||||
service.run()
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
@@ -6,6 +6,13 @@ from typing import Any
|
||||
|
||||
|
||||
class ConfigHelper:
|
||||
_LEGACY_SECTION_MAP = {
|
||||
"admin_config": "admin",
|
||||
"token_config": "auth",
|
||||
"app_config": "app",
|
||||
"api_key_config": "api_key",
|
||||
}
|
||||
|
||||
"""Helper class for reading and writing configuration files."""
|
||||
|
||||
def __init__(self, base_dir: Path | None = None):
|
||||
@@ -50,14 +57,8 @@ class ConfigHelper:
|
||||
Dictionary containing config data, or None if file doesn't exist
|
||||
"""
|
||||
# Provide backward compatibility for old config names
|
||||
if filename in ["admin_config", "token_config", "app_config", "api_key_config"]:
|
||||
section_map = {
|
||||
"admin_config": "admin",
|
||||
"token_config": "auth",
|
||||
"app_config": "app",
|
||||
"api_key_config": "api_key",
|
||||
}
|
||||
return self.get_state_section(section_map[filename])
|
||||
if filename in self._LEGACY_SECTION_MAP:
|
||||
return self.get_state_section(self._LEGACY_SECTION_MAP[filename])
|
||||
|
||||
config_path = self.get_config_path(filename)
|
||||
|
||||
@@ -85,14 +86,11 @@ class ConfigHelper:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
# Provide backward compatibility for old config names
|
||||
if filename in ["admin_config", "token_config", "app_config", "api_key_config"]:
|
||||
section_map = {
|
||||
"admin_config": "admin",
|
||||
"token_config": "auth",
|
||||
"app_config": "app",
|
||||
"api_key_config": "api_key",
|
||||
}
|
||||
return self.update_state_section(section_map[filename], data)
|
||||
if filename in self._LEGACY_SECTION_MAP:
|
||||
return self.update_state_section(
|
||||
self._LEGACY_SECTION_MAP[filename],
|
||||
data,
|
||||
)
|
||||
|
||||
self.ensure_config_dir()
|
||||
config_path = self.get_config_path(filename)
|
||||
|
||||
@@ -2,6 +2,12 @@
|
||||
|
||||
- Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions.
|
||||
|
||||
## Overlay Components (Mandatory)
|
||||
|
||||
- `./docs/overlay-migration.md` is the source of truth for overlay-related work.
|
||||
- In new or modified code, use only overlay primitives from `@/app/components/base/ui/*`.
|
||||
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding).
|
||||
|
||||
## Automated Test Generation
|
||||
|
||||
- Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests.
|
||||
|
||||
@@ -503,7 +503,7 @@ describe('TimePicker', () => {
|
||||
const emitted = onChange.mock.calls[0][0]
|
||||
expect(isDayjsObject(emitted)).toBe(true)
|
||||
// 10:30 UTC converted to America/New_York (UTC-5 in Jan) = 05:30
|
||||
expect(emitted.utcOffset()).toBe(dayjs().tz('America/New_York').utcOffset())
|
||||
expect(emitted.utcOffset()).toBe(dayjs.tz('2024-01-01', 'America/New_York').utcOffset())
|
||||
expect(emitted.hour()).toBe(5)
|
||||
expect(emitted.minute()).toBe(30)
|
||||
})
|
||||
|
||||
@@ -20,7 +20,7 @@ describe('dayjs utilities', () => {
|
||||
const result = toDayjs('07:15 PM', { timezone: tz })
|
||||
expect(result).toBeDefined()
|
||||
expect(result?.format('HH:mm')).toBe('19:15')
|
||||
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).utcOffset())
|
||||
expect(result?.utcOffset()).toBe(getDateWithTimezone({ timezone: tz }).startOf('day').utcOffset())
|
||||
})
|
||||
|
||||
it('isDayjsObject detects dayjs instances', () => {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
'use client'
|
||||
import type { ReactNode } from 'react'
|
||||
import type { IToastProps } from './context'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import type { IToastProps } from './context'
|
||||
import { ToastContext, useToastContext } from './context'
|
||||
|
||||
export type ToastHandle = {
|
||||
|
||||
257
web/app/components/base/ui/context-menu/__tests__/index.spec.tsx
Normal file
257
web/app/components/base/ui/context-menu/__tests__/index.spec.tsx
Normal file
@@ -0,0 +1,257 @@
|
||||
import { fireEvent, render, screen, within } from '@testing-library/react'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
ContextMenu,
|
||||
ContextMenuContent,
|
||||
ContextMenuItem,
|
||||
ContextMenuLinkItem,
|
||||
ContextMenuSeparator,
|
||||
ContextMenuSub,
|
||||
ContextMenuSubContent,
|
||||
ContextMenuSubTrigger,
|
||||
ContextMenuTrigger,
|
||||
} from '../index'
|
||||
|
||||
describe('context-menu wrapper', () => {
|
||||
describe('ContextMenuContent', () => {
|
||||
it('should position content at bottom-start with default placement when props are omitted', () => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent positionerProps={{ 'role': 'group', 'aria-label': 'content positioner' }}>
|
||||
<ContextMenuItem>Content action</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const positioner = screen.getByRole('group', { name: 'content positioner' })
|
||||
const popup = screen.getByRole('menu')
|
||||
expect(positioner).toHaveAttribute('data-side', 'bottom')
|
||||
expect(positioner).toHaveAttribute('data-align', 'start')
|
||||
expect(within(popup).getByRole('menuitem', { name: 'Content action' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply custom placement when custom positioning props are provided', () => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent
|
||||
placement="top-end"
|
||||
sideOffset={12}
|
||||
alignOffset={-3}
|
||||
positionerProps={{ 'role': 'group', 'aria-label': 'custom content positioner' }}
|
||||
>
|
||||
<ContextMenuItem>Custom content</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const positioner = screen.getByRole('group', { name: 'custom content positioner' })
|
||||
const popup = screen.getByRole('menu')
|
||||
expect(positioner).toHaveAttribute('data-side', 'top')
|
||||
expect(positioner).toHaveAttribute('data-align', 'end')
|
||||
expect(within(popup).getByRole('menuitem', { name: 'Custom content' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should forward passthrough attributes and handlers when positionerProps and popupProps are provided', () => {
|
||||
const handlePositionerMouseEnter = vi.fn()
|
||||
const handlePopupClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent
|
||||
positionerProps={{
|
||||
'role': 'group',
|
||||
'aria-label': 'context content positioner',
|
||||
'id': 'context-content-positioner',
|
||||
'onMouseEnter': handlePositionerMouseEnter,
|
||||
}}
|
||||
popupProps={{
|
||||
role: 'menu',
|
||||
id: 'context-content-popup',
|
||||
onClick: handlePopupClick,
|
||||
}}
|
||||
>
|
||||
<ContextMenuItem>Passthrough content</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const positioner = screen.getByRole('group', { name: 'context content positioner' })
|
||||
const popup = screen.getByRole('menu')
|
||||
fireEvent.mouseEnter(positioner)
|
||||
fireEvent.click(popup)
|
||||
expect(positioner).toHaveAttribute('id', 'context-content-positioner')
|
||||
expect(popup).toHaveAttribute('id', 'context-content-popup')
|
||||
expect(handlePositionerMouseEnter).toHaveBeenCalledTimes(1)
|
||||
expect(handlePopupClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('ContextMenuSubContent', () => {
|
||||
it('should position sub-content at right-start with default placement when props are omitted', () => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuSub open>
|
||||
<ContextMenuSubTrigger>More actions</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent positionerProps={{ 'role': 'group', 'aria-label': 'sub positioner' }}>
|
||||
<ContextMenuItem>Sub action</ContextMenuItem>
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const positioner = screen.getByRole('group', { name: 'sub positioner' })
|
||||
expect(positioner).toHaveAttribute('data-side', 'right')
|
||||
expect(positioner).toHaveAttribute('data-align', 'start')
|
||||
expect(screen.getByRole('menuitem', { name: 'Sub action' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('destructive prop behavior', () => {
|
||||
it.each([true, false])('should remain interactive and not leak destructive prop on item when destructive is %s', (destructive) => {
|
||||
const handleClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem
|
||||
destructive={destructive}
|
||||
aria-label="menu action"
|
||||
id={`context-item-${String(destructive)}`}
|
||||
onClick={handleClick}
|
||||
>
|
||||
Item label
|
||||
</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const item = screen.getByRole('menuitem', { name: 'menu action' })
|
||||
fireEvent.click(item)
|
||||
expect(item).toHaveAttribute('id', `context-item-${String(destructive)}`)
|
||||
expect(item).not.toHaveAttribute('destructive')
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it.each([true, false])('should remain interactive and not leak destructive prop on submenu trigger when destructive is %s', (destructive) => {
|
||||
const handleClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuSub open>
|
||||
<ContextMenuSubTrigger
|
||||
destructive={destructive}
|
||||
aria-label="submenu action"
|
||||
id={`context-sub-${String(destructive)}`}
|
||||
onClick={handleClick}
|
||||
>
|
||||
Trigger item
|
||||
</ContextMenuSubTrigger>
|
||||
</ContextMenuSub>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const trigger = screen.getByRole('menuitem', { name: 'submenu action' })
|
||||
fireEvent.click(trigger)
|
||||
expect(trigger).toHaveAttribute('id', `context-sub-${String(destructive)}`)
|
||||
expect(trigger).not.toHaveAttribute('destructive')
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it.each([true, false])('should remain interactive and not leak destructive prop on link item when destructive is %s', (destructive) => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuLinkItem
|
||||
destructive={destructive}
|
||||
href="https://example.com/docs"
|
||||
aria-label="context docs link"
|
||||
id={`context-link-${String(destructive)}`}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Docs
|
||||
</ContextMenuLinkItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'context docs link' })
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('id', `context-link-${String(destructive)}`)
|
||||
expect(link).not.toHaveAttribute('destructive')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ContextMenuLinkItem close behavior', () => {
|
||||
it('should keep link semantics and not leak closeOnClick prop when closeOnClick is false', () => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuLinkItem
|
||||
href="https://example.com/docs"
|
||||
closeOnClick={false}
|
||||
aria-label="docs link"
|
||||
>
|
||||
Docs
|
||||
</ContextMenuLinkItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'docs link' })
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('href', 'https://example.com/docs')
|
||||
expect(link).not.toHaveAttribute('closeOnClick')
|
||||
})
|
||||
})
|
||||
|
||||
describe('ContextMenuTrigger interaction', () => {
|
||||
it('should open menu when right-clicking trigger area', () => {
|
||||
render(
|
||||
<ContextMenu>
|
||||
<ContextMenuTrigger aria-label="context trigger area">
|
||||
Trigger area
|
||||
</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem>Open on right click</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
const trigger = screen.getByLabelText('context trigger area')
|
||||
fireEvent.contextMenu(trigger)
|
||||
expect(screen.getByRole('menuitem', { name: 'Open on right click' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('ContextMenuSeparator', () => {
|
||||
it('should render separator and keep surrounding rows when separator is between items', () => {
|
||||
render(
|
||||
<ContextMenu open>
|
||||
<ContextMenuTrigger aria-label="context trigger">Open</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem>First action</ContextMenuItem>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuItem>Second action</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('menuitem', { name: 'First action' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('menuitem', { name: 'Second action' })).toBeInTheDocument()
|
||||
expect(screen.getAllByRole('separator')).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
215
web/app/components/base/ui/context-menu/index.stories.tsx
Normal file
215
web/app/components/base/ui/context-menu/index.stories.tsx
Normal file
@@ -0,0 +1,215 @@
|
||||
import type { Meta, StoryObj } from '@storybook/nextjs-vite'
|
||||
import { useState } from 'react'
|
||||
import {
|
||||
ContextMenu,
|
||||
ContextMenuCheckboxItem,
|
||||
ContextMenuCheckboxItemIndicator,
|
||||
ContextMenuContent,
|
||||
ContextMenuGroup,
|
||||
ContextMenuGroupLabel,
|
||||
ContextMenuItem,
|
||||
ContextMenuLinkItem,
|
||||
ContextMenuRadioGroup,
|
||||
ContextMenuRadioItem,
|
||||
ContextMenuRadioItemIndicator,
|
||||
ContextMenuSeparator,
|
||||
ContextMenuSub,
|
||||
ContextMenuSubContent,
|
||||
ContextMenuSubTrigger,
|
||||
ContextMenuTrigger,
|
||||
} from '.'
|
||||
|
||||
const TriggerArea = ({ label = 'Right-click inside this area' }: { label?: string }) => (
|
||||
<ContextMenuTrigger
|
||||
aria-label="context menu trigger area"
|
||||
render={<button type="button" className="flex h-44 w-80 select-none items-center justify-center rounded-xl border border-divider-subtle bg-background-default-subtle px-6 text-center text-sm text-text-tertiary" />}
|
||||
>
|
||||
{label}
|
||||
</ContextMenuTrigger>
|
||||
)
|
||||
|
||||
const meta = {
|
||||
title: 'Base/Navigation/ContextMenu',
|
||||
component: ContextMenu,
|
||||
parameters: {
|
||||
layout: 'centered',
|
||||
docs: {
|
||||
description: {
|
||||
component: 'Compound context menu built on Base UI ContextMenu. Open by right-clicking the trigger area.',
|
||||
},
|
||||
},
|
||||
},
|
||||
tags: ['autodocs'],
|
||||
} satisfies Meta<typeof ContextMenu>
|
||||
|
||||
export default meta
|
||||
type Story = StoryObj<typeof meta>
|
||||
|
||||
export const Default: Story = {
|
||||
render: () => (
|
||||
<ContextMenu>
|
||||
<TriggerArea />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem>Edit</ContextMenuItem>
|
||||
<ContextMenuItem>Duplicate</ContextMenuItem>
|
||||
<ContextMenuItem>Archive</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
),
|
||||
}
|
||||
|
||||
export const WithSubmenu: Story = {
|
||||
render: () => (
|
||||
<ContextMenu>
|
||||
<TriggerArea />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem>Copy</ContextMenuItem>
|
||||
<ContextMenuItem>Paste</ContextMenuItem>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuSub>
|
||||
<ContextMenuSubTrigger>Share</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent>
|
||||
<ContextMenuItem>Email</ContextMenuItem>
|
||||
<ContextMenuItem>Slack</ContextMenuItem>
|
||||
<ContextMenuItem>Copy link</ContextMenuItem>
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
),
|
||||
}
|
||||
|
||||
export const WithGroupLabel: Story = {
|
||||
render: () => (
|
||||
<ContextMenu>
|
||||
<TriggerArea />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuGroup>
|
||||
<ContextMenuGroupLabel>Actions</ContextMenuGroupLabel>
|
||||
<ContextMenuItem>Rename</ContextMenuItem>
|
||||
<ContextMenuItem>Duplicate</ContextMenuItem>
|
||||
</ContextMenuGroup>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuGroup>
|
||||
<ContextMenuGroupLabel>Danger Zone</ContextMenuGroupLabel>
|
||||
<ContextMenuItem destructive>Delete</ContextMenuItem>
|
||||
</ContextMenuGroup>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
),
|
||||
}
|
||||
|
||||
const WithRadioItemsDemo = () => {
|
||||
const [value, setValue] = useState('comfortable')
|
||||
|
||||
return (
|
||||
<ContextMenu>
|
||||
<TriggerArea label={`Right-click to set density: ${value}`} />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuRadioGroup value={value} onValueChange={setValue}>
|
||||
<ContextMenuRadioItem value="compact">
|
||||
Compact
|
||||
<ContextMenuRadioItemIndicator />
|
||||
</ContextMenuRadioItem>
|
||||
<ContextMenuRadioItem value="comfortable">
|
||||
Comfortable
|
||||
<ContextMenuRadioItemIndicator />
|
||||
</ContextMenuRadioItem>
|
||||
<ContextMenuRadioItem value="spacious">
|
||||
Spacious
|
||||
<ContextMenuRadioItemIndicator />
|
||||
</ContextMenuRadioItem>
|
||||
</ContextMenuRadioGroup>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
)
|
||||
}
|
||||
|
||||
export const WithRadioItems: Story = {
|
||||
render: () => <WithRadioItemsDemo />,
|
||||
}
|
||||
|
||||
const WithCheckboxItemsDemo = () => {
|
||||
const [showToolbar, setShowToolbar] = useState(true)
|
||||
const [showSidebar, setShowSidebar] = useState(false)
|
||||
const [showStatusBar, setShowStatusBar] = useState(true)
|
||||
|
||||
return (
|
||||
<ContextMenu>
|
||||
<TriggerArea label="Right-click to configure panel visibility" />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuCheckboxItem checked={showToolbar} onCheckedChange={setShowToolbar}>
|
||||
Toolbar
|
||||
<ContextMenuCheckboxItemIndicator />
|
||||
</ContextMenuCheckboxItem>
|
||||
<ContextMenuCheckboxItem checked={showSidebar} onCheckedChange={setShowSidebar}>
|
||||
Sidebar
|
||||
<ContextMenuCheckboxItemIndicator />
|
||||
</ContextMenuCheckboxItem>
|
||||
<ContextMenuCheckboxItem checked={showStatusBar} onCheckedChange={setShowStatusBar}>
|
||||
Status bar
|
||||
<ContextMenuCheckboxItemIndicator />
|
||||
</ContextMenuCheckboxItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
)
|
||||
}
|
||||
|
||||
export const WithCheckboxItems: Story = {
|
||||
render: () => <WithCheckboxItemsDemo />,
|
||||
}
|
||||
|
||||
export const WithLinkItems: Story = {
|
||||
render: () => (
|
||||
<ContextMenu>
|
||||
<TriggerArea label="Right-click to open links" />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuLinkItem href="https://docs.dify.ai" rel="noopener noreferrer" target="_blank">
|
||||
Dify Docs
|
||||
</ContextMenuLinkItem>
|
||||
<ContextMenuLinkItem href="https://roadmap.dify.ai" rel="noopener noreferrer" target="_blank">
|
||||
Product Roadmap
|
||||
</ContextMenuLinkItem>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuLinkItem destructive href="https://example.com/delete" rel="noopener noreferrer" target="_blank">
|
||||
Dangerous External Action
|
||||
</ContextMenuLinkItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
),
|
||||
}
|
||||
|
||||
export const Complex: Story = {
|
||||
render: () => (
|
||||
<ContextMenu>
|
||||
<TriggerArea label="Right-click to inspect all menu capabilities" />
|
||||
<ContextMenuContent>
|
||||
<ContextMenuItem>
|
||||
<span aria-hidden className="i-ri-pencil-line size-4 shrink-0 text-text-tertiary" />
|
||||
Rename
|
||||
</ContextMenuItem>
|
||||
<ContextMenuItem>
|
||||
<span aria-hidden className="i-ri-file-copy-line size-4 shrink-0 text-text-tertiary" />
|
||||
Duplicate
|
||||
</ContextMenuItem>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuSub>
|
||||
<ContextMenuSubTrigger>
|
||||
<span aria-hidden className="i-ri-share-line size-4 shrink-0 text-text-tertiary" />
|
||||
Share
|
||||
</ContextMenuSubTrigger>
|
||||
<ContextMenuSubContent>
|
||||
<ContextMenuItem>Email</ContextMenuItem>
|
||||
<ContextMenuItem>Slack</ContextMenuItem>
|
||||
<ContextMenuItem>Copy Link</ContextMenuItem>
|
||||
</ContextMenuSubContent>
|
||||
</ContextMenuSub>
|
||||
<ContextMenuSeparator />
|
||||
<ContextMenuItem destructive>
|
||||
<span aria-hidden className="i-ri-delete-bin-line size-4 shrink-0" />
|
||||
Delete
|
||||
</ContextMenuItem>
|
||||
</ContextMenuContent>
|
||||
</ContextMenu>
|
||||
),
|
||||
}
|
||||
302
web/app/components/base/ui/context-menu/index.tsx
Normal file
302
web/app/components/base/ui/context-menu/index.tsx
Normal file
@@ -0,0 +1,302 @@
|
||||
'use client'
|
||||
|
||||
import type { Placement } from '@/app/components/base/ui/placement'
|
||||
import { ContextMenu as BaseContextMenu } from '@base-ui/react/context-menu'
|
||||
import * as React from 'react'
|
||||
import {
|
||||
menuBackdropClassName,
|
||||
menuGroupLabelClassName,
|
||||
menuIndicatorClassName,
|
||||
menuPopupAnimationClassName,
|
||||
menuPopupBaseClassName,
|
||||
menuRowClassName,
|
||||
menuSeparatorClassName,
|
||||
} from '@/app/components/base/ui/menu-shared'
|
||||
import { parsePlacement } from '@/app/components/base/ui/placement'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
export const ContextMenu = BaseContextMenu.Root
|
||||
export const ContextMenuTrigger = BaseContextMenu.Trigger
|
||||
export const ContextMenuPortal = BaseContextMenu.Portal
|
||||
export const ContextMenuBackdrop = BaseContextMenu.Backdrop
|
||||
export const ContextMenuSub = BaseContextMenu.SubmenuRoot
|
||||
export const ContextMenuGroup = BaseContextMenu.Group
|
||||
export const ContextMenuRadioGroup = BaseContextMenu.RadioGroup
|
||||
|
||||
type ContextMenuContentProps = {
|
||||
children: React.ReactNode
|
||||
placement?: Placement
|
||||
sideOffset?: number
|
||||
alignOffset?: number
|
||||
className?: string
|
||||
popupClassName?: string
|
||||
positionerProps?: Omit<
|
||||
React.ComponentPropsWithoutRef<typeof BaseContextMenu.Positioner>,
|
||||
'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset'
|
||||
>
|
||||
popupProps?: Omit<
|
||||
React.ComponentPropsWithoutRef<typeof BaseContextMenu.Popup>,
|
||||
'children' | 'className'
|
||||
>
|
||||
}
|
||||
|
||||
type ContextMenuPopupRenderProps = Required<Pick<ContextMenuContentProps, 'children'>> & {
|
||||
placement: Placement
|
||||
sideOffset: number
|
||||
alignOffset: number
|
||||
className?: string
|
||||
popupClassName?: string
|
||||
positionerProps?: ContextMenuContentProps['positionerProps']
|
||||
popupProps?: ContextMenuContentProps['popupProps']
|
||||
withBackdrop?: boolean
|
||||
}
|
||||
|
||||
function renderContextMenuPopup({
|
||||
children,
|
||||
placement,
|
||||
sideOffset,
|
||||
alignOffset,
|
||||
className,
|
||||
popupClassName,
|
||||
positionerProps,
|
||||
popupProps,
|
||||
withBackdrop = false,
|
||||
}: ContextMenuPopupRenderProps) {
|
||||
const { side, align } = parsePlacement(placement)
|
||||
|
||||
return (
|
||||
<BaseContextMenu.Portal>
|
||||
{withBackdrop && (
|
||||
<BaseContextMenu.Backdrop className={menuBackdropClassName} />
|
||||
)}
|
||||
<BaseContextMenu.Positioner
|
||||
side={side}
|
||||
align={align}
|
||||
sideOffset={sideOffset}
|
||||
alignOffset={alignOffset}
|
||||
className={cn('z-50 outline-none', className)}
|
||||
{...positionerProps}
|
||||
>
|
||||
<BaseContextMenu.Popup
|
||||
className={cn(
|
||||
menuPopupBaseClassName,
|
||||
menuPopupAnimationClassName,
|
||||
popupClassName,
|
||||
)}
|
||||
{...popupProps}
|
||||
>
|
||||
{children}
|
||||
</BaseContextMenu.Popup>
|
||||
</BaseContextMenu.Positioner>
|
||||
</BaseContextMenu.Portal>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuContent({
|
||||
children,
|
||||
placement = 'bottom-start',
|
||||
sideOffset = 0,
|
||||
alignOffset = 0,
|
||||
className,
|
||||
popupClassName,
|
||||
positionerProps,
|
||||
popupProps,
|
||||
}: ContextMenuContentProps) {
|
||||
return renderContextMenuPopup({
|
||||
children,
|
||||
placement,
|
||||
sideOffset,
|
||||
alignOffset,
|
||||
className,
|
||||
popupClassName,
|
||||
positionerProps,
|
||||
popupProps,
|
||||
withBackdrop: true,
|
||||
})
|
||||
}
|
||||
|
||||
type ContextMenuItemProps = React.ComponentPropsWithoutRef<typeof BaseContextMenu.Item> & {
|
||||
destructive?: boolean
|
||||
}
|
||||
|
||||
export function ContextMenuItem({
|
||||
className,
|
||||
destructive,
|
||||
...props
|
||||
}: ContextMenuItemProps) {
|
||||
return (
|
||||
<BaseContextMenu.Item
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
type ContextMenuLinkItemProps = React.ComponentPropsWithoutRef<typeof BaseContextMenu.LinkItem> & {
|
||||
destructive?: boolean
|
||||
}
|
||||
|
||||
export function ContextMenuLinkItem({
|
||||
className,
|
||||
destructive,
|
||||
closeOnClick = true,
|
||||
...props
|
||||
}: ContextMenuLinkItemProps) {
|
||||
return (
|
||||
<BaseContextMenu.LinkItem
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
closeOnClick={closeOnClick}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuRadioItem({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentPropsWithoutRef<typeof BaseContextMenu.RadioItem>) {
|
||||
return (
|
||||
<BaseContextMenu.RadioItem
|
||||
className={cn(menuRowClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuCheckboxItem({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentPropsWithoutRef<typeof BaseContextMenu.CheckboxItem>) {
|
||||
return (
|
||||
<BaseContextMenu.CheckboxItem
|
||||
className={cn(menuRowClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
type ContextMenuIndicatorProps = Omit<React.ComponentPropsWithoutRef<'span'>, 'children'> & {
|
||||
children?: React.ReactNode
|
||||
}
|
||||
|
||||
export function ContextMenuItemIndicator({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: ContextMenuIndicatorProps) {
|
||||
return (
|
||||
<span
|
||||
aria-hidden
|
||||
className={cn(menuIndicatorClassName, className)}
|
||||
{...props}
|
||||
>
|
||||
{children ?? <span aria-hidden className="i-ri-check-line h-4 w-4" />}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuCheckboxItemIndicator({
|
||||
className,
|
||||
...props
|
||||
}: Omit<React.ComponentPropsWithoutRef<typeof BaseContextMenu.CheckboxItemIndicator>, 'children'>) {
|
||||
return (
|
||||
<BaseContextMenu.CheckboxItemIndicator
|
||||
className={cn(menuIndicatorClassName, className)}
|
||||
{...props}
|
||||
>
|
||||
<span aria-hidden className="i-ri-check-line h-4 w-4" />
|
||||
</BaseContextMenu.CheckboxItemIndicator>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuRadioItemIndicator({
|
||||
className,
|
||||
...props
|
||||
}: Omit<React.ComponentPropsWithoutRef<typeof BaseContextMenu.RadioItemIndicator>, 'children'>) {
|
||||
return (
|
||||
<BaseContextMenu.RadioItemIndicator
|
||||
className={cn(menuIndicatorClassName, className)}
|
||||
{...props}
|
||||
>
|
||||
<span aria-hidden className="i-ri-check-line h-4 w-4" />
|
||||
</BaseContextMenu.RadioItemIndicator>
|
||||
)
|
||||
}
|
||||
|
||||
type ContextMenuSubTriggerProps = React.ComponentPropsWithoutRef<typeof BaseContextMenu.SubmenuTrigger> & {
|
||||
destructive?: boolean
|
||||
}
|
||||
|
||||
export function ContextMenuSubTrigger({
|
||||
className,
|
||||
destructive,
|
||||
children,
|
||||
...props
|
||||
}: ContextMenuSubTriggerProps) {
|
||||
return (
|
||||
<BaseContextMenu.SubmenuTrigger
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<span aria-hidden className="i-ri-arrow-right-s-line ml-auto size-4 shrink-0 text-text-tertiary" />
|
||||
</BaseContextMenu.SubmenuTrigger>
|
||||
)
|
||||
}
|
||||
|
||||
type ContextMenuSubContentProps = {
|
||||
children: React.ReactNode
|
||||
placement?: Placement
|
||||
sideOffset?: number
|
||||
alignOffset?: number
|
||||
className?: string
|
||||
popupClassName?: string
|
||||
positionerProps?: ContextMenuContentProps['positionerProps']
|
||||
popupProps?: ContextMenuContentProps['popupProps']
|
||||
}
|
||||
|
||||
export function ContextMenuSubContent({
|
||||
children,
|
||||
placement = 'right-start',
|
||||
sideOffset = 4,
|
||||
alignOffset = 0,
|
||||
className,
|
||||
popupClassName,
|
||||
positionerProps,
|
||||
popupProps,
|
||||
}: ContextMenuSubContentProps) {
|
||||
return renderContextMenuPopup({
|
||||
children,
|
||||
placement,
|
||||
sideOffset,
|
||||
alignOffset,
|
||||
className,
|
||||
popupClassName,
|
||||
positionerProps,
|
||||
popupProps,
|
||||
})
|
||||
}
|
||||
|
||||
export function ContextMenuGroupLabel({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentPropsWithoutRef<typeof BaseContextMenu.GroupLabel>) {
|
||||
return (
|
||||
<BaseContextMenu.GroupLabel
|
||||
className={cn(menuGroupLabelClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export function ContextMenuSeparator({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentPropsWithoutRef<typeof BaseContextMenu.Separator>) {
|
||||
return (
|
||||
<BaseContextMenu.Separator
|
||||
className={cn(menuSeparatorClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
import { Menu } from '@base-ui/react/menu'
|
||||
import type { ComponentPropsWithoutRef, ReactNode } from 'react'
|
||||
import { fireEvent, render, screen, within } from '@testing-library/react'
|
||||
import Link from 'next/link'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuGroup,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuPortal,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuLinkItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuSub,
|
||||
DropdownMenuSubContent,
|
||||
@@ -15,18 +14,22 @@ import {
|
||||
DropdownMenuTrigger,
|
||||
} from '../index'
|
||||
|
||||
describe('dropdown-menu wrapper', () => {
|
||||
describe('alias exports', () => {
|
||||
it('should map direct aliases to the corresponding Menu primitive when importing menu roots', () => {
|
||||
expect(DropdownMenu).toBe(Menu.Root)
|
||||
expect(DropdownMenuPortal).toBe(Menu.Portal)
|
||||
expect(DropdownMenuTrigger).toBe(Menu.Trigger)
|
||||
expect(DropdownMenuSub).toBe(Menu.SubmenuRoot)
|
||||
expect(DropdownMenuGroup).toBe(Menu.Group)
|
||||
expect(DropdownMenuRadioGroup).toBe(Menu.RadioGroup)
|
||||
})
|
||||
})
|
||||
vi.mock('next/link', () => ({
|
||||
default: ({
|
||||
href,
|
||||
children,
|
||||
...props
|
||||
}: {
|
||||
href: string
|
||||
children?: ReactNode
|
||||
} & Omit<ComponentPropsWithoutRef<'a'>, 'href'>) => (
|
||||
<a href={href} {...props}>
|
||||
{children}
|
||||
</a>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('dropdown-menu wrapper', () => {
|
||||
describe('DropdownMenuContent', () => {
|
||||
it('should position content at bottom-end with default placement when props are omitted', () => {
|
||||
render(
|
||||
@@ -250,6 +253,99 @@ describe('dropdown-menu wrapper', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropdownMenuLinkItem', () => {
|
||||
it('should render as anchor and keep href/target attributes when link props are provided', () => {
|
||||
render(
|
||||
<DropdownMenu open>
|
||||
<DropdownMenuTrigger aria-label="menu trigger">Open</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLinkItem href="https://example.com/docs" target="_blank" rel="noopener noreferrer">
|
||||
Docs
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'Docs' })
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('href', 'https://example.com/docs')
|
||||
expect(link).toHaveAttribute('target', '_blank')
|
||||
expect(link).toHaveAttribute('rel', 'noopener noreferrer')
|
||||
})
|
||||
|
||||
it('should keep link semantics and not leak closeOnClick prop when closeOnClick is false', () => {
|
||||
render(
|
||||
<DropdownMenu open>
|
||||
<DropdownMenuTrigger aria-label="menu trigger">Open</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLinkItem
|
||||
href="https://example.com/docs"
|
||||
closeOnClick={false}
|
||||
aria-label="docs link"
|
||||
>
|
||||
Docs
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'docs link' })
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('href', 'https://example.com/docs')
|
||||
expect(link).not.toHaveAttribute('closeOnClick')
|
||||
})
|
||||
|
||||
it('should preserve link semantics when render prop uses a custom link component', () => {
|
||||
render(
|
||||
<DropdownMenu open>
|
||||
<DropdownMenuTrigger aria-label="menu trigger">Open</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLinkItem
|
||||
render={<Link href="/account" />}
|
||||
aria-label="account link"
|
||||
>
|
||||
Account settings
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'account link' })
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('href', '/account')
|
||||
expect(link).toHaveTextContent('Account settings')
|
||||
})
|
||||
|
||||
it.each([true, false])('should remain interactive and not leak destructive prop when destructive is %s', (destructive) => {
|
||||
const handleClick = vi.fn()
|
||||
|
||||
render(
|
||||
<DropdownMenu open>
|
||||
<DropdownMenuTrigger aria-label="menu trigger">Open</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLinkItem
|
||||
destructive={destructive}
|
||||
href="https://example.com/docs"
|
||||
aria-label="docs link"
|
||||
id={`menu-link-${String(destructive)}`}
|
||||
onClick={handleClick}
|
||||
>
|
||||
Docs
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>,
|
||||
)
|
||||
|
||||
const link = screen.getByRole('menuitem', { name: 'docs link' })
|
||||
fireEvent.click(link)
|
||||
|
||||
expect(link.tagName.toLowerCase()).toBe('a')
|
||||
expect(link).toHaveAttribute('id', `menu-link-${String(destructive)}`)
|
||||
expect(link).not.toHaveAttribute('destructive')
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('DropdownMenuSeparator', () => {
|
||||
it('should forward passthrough props and handlers when separator props are provided', () => {
|
||||
const handleMouseEnter = vi.fn()
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
DropdownMenuGroup,
|
||||
DropdownMenuGroupLabel,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLinkItem,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuRadioItem,
|
||||
DropdownMenuRadioItemIndicator,
|
||||
@@ -234,6 +235,22 @@ export const WithIcons: Story = {
|
||||
),
|
||||
}
|
||||
|
||||
export const WithLinkItems: Story = {
|
||||
render: () => (
|
||||
<DropdownMenu>
|
||||
<TriggerButton label="Open links" />
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLinkItem href="https://docs.dify.ai" rel="noopener noreferrer" target="_blank">
|
||||
Dify Docs
|
||||
</DropdownMenuLinkItem>
|
||||
<DropdownMenuLinkItem href="https://roadmap.dify.ai" rel="noopener noreferrer" target="_blank">
|
||||
Product Roadmap
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
),
|
||||
}
|
||||
|
||||
const ComplexDemo = () => {
|
||||
const [sortOrder, setSortOrder] = useState('newest')
|
||||
const [showArchived, setShowArchived] = useState(false)
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
import type { Placement } from '@/app/components/base/ui/placement'
|
||||
import { Menu } from '@base-ui/react/menu'
|
||||
import * as React from 'react'
|
||||
import {
|
||||
menuGroupLabelClassName,
|
||||
menuIndicatorClassName,
|
||||
menuPopupAnimationClassName,
|
||||
menuPopupBaseClassName,
|
||||
menuRowClassName,
|
||||
menuSeparatorClassName,
|
||||
} from '@/app/components/base/ui/menu-shared'
|
||||
import { parsePlacement } from '@/app/components/base/ui/placement'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@@ -13,20 +21,13 @@ export const DropdownMenuSub = Menu.SubmenuRoot
|
||||
export const DropdownMenuGroup = Menu.Group
|
||||
export const DropdownMenuRadioGroup = Menu.RadioGroup
|
||||
|
||||
const menuRowBaseClassName = 'mx-1 flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg px-2 outline-none'
|
||||
const menuRowStateClassName = 'data-[highlighted]:bg-state-base-hover data-[disabled]:cursor-not-allowed data-[disabled]:opacity-30'
|
||||
|
||||
export function DropdownMenuRadioItem({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentPropsWithoutRef<typeof Menu.RadioItem>) {
|
||||
return (
|
||||
<Menu.RadioItem
|
||||
className={cn(
|
||||
menuRowBaseClassName,
|
||||
menuRowStateClassName,
|
||||
className,
|
||||
)}
|
||||
className={cn(menuRowClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@@ -38,10 +39,7 @@ export function DropdownMenuRadioItemIndicator({
|
||||
}: Omit<React.ComponentPropsWithoutRef<typeof Menu.RadioItemIndicator>, 'children'>) {
|
||||
return (
|
||||
<Menu.RadioItemIndicator
|
||||
className={cn(
|
||||
'ml-auto flex shrink-0 items-center text-text-accent',
|
||||
className,
|
||||
)}
|
||||
className={cn(menuIndicatorClassName, className)}
|
||||
{...props}
|
||||
>
|
||||
<span aria-hidden className="i-ri-check-line h-4 w-4" />
|
||||
@@ -55,11 +53,7 @@ export function DropdownMenuCheckboxItem({
|
||||
}: React.ComponentPropsWithoutRef<typeof Menu.CheckboxItem>) {
|
||||
return (
|
||||
<Menu.CheckboxItem
|
||||
className={cn(
|
||||
menuRowBaseClassName,
|
||||
menuRowStateClassName,
|
||||
className,
|
||||
)}
|
||||
className={cn(menuRowClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@@ -71,10 +65,7 @@ export function DropdownMenuCheckboxItemIndicator({
|
||||
}: Omit<React.ComponentPropsWithoutRef<typeof Menu.CheckboxItemIndicator>, 'children'>) {
|
||||
return (
|
||||
<Menu.CheckboxItemIndicator
|
||||
className={cn(
|
||||
'ml-auto flex shrink-0 items-center text-text-accent',
|
||||
className,
|
||||
)}
|
||||
className={cn(menuIndicatorClassName, className)}
|
||||
{...props}
|
||||
>
|
||||
<span aria-hidden className="i-ri-check-line h-4 w-4" />
|
||||
@@ -88,10 +79,7 @@ export function DropdownMenuGroupLabel({
|
||||
}: React.ComponentPropsWithoutRef<typeof Menu.GroupLabel>) {
|
||||
return (
|
||||
<Menu.GroupLabel
|
||||
className={cn(
|
||||
'px-3 pb-0.5 pt-1 text-text-tertiary system-xs-medium-uppercase',
|
||||
className,
|
||||
)}
|
||||
className={cn(menuGroupLabelClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@@ -148,8 +136,8 @@ function renderDropdownMenuPopup({
|
||||
>
|
||||
<Menu.Popup
|
||||
className={cn(
|
||||
'max-h-[var(--available-height)] overflow-y-auto overflow-x-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur py-1 text-sm text-text-secondary shadow-lg backdrop-blur-[5px]',
|
||||
'origin-[var(--transform-origin)] transition-[transform,scale,opacity] data-[ending-style]:scale-95 data-[starting-style]:scale-95 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
|
||||
menuPopupBaseClassName,
|
||||
menuPopupAnimationClassName,
|
||||
popupClassName,
|
||||
)}
|
||||
{...popupProps}
|
||||
@@ -195,12 +183,7 @@ export function DropdownMenuSubTrigger({
|
||||
}: DropdownMenuSubTriggerProps) {
|
||||
return (
|
||||
<Menu.SubmenuTrigger
|
||||
className={cn(
|
||||
menuRowBaseClassName,
|
||||
menuRowStateClassName,
|
||||
destructive && 'text-text-destructive',
|
||||
className,
|
||||
)}
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
@@ -253,12 +236,26 @@ export function DropdownMenuItem({
|
||||
}: DropdownMenuItemProps) {
|
||||
return (
|
||||
<Menu.Item
|
||||
className={cn(
|
||||
menuRowBaseClassName,
|
||||
menuRowStateClassName,
|
||||
destructive && 'text-text-destructive',
|
||||
className,
|
||||
)}
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
type DropdownMenuLinkItemProps = React.ComponentPropsWithoutRef<typeof Menu.LinkItem> & {
|
||||
destructive?: boolean
|
||||
}
|
||||
|
||||
export function DropdownMenuLinkItem({
|
||||
className,
|
||||
destructive,
|
||||
closeOnClick = true,
|
||||
...props
|
||||
}: DropdownMenuLinkItemProps) {
|
||||
return (
|
||||
<Menu.LinkItem
|
||||
className={cn(menuRowClassName, destructive && 'text-text-destructive', className)}
|
||||
closeOnClick={closeOnClick}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
@@ -270,7 +267,7 @@ export function DropdownMenuSeparator({
|
||||
}: React.ComponentPropsWithoutRef<typeof Menu.Separator>) {
|
||||
return (
|
||||
<Menu.Separator
|
||||
className={cn('my-1 h-px bg-divider-subtle', className)}
|
||||
className={cn(menuSeparatorClassName, className)}
|
||||
{...props}
|
||||
/>
|
||||
)
|
||||
|
||||
7
web/app/components/base/ui/menu-shared.ts
Normal file
7
web/app/components/base/ui/menu-shared.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export const menuRowClassName = 'mx-1 flex h-8 cursor-pointer select-none items-center gap-1 rounded-lg px-2 outline-none data-[highlighted]:bg-state-base-hover data-[disabled]:cursor-not-allowed data-[disabled]:opacity-30'
|
||||
export const menuIndicatorClassName = 'ml-auto flex shrink-0 items-center text-text-accent'
|
||||
export const menuGroupLabelClassName = 'px-3 pb-0.5 pt-1 text-text-tertiary system-xs-medium-uppercase'
|
||||
export const menuSeparatorClassName = 'my-1 h-px bg-divider-subtle'
|
||||
export const menuPopupBaseClassName = 'max-h-[var(--available-height)] overflow-y-auto overflow-x-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur py-1 text-sm text-text-secondary shadow-lg outline-none focus:outline-none focus-visible:outline-none backdrop-blur-[5px]'
|
||||
export const menuPopupAnimationClassName = 'origin-[var(--transform-origin)] transition-[transform,scale,opacity] data-[ending-style]:scale-95 data-[starting-style]:scale-95 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none'
|
||||
export const menuBackdropClassName = 'fixed inset-0 z-50 bg-transparent transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none'
|
||||
@@ -225,5 +225,97 @@ describe('Compliance', () => {
|
||||
payload: ACCOUNT_SETTING_TAB.BILLING,
|
||||
})
|
||||
})
|
||||
|
||||
// isPending branches: spinner visible, disabled class, guard blocks second call
|
||||
it('should show spinner and guard against duplicate download when isPending is true', async () => {
|
||||
// Arrange
|
||||
let resolveDownload: (value: { url: string }) => void
|
||||
vi.mocked(getDocDownloadUrl).mockImplementation(() => new Promise((resolve) => {
|
||||
resolveDownload = resolve
|
||||
}))
|
||||
vi.mocked(useProviderContext).mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.team,
|
||||
},
|
||||
})
|
||||
|
||||
// Act
|
||||
openMenuAndRender()
|
||||
const downloadButtons = screen.getAllByText('common.operation.download')
|
||||
fireEvent.click(downloadButtons[0])
|
||||
|
||||
// Assert - btn-disabled class and spinner should appear while mutation is pending
|
||||
await waitFor(() => {
|
||||
const menuItem = screen.getByText('common.compliance.soc2Type1').closest('[role="menuitem"]')
|
||||
expect(menuItem).not.toBeNull()
|
||||
const disabledBtn = menuItem!.querySelector('.cursor-not-allowed')
|
||||
expect(disabledBtn).not.toBeNull()
|
||||
}, { timeout: 10000 })
|
||||
|
||||
// Cleanup: resolve the pending promise
|
||||
resolveDownload!({ url: 'http://example.com/doc.pdf' })
|
||||
await waitFor(() => {
|
||||
expect(downloadUrl).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call downloadCompliance again while pending', async () => {
|
||||
let resolveDownload: (value: { url: string }) => void
|
||||
vi.mocked(getDocDownloadUrl).mockImplementation(() => new Promise((resolve) => {
|
||||
resolveDownload = resolve
|
||||
}))
|
||||
vi.mocked(useProviderContext).mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.team,
|
||||
},
|
||||
})
|
||||
|
||||
openMenuAndRender()
|
||||
const downloadButtons = screen.getAllByText('common.operation.download')
|
||||
|
||||
// First click starts download
|
||||
fireEvent.click(downloadButtons[0])
|
||||
|
||||
// Wait for mutation to start and React to re-render (isPending=true)
|
||||
await waitFor(() => {
|
||||
const menuItem = screen.getByText('common.compliance.soc2Type1').closest('[role="menuitem"]')
|
||||
const el = menuItem!.querySelector('.cursor-not-allowed')
|
||||
expect(el).not.toBeNull()
|
||||
expect(getDocDownloadUrl).toHaveBeenCalledTimes(1)
|
||||
}, { timeout: 10000 })
|
||||
|
||||
// Second click while pending - should be guarded by isPending check
|
||||
fireEvent.click(downloadButtons[0])
|
||||
|
||||
resolveDownload!({ url: 'http://example.com/doc.pdf' })
|
||||
await waitFor(() => {
|
||||
expect(downloadUrl).toHaveBeenCalledTimes(1)
|
||||
}, { timeout: 10000 })
|
||||
// getDocDownloadUrl should still have only been called once
|
||||
expect(getDocDownloadUrl).toHaveBeenCalledTimes(1)
|
||||
}, 20000)
|
||||
|
||||
// canShowUpgradeTooltip=false: enterprise plan has empty tooltip text → no TooltipContent
|
||||
it('should show upgrade badge with empty tooltip for enterprise plan', () => {
|
||||
// Arrange
|
||||
vi.mocked(useProviderContext).mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.enterprise,
|
||||
},
|
||||
})
|
||||
|
||||
// Act
|
||||
openMenuAndRender()
|
||||
|
||||
// Assert - enterprise is not in any download list, so upgrade badges should appear
|
||||
// The key branch: upgradeTooltip[Plan.enterprise] = '' → canShowUpgradeTooltip=false
|
||||
expect(screen.getAllByText('billing.upgradeBtn.encourageShort').length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -184,7 +184,7 @@ export default function Compliance() {
|
||||
<DropdownMenuSubContent
|
||||
popupClassName="w-[337px] divide-y divide-divider-subtle !bg-components-panel-bg-blur !py-0 backdrop-blur-sm"
|
||||
>
|
||||
<DropdownMenuGroup className="p-1">
|
||||
<DropdownMenuGroup className="py-1">
|
||||
<ComplianceDocRowItem
|
||||
icon={<Soc2 aria-hidden className="size-7 shrink-0" />}
|
||||
label={t('compliance.soc2Type1', { ns: 'common' })}
|
||||
|
||||
@@ -247,6 +247,23 @@ describe('AccountDropdown', () => {
|
||||
// Assert
|
||||
expect(screen.getByText('common.userProfile.compliance')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Compound AND middle-false: IS_CLOUD_EDITION=true but isCurrentWorkspaceOwner=false
|
||||
it('should hide Compliance in Cloud Edition when user is not workspace owner', () => {
|
||||
// Arrange
|
||||
mockConfig.IS_CLOUD_EDITION = true
|
||||
vi.mocked(useAppContext).mockReturnValue({
|
||||
...baseAppContextValue,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
})
|
||||
|
||||
// Act
|
||||
renderWithRouter(<AppSelector />)
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.userProfile.compliance')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Actions', () => {
|
||||
|
||||
@@ -9,7 +9,7 @@ import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import ThemeSwitcher from '@/app/components/base/theme-switcher'
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { DropdownMenu, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSeparator, DropdownMenuTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@@ -41,12 +41,12 @@ function AccountMenuRouteItem({
|
||||
trailing,
|
||||
}: AccountMenuRouteItemProps) {
|
||||
return (
|
||||
<DropdownMenuItem
|
||||
<DropdownMenuLinkItem
|
||||
className="justify-between"
|
||||
render={<Link href={href} />}
|
||||
>
|
||||
<MenuItemContent iconClassName={iconClassName} label={label} trailing={trailing} />
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuLinkItem>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -64,12 +64,14 @@ function AccountMenuExternalItem({
|
||||
trailing,
|
||||
}: AccountMenuExternalItemProps) {
|
||||
return (
|
||||
<DropdownMenuItem
|
||||
<DropdownMenuLinkItem
|
||||
className="justify-between"
|
||||
render={<a href={href} rel="noopener noreferrer" target="_blank" />}
|
||||
href={href}
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
>
|
||||
<MenuItemContent iconClassName={iconClassName} label={label} trailing={trailing} />
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuLinkItem>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -101,7 +103,7 @@ type AccountMenuSectionProps = {
|
||||
}
|
||||
|
||||
function AccountMenuSection({ children }: AccountMenuSectionProps) {
|
||||
return <DropdownMenuGroup className="p-1">{children}</DropdownMenuGroup>
|
||||
return <DropdownMenuGroup className="py-1">{children}</DropdownMenuGroup>
|
||||
}
|
||||
|
||||
export default function AppSelector() {
|
||||
@@ -144,8 +146,8 @@ export default function AppSelector() {
|
||||
sideOffset={6}
|
||||
popupClassName="w-60 max-w-80 !bg-components-panel-bg-blur !py-0 backdrop-blur-sm"
|
||||
>
|
||||
<DropdownMenuGroup className="px-1 py-1">
|
||||
<div className="flex flex-nowrap items-center py-2 pl-3 pr-2">
|
||||
<DropdownMenuGroup className="py-1">
|
||||
<div className="mx-1 flex flex-nowrap items-center py-2 pl-3 pr-2">
|
||||
<div className="grow">
|
||||
<div className="break-all text-text-primary system-md-medium">
|
||||
{userProfile.name}
|
||||
|
||||
@@ -36,8 +36,8 @@ vi.mock('@/config', async (importOriginal) => {
|
||||
return {
|
||||
...actual,
|
||||
IS_CE_EDITION: false,
|
||||
get ZENDESK_WIDGET_KEY() { return mockZendeskKey.value },
|
||||
get SUPPORT_EMAIL_ADDRESS() { return mockSupportEmailKey.value },
|
||||
get ZENDESK_WIDGET_KEY() { return mockZendeskKey.value || '' },
|
||||
get SUPPORT_EMAIL_ADDRESS() { return mockSupportEmailKey.value || '' },
|
||||
}
|
||||
})
|
||||
|
||||
@@ -173,25 +173,18 @@ describe('Support', () => {
|
||||
expect(screen.queryByText('common.userProfile.contactUs')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show email support if specified in the config', () => {
|
||||
// Optional chain null guard: ZENDESK_WIDGET_KEY is null
|
||||
it('should show Email Support when ZENDESK_WIDGET_KEY is null', () => {
|
||||
// Arrange
|
||||
mockZendeskKey.value = ''
|
||||
mockSupportEmailKey.value = 'support@example.com'
|
||||
vi.mocked(useProviderContext).mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.sandbox,
|
||||
},
|
||||
})
|
||||
mockZendeskKey.value = null as unknown as string
|
||||
|
||||
// Act
|
||||
renderSupport()
|
||||
fireEvent.click(screen.getByText('common.userProfile.support'))
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.userProfile.emailSupport')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.userProfile.emailSupport')?.closest('a')?.getAttribute('href')?.startsWith(`mailto:${mockSupportEmailKey.value}`)).toBe(true)
|
||||
expect(screen.getByText('common.userProfile.emailSupport')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.userProfile.contactUs')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { DropdownMenuGroup, DropdownMenuItem, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { DropdownMenuGroup, DropdownMenuItem, DropdownMenuLinkItem, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger } from '@/app/components/base/ui/dropdown-menu'
|
||||
import { toggleZendeskWindow } from '@/app/components/base/zendesk/utils'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { SUPPORT_EMAIL_ADDRESS, ZENDESK_WIDGET_KEY } from '@/config'
|
||||
@@ -31,7 +31,7 @@ export default function Support({ closeAccountDropdown }: SupportProps) {
|
||||
<DropdownMenuSubContent
|
||||
popupClassName="w-[216px] divide-y divide-divider-subtle !bg-components-panel-bg-blur !py-0 backdrop-blur-sm"
|
||||
>
|
||||
<DropdownMenuGroup className="p-1">
|
||||
<DropdownMenuGroup className="py-1">
|
||||
{hasDedicatedChannel && hasZendeskWidget && (
|
||||
<DropdownMenuItem
|
||||
className="justify-between"
|
||||
@@ -47,37 +47,43 @@ export default function Support({ closeAccountDropdown }: SupportProps) {
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{hasDedicatedChannel && !hasZendeskWidget && (
|
||||
<DropdownMenuItem
|
||||
<DropdownMenuLinkItem
|
||||
className="justify-between"
|
||||
render={<a href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo?.current_version, SUPPORT_EMAIL_ADDRESS)} rel="noopener noreferrer" target="_blank" />}
|
||||
href={mailToSupport(userProfile.email, plan.type, langGeniusVersionInfo?.current_version, SUPPORT_EMAIL_ADDRESS)}
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
>
|
||||
<MenuItemContent
|
||||
iconClassName="i-ri-mail-send-line"
|
||||
label={t('userProfile.emailSupport', { ns: 'common' })}
|
||||
trailing={<ExternalLinkIndicator />}
|
||||
/>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuLinkItem>
|
||||
)}
|
||||
<DropdownMenuItem
|
||||
<DropdownMenuLinkItem
|
||||
className="justify-between"
|
||||
render={<a href="https://forum.dify.ai/" rel="noopener noreferrer" target="_blank" />}
|
||||
href="https://forum.dify.ai/"
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
>
|
||||
<MenuItemContent
|
||||
iconClassName="i-ri-discuss-line"
|
||||
label={t('userProfile.forum', { ns: 'common' })}
|
||||
trailing={<ExternalLinkIndicator />}
|
||||
/>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem
|
||||
</DropdownMenuLinkItem>
|
||||
<DropdownMenuLinkItem
|
||||
className="justify-between"
|
||||
render={<a href="https://discord.gg/5AEfbxcd9k" rel="noopener noreferrer" target="_blank" />}
|
||||
href="https://discord.gg/5AEfbxcd9k"
|
||||
rel="noopener noreferrer"
|
||||
target="_blank"
|
||||
>
|
||||
<MenuItemContent
|
||||
iconClassName="i-ri-discord-line"
|
||||
label={t('userProfile.community', { ns: 'common' })}
|
||||
trailing={<ExternalLinkIndicator />}
|
||||
/>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuLinkItem>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuSub>
|
||||
|
||||
@@ -136,4 +136,32 @@ describe('WorkplaceSelector', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
// find() returns undefined: no workspace with current: true
|
||||
it('should not crash when no workspace has current: true', () => {
|
||||
// Arrange
|
||||
vi.mocked(useWorkspacesContext).mockReturnValue({
|
||||
workspaces: [
|
||||
{ id: '1', name: 'Workspace 1', current: false, plan: 'professional', status: 'normal', created_at: Date.now() },
|
||||
],
|
||||
})
|
||||
|
||||
// Act & Assert - should not throw
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
|
||||
// name[0]?.toLocaleUpperCase() undefined: workspace with empty name
|
||||
it('should not crash when workspace name is empty string', () => {
|
||||
// Arrange
|
||||
vi.mocked(useWorkspacesContext).mockReturnValue({
|
||||
workspaces: [
|
||||
{ id: '1', name: '', current: true, plan: 'sandbox', status: 'normal', created_at: Date.now() },
|
||||
],
|
||||
})
|
||||
|
||||
// Act & Assert - should not throw
|
||||
expect(() => renderComponent()).not.toThrow()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -388,37 +388,33 @@ describe('DataSourceNotion Component', () => {
|
||||
})
|
||||
|
||||
describe('Additional Action Edge Cases', () => {
|
||||
it('should cover all possible falsy/nullish branches for connection data in handleAuthAgain and useEffect', async () => {
|
||||
it.each([
|
||||
undefined,
|
||||
null,
|
||||
{},
|
||||
{ data: undefined },
|
||||
{ data: null },
|
||||
{ data: '' },
|
||||
{ data: 0 },
|
||||
{ data: false },
|
||||
{ data: 'http' },
|
||||
{ data: 'internal' },
|
||||
{ data: 'unknown' },
|
||||
])('should cover connection data branch: %s', async (val) => {
|
||||
vi.mocked(useDataSourceIntegrates).mockReturnValue(mockQuerySuccess({ data: mockWorkspaces }))
|
||||
/* eslint-disable-next-line ts/no-explicit-any */
|
||||
vi.mocked(useNotionConnection).mockReturnValue({ data: val, isSuccess: true } as any)
|
||||
|
||||
render(<DataSourceNotion />)
|
||||
|
||||
const connectionCases = [
|
||||
undefined,
|
||||
null,
|
||||
{},
|
||||
{ data: undefined },
|
||||
{ data: null },
|
||||
{ data: '' },
|
||||
{ data: 0 },
|
||||
{ data: false },
|
||||
{ data: 'http' },
|
||||
{ data: 'internal' },
|
||||
{ data: 'unknown' },
|
||||
]
|
||||
// Trigger handleAuthAgain with these values
|
||||
const workspaceItem = getWorkspaceItem('Workspace 1')
|
||||
const actionBtn = within(workspaceItem).getByRole('button')
|
||||
fireEvent.click(actionBtn)
|
||||
const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages')
|
||||
fireEvent.click(authAgainBtn)
|
||||
|
||||
for (const val of connectionCases) {
|
||||
/* eslint-disable-next-line ts/no-explicit-any */
|
||||
vi.mocked(useNotionConnection).mockReturnValue({ data: val, isSuccess: true } as any)
|
||||
|
||||
// Trigger handleAuthAgain with these values
|
||||
const workspaceItem = getWorkspaceItem('Workspace 1')
|
||||
const actionBtn = within(workspaceItem).getByRole('button')
|
||||
fireEvent.click(actionBtn)
|
||||
const authAgainBtn = await screen.findByText('common.dataSource.notion.changeAuthorizedPages')
|
||||
fireEvent.click(authAgainBtn)
|
||||
}
|
||||
|
||||
await waitFor(() => expect(useNotionConnection).toHaveBeenCalled())
|
||||
expect(useNotionConnection).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -134,5 +134,46 @@ describe('ConfigJinaReaderModal Component', () => {
|
||||
resolveSave!({ result: 'success' })
|
||||
await waitFor(() => expect(mockOnSaved).toHaveBeenCalledTimes(1))
|
||||
})
|
||||
|
||||
it('should show encryption info and external link in the modal', async () => {
|
||||
render(<ConfigJinaReaderModal onCancel={mockOnCancel} onSaved={mockOnSaved} />)
|
||||
|
||||
// Verify PKCS1_OAEP link exists
|
||||
const pkcsLink = screen.getByText('PKCS1_OAEP')
|
||||
expect(pkcsLink.closest('a')).toHaveAttribute('href', 'https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html')
|
||||
|
||||
// Verify the Jina Reader external link
|
||||
const jinaLink = screen.getByRole('link', { name: /datasetCreation\.jinaReader\.getApiKeyLinkText/i })
|
||||
expect(jinaLink).toHaveAttribute('target', '_blank')
|
||||
})
|
||||
|
||||
it('should return early when save is clicked while already saving (isSaving guard)', async () => {
|
||||
const user = userEvent.setup()
|
||||
// Arrange - a save that never resolves so isSaving stays true
|
||||
let resolveFirst: (value: { result: 'success' }) => void
|
||||
const neverResolves = new Promise<{ result: 'success' }>((resolve) => {
|
||||
resolveFirst = resolve
|
||||
})
|
||||
vi.mocked(createDataSourceApiKeyBinding).mockReturnValue(neverResolves)
|
||||
render(<ConfigJinaReaderModal onCancel={mockOnCancel} onSaved={mockOnSaved} />)
|
||||
|
||||
const apiKeyInput = screen.getByPlaceholderText('datasetCreation.jinaReader.apiKeyPlaceholder')
|
||||
await user.type(apiKeyInput, 'valid-key')
|
||||
|
||||
const saveBtn = screen.getByRole('button', { name: /common\.operation\.save/i })
|
||||
// First click - starts saving, isSaving becomes true
|
||||
await user.click(saveBtn)
|
||||
expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Second click using fireEvent bypasses disabled check - hits isSaving guard
|
||||
const { fireEvent: fe } = await import('@testing-library/react')
|
||||
fe.click(saveBtn)
|
||||
// Still only called once because isSaving=true returns early
|
||||
expect(createDataSourceApiKeyBinding).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Cleanup
|
||||
resolveFirst!({ result: 'success' })
|
||||
await waitFor(() => expect(mockOnSaved).toHaveBeenCalled())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -195,4 +195,57 @@ describe('DataSourceWebsite Component', () => {
|
||||
expect(removeDataSourceApiKeyBinding).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Firecrawl Save Flow', () => {
|
||||
it('should re-fetch sources after saving Firecrawl configuration', async () => {
|
||||
// Arrange
|
||||
await renderAndWait(DataSourceProvider.fireCrawl)
|
||||
fireEvent.click(screen.getByText('common.dataSource.configure'))
|
||||
expect(screen.getByText('datasetCreation.firecrawl.configFirecrawl')).toBeInTheDocument()
|
||||
vi.mocked(fetchDataSources).mockClear()
|
||||
|
||||
// Act - fill in required API key field and save
|
||||
const apiKeyInput = screen.getByPlaceholderText('datasetCreation.firecrawl.apiKeyPlaceholder')
|
||||
fireEvent.change(apiKeyInput, { target: { value: 'test-key' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/i }))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(fetchDataSources).toHaveBeenCalled()
|
||||
expect(screen.queryByText('datasetCreation.firecrawl.configFirecrawl')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Cancel Flow', () => {
|
||||
it('should close watercrawl modal when cancel is clicked', async () => {
|
||||
// Arrange
|
||||
await renderAndWait(DataSourceProvider.waterCrawl)
|
||||
fireEvent.click(screen.getByText('common.dataSource.configure'))
|
||||
expect(screen.getByText('datasetCreation.watercrawl.configWatercrawl')).toBeInTheDocument()
|
||||
|
||||
// Act
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i }))
|
||||
|
||||
// Assert - modal closed
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('datasetCreation.watercrawl.configWatercrawl')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should close jina reader modal when cancel is clicked', async () => {
|
||||
// Arrange
|
||||
await renderAndWait(DataSourceProvider.jinaReader)
|
||||
fireEvent.click(screen.getByText('common.dataSource.configure'))
|
||||
expect(screen.getByText('datasetCreation.jinaReader.configJinaReader')).toBeInTheDocument()
|
||||
|
||||
// Act
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.operation\.cancel/i }))
|
||||
|
||||
// Assert - modal closed
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('datasetCreation.jinaReader.configJinaReader')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import Operate from './Operate'
|
||||
|
||||
describe('Operate', () => {
|
||||
it('renders cancel and save when editing', () => {
|
||||
it('should render cancel and save when editing is open', () => {
|
||||
render(
|
||||
<Operate
|
||||
isOpen
|
||||
@@ -18,7 +19,7 @@ describe('Operate', () => {
|
||||
expect(screen.getByText('common.operation.save')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows add key prompt when closed', () => {
|
||||
it('should show add-key prompt when closed', () => {
|
||||
render(
|
||||
<Operate
|
||||
isOpen={false}
|
||||
@@ -33,7 +34,7 @@ describe('Operate', () => {
|
||||
expect(screen.getByText('common.provider.addKey')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows invalid state indicator and edit prompt when status is fail', () => {
|
||||
it('should show invalid state and edit prompt when status is fail', () => {
|
||||
render(
|
||||
<Operate
|
||||
isOpen={false}
|
||||
@@ -49,7 +50,7 @@ describe('Operate', () => {
|
||||
expect(screen.getByText('common.provider.editKey')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows edit prompt without error text when status is success', () => {
|
||||
it('should show edit prompt without error text when status is success', () => {
|
||||
render(
|
||||
<Operate
|
||||
isOpen={false}
|
||||
@@ -65,11 +66,30 @@ describe('Operate', () => {
|
||||
expect(screen.queryByText('common.provider.invalidApiKey')).toBeNull()
|
||||
})
|
||||
|
||||
it('shows no actions for unsupported status', () => {
|
||||
it('should not call onAdd when disabled', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onAdd = vi.fn()
|
||||
render(
|
||||
<Operate
|
||||
isOpen={false}
|
||||
status={'unknown' as never}
|
||||
status="add"
|
||||
disabled
|
||||
onAdd={onAdd}
|
||||
onCancel={vi.fn()}
|
||||
onEdit={vi.fn()}
|
||||
onSave={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
await user.click(screen.getByText('common.provider.addKey'))
|
||||
expect(onAdd).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show no actions when status is unsupported', () => {
|
||||
render(
|
||||
<Operate
|
||||
isOpen={false}
|
||||
// @ts-expect-error intentional invalid status for runtime fallback coverage
|
||||
status="unknown"
|
||||
onAdd={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
onEdit={vi.fn()}
|
||||
|
||||
@@ -267,6 +267,99 @@ describe('MembersPage', () => {
|
||||
expect(screen.getByText(/plansCommon\.unlimited/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show non-billing member format for team plan even when billing is enabled', () => {
|
||||
vi.mocked(useProviderContext).mockReturnValue(createMockProviderContextValue({
|
||||
enableBilling: true,
|
||||
plan: {
|
||||
type: Plan.team,
|
||||
total: { teamMembers: 50 } as unknown as ReturnType<typeof useProviderContext>['plan']['total'],
|
||||
} as unknown as ReturnType<typeof useProviderContext>['plan'],
|
||||
}))
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
// Plan.team is an unlimited member plan → isNotUnlimitedMemberPlan=false → non-billing layout
|
||||
expect(screen.getByText(/plansCommon\.memberAfter/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show invite button when user is manager but not owner', () => {
|
||||
vi.mocked(useAppContext).mockReturnValue({
|
||||
userProfile: { email: 'admin@example.com' },
|
||||
currentWorkspace: { name: 'Test Workspace', role: 'admin' } as ICurrentWorkspace,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
isCurrentWorkspaceManager: true,
|
||||
} as unknown as AppContextValue)
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
expect(screen.getByRole('button', { name: /invite/i })).toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: /transfer ownership/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use created_at as fallback when last_active_at is empty', () => {
|
||||
const memberNoLastActive: Member = {
|
||||
...mockAccounts[1],
|
||||
last_active_at: '',
|
||||
created_at: '1700000000',
|
||||
}
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: { accounts: [memberNoLastActive] },
|
||||
refetch: mockRefetch,
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(1700000000000)
|
||||
})
|
||||
|
||||
it('should not show plural s when only one account in billing layout', () => {
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: { accounts: [mockAccounts[0]] },
|
||||
refetch: mockRefetch,
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
vi.mocked(useProviderContext).mockReturnValue(createMockProviderContextValue({
|
||||
enableBilling: true,
|
||||
plan: {
|
||||
type: Plan.sandbox,
|
||||
total: { teamMembers: 5 } as unknown as ReturnType<typeof useProviderContext>['plan']['total'],
|
||||
} as unknown as ReturnType<typeof useProviderContext>['plan'],
|
||||
}))
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
expect(screen.getByText(/plansCommon\.member/i)).toBeInTheDocument()
|
||||
expect(screen.getByText('1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show plural s when only one account in non-billing layout', () => {
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: { accounts: [mockAccounts[0]] },
|
||||
refetch: mockRefetch,
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
expect(screen.getByText(/plansCommon\.memberAfter/i)).toBeInTheDocument()
|
||||
expect(screen.getByText('1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show normal role as fallback for unknown role', () => {
|
||||
vi.mocked(useAppContext).mockReturnValue({
|
||||
userProfile: { email: 'admin@example.com' },
|
||||
currentWorkspace: { name: 'Test Workspace', role: 'admin' } as ICurrentWorkspace,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
isCurrentWorkspaceManager: false,
|
||||
} as unknown as AppContextValue)
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: { accounts: [{ ...mockAccounts[1], role: 'unknown_role' as Member['role'] }] },
|
||||
refetch: mockRefetch,
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
|
||||
render(<MembersPage />)
|
||||
|
||||
expect(screen.getByText('common.members.normal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show upgrade button when member limit is full', () => {
|
||||
vi.mocked(useProviderContext).mockReturnValue(createMockProviderContextValue({
|
||||
enableBilling: true,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { InvitationResponse } from '@/models/common'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { vi } from 'vitest'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
@@ -171,6 +171,66 @@ describe('InviteModal', () => {
|
||||
expect(screen.queryByText('user@example.com')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show unlimited label when workspace member limit is zero', async () => {
|
||||
vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({
|
||||
licenseLimit: { workspace_members: { size: 5, limit: 0 } },
|
||||
refreshLicenseLimit: mockRefreshLicenseLimit,
|
||||
} as unknown as Parameters<typeof selector>[0]))
|
||||
|
||||
renderModal()
|
||||
|
||||
expect(await screen.findByText(/license\.unlimited/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should initialize usedSize to zero when workspace_members.size is null', async () => {
|
||||
vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({
|
||||
licenseLimit: { workspace_members: { size: null, limit: 10 } },
|
||||
refreshLicenseLimit: mockRefreshLicenseLimit,
|
||||
} as unknown as Parameters<typeof selector>[0]))
|
||||
|
||||
renderModal()
|
||||
|
||||
// usedSize starts at 0 (via ?? 0 fallback), no emails added → counter shows 0
|
||||
expect(await screen.findByText('0')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not call onSend when invite result is not success', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.mocked(inviteMember).mockResolvedValue({
|
||||
result: 'error',
|
||||
invitation_results: [],
|
||||
} as unknown as InvitationResponse)
|
||||
|
||||
renderModal()
|
||||
|
||||
await user.type(screen.getByTestId('mock-email-input'), 'user@example.com')
|
||||
await user.click(screen.getByRole('button', { name: /members\.sendInvite/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(inviteMember).toHaveBeenCalled()
|
||||
expect(mockOnSend).not.toHaveBeenCalled()
|
||||
expect(mockOnCancel).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show destructive text color when used size exceeds limit', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({
|
||||
licenseLimit: { workspace_members: { size: 10, limit: 10 } },
|
||||
refreshLicenseLimit: mockRefreshLicenseLimit,
|
||||
} as unknown as Parameters<typeof selector>[0]))
|
||||
|
||||
renderModal()
|
||||
|
||||
const input = screen.getByTestId('mock-email-input')
|
||||
await user.type(input, 'user@example.com')
|
||||
|
||||
// usedSize = 10 + 1 = 11 > limit 10 → destructive color
|
||||
const counter = screen.getByText('11')
|
||||
expect(counter.closest('div')).toHaveClass('text-text-destructive')
|
||||
})
|
||||
|
||||
it('should not submit if already submitting', async () => {
|
||||
const user = userEvent.setup()
|
||||
let resolveInvite: (value: InvitationResponse) => void
|
||||
@@ -202,4 +262,72 @@ describe('InviteModal', () => {
|
||||
expect(mockOnCancel).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show destructive color and disable send button when limit is exactly met with one email', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
// size=10, limit=10 - adding 1 email makes usedSize=11 > limit=10
|
||||
vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({
|
||||
licenseLimit: { workspace_members: { size: 10, limit: 10 } },
|
||||
refreshLicenseLimit: mockRefreshLicenseLimit,
|
||||
} as unknown as Parameters<typeof selector>[0]))
|
||||
|
||||
renderModal()
|
||||
|
||||
const input = screen.getByTestId('mock-email-input')
|
||||
await user.type(input, 'user@example.com')
|
||||
|
||||
// isLimitExceeded=true → button is disabled, cannot submit
|
||||
const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i })
|
||||
expect(sendBtn).toBeDisabled()
|
||||
expect(inviteMember).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should hit isSubmitting guard inside handleSend when button is force-clicked during submission', async () => {
|
||||
const user = userEvent.setup()
|
||||
let resolveInvite: (value: InvitationResponse) => void
|
||||
const invitePromise = new Promise<InvitationResponse>((resolve) => {
|
||||
resolveInvite = resolve
|
||||
})
|
||||
vi.mocked(inviteMember).mockReturnValue(invitePromise)
|
||||
|
||||
renderModal()
|
||||
|
||||
const input = screen.getByTestId('mock-email-input')
|
||||
await user.type(input, 'user@example.com')
|
||||
|
||||
const sendBtn = screen.getByRole('button', { name: /members\.sendInvite/i })
|
||||
|
||||
// First click starts submission
|
||||
await user.click(sendBtn)
|
||||
expect(inviteMember).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Force-click bypasses disabled attribute → hits isSubmitting guard in handleSend
|
||||
fireEvent.click(sendBtn)
|
||||
expect(inviteMember).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Cleanup
|
||||
resolveInvite!({ result: 'success', invitation_results: [] })
|
||||
await waitFor(() => {
|
||||
expect(mockOnCancel).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not show error text color when isLimited is false even with many emails', async () => {
|
||||
// size=0, limit=0 → isLimited=false, usedSize=emails.length
|
||||
vi.mocked(useProviderContextSelector).mockImplementation(selector => selector({
|
||||
licenseLimit: { workspace_members: { size: 0, limit: 0 } },
|
||||
refreshLicenseLimit: mockRefreshLicenseLimit,
|
||||
} as unknown as Parameters<typeof selector>[0]))
|
||||
|
||||
const user = userEvent.setup()
|
||||
renderModal()
|
||||
|
||||
const input = screen.getByTestId('mock-email-input')
|
||||
await user.type(input, 'user@example.com')
|
||||
|
||||
// isLimited=false → no destructive color
|
||||
const counter = screen.getByText('1')
|
||||
expect(counter.closest('div')).not.toHaveClass('text-text-destructive')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,8 +2,12 @@ import type { InvitationResult } from '@/models/common'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import InvitedModal from './index'
|
||||
|
||||
const mockConfigState = vi.hoisted(() => ({ isCeEdition: true }))
|
||||
|
||||
vi.mock('@/config', () => ({
|
||||
IS_CE_EDITION: true,
|
||||
get IS_CE_EDITION() {
|
||||
return mockConfigState.isCeEdition
|
||||
},
|
||||
}))
|
||||
|
||||
describe('InvitedModal', () => {
|
||||
@@ -13,6 +17,11 @@ describe('InvitedModal', () => {
|
||||
{ email: 'failed@example.com', status: 'failed', message: 'Error msg' },
|
||||
]
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockConfigState.isCeEdition = true
|
||||
})
|
||||
|
||||
it('should show success and failed invitation sections', async () => {
|
||||
render(<InvitedModal invitationResults={results} onCancel={mockOnCancel} />)
|
||||
|
||||
@@ -21,4 +30,59 @@ describe('InvitedModal', () => {
|
||||
expect(screen.getByText('http://invite.com/1')).toBeInTheDocument()
|
||||
expect(screen.getByText('failed@example.com')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide invitation link section when there are no successes', () => {
|
||||
const failedOnly: InvitationResult[] = [
|
||||
{ email: 'fail@example.com', status: 'failed', message: 'Quota exceeded' },
|
||||
]
|
||||
|
||||
render(<InvitedModal invitationResults={failedOnly} onCancel={mockOnCancel} />)
|
||||
|
||||
expect(screen.queryByText(/members\.invitationLink/i)).not.toBeInTheDocument()
|
||||
expect(screen.getByText(/members\.failedInvitationEmails/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide failed section when there are only successes', () => {
|
||||
const successOnly: InvitationResult[] = [
|
||||
{ email: 'ok@example.com', status: 'success', url: 'http://invite.com/2' },
|
||||
]
|
||||
|
||||
render(<InvitedModal invitationResults={successOnly} onCancel={mockOnCancel} />)
|
||||
|
||||
expect(screen.getByText(/members\.invitationLink/i)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/members\.failedInvitationEmails/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide both sections when results are empty', () => {
|
||||
render(<InvitedModal invitationResults={[]} onCancel={mockOnCancel} />)
|
||||
|
||||
expect(screen.queryByText(/members\.invitationLink/i)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/members\.failedInvitationEmails/i)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('InvitedModal (non-CE edition)', () => {
|
||||
const mockOnCancel = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockConfigState.isCeEdition = false
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
mockConfigState.isCeEdition = true
|
||||
})
|
||||
|
||||
it('should render invitationSentTip without CE edition content when IS_CE_EDITION is false', async () => {
|
||||
const results: InvitationResult[] = [
|
||||
{ email: 'success@example.com', status: 'success', url: 'http://invite.com/1' },
|
||||
]
|
||||
|
||||
render(<InvitedModal invitationResults={results} onCancel={mockOnCancel} />)
|
||||
|
||||
// The !IS_CE_EDITION branch - should show the tip text
|
||||
expect(await screen.findByText(/members\.invitationSentTip/i)).toBeInTheDocument()
|
||||
// CE-only content should not be shown
|
||||
expect(screen.queryByText(/members\.invitationLink/i)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -49,13 +49,13 @@ describe('Operation', () => {
|
||||
mockUseProviderContext.mockReturnValue({ datasetOperatorEnabled: false })
|
||||
})
|
||||
|
||||
it('renders the current role label', () => {
|
||||
it('should render the current role label when member has editor role', () => {
|
||||
renderOperation()
|
||||
|
||||
expect(screen.getByText('common.members.editor')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows dataset operator option when the feature flag is enabled', async () => {
|
||||
it('should show dataset operator option when feature flag is enabled', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
mockUseProviderContext.mockReturnValue({ datasetOperatorEnabled: true })
|
||||
@@ -66,7 +66,7 @@ describe('Operation', () => {
|
||||
expect(await screen.findByText('common.members.datasetOperator')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows owner-allowed role options for admin operators', async () => {
|
||||
it('should show owner-allowed role options when operator role is admin', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
renderOperation({}, 'admin')
|
||||
@@ -77,7 +77,7 @@ describe('Operation', () => {
|
||||
expect(screen.getByText('common.members.normal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('does not show role options for unsupported operators', async () => {
|
||||
it('should not show role options when operator role is unsupported', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
renderOperation({}, 'normal')
|
||||
@@ -88,7 +88,7 @@ describe('Operation', () => {
|
||||
expect(screen.getByText('common.members.removeFromTeam')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('calls updateMemberRole and onOperate when selecting another role', async () => {
|
||||
it('should call updateMemberRole and onOperate when selecting another role', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
renderOperation({}, 'owner', onOperate)
|
||||
@@ -102,7 +102,24 @@ describe('Operation', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('calls deleteMemberOrCancelInvitation when removing the member', async () => {
|
||||
it('should show dataset operator option when operator is admin and feature flag is enabled', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockUseProviderContext.mockReturnValue({ datasetOperatorEnabled: true })
|
||||
renderOperation({}, 'admin')
|
||||
|
||||
await user.click(screen.getByText('common.members.editor'))
|
||||
|
||||
expect(await screen.findByText('common.members.datasetOperator')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.members.admin')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fall back to normal role label when member role is unknown', () => {
|
||||
renderOperation({ role: 'unknown_role' as Member['role'] })
|
||||
|
||||
expect(screen.getByText('common.members.normal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call deleteMemberOrCancelInvitation when removing the member', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onOperate = vi.fn()
|
||||
renderOperation({}, 'owner', onOperate)
|
||||
|
||||
@@ -13,11 +13,6 @@ vi.mock('@/context/app-context')
|
||||
vi.mock('@/service/common')
|
||||
vi.mock('@/service/use-common')
|
||||
|
||||
// Mock Modal directly to avoid transition/portal issues in tests
|
||||
vi.mock('@/app/components/base/modal', () => ({
|
||||
default: ({ children, isShow }: { children: React.ReactNode, isShow: boolean }) => isShow ? <div data-testid="mock-modal">{children}</div> : null,
|
||||
}))
|
||||
|
||||
vi.mock('./member-selector', () => ({
|
||||
default: ({ onSelect }: { onSelect: (id: string) => void }) => (
|
||||
<button onClick={() => onSelect('new-owner-id')}>Select member</button>
|
||||
@@ -40,11 +35,13 @@ describe('TransferOwnershipModal', () => {
|
||||
data: { accounts: [] },
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
|
||||
// Fix Location stubbing for reload
|
||||
// Stub globalThis.location.reload (component calls globalThis.location.reload())
|
||||
const mockReload = vi.fn()
|
||||
vi.stubGlobal('location', {
|
||||
...window.location,
|
||||
reload: mockReload,
|
||||
href: '',
|
||||
assign: vi.fn(),
|
||||
replace: vi.fn(),
|
||||
} as unknown as Location)
|
||||
})
|
||||
|
||||
@@ -105,8 +102,8 @@ describe('TransferOwnershipModal', () => {
|
||||
await waitFor(() => {
|
||||
expect(ownershipTransfer).toHaveBeenCalledWith('new-owner-id', { token: 'final-token' })
|
||||
expect(window.location.reload).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
}, { timeout: 10000 })
|
||||
}, 15000)
|
||||
|
||||
it('should handle timer countdown and resend', async () => {
|
||||
vi.useFakeTimers()
|
||||
@@ -202,6 +199,70 @@ describe('TransferOwnershipModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle sendOwnerEmail returning null data', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.mocked(sendOwnerEmail).mockResolvedValue({
|
||||
data: null,
|
||||
result: 'success',
|
||||
} as unknown as Awaited<ReturnType<typeof sendOwnerEmail>>)
|
||||
|
||||
renderModal()
|
||||
await user.click(screen.getByTestId('transfer-modal-send-code'))
|
||||
|
||||
// Should advance to verify step even with null data
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/members\.transferModal\.verifyEmail/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show fallback error prefix when sendOwnerEmail throws null', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.mocked(sendOwnerEmail).mockRejectedValue(null)
|
||||
|
||||
renderModal()
|
||||
await user.click(screen.getByTestId('transfer-modal-send-code'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: expect.stringContaining('Error sending verification code:'),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('should show fallback error prefix when verifyOwnerEmail throws null', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockEmailVerification()
|
||||
vi.mocked(verifyOwnerEmail).mockRejectedValue(null)
|
||||
|
||||
renderModal()
|
||||
await goToTransferStep(user)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: expect.stringContaining('Error verifying email:'),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('should show fallback error prefix when ownershipTransfer throws null', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockEmailVerification()
|
||||
vi.mocked(ownershipTransfer).mockRejectedValue(null)
|
||||
|
||||
renderModal()
|
||||
await goToTransferStep(user)
|
||||
await selectNewOwnerAndSubmit(user)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: expect.stringContaining('Error ownership transfer:'),
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
it('should close when close button is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderModal()
|
||||
|
||||
@@ -71,9 +71,80 @@ describe('MemberSelector', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should filter list by email when name does not match', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
|
||||
await user.click(screen.getByTestId('member-selector-trigger'))
|
||||
await user.type(screen.getByTestId('member-selector-search'), 'john@')
|
||||
|
||||
const items = screen.getAllByTestId('member-selector-item')
|
||||
expect(items).toHaveLength(1)
|
||||
expect(screen.getByText('John Doe')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Jane Smith')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show placeholder when value does not match any account', () => {
|
||||
render(<MemberSelector value="nonexistent-id" onSelect={mockOnSelect} />)
|
||||
|
||||
expect(screen.getByText(/members\.transferModal\.transferPlaceholder/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle missing data gracefully', () => {
|
||||
vi.mocked(useMembers).mockReturnValue({ data: undefined } as unknown as ReturnType<typeof useMembers>)
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
expect(screen.getByText(/members\.transferModal\.transferPlaceholder/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter by email when account name is empty', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: { accounts: [...mockAccounts, { id: '4', name: '', email: 'noname@example.com', avatar_url: '' }] },
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
|
||||
await user.click(screen.getByTestId('member-selector-trigger'))
|
||||
await user.type(screen.getByTestId('member-selector-search'), 'noname@')
|
||||
|
||||
const items = screen.getAllByTestId('member-selector-item')
|
||||
expect(items).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should apply hover background class when dropdown is open', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
|
||||
const trigger = screen.getByTestId('member-selector-trigger')
|
||||
await user.click(trigger)
|
||||
|
||||
expect(trigger).toHaveClass('bg-state-base-hover-alt')
|
||||
})
|
||||
|
||||
it('should not match account when neither name nor email contains search value', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
|
||||
await user.click(screen.getByTestId('member-selector-trigger'))
|
||||
await user.type(screen.getByTestId('member-selector-search'), 'xyz-no-match-xyz')
|
||||
|
||||
expect(screen.queryByTestId('member-selector-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fall back to empty string for account with undefined email when searching', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.mocked(useMembers).mockReturnValue({
|
||||
data: {
|
||||
accounts: [
|
||||
{ id: '1', name: 'John', email: undefined as unknown as string, avatar_url: '' },
|
||||
],
|
||||
},
|
||||
} as unknown as ReturnType<typeof useMembers>)
|
||||
render(<MemberSelector onSelect={mockOnSelect} />)
|
||||
|
||||
await user.click(screen.getByTestId('member-selector-trigger'))
|
||||
await user.type(screen.getByTestId('member-selector-search'), 'john')
|
||||
|
||||
const items = screen.getAllByTestId('member-selector-item')
|
||||
expect(items).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -433,6 +433,55 @@ describe('hooks', () => {
|
||||
|
||||
expect(result.current.credentials).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should not call invalidateQueries when neither predefined nor custom is enabled', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useQuery as Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
// Both predefinedEnabled and customEnabled are false (no credentialId)
|
||||
const { result } = renderHook(() => useProviderCredentialsAndLoadBalancing(
|
||||
'openai',
|
||||
ConfigurationMethodEnum.predefinedModel,
|
||||
false,
|
||||
undefined,
|
||||
undefined,
|
||||
))
|
||||
|
||||
act(() => {
|
||||
result.current.mutate()
|
||||
})
|
||||
|
||||
expect(invalidateQueries).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should build URL without credentialId when not provided in predefined queryFn', async () => {
|
||||
// Trigger the queryFn when credentialId is undefined but predefinedEnabled is true
|
||||
; (useQuery as Mock).mockReturnValue({
|
||||
data: { credentials: { api_key: 'k' } },
|
||||
isPending: false,
|
||||
})
|
||||
|
||||
const { result: _result } = renderHook(() => useProviderCredentialsAndLoadBalancing(
|
||||
'openai',
|
||||
ConfigurationMethodEnum.predefinedModel,
|
||||
true,
|
||||
undefined,
|
||||
undefined,
|
||||
))
|
||||
|
||||
// Find and invoke the predefined queryFn
|
||||
const queryCall = (useQuery as Mock).mock.calls.find(
|
||||
call => call[0].queryKey?.[1] === 'credentials',
|
||||
)
|
||||
if (queryCall) {
|
||||
await queryCall[0].queryFn()
|
||||
expect(fetchModelProviderCredentials).toHaveBeenCalled()
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('useModelList', () => {
|
||||
@@ -1111,6 +1160,26 @@ describe('hooks', () => {
|
||||
expect(result.current.plugins![0].plugin_id).toBe('plugin1')
|
||||
})
|
||||
|
||||
it('should deduplicate plugins that exist in both collections and regular plugins', () => {
|
||||
const duplicatePlugin = { plugin_id: 'shared-plugin', type: 'plugin' }
|
||||
|
||||
; (useMarketplacePluginsByCollectionId as Mock).mockReturnValue({
|
||||
plugins: [duplicatePlugin],
|
||||
isLoading: false,
|
||||
})
|
||||
; (useMarketplacePlugins as Mock).mockReturnValue({
|
||||
plugins: [{ ...duplicatePlugin }, { plugin_id: 'unique-plugin', type: 'plugin' }],
|
||||
queryPlugins: vi.fn(),
|
||||
queryPluginsWithDebounced: vi.fn(),
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useMarketplaceAllPlugins([], ''))
|
||||
|
||||
expect(result.current.plugins).toHaveLength(2)
|
||||
expect(result.current.plugins!.filter(p => p.plugin_id === 'shared-plugin')).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should handle loading states', () => {
|
||||
; (useMarketplacePluginsByCollectionId as Mock).mockReturnValue({
|
||||
plugins: [],
|
||||
@@ -1127,6 +1196,45 @@ describe('hooks', () => {
|
||||
|
||||
expect(result.current.isLoading).toBe(true)
|
||||
})
|
||||
|
||||
it('should not crash when plugins is undefined', () => {
|
||||
; (useMarketplacePluginsByCollectionId as Mock).mockReturnValue({
|
||||
plugins: [],
|
||||
isLoading: false,
|
||||
})
|
||||
; (useMarketplacePlugins as Mock).mockReturnValue({
|
||||
plugins: undefined,
|
||||
queryPlugins: vi.fn(),
|
||||
queryPluginsWithDebounced: vi.fn(),
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useMarketplaceAllPlugins([], ''))
|
||||
|
||||
expect(result.current.plugins).toBeDefined()
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
})
|
||||
|
||||
it('should return search plugins (not allPlugins) when searchText is truthy', () => {
|
||||
const searchPlugins = [{ plugin_id: 'search-result', type: 'plugin' }]
|
||||
const collectionPlugins = [{ plugin_id: 'collection-only', type: 'plugin' }]
|
||||
|
||||
; (useMarketplacePluginsByCollectionId as Mock).mockReturnValue({
|
||||
plugins: collectionPlugins,
|
||||
isLoading: false,
|
||||
})
|
||||
; (useMarketplacePlugins as Mock).mockReturnValue({
|
||||
plugins: searchPlugins,
|
||||
queryPlugins: vi.fn(),
|
||||
queryPluginsWithDebounced: vi.fn(),
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useMarketplaceAllPlugins([], 'openai'))
|
||||
|
||||
expect(result.current.plugins).toEqual(searchPlugins)
|
||||
expect(result.current.plugins?.some(p => p.plugin_id === 'collection-only')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useRefreshModel', () => {
|
||||
@@ -1234,6 +1342,35 @@ describe('hooks', () => {
|
||||
expect(emit).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should emit event and invalidate all supported model types when __model_type is undefined', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
|
||||
const provider = createMockProvider()
|
||||
const customFields = { __model_name: 'my-model', __model_type: undefined } as unknown as CustomConfigurationModelFixedFields
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider, customFields, true)
|
||||
})
|
||||
|
||||
expect(emit).toHaveBeenCalledWith({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: 'openai',
|
||||
})
|
||||
// When __model_type is undefined, all supported model types are invalidated
|
||||
const modelListCalls = invalidateQueries.mock.calls.filter(
|
||||
call => call[0]?.queryKey?.[0] === 'model-list',
|
||||
)
|
||||
expect(modelListCalls).toHaveLength(provider.supported_model_types.length)
|
||||
})
|
||||
|
||||
it('should handle provider with single model type', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
|
||||
|
||||
@@ -60,7 +60,15 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockDefaultModelState = {
|
||||
type MockDefaultModelData = {
|
||||
model: string
|
||||
provider?: { provider: string }
|
||||
} | null
|
||||
|
||||
const mockDefaultModelState: {
|
||||
data: MockDefaultModelData
|
||||
isLoading: boolean
|
||||
} = {
|
||||
data: null,
|
||||
isLoading: false,
|
||||
}
|
||||
@@ -196,4 +204,129 @@ describe('ModelProviderPage', () => {
|
||||
])
|
||||
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show not configured alert when all default models are absent', () => {
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when default model is loading', () => {
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = true
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter providers by label text', () => {
|
||||
render(<ModelProviderPage searchText="OpenAI" />)
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(600)
|
||||
})
|
||||
expect(screen.getByText('openai')).toBeInTheDocument()
|
||||
expect(screen.queryByText('anthropic')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should classify system-enabled providers with matching quota as configured', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'sys-provider',
|
||||
label: { en_US: 'System Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: true,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('sys-provider')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should classify system-enabled provider with no matching quota as not configured', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'sys-no-quota',
|
||||
label: { en_US: 'System No Quota' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: true,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('sys-no-quota')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.toBeConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should preserve order of two non-fixed providers (sort returns 0)', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'alpha-provider',
|
||||
label: { en_US: 'Alpha Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.active },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
}, {
|
||||
provider: 'beta-provider',
|
||||
label: { en_US: 'Beta Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.active },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
const renderedProviders = screen.getAllByTestId('provider-card').map(item => item.textContent)
|
||||
expect(renderedProviders).toEqual(['alpha-provider', 'beta-provider'])
|
||||
})
|
||||
|
||||
it('should not show not configured alert when shared default model mock has data', () => {
|
||||
mockDefaultModelState.data = { model: 'embed-model' }
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when rerankDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'rerank-model', provider: { provider: 'cohere' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when ttsDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'tts-model', provider: { provider: 'openai' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when speech2textDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'whisper', provider: { provider: 'openai' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -96,4 +96,97 @@ describe('AddCredentialInLoadBalancing', () => {
|
||||
|
||||
expect(onSelectCredential).toHaveBeenCalledWith(modelCredential.available_credentials[0])
|
||||
})
|
||||
|
||||
// renderTrigger with open=true: bg-state-base-hover style applied
|
||||
it('should apply hover background when trigger is rendered with open=true', async () => {
|
||||
vi.doMock('@/app/components/header/account-setting/model-provider-page/model-auth', () => ({
|
||||
Authorized: ({
|
||||
renderTrigger,
|
||||
}: {
|
||||
renderTrigger: (open?: boolean) => React.ReactNode
|
||||
}) => (
|
||||
<div data-testid="open-trigger">{renderTrigger(true)}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
// Must invalidate module cache so the component picks up the new mock
|
||||
vi.resetModules()
|
||||
try {
|
||||
const { default: AddCredentialLB } = await import('./add-credential-in-load-balancing')
|
||||
|
||||
const { container } = render(
|
||||
<AddCredentialLB
|
||||
provider={provider}
|
||||
model={model}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={modelCredential}
|
||||
onSelectCredential={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// The trigger div rendered by renderTrigger(true) should have bg-state-base-hover
|
||||
// (the static class applied when open=true via cn())
|
||||
const triggerDiv = container.querySelector('[data-testid="open-trigger"] > div')
|
||||
expect(triggerDiv).toBeInTheDocument()
|
||||
expect(triggerDiv!.className).toContain('bg-state-base-hover')
|
||||
}
|
||||
finally {
|
||||
vi.doUnmock('@/app/components/header/account-setting/model-provider-page/model-auth')
|
||||
vi.resetModules()
|
||||
}
|
||||
})
|
||||
|
||||
// customizableModel configuration method: component renders the add credential label
|
||||
it('should render correctly with customizableModel configuration method', () => {
|
||||
render(
|
||||
<AddCredentialInLoadBalancing
|
||||
provider={provider}
|
||||
model={model}
|
||||
configurationMethod={ConfigurationMethodEnum.customizableModel}
|
||||
modelCredential={modelCredential}
|
||||
onSelectCredential={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/modelProvider.auth.addCredential/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle undefined available_credentials gracefully using nullish coalescing', () => {
|
||||
const credentialWithNoAvailable = {
|
||||
available_credentials: undefined,
|
||||
credentials: {},
|
||||
load_balancing: { enabled: false, configs: [] },
|
||||
} as unknown as typeof modelCredential
|
||||
|
||||
render(
|
||||
<AddCredentialInLoadBalancing
|
||||
provider={provider}
|
||||
model={model}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={credentialWithNoAvailable}
|
||||
onSelectCredential={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Component should render without error - the ?? [] fallback is used
|
||||
expect(screen.getByText(/modelProvider.auth.addCredential/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not throw when update action fires without onUpdate prop', () => {
|
||||
// Arrange - no onUpdate prop
|
||||
render(
|
||||
<AddCredentialInLoadBalancing
|
||||
provider={provider}
|
||||
model={model}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={modelCredential}
|
||||
onSelectCredential={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act - trigger the update without onUpdate being set (should not throw)
|
||||
expect(() => {
|
||||
fireEvent.click(screen.getByRole('button', { name: 'Run update' }))
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -85,4 +85,69 @@ describe('CredentialItem', () => {
|
||||
|
||||
expect(onDelete).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// All disable flags true → no action buttons rendered
|
||||
it('should hide all action buttons when disableRename, disableEdit, and disableDelete are all true', () => {
|
||||
// Act
|
||||
render(
|
||||
<CredentialItem
|
||||
credential={credential}
|
||||
onEdit={vi.fn()}
|
||||
onDelete={vi.fn()}
|
||||
disableRename
|
||||
disableEdit
|
||||
disableDelete
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('edit-icon')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('delete-icon')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// disabled=true guards: clicks on the item row and on delete should both be no-ops
|
||||
it('should not call onItemClick when disabled=true and item is clicked', () => {
|
||||
const onItemClick = vi.fn()
|
||||
|
||||
render(<CredentialItem credential={credential} disabled onItemClick={onItemClick} />)
|
||||
|
||||
fireEvent.click(screen.getByText('Test API Key'))
|
||||
|
||||
expect(onItemClick).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not call onDelete when disabled=true and delete button is clicked', () => {
|
||||
const onDelete = vi.fn()
|
||||
|
||||
render(<CredentialItem credential={credential} disabled onDelete={onDelete} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
|
||||
expect(onDelete).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// showSelectedIcon=true: check icon area is always rendered; check icon only appears when IDs match
|
||||
it('should render check icon area when showSelectedIcon=true and selectedCredentialId matches', () => {
|
||||
render(
|
||||
<CredentialItem
|
||||
credential={credential}
|
||||
showSelectedIcon
|
||||
selectedCredentialId="cred-1"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('check-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render check icon when showSelectedIcon=true but selectedCredentialId does not match', () => {
|
||||
render(
|
||||
<CredentialItem
|
||||
credential={credential}
|
||||
showSelectedIcon
|
||||
selectedCredentialId="other-cred"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('check-icon')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -24,36 +24,6 @@ vi.mock('../hooks', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
let mockPortalOpen = false
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => {
|
||||
mockPortalOpen = open
|
||||
return <div data-testid="portal" data-open={open}>{children}</div>
|
||||
},
|
||||
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
|
||||
<div data-testid="portal-trigger" onClick={onClick}>{children}</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
|
||||
if (!mockPortalOpen)
|
||||
return null
|
||||
return <div data-testid="portal-content">{children}</div>
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/confirm', () => ({
|
||||
default: ({ isShow, onCancel, onConfirm }: { isShow: boolean, onCancel: () => void, onConfirm: () => void }) => {
|
||||
if (!isShow)
|
||||
return null
|
||||
return (
|
||||
<div data-testid="confirm-dialog">
|
||||
<button onClick={onCancel}>Cancel</button>
|
||||
<button onClick={onConfirm}>Confirm</button>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('./authorized-item', () => ({
|
||||
default: ({ credentials, model, onEdit, onDelete, onItemClick }: {
|
||||
credentials: Credential[]
|
||||
@@ -105,382 +75,127 @@ describe('Authorized', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpen = false
|
||||
mockDeleteCredentialId = null
|
||||
mockDoingAction = false
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render trigger button', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
it('should render trigger and open popup when trigger is clicked', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/Trigger/)).toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
expect(screen.getByTestId('authorized-item')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /addApiKey/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render portal content when open', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
it('should call handleOpenModal when triggerOnlyOpenModal is true', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
triggerOnlyOpenModal
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('authorized-item')).toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
expect(mockHandleOpenModal).toHaveBeenCalled()
|
||||
expect(screen.queryByTestId('authorized-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render portal content when closed', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
it('should call onItemClick when credential is selected', () => {
|
||||
const onItemClick = vi.fn()
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
onItemClick={onItemClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'Select' })[0])
|
||||
|
||||
it('should render Add API Key button when not model credential', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
expect(onItemClick).toHaveBeenCalledWith(mockCredentials[0], mockItems[0].model)
|
||||
})
|
||||
|
||||
expect(screen.getByText(/addApiKey/)).toBeInTheDocument()
|
||||
})
|
||||
it('should call handleActiveCredential when onItemClick is not provided', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
it('should render Add Model Credential button when is model credential', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
authParams={{ isModelCredential: true }}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'Select' })[0])
|
||||
|
||||
expect(screen.getByText(/addModelCredential/)).toBeInTheDocument()
|
||||
})
|
||||
expect(mockHandleActiveCredential).toHaveBeenCalledWith(mockCredentials[0], mockItems[0].model)
|
||||
})
|
||||
|
||||
it('should not render add action when hideAddAction is true', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
hideAddAction
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
it('should call handleOpenModal with fixed model fields when adding model credential', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.customizableModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
authParams={{ isModelCredential: true }}
|
||||
currentCustomConfigurationModelFixedFields={{
|
||||
__model_name: 'gpt-4',
|
||||
__model_type: ModelTypeEnum.textGeneration,
|
||||
}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText(/addApiKey/)).not.toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
fireEvent.click(screen.getByText(/addModelCredential/))
|
||||
|
||||
it('should render popup title when provided', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
popupTitle="Select Credential"
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Select Credential')).toBeInTheDocument()
|
||||
expect(mockHandleOpenModal).toHaveBeenCalledWith(undefined, {
|
||||
model: 'gpt-4',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should call onOpenChange when trigger is clicked in controlled mode', () => {
|
||||
const onOpenChange = vi.fn()
|
||||
it('should not render add action when hideAddAction is true', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
hideAddAction
|
||||
/>,
|
||||
)
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen={false}
|
||||
onOpenChange={onOpenChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
|
||||
expect(onOpenChange).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should toggle portal on trigger click', () => {
|
||||
const { rerender } = render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
|
||||
rerender(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open modal when triggerOnlyOpenModal is true', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
triggerOnlyOpenModal
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call handleOpenModal when Add API Key is clicked', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText(/addApiKey/))
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call handleOpenModal with credential and model when edit is clicked', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getAllByText('Edit')[0])
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalledWith(
|
||||
mockCredentials[0],
|
||||
mockItems[0].model,
|
||||
)
|
||||
})
|
||||
|
||||
it('should pass current model fields when adding model credential', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.customizableModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
authParams={{ isModelCredential: true }}
|
||||
currentCustomConfigurationModelFixedFields={{
|
||||
__model_name: 'gpt-4',
|
||||
__model_type: ModelTypeEnum.textGeneration,
|
||||
}}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText(/addModelCredential/))
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalledWith(undefined, {
|
||||
model: 'gpt-4',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onItemClick when credential is selected', () => {
|
||||
const onItemClick = vi.fn()
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
onItemClick={onItemClick}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getAllByText('Select')[0])
|
||||
|
||||
expect(onItemClick).toHaveBeenCalledWith(mockCredentials[0], mockItems[0].model)
|
||||
})
|
||||
|
||||
it('should call handleActiveCredential when onItemClick is not provided', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getAllByText('Select')[0])
|
||||
|
||||
expect(mockHandleActiveCredential).toHaveBeenCalledWith(mockCredentials[0], mockItems[0].model)
|
||||
})
|
||||
|
||||
it('should not call onItemClick when disableItemClick is true', () => {
|
||||
const onItemClick = vi.fn()
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
onItemClick={onItemClick}
|
||||
disableItemClick
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getAllByText('Select')[0])
|
||||
|
||||
expect(onItemClick).not.toHaveBeenCalled()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /trigger\s*closed/i }))
|
||||
expect(screen.queryByRole('button', { name: /addApiKey/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('Delete Confirmation', () => {
|
||||
it('should show confirm dialog when deleteCredentialId is set', () => {
|
||||
mockDeleteCredentialId = 'cred-1'
|
||||
it('should show confirm dialog and call confirm handler when delete is confirmed', () => {
|
||||
mockDeleteCredentialId = 'cred-1'
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show confirm dialog when deleteCredentialId is null', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call closeConfirmDelete when cancel is clicked', () => {
|
||||
mockDeleteCredentialId = 'cred-1'
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Cancel'))
|
||||
|
||||
expect(mockCloseConfirmDelete).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call handleConfirmDelete when confirm is clicked', () => {
|
||||
mockDeleteCredentialId = 'cred-1'
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Confirm'))
|
||||
|
||||
expect(mockHandleConfirmDelete).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty items array', () => {
|
||||
render(
|
||||
<Authorized
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={[]}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('authorized-item')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render add action when provider does not allow custom token', () => {
|
||||
const restrictedProvider = { ...mockProvider, allow_custom_token: false }
|
||||
|
||||
render(
|
||||
<Authorized
|
||||
provider={restrictedProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
items={mockItems}
|
||||
renderTrigger={mockRenderTrigger}
|
||||
isOpen
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText(/addApiKey/)).not.toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: /common.operation.confirm/i }))
|
||||
expect(mockHandleConfirmDelete).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ModelProvider } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import ConfigProvider from './config-provider'
|
||||
|
||||
const mockUseCredentialStatus = vi.fn()
|
||||
@@ -54,7 +55,8 @@ describe('ConfigProvider', () => {
|
||||
expect(screen.getByText(/operation.config/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should still render setup label when custom credentials are not allowed', () => {
|
||||
it('should show setup label and unavailable tooltip when custom credentials are not allowed and no credential exists', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockUseCredentialStatus.mockReturnValue({
|
||||
hasCredential: false,
|
||||
authorized: false,
|
||||
@@ -65,6 +67,50 @@ describe('ConfigProvider', () => {
|
||||
|
||||
render(<ConfigProvider provider={{ ...baseProvider, allow_custom_token: false }} />)
|
||||
|
||||
expect(screen.getByText(/operation.setup/i)).toBeInTheDocument()
|
||||
await user.hover(screen.getByText(/operation.setup/i))
|
||||
expect(await screen.findByText(/auth\.credentialUnavailable/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show config label when hasCredential but not authorized', () => {
|
||||
mockUseCredentialStatus.mockReturnValue({
|
||||
hasCredential: true,
|
||||
authorized: false,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'Key 1',
|
||||
available_credentials: [],
|
||||
})
|
||||
|
||||
render(<ConfigProvider provider={baseProvider} />)
|
||||
|
||||
expect(screen.getByText(/operation.config/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show config label when custom credentials are not allowed but credential exists', () => {
|
||||
mockUseCredentialStatus.mockReturnValue({
|
||||
hasCredential: true,
|
||||
authorized: true,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'Key 1',
|
||||
available_credentials: [],
|
||||
})
|
||||
|
||||
render(<ConfigProvider provider={{ ...baseProvider, allow_custom_token: false }} />)
|
||||
|
||||
expect(screen.getByText(/operation.config/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle nullish credential values with fallbacks', () => {
|
||||
mockUseCredentialStatus.mockReturnValue({
|
||||
hasCredential: false,
|
||||
authorized: false,
|
||||
current_credential_id: null,
|
||||
current_credential_name: null,
|
||||
available_credentials: null,
|
||||
})
|
||||
|
||||
render(<ConfigProvider provider={baseProvider} />)
|
||||
|
||||
expect(screen.getByText(/operation.setup/i)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import CredentialSelector from './credential-selector'
|
||||
|
||||
// Mock components
|
||||
vi.mock('./authorized/credential-item', () => ({
|
||||
default: ({ credential, onItemClick }: { credential: { credential_name: string }, onItemClick: (c: unknown) => void }) => (
|
||||
<div data-testid="credential-item" onClick={() => onItemClick(credential)}>
|
||||
default: ({ credential, onItemClick }: { credential: { credential_name: string }, onItemClick?: (c: unknown) => void }) => (
|
||||
<button type="button" onClick={() => onItemClick?.(credential)}>
|
||||
{credential.credential_name}
|
||||
</div>
|
||||
</button>
|
||||
),
|
||||
}))
|
||||
|
||||
@@ -19,22 +19,6 @@ vi.mock('@remixicon/react', () => ({
|
||||
RiArrowDownSLine: () => <div data-testid="arrow-icon" />,
|
||||
}))
|
||||
|
||||
// Mock portal components
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
|
||||
<div data-testid="portal" data-open={open}>{children}</div>
|
||||
),
|
||||
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
|
||||
<div data-testid="portal-trigger" onClick={onClick}>{children}</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children }: { children: React.ReactNode, open?: boolean }) => {
|
||||
// We should only render children if open or if we want to test they are hidden
|
||||
// The real component might handle this with CSS or conditional rendering.
|
||||
// Let's use conditional rendering in the mock to avoid "multiple elements" errors.
|
||||
return <div data-testid="portal-content">{children}</div>
|
||||
},
|
||||
}))
|
||||
|
||||
describe('CredentialSelector', () => {
|
||||
const mockCredentials = [
|
||||
{ credential_id: 'cred-1', credential_name: 'Key 1' },
|
||||
@@ -46,7 +30,7 @@ describe('CredentialSelector', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render selected credential name', () => {
|
||||
it('should render selected credential name when selectedCredential is provided', () => {
|
||||
render(
|
||||
<CredentialSelector
|
||||
selectedCredential={mockCredentials[0]}
|
||||
@@ -55,12 +39,11 @@ describe('CredentialSelector', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
// Use getAllByText and take the first one (the one in the trigger)
|
||||
expect(screen.getAllByText('Key 1')[0]).toBeInTheDocument()
|
||||
expect(screen.getByText('Key 1')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('indicator')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render placeholder when no credential selected', () => {
|
||||
it('should render placeholder when selectedCredential is missing', () => {
|
||||
render(
|
||||
<CredentialSelector
|
||||
credentials={mockCredentials}
|
||||
@@ -71,7 +54,8 @@ describe('CredentialSelector', () => {
|
||||
expect(screen.getByText(/modelProvider.auth.selectModelCredential/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open portal on click', () => {
|
||||
it('should call onSelect when a credential item is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<CredentialSelector
|
||||
credentials={mockCredentials}
|
||||
@@ -79,26 +63,14 @@ describe('CredentialSelector', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
expect(screen.getByTestId('portal')).toHaveAttribute('data-open', 'true')
|
||||
expect(screen.getAllByTestId('credential-item')).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('should call onSelect when a credential is clicked', () => {
|
||||
render(
|
||||
<CredentialSelector
|
||||
credentials={mockCredentials}
|
||||
onSelect={mockOnSelect}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
fireEvent.click(screen.getByText('Key 2'))
|
||||
await user.click(screen.getByText(/modelProvider.auth.selectModelCredential/))
|
||||
await user.click(screen.getByRole('button', { name: 'Key 2' }))
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(mockCredentials[1])
|
||||
})
|
||||
|
||||
it('should call onSelect with add new credential data when clicking add button', () => {
|
||||
it('should call onSelect with add-new payload when add action is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<CredentialSelector
|
||||
credentials={mockCredentials}
|
||||
@@ -106,8 +78,8 @@ describe('CredentialSelector', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
fireEvent.click(screen.getByText(/modelProvider.auth.addNewModelCredential/))
|
||||
await user.click(screen.getByText(/modelProvider.auth.selectModelCredential/))
|
||||
await user.click(screen.getByText(/modelProvider.auth.addNewModelCredential/))
|
||||
|
||||
expect(mockOnSelect).toHaveBeenCalledWith(expect.objectContaining({
|
||||
credential_id: '__add_new_credential',
|
||||
@@ -115,7 +87,8 @@ describe('CredentialSelector', () => {
|
||||
}))
|
||||
})
|
||||
|
||||
it('should not open portal when disabled', () => {
|
||||
it('should not open options when disabled is true', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<CredentialSelector
|
||||
disabled
|
||||
@@ -124,7 +97,7 @@ describe('CredentialSelector', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
expect(screen.getByTestId('portal')).toHaveAttribute('data-open', 'false')
|
||||
await user.click(screen.getByText(/modelProvider.auth.selectModelCredential/))
|
||||
expect(screen.queryByRole('button', { name: 'Key 1' })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import type {
|
||||
Credential,
|
||||
CustomModel,
|
||||
ModelProvider,
|
||||
} from '../../declarations'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { ConfigurationMethodEnum, ModelModalModeEnum, ModelTypeEnum } from '../../declarations'
|
||||
import { useAuth } from './use-auth'
|
||||
|
||||
@@ -20,9 +22,13 @@ const mockAddModelCredential = vi.fn()
|
||||
const mockEditProviderCredential = vi.fn()
|
||||
const mockEditModelCredential = vi.fn()
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({ notify: mockNotify }),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({ notify: mockNotify }),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
|
||||
useModelModalHandler: () => mockOpenModelModal,
|
||||
@@ -66,6 +72,12 @@ describe('useAuth', () => {
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
}
|
||||
|
||||
const createWrapper = ({ children }: { children: ReactNode }) => (
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
{children}
|
||||
</ToastContext.Provider>
|
||||
)
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockDeleteModelService.mockResolvedValue({ result: 'success' })
|
||||
@@ -80,7 +92,7 @@ describe('useAuth', () => {
|
||||
})
|
||||
|
||||
it('should open and close delete confirmation state', () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel))
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.openConfirmDelete(credential, model)
|
||||
@@ -100,7 +112,7 @@ describe('useAuth', () => {
|
||||
})
|
||||
|
||||
it('should activate credential, notify success, and refresh models', async () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.customizableModel))
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.customizableModel), { wrapper: createWrapper })
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleActiveCredential(credential, model)
|
||||
@@ -120,7 +132,7 @@ describe('useAuth', () => {
|
||||
})
|
||||
|
||||
it('should close delete dialog without calling services when nothing is pending', async () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel))
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleConfirmDelete()
|
||||
@@ -137,7 +149,7 @@ describe('useAuth', () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel, undefined, {
|
||||
isModelCredential: false,
|
||||
onRemove,
|
||||
}))
|
||||
}), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.openConfirmDelete(credential, model)
|
||||
@@ -161,7 +173,7 @@ describe('useAuth', () => {
|
||||
const onRemove = vi.fn()
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.customizableModel, undefined, {
|
||||
onRemove,
|
||||
}))
|
||||
}), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.openConfirmDelete(undefined, model)
|
||||
@@ -179,7 +191,7 @@ describe('useAuth', () => {
|
||||
})
|
||||
|
||||
it('should add or edit credentials and refresh on successful save', async () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel))
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleSaveCredential({ api_key: 'new-key' })
|
||||
@@ -200,7 +212,7 @@ describe('useAuth', () => {
|
||||
const deferred = createDeferred<{ result: string }>()
|
||||
mockAddProviderCredential.mockReturnValueOnce(deferred.promise)
|
||||
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel))
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
let first!: Promise<void>
|
||||
let second!: Promise<void>
|
||||
@@ -226,7 +238,7 @@ describe('useAuth', () => {
|
||||
isModelCredential: true,
|
||||
onUpdate,
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
}))
|
||||
}), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.handleOpenModal(credential, model)
|
||||
@@ -244,4 +256,90 @@ describe('useAuth', () => {
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should not notify or refresh when handleSaveCredential returns non-success result', async () => {
|
||||
mockAddProviderCredential.mockResolvedValue({ result: 'error' })
|
||||
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleSaveCredential({ api_key: 'some-key' })
|
||||
})
|
||||
|
||||
expect(mockAddProviderCredential).toHaveBeenCalledWith({ api_key: 'some-key' })
|
||||
expect(mockNotify).not.toHaveBeenCalled()
|
||||
expect(mockHandleRefreshModel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should pass undefined model and model_type when handleActiveCredential is called without a model parameter', async () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleActiveCredential(credential)
|
||||
})
|
||||
|
||||
expect(mockActiveProviderCredential).toHaveBeenCalledWith({
|
||||
credential_id: 'cred-1',
|
||||
model: undefined,
|
||||
model_type: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
// openConfirmDelete with credential only (no model): deleteCredentialId set, deleteModel stays null
|
||||
it('should only set deleteCredentialId when openConfirmDelete is called without a model', () => {
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.openConfirmDelete(credential, undefined)
|
||||
})
|
||||
|
||||
expect(result.current.deleteCredentialId).toBe('cred-1')
|
||||
expect(result.current.deleteModel).toBeNull()
|
||||
expect(result.current.pendingOperationCredentialId.current).toBe('cred-1')
|
||||
expect(result.current.pendingOperationModel.current).toBeNull()
|
||||
})
|
||||
|
||||
// doingActionRef guard: second handleConfirmDelete call while first is in progress is a no-op
|
||||
it('should ignore a second handleConfirmDelete call while the first is still in progress', async () => {
|
||||
const deferred = createDeferred<{ result: string }>()
|
||||
mockDeleteProviderCredential.mockReturnValueOnce(deferred.promise)
|
||||
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
act(() => {
|
||||
result.current.openConfirmDelete(credential, model)
|
||||
})
|
||||
|
||||
let first!: Promise<void>
|
||||
let second!: Promise<void>
|
||||
|
||||
await act(async () => {
|
||||
first = result.current.handleConfirmDelete()
|
||||
second = result.current.handleConfirmDelete()
|
||||
deferred.resolve({ result: 'success' })
|
||||
await Promise.all([first, second])
|
||||
})
|
||||
|
||||
expect(mockDeleteProviderCredential).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// doingActionRef guard: second handleActiveCredential call while first is in progress is a no-op
|
||||
it('should ignore a second handleActiveCredential call while the first is still in progress', async () => {
|
||||
const deferred = createDeferred<{ result: string }>()
|
||||
mockActiveProviderCredential.mockReturnValueOnce(deferred.promise)
|
||||
|
||||
const { result } = renderHook(() => useAuth(provider, ConfigurationMethodEnum.predefinedModel), { wrapper: createWrapper })
|
||||
|
||||
let first!: Promise<void>
|
||||
let second!: Promise<void>
|
||||
|
||||
await act(async () => {
|
||||
first = result.current.handleActiveCredential(credential)
|
||||
second = result.current.handleActiveCredential(credential)
|
||||
deferred.resolve({ result: 'success' })
|
||||
await Promise.all([first, second])
|
||||
})
|
||||
|
||||
expect(mockActiveProviderCredential).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -13,11 +13,13 @@ vi.mock('./hooks', () => ({
|
||||
|
||||
// Mock Authorized
|
||||
vi.mock('./authorized', () => ({
|
||||
default: ({ renderTrigger, items, popupTitle }: { renderTrigger: (o?: boolean) => React.ReactNode, items: { length: number }, popupTitle: string }) => (
|
||||
default: ({ renderTrigger, items, popupTitle }: { renderTrigger: (o?: boolean) => React.ReactNode, items: Array<{ selectedCredential?: unknown }>, popupTitle: string }) => (
|
||||
<div data-testid="authorized-mock">
|
||||
<div data-testid="trigger-container">{renderTrigger()}</div>
|
||||
<div data-testid="trigger-closed">{renderTrigger()}</div>
|
||||
<div data-testid="trigger-open">{renderTrigger(true)}</div>
|
||||
<div data-testid="popup-title">{popupTitle}</div>
|
||||
<div data-testid="items-count">{items.length}</div>
|
||||
<div data-testid="items-selected">{items.map((it, i) => <span key={i} data-testid={`selected-${i}`}>{it.selectedCredential ? 'has-cred' : 'no-cred'}</span>)}</div>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
@@ -55,8 +57,41 @@ describe('ManageCustomModelCredentials', () => {
|
||||
render(<ManageCustomModelCredentials provider={mockProvider} />)
|
||||
|
||||
expect(screen.getByTestId('authorized-mock')).toBeInTheDocument()
|
||||
expect(screen.getByText(/modelProvider.auth.manageCredentials/)).toBeInTheDocument()
|
||||
expect(screen.getAllByText(/modelProvider.auth.manageCredentials/).length).toBeGreaterThan(0)
|
||||
expect(screen.getByTestId('items-count')).toHaveTextContent('2')
|
||||
expect(screen.getByTestId('popup-title')).toHaveTextContent('modelProvider.auth.customModelCredentials')
|
||||
})
|
||||
|
||||
it('should render trigger in both open and closed states', () => {
|
||||
const mockModels = [
|
||||
{
|
||||
model: 'gpt-4',
|
||||
available_model_credentials: [{ credential_id: 'c1', credential_name: 'Key 1' }],
|
||||
current_credential_id: 'c1',
|
||||
current_credential_name: 'Key 1',
|
||||
},
|
||||
]
|
||||
mockUseCustomModels.mockReturnValue(mockModels)
|
||||
|
||||
render(<ManageCustomModelCredentials provider={mockProvider} />)
|
||||
|
||||
expect(screen.getByTestId('trigger-closed')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('trigger-open')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass undefined selectedCredential when model has no current_credential_id', () => {
|
||||
const mockModels = [
|
||||
{
|
||||
model: 'gpt-3.5',
|
||||
available_model_credentials: [{ credential_id: 'c1', credential_name: 'Key 1' }],
|
||||
current_credential_id: '',
|
||||
current_credential_name: '',
|
||||
},
|
||||
]
|
||||
mockUseCustomModels.mockReturnValue(mockModels)
|
||||
|
||||
render(<ManageCustomModelCredentials provider={mockProvider} />)
|
||||
|
||||
expect(screen.getByTestId('selected-0')).toHaveTextContent('no-cred')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -18,15 +18,6 @@ vi.mock('@/app/components/header/indicator', () => ({
|
||||
default: ({ color }: { color: string }) => <div data-testid={`indicator-${color}`} />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
|
||||
<div data-testid="tooltip-mock">
|
||||
{children}
|
||||
<div>{popupContent}</div>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@remixicon/react', () => ({
|
||||
RiArrowDownSLine: () => <div data-testid="arrow-icon" />,
|
||||
}))
|
||||
@@ -125,6 +116,131 @@ describe('SwitchCredentialInLoadBalancing', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.mouseEnter(screen.getByText(/auth.credentialUnavailableInButton/))
|
||||
expect(screen.getByText('plugin.auth.credentialUnavailable')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Empty credentials with allowed custom: no tooltip but still shows unavailable text
|
||||
it('should show unavailable status without tooltip when custom credentials are allowed', () => {
|
||||
// Act
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
model={mockModel}
|
||||
credentials={[]}
|
||||
customModelCredential={undefined}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(/auth.credentialUnavailableInButton/)).toBeInTheDocument()
|
||||
expect(screen.queryByText('plugin.auth.credentialUnavailable')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// not_allowed_to_use=true: indicator is red and destructive button text is shown
|
||||
it('should show red indicator and unavailable button text when credential has not_allowed_to_use=true', () => {
|
||||
const unavailableCredential = { credential_id: 'cred-1', credential_name: 'Key 1', not_allowed_to_use: true }
|
||||
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
model={mockModel}
|
||||
credentials={[unavailableCredential]}
|
||||
customModelCredential={unavailableCredential}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('indicator-red')).toBeInTheDocument()
|
||||
expect(screen.getByText(/auth.credentialUnavailableInButton/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// from_enterprise=true on the selected credential: Enterprise badge appears in the trigger
|
||||
it('should show Enterprise badge when selected credential has from_enterprise=true', () => {
|
||||
const enterpriseCredential = { credential_id: 'cred-1', credential_name: 'Enterprise Key', from_enterprise: true }
|
||||
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
model={mockModel}
|
||||
credentials={[enterpriseCredential]}
|
||||
customModelCredential={enterpriseCredential}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Enterprise')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// non-empty credentials with allow_custom_token=false: no tooltip (tooltip only for empty+notAllowCustom)
|
||||
it('should not show unavailable tooltip when credentials are non-empty and allow_custom_token=false', () => {
|
||||
const restrictedProvider = { ...mockProvider, allow_custom_token: false }
|
||||
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={restrictedProvider}
|
||||
model={mockModel}
|
||||
credentials={mockCredentials}
|
||||
customModelCredential={mockCredentials[0]}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.mouseEnter(screen.getByText('Key 1'))
|
||||
expect(screen.queryByText('plugin.auth.credentialUnavailable')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Key 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass undefined currentCustomConfigurationModelFixedFields when model is undefined', () => {
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
// @ts-expect-error testing runtime handling when model is omitted
|
||||
model={undefined}
|
||||
credentials={mockCredentials}
|
||||
customModelCredential={mockCredentials[0]}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Component still renders (Authorized receives undefined currentCustomConfigurationModelFixedFields)
|
||||
expect(screen.getByTestId('authorized-mock')).toBeInTheDocument()
|
||||
expect(screen.getByText('Key 1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should treat undefined credentials as empty list', () => {
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
model={mockModel}
|
||||
credentials={undefined}
|
||||
customModelCredential={undefined}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// credentials is undefined → empty=true → unavailable text shown
|
||||
expect(screen.getByText(/auth.credentialUnavailableInButton/)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId(/indicator-/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render nothing for credential_name when it is empty string', () => {
|
||||
const credWithEmptyName = { credential_id: 'cred-1', credential_name: '' }
|
||||
|
||||
render(
|
||||
<SwitchCredentialInLoadBalancing
|
||||
provider={mockProvider}
|
||||
model={mockModel}
|
||||
credentials={[credWithEmptyName]}
|
||||
customModelCredential={credWithEmptyName}
|
||||
setCustomModelCredential={mockSetCustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// indicator-green shown (not authRemoved, not unavailable, not empty)
|
||||
expect(screen.getByTestId('indicator-green')).toBeInTheDocument()
|
||||
// credential_name is empty so nothing printed for name
|
||||
expect(screen.queryByText('Key 1')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -24,10 +24,6 @@ vi.mock('../hooks', () => ({
|
||||
useLanguage: () => mockLanguage,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/icons/src/public/llm', () => ({
|
||||
OpenaiYellow: () => <svg data-testid="openai-yellow-icon" />,
|
||||
}))
|
||||
|
||||
const createI18nText = (value: string): I18nText => ({
|
||||
en_US: value,
|
||||
zh_Hans: value,
|
||||
@@ -92,10 +88,10 @@ describe('ModelIcon', () => {
|
||||
icon_small: createI18nText('openai.png'),
|
||||
})
|
||||
|
||||
render(<ModelIcon provider={provider} modelName="o1" />)
|
||||
const { container } = render(<ModelIcon provider={provider} modelName="o1" />)
|
||||
|
||||
expect(screen.queryByRole('img', { name: /model-icon/i })).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('openai-yellow-icon')).toBeInTheDocument()
|
||||
expect(container.querySelector('svg')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Edge case
|
||||
@@ -105,4 +101,25 @@ describe('ModelIcon', () => {
|
||||
expect(screen.queryByRole('img', { name: /model-icon/i })).not.toBeInTheDocument()
|
||||
expect(container.firstChild).not.toBeNull()
|
||||
})
|
||||
|
||||
it('should render OpenAI Yellow icon for langgenius/openai/openai provider with model starting with o', () => {
|
||||
const provider = createModel({
|
||||
provider: 'langgenius/openai/openai',
|
||||
icon_small: createI18nText('openai.png'),
|
||||
})
|
||||
|
||||
const { container } = render(<ModelIcon provider={provider} modelName="o3" />)
|
||||
|
||||
expect(screen.queryByRole('img', { name: /model-icon/i })).not.toBeInTheDocument()
|
||||
expect(container.querySelector('svg')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply opacity-50 when isDeprecated is true', () => {
|
||||
const provider = createModel()
|
||||
|
||||
const { container } = render(<ModelIcon provider={provider} isDeprecated={true} />)
|
||||
|
||||
const wrapper = container.querySelector('.opacity-50')
|
||||
expect(wrapper).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -161,7 +161,7 @@ function Form<
|
||||
const disabled = readonly || (isEditMode && (variable === '__model_type' || variable === '__model_name'))
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
@@ -204,13 +204,14 @@ function Form<
|
||||
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
)}
|
||||
{tooltipContent}
|
||||
</div>
|
||||
{/* eslint-disable-next-line tailwindcss/no-unknown-classes */}
|
||||
<div className={cn('grid gap-3', `grid-cols-${options?.length}`)}>
|
||||
{options.filter((option) => {
|
||||
if (option.show_on.length)
|
||||
@@ -229,7 +230,7 @@ function Form<
|
||||
>
|
||||
<RadioE isChecked={value[variable] === option.value} />
|
||||
|
||||
<div className="system-sm-regular text-text-secondary">{option.label[language] || option.label.en_US}</div>
|
||||
<div className="text-text-secondary system-sm-regular">{option.label[language] || option.label.en_US}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -254,7 +255,7 @@ function Form<
|
||||
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
|
||||
{required && (
|
||||
@@ -295,9 +296,9 @@ function Form<
|
||||
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className="system-sm-semibold flex items-center justify-between py-2 text-text-secondary">
|
||||
<div className="flex items-center justify-between py-2 text-text-secondary system-sm-semibold">
|
||||
<div className="flex items-center space-x-2">
|
||||
<span className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>{label[language] || label.en_US}</span>
|
||||
<span className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>{label[language] || label.en_US}</span>
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
)}
|
||||
@@ -326,7 +327,7 @@ function Form<
|
||||
} = formSchema as (CredentialFormSchemaTextInput | CredentialFormSchemaSecretInput)
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
@@ -358,7 +359,7 @@ function Form<
|
||||
} = formSchema as (CredentialFormSchemaTextInput | CredentialFormSchemaSecretInput)
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
@@ -422,7 +423,7 @@ function Form<
|
||||
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
@@ -451,7 +452,7 @@ function Form<
|
||||
|
||||
return (
|
||||
<div key={variable} className={cn(itemClassName, 'py-3')}>
|
||||
<div className={cn(fieldLabelClassName, 'system-sm-semibold flex items-center py-2 text-text-secondary')}>
|
||||
<div className={cn(fieldLabelClassName, 'flex items-center py-2 text-text-secondary system-sm-semibold')}>
|
||||
{label[language] || label.en_US}
|
||||
{required && (
|
||||
<span className="ml-1 text-red-500">*</span>
|
||||
|
||||
@@ -93,4 +93,88 @@ describe('Input', () => {
|
||||
expect(onChange).not.toHaveBeenCalledWith('2')
|
||||
expect(onChange).not.toHaveBeenCalledWith('6')
|
||||
})
|
||||
|
||||
it('should not clamp when min and max are not provided', () => {
|
||||
const onChange = vi.fn()
|
||||
|
||||
render(
|
||||
<Input
|
||||
placeholder="Free"
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const input = screen.getByPlaceholderText('Free')
|
||||
fireEvent.change(input, { target: { value: '999' } })
|
||||
fireEvent.blur(input)
|
||||
|
||||
// onChange only called from change event, not from blur clamping
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
expect(onChange).toHaveBeenCalledWith('999')
|
||||
})
|
||||
|
||||
it('should show check circle icon when validated is true', () => {
|
||||
const { container } = render(
|
||||
<Input
|
||||
placeholder="Key"
|
||||
onChange={vi.fn()}
|
||||
validated
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByPlaceholderText('Key')).toBeInTheDocument()
|
||||
expect(container.querySelector('.absolute.right-2\\.5.top-2\\.5')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show check circle icon when validated is false', () => {
|
||||
const { container } = render(
|
||||
<Input
|
||||
placeholder="Key"
|
||||
onChange={vi.fn()}
|
||||
validated={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByPlaceholderText('Key')).toBeInTheDocument()
|
||||
expect(container.querySelector('.absolute.right-2\\.5.top-2\\.5')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply disabled attribute when disabled prop is true', () => {
|
||||
render(
|
||||
<Input
|
||||
placeholder="Disabled"
|
||||
onChange={vi.fn()}
|
||||
disabled
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByPlaceholderText('Disabled')).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should call onFocus when input receives focus', () => {
|
||||
const onFocus = vi.fn()
|
||||
|
||||
render(
|
||||
<Input
|
||||
placeholder="Focus"
|
||||
onChange={vi.fn()}
|
||||
onFocus={onFocus}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.focus(screen.getByPlaceholderText('Focus'))
|
||||
expect(onFocus).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render with custom className', () => {
|
||||
render(
|
||||
<Input
|
||||
placeholder="Styled"
|
||||
onChange={vi.fn()}
|
||||
className="custom-class"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByPlaceholderText('Styled')).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import type { Credential, CredentialFormSchema, ModelProvider } from '../declarations'
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { Credential, CredentialFormSchema, CustomModel, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CurrentSystemQuotaTypeEnum,
|
||||
@@ -43,15 +45,6 @@ const mockHandlers = vi.hoisted(() => ({
|
||||
handleActiveCredential: vi.fn(),
|
||||
}))
|
||||
|
||||
type FormResponse = {
|
||||
isCheckValidated: boolean
|
||||
values: Record<string, unknown>
|
||||
}
|
||||
const mockFormState = vi.hoisted(() => ({
|
||||
responses: [] as FormResponse[],
|
||||
setFieldValue: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('../model-auth/hooks', () => ({
|
||||
useCredentialData: () => ({
|
||||
isLoading: mockState.isLoading,
|
||||
@@ -86,36 +79,6 @@ vi.mock('../hooks', () => ({
|
||||
useLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/form/form-scenarios/auth', async () => {
|
||||
const React = await import('react')
|
||||
const AuthForm = React.forwardRef(({
|
||||
onChange,
|
||||
}: {
|
||||
onChange?: (field: string, value: string) => void
|
||||
}, ref: React.ForwardedRef<{ getFormValues: () => FormResponse, getForm: () => { setFieldValue: (field: string, value: string) => void } }>) => {
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
getFormValues: () => mockFormState.responses.shift() || { isCheckValidated: false, values: {} },
|
||||
getForm: () => ({ setFieldValue: mockFormState.setFieldValue }),
|
||||
}))
|
||||
return (
|
||||
<div>
|
||||
<button type="button" onClick={() => onChange?.('__model_name', 'updated-model')}>Model Name Change</button>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
return { default: AuthForm }
|
||||
})
|
||||
|
||||
vi.mock('../model-auth', () => ({
|
||||
CredentialSelector: ({ onSelect }: { onSelect: (credential: Credential & { addNewCredential?: boolean }) => void }) => (
|
||||
<div>
|
||||
<button type="button" onClick={() => onSelect({ credential_id: 'existing' })}>Choose Existing</button>
|
||||
<button type="button" onClick={() => onSelect({ credential_id: 'new', addNewCredential: true })}>Add New</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createI18n = (text: string) => ({ en_US: text, zh_Hans: text })
|
||||
|
||||
const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
|
||||
@@ -158,7 +121,7 @@ const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderModal = (overrides?: Partial<React.ComponentProps<typeof ModelModal>>) => {
|
||||
const renderModal = (overrides?: Partial<ComponentProps<typeof ModelModal>>) => {
|
||||
const provider = createProvider()
|
||||
const props = {
|
||||
provider,
|
||||
@@ -168,13 +131,50 @@ const renderModal = (overrides?: Partial<React.ComponentProps<typeof ModelModal>
|
||||
onRemove: vi.fn(),
|
||||
...overrides,
|
||||
}
|
||||
const view = render(<ModelModal {...props} />)
|
||||
return {
|
||||
...props,
|
||||
unmount: view.unmount,
|
||||
}
|
||||
render(<ModelModal {...props} />)
|
||||
return props
|
||||
}
|
||||
|
||||
const mockFormRef1 = {
|
||||
getFormValues: vi.fn(),
|
||||
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
|
||||
}
|
||||
|
||||
const mockFormRef2 = {
|
||||
getFormValues: vi.fn(),
|
||||
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({
|
||||
default: React.forwardRef((props: { formSchemas: Record<string, unknown>[], onChange?: (f: string, v: string) => void }, ref: React.ForwardedRef<unknown>) => {
|
||||
React.useImperativeHandle(ref, () => {
|
||||
// Return the mock depending on schemas passed (hacky but works for refs)
|
||||
if (props.formSchemas.length > 0 && props.formSchemas[0].name === '__model_name')
|
||||
return mockFormRef1
|
||||
return mockFormRef2
|
||||
})
|
||||
return (
|
||||
<div data-testid="auth-form" onClick={() => props.onChange?.('test-field', 'val')}>
|
||||
AuthForm Mock (
|
||||
{props.formSchemas.length}
|
||||
{' '}
|
||||
fields)
|
||||
</div>
|
||||
)
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../model-auth', () => ({
|
||||
CredentialSelector: ({ onSelect }: { onSelect: (val: unknown) => void }) => (
|
||||
<button onClick={() => onSelect({ addNewCredential: true })} data-testid="credential-selector">
|
||||
Select Credential
|
||||
</button>
|
||||
),
|
||||
useAuth: vi.fn(),
|
||||
useCredentialData: vi.fn(),
|
||||
useModelFormSchemas: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('ModelModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -187,167 +187,131 @@ describe('ModelModal', () => {
|
||||
mockState.formValues = {}
|
||||
mockState.modelNameAndTypeFormSchemas = []
|
||||
mockState.modelNameAndTypeFormValues = {}
|
||||
mockFormState.responses = []
|
||||
|
||||
// reset form refs
|
||||
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __model_name: 'test', __model_type: ModelTypeEnum.textGeneration } })
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'test_auth', api_key: 'sk-test' } })
|
||||
})
|
||||
|
||||
it('should show title, description, and loading state for predefined models', () => {
|
||||
it('should render title and loading state for predefined credential modal', () => {
|
||||
mockState.isLoading = true
|
||||
|
||||
const predefined = renderModal()
|
||||
|
||||
renderModal()
|
||||
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.title')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.desc')).toBeInTheDocument()
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeDisabled()
|
||||
})
|
||||
|
||||
predefined.unmount()
|
||||
const customizable = renderModal({ configurateMethod: ConfigurationMethodEnum.customizableModel })
|
||||
expect(screen.queryByText('common.modelProvider.auth.apiKeyModal.desc')).not.toBeInTheDocument()
|
||||
customizable.unmount()
|
||||
|
||||
mockState.credentialData = { credentials: {}, available_credentials: [] }
|
||||
renderModal({ mode: ModelModalModeEnum.configModelCredential, model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration } })
|
||||
it('should render model credential title when mode is configModelCredential', () => {
|
||||
renderModal({
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration },
|
||||
})
|
||||
expect(screen.getByText('common.modelProvider.auth.addModelCredential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should reveal the credential label when adding a new credential', () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList })
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.auth.modelCredential')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('Add New'))
|
||||
|
||||
expect(screen.getByText('common.modelProvider.auth.modelCredential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onCancel when the cancel button is clicked', () => {
|
||||
const { onCancel } = renderModal()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should call onCancel when the escape key is pressed', () => {
|
||||
const { onCancel } = renderModal()
|
||||
|
||||
fireEvent.keyDown(document, { key: 'Escape' })
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should confirm deletion when a delete dialog is shown', () => {
|
||||
mockState.credentialData = { credentials: { api_key: 'secret' }, available_credentials: [] }
|
||||
mockState.deleteCredentialId = 'delete-id'
|
||||
|
||||
const credential: Credential = { credential_id: 'cred-1' }
|
||||
const { onCancel } = renderModal({ credential })
|
||||
|
||||
expect(screen.getByText('common.modelProvider.confirmDelete')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalledTimes(1)
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle save flows for different modal modes', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text-input' } as unknown as CredentialFormSchema]
|
||||
mockState.formSchemas = [{ variable: 'api_key', type: 'secret-input' } as unknown as CredentialFormSchema]
|
||||
mockFormState.responses = [
|
||||
{ isCheckValidated: true, values: { __model_name: 'custom-model', __model_type: ModelTypeEnum.textGeneration } },
|
||||
{ isCheckValidated: true, values: { __authorization_name__: 'Auth Name', api_key: 'secret' } },
|
||||
]
|
||||
const configCustomModel = renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getAllByText('Model Name Change')[0])
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
expect(mockFormState.setFieldValue).toHaveBeenCalledWith('__model_name', 'updated-model')
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'secret' },
|
||||
name: 'Auth Name',
|
||||
model: 'custom-model',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
expect(configCustomModel.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Auth Name', api_key: 'secret' })
|
||||
configCustomModel.unmount()
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Model Auth', api_key: 'abc' } }]
|
||||
const model = { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration }
|
||||
const configModelCredential = renderModal({
|
||||
it('should render edit credential title when credential exists', () => {
|
||||
renderModal({
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
model,
|
||||
credential: { credential_id: 'cred-123' },
|
||||
credential: { credential_id: '1' } as unknown as Credential,
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: 'cred-123',
|
||||
credentials: { api_key: 'abc' },
|
||||
name: 'Model Auth',
|
||||
model: 'gpt-4',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
expect(configModelCredential.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Model Auth', api_key: 'abc' })
|
||||
configModelCredential.unmount()
|
||||
expect(screen.getByText('common.modelProvider.auth.editModelCredential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should change title to Add Model when mode is configCustomModel', () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
expect(screen.getByText('common.modelProvider.auth.addModel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should validate and fail save if form is invalid in configCustomModel mode', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
|
||||
renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate and save new credential and model in configCustomModel mode', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
const props = renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Provider Auth', api_key: 'provider-key' } }]
|
||||
const configProviderCredential = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'provider-key' },
|
||||
name: 'Provider Auth',
|
||||
credentials: { api_key: 'sk-test' },
|
||||
name: 'test_auth',
|
||||
model: 'test',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
expect(props.onSave).toHaveBeenCalled()
|
||||
})
|
||||
configProviderCredential.unmount()
|
||||
})
|
||||
|
||||
const addToModelList = renderModal({
|
||||
mode: ModelModalModeEnum.addCustomModelToModelList,
|
||||
model,
|
||||
})
|
||||
fireEvent.click(screen.getByText('Choose Existing'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleActiveCredential).toHaveBeenCalledWith({ credential_id: 'existing' }, model)
|
||||
expect(addToModelList.onCancel).toHaveBeenCalled()
|
||||
addToModelList.unmount()
|
||||
it('should save credential only in standard configProviderCredential mode', async () => {
|
||||
const { onSave } = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'New Auth', api_key: 'new-key' } }]
|
||||
const addToModelListWithNew = renderModal({
|
||||
mode: ModelModalModeEnum.addCustomModelToModelList,
|
||||
model,
|
||||
})
|
||||
fireEvent.click(screen.getByText('Add New'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'new-key' },
|
||||
name: 'New Auth',
|
||||
model: 'gpt-4',
|
||||
credentials: { api_key: 'sk-test' },
|
||||
name: 'test_auth',
|
||||
})
|
||||
expect(onSave).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should save active credential and cancel when picking existing credential in addCustomModelToModelList mode', async () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm1', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
|
||||
// By default selected is undefined so button clicks form
|
||||
// Let's not click credential selector, so it evaluates without it. If selectedCredential is undefined, form validation is checked.
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should save active credential when picking existing credential in addCustomModelToModelList mode', async () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm2', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
|
||||
|
||||
// Select existing credential (addNewCredential: true simulates new but we can simulate false if we just hack the mocked state in the component, but it's internal.
|
||||
// The credential selector sets selectedCredential.
|
||||
fireEvent.click(screen.getByTestId('credential-selector')) // Sets addNewCredential = true internally, so it proceeds to form save
|
||||
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'auth', api: 'key' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api: 'key' },
|
||||
name: 'auth',
|
||||
model: 'm2',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
addToModelListWithNew.unmount()
|
||||
})
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: false, values: {} }]
|
||||
const invalidSave = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
invalidSave.unmount()
|
||||
it('should open and confirm deletion of credential', () => {
|
||||
mockState.credentialData = { credentials: { api_key: '123' }, available_credentials: [] }
|
||||
mockState.formValues = { api_key: '123' } // To trigger isEditMode = true
|
||||
const credential = { credential_id: 'c1' } as unknown as Credential
|
||||
renderModal({ credential })
|
||||
|
||||
mockState.credentialData = { credentials: { api_key: 'value' }, available_credentials: [] }
|
||||
mockState.formValues = { api_key: 'value' }
|
||||
const removable = renderModal({ credential: { credential_id: 'remove-1' } })
|
||||
// Open Delete Confirm
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.remove' }))
|
||||
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith({ credential_id: 'remove-1' }, undefined)
|
||||
removable.unmount()
|
||||
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith(credential, undefined)
|
||||
|
||||
// Simulate the dialog appearing and confirming
|
||||
mockState.deleteCredentialId = 'c1'
|
||||
renderModal({ credential }) // Re-render logic mock
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.confirm' })[0])
|
||||
|
||||
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should bind escape key to cancel', () => {
|
||||
const props = renderModal()
|
||||
fireEvent.keyDown(document, { key: 'Escape' })
|
||||
expect(props.onCancel).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import ModelParameterModal from './index'
|
||||
|
||||
let isAPIKeySet = true
|
||||
let parameterRules = [
|
||||
let parameterRules: Array<Record<string, unknown>> | undefined = [
|
||||
{
|
||||
name: 'temperature',
|
||||
label: { en_US: 'Temperature' },
|
||||
@@ -62,42 +61,17 @@ vi.mock('../hooks', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock PortalToFollowElem components to control visibility and simplify testing
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => {
|
||||
return {
|
||||
PortalToFollowElem: ({ children }: { children: React.ReactNode }) => {
|
||||
return (
|
||||
<div>
|
||||
<div data-testid="portal-wrapper">
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
|
||||
<div data-testid="portal-trigger" onClick={onClick}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className: string }) => (
|
||||
<div data-testid="portal-content" className={className}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('./parameter-item', () => ({
|
||||
default: ({ parameterRule, value, onChange, onSwitch }: { parameterRule: { name: string, label: { en_US: string } }, value: string | number, onChange: (v: number) => void, onSwitch: (checked: boolean, val: unknown) => void }) => (
|
||||
default: ({ parameterRule, onChange, onSwitch }: {
|
||||
parameterRule: { name: string, label: { en_US: string } }
|
||||
onChange: (v: number) => void
|
||||
onSwitch: (checked: boolean, val: unknown) => void
|
||||
}) => (
|
||||
<div data-testid={`param-${parameterRule.name}`}>
|
||||
{parameterRule.label.en_US}
|
||||
<input
|
||||
aria-label={parameterRule.name}
|
||||
value={value || ''}
|
||||
onChange={e => onChange(Number(e.target.value))}
|
||||
/>
|
||||
<button onClick={() => onSwitch?.(false, undefined)}>Remove</button>
|
||||
<button onClick={() => onSwitch?.(true, 'assigned')}>Add</button>
|
||||
<button onClick={() => onChange(0.9)}>Change</button>
|
||||
<button onClick={() => onSwitch(false, undefined)}>Remove</button>
|
||||
<button onClick={() => onSwitch(true, 'assigned')}>Add</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
@@ -105,7 +79,6 @@ vi.mock('./parameter-item', () => ({
|
||||
vi.mock('../model-selector', () => ({
|
||||
default: ({ onSelect }: { onSelect: (value: { provider: string, model: string }) => void }) => (
|
||||
<div data-testid="model-selector">
|
||||
Model Selector
|
||||
<button onClick={() => onSelect({ provider: 'openai', model: 'gpt-4.1' })}>Select GPT-4.1</button>
|
||||
</div>
|
||||
),
|
||||
@@ -121,16 +94,11 @@ vi.mock('./trigger', () => ({
|
||||
default: () => <button>Open Settings</button>,
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/classnames', () => ({
|
||||
cn: (...args: (string | undefined | null | false)[]) => args.filter(Boolean).join(' '),
|
||||
}))
|
||||
|
||||
// Mock config
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/config')>()
|
||||
return {
|
||||
...actual,
|
||||
PROVIDER_WITH_PRESET_TONE: ['openai'], // ensure presets mock renders
|
||||
PROVIDER_WITH_PRESET_TONE: ['openai'],
|
||||
}
|
||||
})
|
||||
|
||||
@@ -188,21 +156,19 @@ describe('ModelParameterModal', () => {
|
||||
]
|
||||
})
|
||||
|
||||
it('should render trigger and content', () => {
|
||||
it('should render trigger and open modal content when trigger is clicked', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('Open Settings')).toBeInTheDocument()
|
||||
expect(screen.getByText('Temperature')).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
expect(screen.getByTestId('param-temperature')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should update params when changed and handle switch add/remove', () => {
|
||||
it('should call onCompletionParamsChange when parameter changes and switch actions happen', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
|
||||
const input = screen.getByLabelText('temperature')
|
||||
fireEvent.change(input, { target: { value: '0.9' } })
|
||||
|
||||
fireEvent.click(screen.getByText('Change'))
|
||||
expect(defaultProps.onCompletionParamsChange).toHaveBeenCalledWith({
|
||||
...defaultProps.completionParams,
|
||||
temperature: 0.9,
|
||||
@@ -218,51 +184,18 @@ describe('ModelParameterModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle preset selection', () => {
|
||||
it('should call onCompletionParamsChange when preset is selected', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
fireEvent.click(screen.getByText('Preset 1'))
|
||||
expect(defaultProps.onCompletionParamsChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle debug mode toggle', () => {
|
||||
const { rerender } = render(<ModelParameterModal {...defaultProps} />)
|
||||
const toggle = screen.getByText(/debugAsMultipleModel/i)
|
||||
fireEvent.click(toggle)
|
||||
expect(defaultProps.onDebugWithMultipleModelChange).toHaveBeenCalled()
|
||||
|
||||
rerender(<ModelParameterModal {...defaultProps} debugWithMultipleModel />)
|
||||
expect(screen.getByText(/debugAsSingleModel/i)).toBeInTheDocument()
|
||||
})
|
||||
it('should handle custom renderTrigger', () => {
|
||||
const renderTrigger = vi.fn().mockReturnValue(<div>Custom Trigger</div>)
|
||||
render(<ModelParameterModal {...defaultProps} renderTrigger={renderTrigger} readonly />)
|
||||
|
||||
expect(screen.getByText('Custom Trigger')).toBeInTheDocument()
|
||||
expect(renderTrigger).toHaveBeenCalled()
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
expect(renderTrigger).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle model selection and advanced mode parameters', () => {
|
||||
parameterRules = [
|
||||
{
|
||||
name: 'temperature',
|
||||
label: { en_US: 'Temperature' },
|
||||
type: 'float',
|
||||
default: 0.7,
|
||||
min: 0,
|
||||
max: 1,
|
||||
help: { en_US: 'Control randomness' },
|
||||
},
|
||||
]
|
||||
const { rerender } = render(<ModelParameterModal {...defaultProps} />)
|
||||
expect(screen.getByTestId('param-temperature')).toBeInTheDocument()
|
||||
|
||||
rerender(<ModelParameterModal {...defaultProps} isAdvancedMode />)
|
||||
expect(screen.getByTestId('param-stop')).toBeInTheDocument()
|
||||
|
||||
it('should call setModel when model selector picks another model', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
fireEvent.click(screen.getByText('Select GPT-4.1'))
|
||||
|
||||
expect(defaultProps.setModel).toHaveBeenCalledWith({
|
||||
modelId: 'gpt-4.1',
|
||||
provider: 'openai',
|
||||
@@ -270,4 +203,32 @@ describe('ModelParameterModal', () => {
|
||||
features: ['vision', 'tool-call'],
|
||||
})
|
||||
})
|
||||
|
||||
it('should toggle debug mode when debug footer is clicked', () => {
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
fireEvent.click(screen.getByText(/debugAsMultipleModel/i))
|
||||
expect(defaultProps.onDebugWithMultipleModelChange).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should render loading state when parameter rules are loading', () => {
|
||||
isRulesLoading = true
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not open content when readonly is true', () => {
|
||||
render(<ModelParameterModal {...defaultProps} readonly />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
expect(screen.queryByTestId('model-selector')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render no parameter items when rules are undefined', () => {
|
||||
parameterRules = undefined
|
||||
render(<ModelParameterModal {...defaultProps} />)
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
expect(screen.queryByTestId('param-temperature')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,238 +1,182 @@
|
||||
import type { ModelParameterRule } from '../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import ParameterItem from './parameter-item'
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
useLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/radio', () => {
|
||||
const Radio = ({ children, value }: { children: React.ReactNode, value: boolean }) => <button data-testid={`radio-${value}`}>{children}</button>
|
||||
Radio.Group = ({ children, onChange }: { children: React.ReactNode, onChange: (value: boolean) => void }) => (
|
||||
<div>
|
||||
{children}
|
||||
<button onClick={() => onChange(true)}>Select True</button>
|
||||
<button onClick={() => onChange(false)}>Select False</button>
|
||||
</div>
|
||||
)
|
||||
return { default: Radio }
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/base/select', () => ({
|
||||
SimpleSelect: ({ onSelect, items }: { onSelect: (item: { value: string }) => void, items: { value: string, name: string }[] }) => (
|
||||
<select onChange={e => onSelect({ value: e.target.value })}>
|
||||
{items.map(item => (
|
||||
<option key={item.value} value={item.value}>{item.name}</option>
|
||||
))}
|
||||
</select>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/slider', () => ({
|
||||
default: ({ value, onChange }: { value: number, onChange: (val: number) => void }) => (
|
||||
<input type="range" value={value} onChange={e => onChange(Number(e.target.value))} />
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/switch', () => ({
|
||||
default: ({ onChange, value }: { onChange: (val: boolean) => void, value: boolean }) => (
|
||||
<button onClick={() => onChange(!value)}>Switch</button>
|
||||
default: ({ onChange }: { onChange: (v: number) => void }) => (
|
||||
<button onClick={() => onChange(2)} data-testid="slider-btn">Slide 2</button>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tag-input', () => ({
|
||||
default: ({ onChange }: { onChange: (val: string[]) => void }) => (
|
||||
<input onChange={e => onChange(e.target.value.split(','))} />
|
||||
default: ({ onChange }: { onChange: (v: string[]) => void }) => (
|
||||
<button onClick={() => onChange(['tag1', 'tag2'])} data-testid="tag-input">Tag</button>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ popupContent }: { popupContent: React.ReactNode }) => <div>{popupContent}</div>,
|
||||
}))
|
||||
|
||||
describe('ParameterItem', () => {
|
||||
const createRule = (overrides: Partial<ModelParameterRule> = {}): ModelParameterRule => ({
|
||||
name: 'temp',
|
||||
label: { en_US: 'Temperature', zh_Hans: 'Temperature' },
|
||||
type: 'float',
|
||||
min: 0,
|
||||
max: 1,
|
||||
help: { en_US: 'Help text', zh_Hans: 'Help text' },
|
||||
required: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createProps = (overrides: {
|
||||
parameterRule?: ModelParameterRule
|
||||
value?: number | string | boolean | string[]
|
||||
} = {}) => {
|
||||
const onChange = vi.fn()
|
||||
const onSwitch = vi.fn()
|
||||
return {
|
||||
parameterRule: createRule(),
|
||||
value: 0.7,
|
||||
onChange,
|
||||
onSwitch,
|
||||
...overrides,
|
||||
}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render float input with slider', () => {
|
||||
const props = createProps()
|
||||
const { rerender } = render(<ParameterItem {...props} />)
|
||||
|
||||
expect(screen.getByText('Temperature')).toBeInTheDocument()
|
||||
// Float tests
|
||||
it('should render float controls and clamp numeric input to max', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 1 })} value={0.7} onChange={onChange} />)
|
||||
const input = screen.getByRole('spinbutton')
|
||||
fireEvent.change(input, { target: { value: '0.8' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(0.8)
|
||||
|
||||
fireEvent.change(input, { target: { value: '1.4' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(1)
|
||||
|
||||
fireEvent.change(input, { target: { value: '-0.2' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(0)
|
||||
|
||||
const slider = screen.getByRole('slider')
|
||||
fireEvent.change(slider, { target: { value: '2' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(1)
|
||||
|
||||
fireEvent.change(slider, { target: { value: '-1' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(0)
|
||||
|
||||
fireEvent.change(slider, { target: { value: '0.4' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith(0.4)
|
||||
|
||||
fireEvent.blur(input)
|
||||
expect(input).toHaveValue(0.7)
|
||||
|
||||
const minBoundedProps = createProps({
|
||||
parameterRule: createRule({ type: 'float', min: 1, max: 2 }),
|
||||
value: 1.5,
|
||||
})
|
||||
rerender(<ParameterItem {...minBoundedProps} />)
|
||||
fireEvent.change(screen.getByRole('slider'), { target: { value: '0' } })
|
||||
expect(minBoundedProps.onChange).toHaveBeenCalledWith(1)
|
||||
expect(onChange).toHaveBeenCalledWith(1)
|
||||
expect(screen.getByTestId('slider-btn')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render boolean radio', () => {
|
||||
const props = createProps({ parameterRule: createRule({ type: 'boolean', default: false }), value: true })
|
||||
render(<ParameterItem {...props} />)
|
||||
it('should clamp float numeric input to min', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0.1, max: 1 })} value={0.7} onChange={onChange} />)
|
||||
const input = screen.getByRole('spinbutton')
|
||||
fireEvent.change(input, { target: { value: '0.05' } })
|
||||
expect(onChange).toHaveBeenCalledWith(0.1)
|
||||
})
|
||||
|
||||
// Int tests
|
||||
it('should render int controls and clamp numeric input', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'int', min: 0, max: 10 })} value={5} onChange={onChange} />)
|
||||
const input = screen.getByRole('spinbutton')
|
||||
fireEvent.change(input, { target: { value: '15' } })
|
||||
expect(onChange).toHaveBeenCalledWith(10)
|
||||
fireEvent.change(input, { target: { value: '-5' } })
|
||||
expect(onChange).toHaveBeenCalledWith(0)
|
||||
})
|
||||
|
||||
it('should adjust step based on max for int type', () => {
|
||||
const { rerender } = render(<ParameterItem parameterRule={createRule({ type: 'int', min: 0, max: 50 })} value={5} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '1')
|
||||
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'int', min: 0, max: 500 })} value={50} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '10')
|
||||
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'int', min: 0, max: 2000 })} value={50} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '100')
|
||||
})
|
||||
|
||||
it('should render int input without slider if min or max is missing', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'int', min: 0 })} value={5} />)
|
||||
expect(screen.queryByRole('slider')).not.toBeInTheDocument()
|
||||
// No max -> precision step
|
||||
expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0')
|
||||
})
|
||||
|
||||
// Slider events (uses generic value mock for slider)
|
||||
it('should handle slide change and clamp values', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 10 })} value={0.7} onChange={onChange} />)
|
||||
|
||||
// Test that the actual slider triggers the onChange logic correctly
|
||||
// The implementation of Slider uses onChange(val) directly via the mock
|
||||
fireEvent.click(screen.getByTestId('slider-btn'))
|
||||
expect(onChange).toHaveBeenCalledWith(2)
|
||||
})
|
||||
|
||||
// Text & String tests
|
||||
it('should render exact string input and propagate text changes', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'string', name: 'prompt' })} value="initial" onChange={onChange} />)
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'updated' } })
|
||||
expect(onChange).toHaveBeenCalledWith('updated')
|
||||
})
|
||||
|
||||
it('should render textarea for text type', () => {
|
||||
const onChange = vi.fn()
|
||||
const { container } = render(<ParameterItem parameterRule={createRule({ type: 'text' })} value="long text" onChange={onChange} />)
|
||||
const textarea = container.querySelector('textarea')!
|
||||
expect(textarea).toBeInTheDocument()
|
||||
fireEvent.change(textarea, { target: { value: 'new long text' } })
|
||||
expect(onChange).toHaveBeenCalledWith('new long text')
|
||||
})
|
||||
|
||||
it('should render select for string with options', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'string', options: ['a', 'b'] })} value="a" />)
|
||||
// SimpleSelect renders an element with text 'a'
|
||||
expect(screen.getByText('a')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Tag Tests
|
||||
it('should render tag input for tag type', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'tag', tagPlaceholder: { en_US: 'placeholder', zh_Hans: 'placeholder' } })} value={['a']} onChange={onChange} />)
|
||||
expect(screen.getByText('placeholder')).toBeInTheDocument()
|
||||
// Trigger mock tag input
|
||||
fireEvent.click(screen.getByTestId('tag-input'))
|
||||
expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2'])
|
||||
})
|
||||
|
||||
// Boolean tests
|
||||
it('should render boolean radios and update value on click', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'boolean', default: false })} value={true} onChange={onChange} />)
|
||||
fireEvent.click(screen.getByText('False'))
|
||||
expect(onChange).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
// Switch tests
|
||||
it('should call onSwitch with current value when optional switch is toggled off', () => {
|
||||
const onSwitch = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule()} value={0.7} onSwitch={onSwitch} />)
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
expect(onSwitch).toHaveBeenCalledWith(false, 0.7)
|
||||
})
|
||||
|
||||
it('should not render switch if required or name is stop', () => {
|
||||
const { rerender } = render(<ParameterItem parameterRule={createRule({ required: true as unknown as false })} value={1} />)
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
rerender(<ParameterItem parameterRule={createRule({ name: 'stop', required: false })} value={1} />)
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Default Value Fallbacks (rendering without value)
|
||||
it('should use default values if value is undefined', () => {
|
||||
const { rerender } = render(<ParameterItem parameterRule={createRule({ type: 'float', default: 0.5 })} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveValue(0.5)
|
||||
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'string', default: 'hello' })} />)
|
||||
expect(screen.getByRole('textbox')).toHaveValue('hello')
|
||||
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'boolean', default: true })} />)
|
||||
expect(screen.getByText('True')).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByText('Select False'))
|
||||
expect(props.onChange).toHaveBeenCalledWith(false)
|
||||
expect(screen.getByText('False')).toBeInTheDocument()
|
||||
|
||||
// Without default
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'float' })} />) // min is 0 by default in createRule
|
||||
expect(screen.getByRole('spinbutton')).toHaveValue(0)
|
||||
})
|
||||
|
||||
it('should render string input and select options', () => {
|
||||
const props = createProps({ parameterRule: createRule({ type: 'string' }), value: 'test' })
|
||||
const { rerender } = render(<ParameterItem {...props} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
fireEvent.change(input, { target: { value: 'new' } })
|
||||
expect(props.onChange).toHaveBeenCalledWith('new')
|
||||
|
||||
const selectProps = createProps({
|
||||
parameterRule: createRule({ type: 'string', options: ['opt1', 'opt2'] }),
|
||||
value: 'opt1',
|
||||
})
|
||||
rerender(<ParameterItem {...selectProps} />)
|
||||
const select = screen.getByRole('combobox')
|
||||
fireEvent.change(select, { target: { value: 'opt2' } })
|
||||
expect(selectProps.onChange).toHaveBeenCalledWith('opt2')
|
||||
// Input Blur
|
||||
it('should reset input to actual bound value on blur', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 1 })} />)
|
||||
const input = screen.getByRole('spinbutton')
|
||||
// change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state)
|
||||
// Actually our test fires a change so localValue = 1, then blur sets it
|
||||
fireEvent.change(input, { target: { value: '5' } })
|
||||
fireEvent.blur(input)
|
||||
expect(input).toHaveValue(1)
|
||||
})
|
||||
|
||||
it('should handle switch toggle', () => {
|
||||
const props = createProps()
|
||||
let view = render(<ParameterItem {...props} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(props.onSwitch).toHaveBeenCalledWith(false, 0.7)
|
||||
|
||||
const intDefaultProps = createProps({
|
||||
parameterRule: createRule({ type: 'int', min: 0, default: undefined }),
|
||||
value: undefined,
|
||||
})
|
||||
view.unmount()
|
||||
view = render(<ParameterItem {...intDefaultProps} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(intDefaultProps.onSwitch).toHaveBeenCalledWith(true, 0)
|
||||
|
||||
const stringDefaultProps = createProps({
|
||||
parameterRule: createRule({ type: 'string', default: 'preset-value' }),
|
||||
value: undefined,
|
||||
})
|
||||
view.unmount()
|
||||
view = render(<ParameterItem {...stringDefaultProps} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(stringDefaultProps.onSwitch).toHaveBeenCalledWith(true, 'preset-value')
|
||||
|
||||
const booleanDefaultProps = createProps({
|
||||
parameterRule: createRule({ type: 'boolean', default: true }),
|
||||
value: undefined,
|
||||
})
|
||||
view.unmount()
|
||||
view = render(<ParameterItem {...booleanDefaultProps} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(booleanDefaultProps.onSwitch).toHaveBeenCalledWith(true, true)
|
||||
|
||||
const tagDefaultProps = createProps({
|
||||
parameterRule: createRule({ type: 'tag', default: ['one'] }),
|
||||
value: undefined,
|
||||
})
|
||||
view.unmount()
|
||||
const tagView = render(<ParameterItem {...tagDefaultProps} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(tagDefaultProps.onSwitch).toHaveBeenCalledWith(true, ['one'])
|
||||
|
||||
const zeroValueProps = createProps({
|
||||
parameterRule: createRule({ type: 'float', default: 0.5 }),
|
||||
value: 0,
|
||||
})
|
||||
tagView.unmount()
|
||||
render(<ParameterItem {...zeroValueProps} />)
|
||||
fireEvent.click(screen.getByText('Switch'))
|
||||
expect(zeroValueProps.onSwitch).toHaveBeenCalledWith(false, 0)
|
||||
})
|
||||
|
||||
it('should support text and tag parameter interactions', () => {
|
||||
const textProps = createProps({
|
||||
parameterRule: createRule({ type: 'text', name: 'prompt' }),
|
||||
value: 'initial prompt',
|
||||
})
|
||||
const { rerender } = render(<ParameterItem {...textProps} />)
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'rewritten prompt' } })
|
||||
expect(textProps.onChange).toHaveBeenCalledWith('rewritten prompt')
|
||||
|
||||
const tagProps = createProps({
|
||||
parameterRule: createRule({
|
||||
type: 'tag',
|
||||
name: 'tags',
|
||||
tagPlaceholder: { en_US: 'Tag hint', zh_Hans: 'Tag hint' },
|
||||
}),
|
||||
value: ['alpha'],
|
||||
})
|
||||
rerender(<ParameterItem {...tagProps} />)
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'one,two' } })
|
||||
expect(tagProps.onChange).toHaveBeenCalledWith(['one', 'two'])
|
||||
})
|
||||
|
||||
it('should support int parameters and unknown type fallback', () => {
|
||||
const intProps = createProps({
|
||||
parameterRule: createRule({ type: 'int', min: 0, max: 500, default: 100 }),
|
||||
value: 100,
|
||||
})
|
||||
const { rerender } = render(<ParameterItem {...intProps} />)
|
||||
fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '350' } })
|
||||
expect(intProps.onChange).toHaveBeenCalledWith(350)
|
||||
|
||||
const unknownTypeProps = createProps({
|
||||
parameterRule: createRule({ type: 'unsupported' }),
|
||||
value: 0.7,
|
||||
})
|
||||
rerender(<ParameterItem {...unknownTypeProps} />)
|
||||
// Unsupported
|
||||
it('should render no input for unsupported parameter type', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'unsupported' as unknown as string })} value={0.7} />)
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -2,19 +2,6 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import PresetsParameter from './presets-parameter'
|
||||
|
||||
vi.mock('@/app/components/base/dropdown', () => ({
|
||||
default: ({ renderTrigger, items, onSelect }: { renderTrigger: (open: boolean) => React.ReactNode, items: { value: number, text: string }[], onSelect: (item: { value: number }) => void }) => (
|
||||
<div>
|
||||
{renderTrigger(false)}
|
||||
{items.map(item => (
|
||||
<button key={item.value} onClick={() => onSelect(item)}>
|
||||
{item.text}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
describe('PresetsParameter', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -26,7 +13,39 @@ describe('PresetsParameter', () => {
|
||||
|
||||
expect(screen.getByText('common.modelProvider.loadPresets')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i }))
|
||||
fireEvent.click(screen.getByText('common.model.tone.Creative'))
|
||||
expect(onSelect).toHaveBeenCalledWith(1)
|
||||
})
|
||||
|
||||
// open=true: trigger has bg-state-base-hover class
|
||||
it('should apply hover background class when open is true', () => {
|
||||
render(<PresetsParameter onSelect={vi.fn()} />)
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i }))
|
||||
|
||||
const button = screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i })
|
||||
expect(button).toHaveClass('bg-state-base-hover')
|
||||
})
|
||||
|
||||
// Tone map branch 2: Balanced → Scales02 icon
|
||||
it('should call onSelect with tone id 2 when Balanced is clicked', () => {
|
||||
const onSelect = vi.fn()
|
||||
render(<PresetsParameter onSelect={onSelect} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i }))
|
||||
fireEvent.click(screen.getByText('common.model.tone.Balanced'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(2)
|
||||
})
|
||||
|
||||
// Tone map branch 3: Precise → Target04 icon
|
||||
it('should call onSelect with tone id 3 when Precise is clicked', () => {
|
||||
const onSelect = vi.fn()
|
||||
render(<PresetsParameter onSelect={onSelect} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /common\.modelProvider\.loadPresets/i }))
|
||||
fireEvent.click(screen.getByText('common.model.tone.Precise'))
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(3)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { vi } from 'vitest'
|
||||
import StatusIndicators from './status-indicators'
|
||||
|
||||
@@ -8,10 +9,6 @@ vi.mock('@/service/use-plugins', () => ({
|
||||
useInstalledPluginList: () => ({ data: { plugins: installedPlugins } }),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ popupContent }: { popupContent: React.ReactNode }) => <div>{popupContent}</div>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/switch-plugin-version', () => ({
|
||||
SwitchPluginVersion: ({ uniqueIdentifier }: { uniqueIdentifier: string }) => <div>{`SwitchVersion:${uniqueIdentifier}`}</div>,
|
||||
}))
|
||||
@@ -38,57 +35,95 @@ describe('StatusIndicators', () => {
|
||||
expect(container).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('should render warning states when provider model is disabled', () => {
|
||||
const parentClick = vi.fn()
|
||||
const { rerender } = render(
|
||||
<div onClick={parentClick}>
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={true}
|
||||
disabled={true}
|
||||
pluginInfo={null}
|
||||
t={t}
|
||||
/>
|
||||
</div>,
|
||||
it('should render deprecated tooltip when provider model is disabled and in model list', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={true}
|
||||
disabled={true}
|
||||
pluginInfo={null}
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('nodes.agent.modelSelectorTooltips.deprecated')).toBeInTheDocument()
|
||||
|
||||
rerender(
|
||||
<div onClick={parentClick}>
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={false}
|
||||
disabled={true}
|
||||
pluginInfo={null}
|
||||
t={t}
|
||||
/>
|
||||
</div>,
|
||||
)
|
||||
expect(screen.getByText('nodes.agent.modelNotSupport.title')).toBeInTheDocument()
|
||||
expect(screen.getByText('nodes.agent.linkToPlugin').closest('a')).toHaveAttribute('href', '/plugins')
|
||||
fireEvent.click(screen.getByText('nodes.agent.modelNotSupport.title'))
|
||||
fireEvent.click(screen.getByText('nodes.agent.linkToPlugin'))
|
||||
expect(parentClick).not.toHaveBeenCalled()
|
||||
const trigger = container.querySelector('[data-state]')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
rerender(
|
||||
<div onClick={parentClick}>
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={false}
|
||||
disabled={true}
|
||||
pluginInfo={{ name: 'demo-plugin' }}
|
||||
t={t}
|
||||
/>
|
||||
</div>,
|
||||
expect(await screen.findByText('nodes.agent.modelSelectorTooltips.deprecated')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render model-not-support tooltip when disabled model is not in model list and has no pluginInfo', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={false}
|
||||
disabled={true}
|
||||
pluginInfo={null}
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
|
||||
const trigger = container.querySelector('[data-state]')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(await screen.findByText('nodes.agent.modelNotSupport.title')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render switch plugin version when pluginInfo exists for disabled unsupported model', () => {
|
||||
render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={false}
|
||||
disabled={true}
|
||||
pluginInfo={{ name: 'demo-plugin' }}
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('SwitchVersion:demo@1.0.0')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render marketplace warning when provider is unavailable', () => {
|
||||
it('should render nothing when needsConfiguration is true even with disabled and modelProvider', () => {
|
||||
const { container } = render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={true}
|
||||
modelProvider={true}
|
||||
inModelList={true}
|
||||
disabled={true}
|
||||
pluginInfo={null}
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
expect(container).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('should render SwitchVersion with empty identifier when plugin is not in installed list', () => {
|
||||
installedPlugins = []
|
||||
|
||||
render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={true}
|
||||
inModelList={false}
|
||||
disabled={true}
|
||||
pluginInfo={{ name: 'missing-plugin' }}
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('SwitchVersion:')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render marketplace warning tooltip when provider is unavailable', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(
|
||||
<StatusIndicators
|
||||
needsConfiguration={false}
|
||||
modelProvider={false}
|
||||
@@ -98,6 +133,11 @@ describe('StatusIndicators', () => {
|
||||
t={t}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('nodes.agent.modelNotInMarketplace.title')).toBeInTheDocument()
|
||||
|
||||
const trigger = container.querySelector('[data-state]')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(await screen.findByText('nodes.agent.modelNotInMarketplace.title')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import Trigger from './trigger'
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
@@ -24,6 +25,10 @@ describe('Trigger', () => {
|
||||
const currentProvider = { provider: 'openai', label: { en_US: 'OpenAI' } } as unknown as ComponentProps<typeof Trigger>['currentProvider']
|
||||
const currentModel = { model: 'gpt-4' } as unknown as ComponentProps<typeof Trigger>['currentModel']
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render initialized state', () => {
|
||||
render(
|
||||
<Trigger
|
||||
@@ -44,4 +49,92 @@ describe('Trigger', () => {
|
||||
)
|
||||
expect(screen.getByText('gpt-4')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// isInWorkflow=true: workflow border class + RiArrowDownSLine arrow
|
||||
it('should render workflow styles when isInWorkflow is true', () => {
|
||||
// Act
|
||||
const { container } = render(
|
||||
<Trigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
isInWorkflow
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(container.firstChild).toHaveClass('border-workflow-block-parma-bg')
|
||||
expect(container.firstChild).toHaveClass('bg-workflow-block-parma-bg')
|
||||
expect(container.querySelectorAll('svg').length).toBe(2)
|
||||
})
|
||||
|
||||
// disabled=true + hasDeprecated=true: AlertTriangle + deprecated tooltip
|
||||
it('should show deprecated warning when disabled with hasDeprecated', () => {
|
||||
// Act
|
||||
render(
|
||||
<Trigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
disabled
|
||||
hasDeprecated
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - AlertTriangle renders with warning color
|
||||
const warningIcon = document.querySelector('.text-\\[\\#F79009\\]')
|
||||
expect(warningIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// disabled=true + modelDisabled=true: status text tooltip
|
||||
it('should show model status tooltip when disabled with modelDisabled', () => {
|
||||
// Act
|
||||
render(
|
||||
<Trigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={{ ...currentModel, status: 'no-configure' } as unknown as typeof currentModel}
|
||||
disabled
|
||||
modelDisabled
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - AlertTriangle warning icon should be present
|
||||
const warningIcon = document.querySelector('.text-\\[\\#F79009\\]')
|
||||
expect(warningIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render empty tooltip content when disabled without deprecated or modelDisabled', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(
|
||||
<Trigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
disabled
|
||||
hasDeprecated={false}
|
||||
modelDisabled={false}
|
||||
/>,
|
||||
)
|
||||
const warningIcon = document.querySelector('.text-\\[\\#F79009\\]')
|
||||
expect(warningIcon).toBeInTheDocument()
|
||||
const trigger = container.querySelector('[data-state]')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
await user.hover(trigger as HTMLElement)
|
||||
const tooltip = screen.queryByRole('tooltip')
|
||||
if (tooltip)
|
||||
expect(tooltip).toBeEmptyDOMElement()
|
||||
expect(screen.queryByText('modelProvider.deprecated')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('No Configure')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// providerName not matching any provider: find() returns undefined
|
||||
it('should render without crashing when providerName does not match any provider', () => {
|
||||
// Act
|
||||
render(
|
||||
<Trigger
|
||||
modelId="gpt-4"
|
||||
providerName="unknown-provider"
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('gpt-4')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -10,4 +10,22 @@ describe('EmptyTrigger', () => {
|
||||
render(<EmptyTrigger open={false} />)
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// open=true: hover bg class present
|
||||
it('should apply hover background class when open is true', () => {
|
||||
// Act
|
||||
const { container } = render(<EmptyTrigger open={true} />)
|
||||
|
||||
// Assert
|
||||
expect(container.firstChild).toHaveClass('bg-components-input-bg-hover')
|
||||
})
|
||||
|
||||
// className prop truthy: custom className appears on root
|
||||
it('should apply custom className when provided', () => {
|
||||
// Act
|
||||
const { container } = render(<EmptyTrigger open={false} className="custom-class" />)
|
||||
|
||||
// Assert
|
||||
expect(container.firstChild).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -10,12 +10,13 @@ import PopupItem from './popup-item'
|
||||
|
||||
const mockUpdateModelList = vi.hoisted(() => vi.fn())
|
||||
const mockUpdateModelProviders = vi.hoisted(() => vi.fn())
|
||||
const mockLanguageRef = vi.hoisted(() => ({ value: 'en_US' }))
|
||||
|
||||
vi.mock('../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
|
||||
return {
|
||||
...actual,
|
||||
useLanguage: () => 'en_US',
|
||||
useLanguage: () => mockLanguageRef.value,
|
||||
useUpdateModelList: () => mockUpdateModelList,
|
||||
useUpdateModelProviders: () => mockUpdateModelProviders,
|
||||
}
|
||||
@@ -69,6 +70,7 @@ const makeModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
describe('PopupItem', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockLanguageRef.value = 'en_US'
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
modelProviders: [{ provider: 'openai' }],
|
||||
})
|
||||
@@ -144,4 +146,87 @@ describe('PopupItem', () => {
|
||||
|
||||
expect(screen.getByText('GPT-4')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show check icon when model matches but provider does not', () => {
|
||||
const defaultModel: DefaultModel = { provider: 'anthropic', model: 'gpt-4' }
|
||||
render(
|
||||
<PopupItem
|
||||
defaultModel={defaultModel}
|
||||
model={makeModel()}
|
||||
onSelect={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
const checkIcons = document.querySelectorAll('.h-4.w-4.shrink-0.text-text-accent')
|
||||
expect(checkIcons.length).toBe(0)
|
||||
})
|
||||
|
||||
it('should not show mode badge when model_properties.mode is absent', () => {
|
||||
const modelItem = makeModelItem({ model_properties: {} })
|
||||
render(
|
||||
<PopupItem
|
||||
model={makeModel({ models: [modelItem] })}
|
||||
onSelect={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('CHAT')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should fall back to en_US label when current locale translation is empty', () => {
|
||||
mockLanguageRef.value = 'zh_Hans'
|
||||
const model = makeModel({
|
||||
label: { en_US: 'English Label', zh_Hans: '' },
|
||||
})
|
||||
render(<PopupItem model={model} onSelect={vi.fn()} />)
|
||||
|
||||
expect(screen.getByText('English Label')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show context_size badge when absent', () => {
|
||||
const modelItem = makeModelItem({ model_properties: { mode: 'chat' } })
|
||||
render(
|
||||
<PopupItem
|
||||
model={makeModel({ models: [modelItem] })}
|
||||
onSelect={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText(/K$/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show capabilities section when features are empty', () => {
|
||||
const modelItem = makeModelItem({ features: [] })
|
||||
render(
|
||||
<PopupItem
|
||||
model={makeModel({ models: [modelItem] })}
|
||||
onSelect={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('common.model.capabilities')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show capabilities for non-qualifying model types', () => {
|
||||
const modelItem = makeModelItem({
|
||||
model_type: ModelTypeEnum.tts,
|
||||
features: [ModelFeatureEnum.vision],
|
||||
})
|
||||
render(
|
||||
<PopupItem
|
||||
model={makeModel({ models: [modelItem] })}
|
||||
onSelect={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('common.model.capabilities')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show en_US label when language is fr_FR and fr_FR key is absent', () => {
|
||||
mockLanguageRef.value = 'fr_FR'
|
||||
const model = makeModel({ label: { en_US: 'FallbackLabel', zh_Hans: 'FallbackLabel' } })
|
||||
render(<PopupItem model={model} onSelect={vi.fn()} />)
|
||||
|
||||
expect(screen.getByText('FallbackLabel')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { Model, ModelItem } from '../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { tooltipManager } from '@/app/components/base/tooltip/TooltipManager'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelFeatureEnum,
|
||||
@@ -22,21 +23,6 @@ vi.mock('@/utils/tool-call', () => ({
|
||||
supportFunctionCall: mockSupportFunctionCall,
|
||||
}))
|
||||
|
||||
const mockCloseActiveTooltip = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/app/components/base/tooltip/TooltipManager', () => ({
|
||||
tooltipManager: {
|
||||
closeActiveTooltip: mockCloseActiveTooltip,
|
||||
register: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/icons/src/vender/solid/general', () => ({
|
||||
XCircle: ({ onClick }: { onClick?: () => void }) => (
|
||||
<button type="button" aria-label="clear-search" onClick={onClick} />
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
|
||||
return {
|
||||
@@ -70,10 +56,13 @@ const makeModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
})
|
||||
|
||||
describe('Popup', () => {
|
||||
let closeActiveTooltipSpy: ReturnType<typeof vi.spyOn>
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockLanguage = 'en_US'
|
||||
mockSupportFunctionCall.mockReturnValue(true)
|
||||
closeActiveTooltipSpy = vi.spyOn(tooltipManager, 'closeActiveTooltip')
|
||||
})
|
||||
|
||||
it('should filter models by search and allow clearing search', () => {
|
||||
@@ -91,8 +80,9 @@ describe('Popup', () => {
|
||||
fireEvent.change(input, { target: { value: 'not-found' } })
|
||||
expect(screen.getByText('No model found for “not-found”')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'clear-search' }))
|
||||
fireEvent.change(input, { target: { value: '' } })
|
||||
expect((input as HTMLInputElement).value).toBe('')
|
||||
expect(screen.getByText('openai')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter by scope features including toolCall and non-toolCall checks', () => {
|
||||
@@ -168,6 +158,24 @@ describe('Popup', () => {
|
||||
expect(screen.getByText('openai')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter out model when features array exists but does not include required scopeFeature', () => {
|
||||
const modelWithToolCallOnly = makeModel({
|
||||
models: [makeModelItem({ features: [ModelFeatureEnum.toolCall] })],
|
||||
})
|
||||
|
||||
render(
|
||||
<Popup
|
||||
modelList={[modelWithToolCallOnly]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
scopeFeatures={[ModelFeatureEnum.vision]}
|
||||
/>,
|
||||
)
|
||||
|
||||
// The model item should be filtered out because it has toolCall but not vision
|
||||
expect(screen.queryByText('openai')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should close tooltip on scroll', () => {
|
||||
const { container } = render(
|
||||
<Popup
|
||||
@@ -178,7 +186,7 @@ describe('Popup', () => {
|
||||
)
|
||||
|
||||
fireEvent.scroll(container.firstElementChild as HTMLElement)
|
||||
expect(mockCloseActiveTooltip).toHaveBeenCalled()
|
||||
expect(closeActiveTooltipSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should open provider settings when clicking footer link', () => {
|
||||
@@ -196,4 +204,35 @@ describe('Popup', () => {
|
||||
payload: 'provider',
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onHide when footer settings link is clicked', () => {
|
||||
const mockOnHide = vi.fn()
|
||||
render(
|
||||
<Popup
|
||||
modelList={[makeModel()]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={mockOnHide}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('common.model.settingsLink'))
|
||||
|
||||
expect(mockOnHide).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should match model label when searchText is non-empty and label key exists for current language', () => {
|
||||
render(
|
||||
<Popup
|
||||
modelList={[makeModel()]}
|
||||
onSelect={vi.fn()}
|
||||
onHide={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// GPT-4 label has en_US key, so modelItem.label[language] is defined
|
||||
const input = screen.getByPlaceholderText('datasetSettings.form.searchModel')
|
||||
fireEvent.change(input, { target: { value: 'gpt' } })
|
||||
|
||||
expect(screen.getByText('openai')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { changeModelProviderPriority } from '@/service/common'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import CredentialPanel from './credential-panel'
|
||||
@@ -24,11 +25,15 @@ vi.mock('@/config', async (importOriginal) => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
@@ -93,8 +98,14 @@ describe('CredentialPanel', () => {
|
||||
})
|
||||
})
|
||||
|
||||
const renderCredentialPanel = (provider: ModelProvider) => render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
<CredentialPanel provider={provider} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
|
||||
it('should show credential name and configuration actions', () => {
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.getByText('test-credential')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('config-provider')).toBeInTheDocument()
|
||||
@@ -103,7 +114,7 @@ describe('CredentialPanel', () => {
|
||||
|
||||
it('should show unauthorized status label when credential is missing', () => {
|
||||
mockCredentialStatus.hasCredential = false
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.unAuthorized/)).toBeInTheDocument()
|
||||
})
|
||||
@@ -111,7 +122,7 @@ describe('CredentialPanel', () => {
|
||||
it('should show removed credential label and priority tip for custom preference', () => {
|
||||
mockCredentialStatus.authorized = false
|
||||
mockCredentialStatus.authRemoved = true
|
||||
render(<CredentialPanel provider={{ ...mockProvider, preferred_provider_type: 'custom' } as ModelProvider} />)
|
||||
renderCredentialPanel({ ...mockProvider, preferred_provider_type: 'custom' } as ModelProvider)
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.authRemoved/)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('priority-use-tip')).toBeInTheDocument()
|
||||
@@ -120,7 +131,7 @@ describe('CredentialPanel', () => {
|
||||
it('should change priority and refresh related data after success', async () => {
|
||||
const mockChangePriority = changeModelProviderPriority as ReturnType<typeof vi.fn>
|
||||
mockChangePriority.mockResolvedValue({ result: 'success' })
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
fireEvent.click(screen.getByTestId('priority-selector'))
|
||||
|
||||
@@ -138,8 +149,70 @@ describe('CredentialPanel', () => {
|
||||
...mockProvider,
|
||||
provider_credential_schema: null,
|
||||
} as unknown as ModelProvider
|
||||
render(<CredentialPanel provider={providerNoSchema} />)
|
||||
renderCredentialPanel(providerNoSchema)
|
||||
expect(screen.getByTestId('priority-selector')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('config-provider')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show gray indicator when notAllowedToUse is true', () => {
|
||||
mockCredentialStatus.notAllowedToUse = true
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.getByTestId('indicator')).toHaveTextContent('gray')
|
||||
})
|
||||
|
||||
it('should not notify or update when priority change returns non-success', async () => {
|
||||
const mockChangePriority = changeModelProviderPriority as ReturnType<typeof vi.fn>
|
||||
mockChangePriority.mockResolvedValue({ result: 'error' })
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
fireEvent.click(screen.getByTestId('priority-selector'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockChangePriority).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockNotify).not.toHaveBeenCalled()
|
||||
expect(mockUpdateModelProviders).not.toHaveBeenCalled()
|
||||
expect(mockEventEmitter.emit).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show empty label when authorized is false and authRemoved is false', () => {
|
||||
mockCredentialStatus.authorized = false
|
||||
mockCredentialStatus.authRemoved = false
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.queryByText(/modelProvider\.auth\.unAuthorized/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/modelProvider\.auth\.authRemoved/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show PriorityUseTip when priorityUseType is system', () => {
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.queryByTestId('priority-use-tip')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not iterate configurateMethods for non-predefinedModel methods', async () => {
|
||||
const mockChangePriority = changeModelProviderPriority as ReturnType<typeof vi.fn>
|
||||
mockChangePriority.mockResolvedValue({ result: 'success' })
|
||||
const providerWithCustomMethod = {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
renderCredentialPanel(providerWithCustomMethod)
|
||||
|
||||
fireEvent.click(screen.getByTestId('priority-selector'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockChangePriority).toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockUpdateModelList).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should show red indicator when hasCredential is false', () => {
|
||||
mockCredentialStatus.hasCredential = false
|
||||
renderCredentialPanel(mockProvider)
|
||||
|
||||
expect(screen.getByTestId('indicator')).toHaveTextContent('red')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -125,6 +125,48 @@ describe('ProviderAddedCard', () => {
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show loading spinner while model list is being fetched', async () => {
|
||||
let resolvePromise: (value: unknown) => void = () => {}
|
||||
const pendingPromise = new Promise((resolve) => {
|
||||
resolvePromise = resolve
|
||||
})
|
||||
vi.mocked(fetchModelProviderModelList).mockReturnValue(pendingPromise as ReturnType<typeof fetchModelProviderModelList>)
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('show-models-button'))
|
||||
|
||||
expect(document.querySelector('.i-ri-loader-2-line.animate-spin')).toBeInTheDocument()
|
||||
|
||||
await act(async () => {
|
||||
resolvePromise({ data: [] })
|
||||
})
|
||||
})
|
||||
|
||||
it('should show modelsNum text after models have loaded', async () => {
|
||||
const models = [
|
||||
{ model: 'gpt-4' },
|
||||
{ model: 'gpt-3.5' },
|
||||
]
|
||||
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: models } as unknown as { data: ModelItem[] })
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('show-models-button'))
|
||||
|
||||
await screen.findByTestId('model-list')
|
||||
|
||||
const collapseBtn = screen.getByRole('button', { name: 'collapse list' })
|
||||
fireEvent.click(collapseBtn)
|
||||
|
||||
await waitFor(() => expect(screen.queryByTestId('model-list')).not.toBeInTheDocument())
|
||||
|
||||
const numTexts = screen.getAllByText(/modelProvider\.modelsNum/)
|
||||
expect(numTexts.length).toBeGreaterThan(0)
|
||||
|
||||
expect(screen.getByText(/modelProvider\.showModelsNum/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render configure tip when provider is not in quota list and not configured', () => {
|
||||
const providerWithoutQuota = {
|
||||
...mockProvider,
|
||||
@@ -163,6 +205,16 @@ describe('ProviderAddedCard', () => {
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should apply anthropic background class for anthropic provider', () => {
|
||||
const anthropicProvider = {
|
||||
...mockProvider,
|
||||
provider: 'langgenius/anthropic/anthropic',
|
||||
} as unknown as ModelProvider
|
||||
const { container } = render(<ProviderAddedCard provider={anthropicProvider} />)
|
||||
|
||||
expect(container.querySelector('.bg-third-party-model-bg-anthropic')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render custom model actions for workspace managers', () => {
|
||||
const customConfigProvider = {
|
||||
...mockProvider,
|
||||
@@ -177,4 +229,36 @@ describe('ProviderAddedCard', () => {
|
||||
rerender(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render credential panel when showCredential is true', () => {
|
||||
// Arrange: use ConfigurationMethodEnum.predefinedModel ('predefined-model') so showCredential=true
|
||||
const predefinedProvider = {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
mockIsCurrentWorkspaceManager = true
|
||||
|
||||
// Act
|
||||
render(<ProviderAddedCard provider={predefinedProvider} />)
|
||||
|
||||
// Assert: credential-panel is rendered (showCredential = true branch)
|
||||
expect(screen.getByTestId('credential-panel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render credential panel when user is not workspace manager', () => {
|
||||
// Arrange: predefined-model but manager=false so showCredential=false
|
||||
const predefinedProvider = {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
|
||||
// Act
|
||||
render(<ProviderAddedCard provider={predefinedProvider} />)
|
||||
|
||||
// Assert: credential-panel is not rendered (showCredential = false)
|
||||
expect(screen.queryByTestId('credential-panel')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import { ModelStatusEnum } from '../declarations'
|
||||
import ModelListItem from './model-list-item'
|
||||
|
||||
let mockModelLoadBalancingEnabled = false
|
||||
let mockPlanType: string = 'pro'
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
@@ -14,7 +15,7 @@ vi.mock('@/context/app-context', () => ({
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
plan: { type: 'pro' },
|
||||
plan: { type: mockPlanType },
|
||||
}),
|
||||
useProviderContextSelector: () => mockModelLoadBalancingEnabled,
|
||||
}))
|
||||
@@ -60,6 +61,7 @@ describe('ModelListItem', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockModelLoadBalancingEnabled = false
|
||||
mockPlanType = 'pro'
|
||||
})
|
||||
|
||||
it('should render model item with icon and name', () => {
|
||||
@@ -127,4 +129,127 @@ describe('ModelListItem', () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: 'modify load balancing' }))
|
||||
expect(onModifyLoadBalancing).toHaveBeenCalledWith(mockModel)
|
||||
})
|
||||
|
||||
// Deprecated branches: opacity-60, disabled switch, no ConfigModel
|
||||
it('should show deprecated model with opacity and disabled switch', () => {
|
||||
// Arrange
|
||||
const deprecatedModel = { ...mockModel, deprecated: true } as unknown as ModelItem
|
||||
mockModelLoadBalancingEnabled = true
|
||||
|
||||
// Act
|
||||
const { container } = render(
|
||||
<ModelListItem
|
||||
model={deprecatedModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(container.querySelector('.opacity-60')).toBeInTheDocument()
|
||||
expect(screen.queryByRole('button', { name: 'modify load balancing' })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Load balancing badge: visible when all 4 conditions met
|
||||
it('should show load balancing badge when all conditions are met', () => {
|
||||
// Arrange
|
||||
mockModelLoadBalancingEnabled = true
|
||||
const lbModel = {
|
||||
...mockModel,
|
||||
load_balancing_enabled: true,
|
||||
has_invalid_load_balancing_configs: false,
|
||||
deprecated: false,
|
||||
} as unknown as ModelItem
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelListItem
|
||||
model={lbModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - Badge component should render
|
||||
const badge = document.querySelector('.border-text-accent-secondary')
|
||||
expect(badge).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Plan.sandbox: ConfigModel shown without load balancing enabled
|
||||
it('should show ConfigModel for sandbox plan even without load balancing enabled', () => {
|
||||
// Arrange - set plan type to sandbox and keep load balancing disabled
|
||||
mockModelLoadBalancingEnabled = false
|
||||
mockPlanType = 'sandbox'
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelListItem
|
||||
model={mockModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - ConfigModel should show because plan.type === 'sandbox'
|
||||
expect(screen.getByRole('button', { name: 'modify load balancing' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Negative proof: non-sandbox plan without load balancing should NOT show ConfigModel
|
||||
it('should hide ConfigModel for non-sandbox plan without load balancing enabled', () => {
|
||||
// Arrange - set plan type to non-sandbox and keep load balancing disabled
|
||||
mockModelLoadBalancingEnabled = false
|
||||
mockPlanType = 'pro'
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelListItem
|
||||
model={mockModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - ConfigModel should NOT show because plan.type !== 'sandbox' and load balancing is disabled
|
||||
expect(screen.queryByRole('button', { name: 'modify load balancing' })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// model.status=credentialRemoved: switch disabled, no ConfigModel
|
||||
it('should disable switch and hide ConfigModel when status is credentialRemoved', () => {
|
||||
// Arrange
|
||||
const removedModel = { ...mockModel, status: ModelStatusEnum.credentialRemoved } as unknown as ModelItem
|
||||
mockModelLoadBalancingEnabled = true
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelListItem
|
||||
model={removedModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert - ConfigModel should not render because status is not active/disabled
|
||||
expect(screen.queryByRole('button', { name: 'modify load balancing' })).not.toBeInTheDocument()
|
||||
const statusSwitch = screen.getByRole('switch')
|
||||
expect(statusSwitch).toHaveClass('!cursor-not-allowed')
|
||||
fireEvent.click(statusSwitch)
|
||||
expect(statusSwitch).toHaveAttribute('aria-checked', 'false')
|
||||
expect(enableModel).not.toHaveBeenCalled()
|
||||
expect(disableModel).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// isConfigurable=true: hover class on row
|
||||
it('should apply hover class when isConfigurable is true', () => {
|
||||
// Act
|
||||
const { container } = render(
|
||||
<ModelListItem
|
||||
model={mockModel}
|
||||
provider={mockProvider}
|
||||
isConfigurable={true}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(container.querySelector('.hover\\:bg-components-panel-on-panel-item-bg-hover')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ModelList from './model-list'
|
||||
|
||||
const mockSetShowModelLoadBalancingModal = vi.fn()
|
||||
@@ -105,4 +106,120 @@ describe('ModelList', () => {
|
||||
expect(screen.queryByTestId('manage-credentials')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('add-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// isConfigurable=false: predefinedModel only provider hides custom model actions
|
||||
it('should hide custom model actions when provider uses predefinedModel only', () => {
|
||||
// Arrange
|
||||
const predefinedProvider = {
|
||||
provider: 'test-provider',
|
||||
configurate_methods: ['predefinedModel'],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelList
|
||||
provider={predefinedProvider}
|
||||
models={mockModels}
|
||||
onCollapse={mockOnCollapse}
|
||||
onChange={mockOnChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('manage-credentials')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('add-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSave (onChange) and onClose from the load balancing modal callbacks', () => {
|
||||
render(
|
||||
<ModelList
|
||||
provider={mockProvider}
|
||||
models={mockModels}
|
||||
onCollapse={mockOnCollapse}
|
||||
onChange={mockOnChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'gpt-4' }))
|
||||
expect(mockSetShowModelLoadBalancingModal).toHaveBeenCalled()
|
||||
|
||||
const callArg = mockSetShowModelLoadBalancingModal.mock.calls[0][0]
|
||||
|
||||
callArg.onSave('test-provider')
|
||||
expect(mockOnChange).toHaveBeenCalledWith('test-provider')
|
||||
|
||||
callArg.onClose()
|
||||
expect(mockSetShowModelLoadBalancingModal).toHaveBeenCalledWith(null)
|
||||
})
|
||||
|
||||
// fetchFromRemote filtered out: provider with only fetchFromRemote
|
||||
it('should hide custom model actions when provider uses fetchFromRemote only', () => {
|
||||
// Arrange
|
||||
const fetchOnlyProvider = {
|
||||
provider: 'test-provider',
|
||||
configurate_methods: ['fetchFromRemote'],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelList
|
||||
provider={fetchOnlyProvider}
|
||||
models={mockModels}
|
||||
onCollapse={mockOnCollapse}
|
||||
onChange={mockOnChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('manage-credentials')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('add-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show custom model actions when provider is configurable and user is workspace manager', () => {
|
||||
// Arrange: use ConfigurationMethodEnum.customizableModel ('customizable-model') so isConfigurable=true
|
||||
const configurableProvider = {
|
||||
provider: 'test-provider',
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
mockIsCurrentWorkspaceManager = true
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelList
|
||||
provider={configurableProvider}
|
||||
models={mockModels}
|
||||
onCollapse={mockOnCollapse}
|
||||
onChange={mockOnChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert: custom model actions are shown (isConfigurable=true && isCurrentWorkspaceManager=true)
|
||||
expect(screen.getByTestId('manage-credentials')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide custom model actions when provider is configurable but user is not workspace manager', () => {
|
||||
// Arrange: use ConfigurationMethodEnum.customizableModel ('customizable-model') so isConfigurable=true, but manager=false
|
||||
const configurableProvider = {
|
||||
provider: 'test-provider',
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelList
|
||||
provider={configurableProvider}
|
||||
models={mockModels}
|
||||
onCollapse={mockOnCollapse}
|
||||
onChange={mockOnChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert: custom model actions are hidden (isCurrentWorkspaceManager=false covers the && short-circuit)
|
||||
expect(screen.queryByTestId('manage-credentials')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('add-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,7 +5,7 @@ import type {
|
||||
ModelLoadBalancingConfig,
|
||||
ModelProvider,
|
||||
} from '../declarations'
|
||||
import { act, render, screen } from '@testing-library/react'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { useState } from 'react'
|
||||
import { AddCredentialInLoadBalancing } from '@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
@@ -261,6 +261,128 @@ describe('ModelLoadBalancingConfigs', () => {
|
||||
expect(screen.getByText('common.modelProvider.defaultConfig')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should remove credential at index 0', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onRemove = vi.fn()
|
||||
// Create config where the target credential is at index 0
|
||||
const config: ModelLoadBalancingConfig = {
|
||||
enabled: true,
|
||||
configs: [
|
||||
{ id: 'cfg-target', credential_id: 'cred-2', enabled: true, name: 'Key 2' },
|
||||
{ id: 'cfg-other', credential_id: 'cred-1', enabled: true, name: 'Key 1' },
|
||||
],
|
||||
} as ModelLoadBalancingConfig
|
||||
|
||||
render(<StatefulHarness initialConfig={config} onRemove={onRemove} />)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'trigger remove' }))
|
||||
|
||||
expect(onRemove).toHaveBeenCalledWith('cred-2')
|
||||
expect(screen.queryByText('Key 2')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not toggle load balancing when modelLoadBalancingEnabled=false and enabling via switch', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockModelLoadBalancingEnabled = false
|
||||
render(<StatefulHarness initialConfig={createDraftConfig(false)} withSwitch />)
|
||||
|
||||
const mainSwitch = screen.getByTestId('load-balancing-switch-main')
|
||||
await user.click(mainSwitch)
|
||||
|
||||
// Switch is disabled so toggling to true should not work
|
||||
expect(mainSwitch).toHaveAttribute('aria-checked', 'false')
|
||||
})
|
||||
|
||||
it('should toggle load balancing to false when modelLoadBalancingEnabled=false but enabled=true via switch', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockModelLoadBalancingEnabled = false
|
||||
// When draftConfig.enabled=true and !enabled (toggling off): condition `(modelLoadBalancingEnabled || !enabled)` = (!enabled) = true
|
||||
render(<StatefulHarness initialConfig={createDraftConfig(true)} withSwitch />)
|
||||
|
||||
const mainSwitch = screen.getByTestId('load-balancing-switch-main')
|
||||
await user.click(mainSwitch)
|
||||
|
||||
expect(mainSwitch).toHaveAttribute('aria-checked', 'false')
|
||||
expect(screen.queryByText('Key 1')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show provider badge when isProviderManaged=true but configurationMethod is customizableModel', () => {
|
||||
const inheritConfig: ModelLoadBalancingConfig = {
|
||||
enabled: true,
|
||||
configs: [
|
||||
{ id: 'cfg-inherit', credential_id: '', enabled: true, name: '__inherit__' },
|
||||
],
|
||||
} as ModelLoadBalancingConfig
|
||||
|
||||
render(
|
||||
<StatefulHarness
|
||||
initialConfig={inheritConfig}
|
||||
configurationMethod={ConfigurationMethodEnum.customizableModel}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.defaultConfig')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.providerManaged')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show upgrade panel when modelLoadBalancingEnabled=false and not CE edition', () => {
|
||||
mockModelLoadBalancingEnabled = false
|
||||
|
||||
render(<StatefulHarness initialConfig={createDraftConfig(false)} />)
|
||||
|
||||
expect(screen.getByText('upgrade')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.upgradeForLoadBalancing')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass explicit boolean state to toggleConfigEntryEnabled (typeof state === boolean branch)', async () => {
|
||||
// Arrange: render with a config entry; the Switch onChange passes explicit boolean value
|
||||
const user = userEvent.setup()
|
||||
render(<StatefulHarness initialConfig={createDraftConfig(true)} />)
|
||||
|
||||
// Act: click the switch which calls toggleConfigEntryEnabled(index, value) where value is boolean
|
||||
const entrySwitch = screen.getByTestId('load-balancing-switch-cfg-1')
|
||||
await user.click(entrySwitch)
|
||||
|
||||
// Assert: component still renders after the toggle (state = explicit boolean true/false)
|
||||
expect(screen.getByTestId('load-balancing-main-panel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with credential that has not_allowed_to_use flag (covers credential?.not_allowed_to_use ? false branch)', () => {
|
||||
// Arrange: config where the credential is not allowed to use
|
||||
const restrictedConfig: ModelLoadBalancingConfig = {
|
||||
enabled: true,
|
||||
configs: [
|
||||
{ id: 'cfg-restricted', credential_id: 'cred-restricted', enabled: true, name: 'Restricted Key' },
|
||||
],
|
||||
} as ModelLoadBalancingConfig
|
||||
|
||||
const mockModelCredentialWithRestricted = {
|
||||
available_credentials: [
|
||||
{
|
||||
credential_id: 'cred-restricted',
|
||||
credential_name: 'Restricted Key',
|
||||
not_allowed_to_use: true,
|
||||
},
|
||||
],
|
||||
} as unknown as ModelCredential
|
||||
|
||||
// Act
|
||||
render(
|
||||
<ModelLoadBalancingConfigs
|
||||
draftConfig={restrictedConfig}
|
||||
setDraftConfig={vi.fn()}
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={mockModelCredentialWithRestricted}
|
||||
model={{ model: 'gpt-4', model_type: 'llm' } as CustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert: Switch value should be false (credential?.not_allowed_to_use ? false branch)
|
||||
const entrySwitch = screen.getByTestId('load-balancing-switch-cfg-restricted')
|
||||
expect(entrySwitch).toHaveAttribute('aria-checked', 'false')
|
||||
})
|
||||
|
||||
it('should handle edge cases where draftConfig becomes null during callbacks', async () => {
|
||||
let capturedAdd: ((credential: Credential) => void) | null = null
|
||||
let capturedUpdate: ((payload?: unknown, formValues?: Record<string, unknown>) => void) | null = null
|
||||
@@ -298,4 +420,82 @@ describe('ModelLoadBalancingConfigs', () => {
|
||||
|
||||
// Should not throw and just return prev (which is undefined)
|
||||
})
|
||||
|
||||
it('should not toggle load balancing when modelLoadBalancingEnabled=false and clicking panel to enable', async () => {
|
||||
// Arrange: load balancing not enabled in context, draftConfig.enabled=false (so panel is clickable)
|
||||
const user = userEvent.setup()
|
||||
mockModelLoadBalancingEnabled = false
|
||||
render(<StatefulHarness initialConfig={createDraftConfig(false)} withSwitch={false} />)
|
||||
|
||||
// Act: clicking the panel calls toggleModalBalancing(true)
|
||||
// but (modelLoadBalancingEnabled || !enabled) = (false || false) = false → condition fails
|
||||
const panel = screen.getByTestId('load-balancing-main-panel')
|
||||
await user.click(panel)
|
||||
|
||||
expect(screen.queryByText('Key 1')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return early from addConfigEntry setDraftConfig when prev is undefined', async () => {
|
||||
// Arrange: use a controlled wrapper that exposes a way to force draftConfig to undefined
|
||||
let capturedAdd: ((credential: Credential) => void) | null = null
|
||||
const MockChild = ({ onSelectCredential }: {
|
||||
onSelectCredential: (credential: Credential) => void
|
||||
}) => {
|
||||
capturedAdd = onSelectCredential
|
||||
return null
|
||||
}
|
||||
vi.mocked(AddCredentialInLoadBalancing).mockImplementation(MockChild as unknown as typeof AddCredentialInLoadBalancing)
|
||||
|
||||
// Use a setDraftConfig spy that tracks calls and simulates null prev
|
||||
const setDraftConfigSpy = vi.fn((updater: ((prev: ModelLoadBalancingConfig | undefined) => ModelLoadBalancingConfig | undefined) | ModelLoadBalancingConfig | undefined) => {
|
||||
if (typeof updater === 'function')
|
||||
updater(undefined)
|
||||
})
|
||||
|
||||
render(
|
||||
<ModelLoadBalancingConfigs
|
||||
draftConfig={createDraftConfig(true)}
|
||||
setDraftConfig={setDraftConfigSpy}
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={mockModelCredential}
|
||||
model={{ model: 'gpt-4', model_type: 'llm' } as CustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: trigger addConfigEntry with undefined prev via the spy
|
||||
act(() => {
|
||||
if (capturedAdd)
|
||||
(capturedAdd as (credential: Credential) => void)({ credential_id: 'new', credential_name: 'New' } as Credential)
|
||||
})
|
||||
|
||||
// Assert: setDraftConfig was called and the updater returned early (prev was undefined)
|
||||
expect(setDraftConfigSpy).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return early from updateConfigEntry setDraftConfig when prev is undefined', async () => {
|
||||
// Arrange: use setDraftConfig spy that invokes updater with undefined prev
|
||||
const setDraftConfigSpy = vi.fn((updater: ((prev: ModelLoadBalancingConfig | undefined) => ModelLoadBalancingConfig | undefined) | ModelLoadBalancingConfig | undefined) => {
|
||||
if (typeof updater === 'function')
|
||||
updater(undefined)
|
||||
})
|
||||
|
||||
render(
|
||||
<ModelLoadBalancingConfigs
|
||||
draftConfig={createDraftConfig(true)}
|
||||
setDraftConfig={setDraftConfigSpy}
|
||||
provider={mockProvider}
|
||||
configurationMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
modelCredential={mockModelCredential}
|
||||
model={{ model: 'gpt-4', model_type: 'llm' } as CustomModelCredential}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: click remove button which triggers updateConfigEntry → setDraftConfig with prev=undefined
|
||||
const removeBtn = screen.getByTestId('load-balancing-remove-cfg-1')
|
||||
fireEvent.click(removeBtn)
|
||||
|
||||
// Assert: setDraftConfig was called and handled undefined prev gracefully
|
||||
expect(setDraftConfigSpy).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -130,7 +130,7 @@ const ModelLoadBalancingConfigs = ({
|
||||
|
||||
const handleRemove = useCallback((credentialId: string) => {
|
||||
const index = draftConfig?.configs.findIndex(item => item.credential_id === credentialId && item.name !== '__inherit__')
|
||||
if (index && index > -1)
|
||||
if (typeof index === 'number' && index > -1)
|
||||
updateConfigEntry(index, () => undefined)
|
||||
onRemove?.(credentialId)
|
||||
}, [draftConfig?.configs, updateConfigEntry, onRemove])
|
||||
|
||||
@@ -1,8 +1,18 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ModelLoadBalancingModal from './model-load-balancing-modal'
|
||||
|
||||
vi.mock('@headlessui/react', () => ({
|
||||
Transition: ({ show, children }: { show: boolean, children: React.ReactNode }) => (show ? <>{children}</> : null),
|
||||
TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}</>,
|
||||
Dialog: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
|
||||
DialogPanel: ({ children, className }: { children: React.ReactNode, className?: string }) => <div className={className}>{children}</div>,
|
||||
DialogTitle: ({ children, className }: { children: React.ReactNode, className?: string }) => <h3 className={className}>{children}</h3>,
|
||||
}))
|
||||
|
||||
type CredentialData = {
|
||||
load_balancing: {
|
||||
enabled: boolean
|
||||
@@ -43,11 +53,15 @@ let mockCredentialData: CredentialData | undefined = {
|
||||
current_credential_name: 'Default',
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/service/use-models', () => ({
|
||||
useGetModelCredential: () => ({
|
||||
@@ -102,6 +116,8 @@ vi.mock('../model-name', () => ({
|
||||
}))
|
||||
|
||||
describe('ModelLoadBalancingModal', () => {
|
||||
let user: ReturnType<typeof userEvent.setup>
|
||||
|
||||
const mockProvider = {
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: {
|
||||
@@ -118,8 +134,15 @@ describe('ModelLoadBalancingModal', () => {
|
||||
fetch_from: 'predefined-model',
|
||||
} as unknown as ModelItem
|
||||
|
||||
const renderModal = (node: Parameters<typeof render>[0]) => render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
{node}
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
user = userEvent.setup()
|
||||
mockDeleteModel = null
|
||||
mockCredentialData = {
|
||||
load_balancing: {
|
||||
@@ -143,7 +166,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
it('should show loading area while draft config is not ready', () => {
|
||||
mockCredentialData = undefined
|
||||
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
@@ -156,7 +179,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
})
|
||||
|
||||
it('should render predefined model content', () => {
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
@@ -173,7 +196,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
it('should render custom model actions and close when update has no credentials', async () => {
|
||||
const onClose = vi.fn()
|
||||
mockRefetch.mockResolvedValue({ data: { available_credentials: [] } })
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
@@ -185,7 +208,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.removeModel/)).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'switch credential' })).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
await user.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
await waitFor(() => {
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
@@ -195,7 +218,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
const onSave = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
@@ -206,9 +229,9 @@ describe('ModelLoadBalancingModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'config rename credential' }))
|
||||
fireEvent.click(screen.getByText(/operation\.save/))
|
||||
await user.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
await user.click(screen.getByRole('button', { name: 'config rename credential' }))
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
@@ -226,7 +249,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
const onClose = vi.fn()
|
||||
mockRefetch.mockResolvedValue({ data: { available_credentials: [] } })
|
||||
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
@@ -236,7 +259,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'switch credential' }))
|
||||
await user.click(screen.getByRole('button', { name: 'switch credential' }))
|
||||
await waitFor(() => {
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
@@ -246,7 +269,7 @@ describe('ModelLoadBalancingModal', () => {
|
||||
const onClose = vi.fn()
|
||||
mockDeleteModel = { model: 'gpt-4' }
|
||||
|
||||
render(
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
@@ -256,8 +279,8 @@ describe('ModelLoadBalancingModal', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText(/modelProvider\.auth\.removeModel/))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
await user.click(screen.getByText(/modelProvider\.auth\.removeModel/))
|
||||
await user.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOpenConfirmDelete).toHaveBeenCalled()
|
||||
@@ -265,4 +288,479 @@ describe('ModelLoadBalancingModal', () => {
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// Disabled load balancing: title shows configModel text
|
||||
it('should show configModel title when load balancing is disabled', () => {
|
||||
mockCredentialData = {
|
||||
...mockCredentialData!,
|
||||
load_balancing: {
|
||||
enabled: false,
|
||||
configs: mockCredentialData!.load_balancing.configs,
|
||||
},
|
||||
}
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.configModel/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Modal hidden when open=false
|
||||
it('should not render modal content when open is false', () => {
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText(/modelProvider\.auth\.configLoadBalancing/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Config rename: updates name in draft config
|
||||
it('should rename credential in draft config', async () => {
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'config rename credential' }))
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// Config remove: removes credential from draft
|
||||
it('should remove credential from draft config', async () => {
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'config remove' }))
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// Save error: shows error toast
|
||||
it('should show error toast when save fails', async () => {
|
||||
mockMutateAsync.mockResolvedValue({ result: 'error' })
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// No current_credential_id: modelCredential is undefined
|
||||
it('should handle missing current_credential_id', () => {
|
||||
mockCredentialData = {
|
||||
...mockCredentialData!,
|
||||
current_credential_id: '',
|
||||
}
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: 'switch credential' })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should disable save button when less than 2 configs are enabled', () => {
|
||||
mockCredentialData = {
|
||||
...mockCredentialData!,
|
||||
load_balancing: {
|
||||
enabled: true,
|
||||
configs: [
|
||||
{ id: 'cfg-1', credential_id: 'cred-1', enabled: true, name: 'Only One', credentials: { api_key: 'key' } },
|
||||
{ id: 'cfg-2', credential_id: 'cred-2', enabled: false, name: 'Disabled', credentials: { api_key: 'key2' } },
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/operation\.save/)).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should encode config entry without id as non-hidden value', async () => {
|
||||
mockCredentialData = {
|
||||
...mockCredentialData!,
|
||||
load_balancing: {
|
||||
enabled: true,
|
||||
configs: [
|
||||
{ id: '', credential_id: 'cred-new', enabled: true, name: 'New Entry', credentials: { api_key: 'new-key' } },
|
||||
{ id: 'cfg-2', credential_id: 'cred-2', enabled: true, name: 'Backup', credentials: { api_key: 'backup-key' } },
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
const payload = mockMutateAsync.mock.calls[0][0] as { load_balancing: { configs: Array<{ credentials: { api_key: string } }> } }
|
||||
// Entry without id should NOT be encoded as hidden
|
||||
expect(payload.load_balancing.configs[0].credentials.api_key).toBe('new-key')
|
||||
})
|
||||
})
|
||||
|
||||
it('should add new credential to draft config when update finds matching credential', async () => {
|
||||
mockRefetch.mockResolvedValue({
|
||||
data: {
|
||||
available_credentials: [
|
||||
{ credential_id: 'cred-new', credential_name: 'New Key' },
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Save after adding credential to verify it was added to draft
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not update draft config when handleUpdate credential name does not match any available credential', async () => {
|
||||
mockRefetch.mockResolvedValue({
|
||||
data: {
|
||||
available_credentials: [
|
||||
{ credential_id: 'cred-other', credential_name: 'Other Key' },
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// "config add credential" triggers onUpdate(undefined, { __authorization_name__: 'New Key' })
|
||||
// But refetch returns 'Other Key' not 'New Key', so find() returns undefined → no config update
|
||||
await user.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
// The payload configs should only have the original 2 entries (no new one added)
|
||||
const payload = mockMutateAsync.mock.calls[0][0] as { load_balancing: { configs: unknown[] } }
|
||||
expect(payload.load_balancing.configs).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
it('should toggle modal from enabled to disabled when clicking the card', async () => {
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
// draftConfig.enabled=true → title shows configLoadBalancing
|
||||
expect(screen.getByText(/modelProvider\.auth\.configLoadBalancing/)).toBeInTheDocument()
|
||||
|
||||
// Clicking the card when enabled=true toggles to disabled
|
||||
const card = screen.getByText(/modelProvider\.auth\.providerManaged$/).closest('div[class]')!.closest('div[class]')!
|
||||
await user.click(card)
|
||||
|
||||
// After toggling, title should show configModel (disabled state)
|
||||
expect(screen.getByText(/modelProvider\.auth\.configModel/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use customModelCredential credential_id when present in handleSave', async () => {
|
||||
// Arrange: set up credential data so customModelCredential is initialized from current_credential_id
|
||||
mockCredentialData = {
|
||||
...mockCredentialData!,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'Default',
|
||||
}
|
||||
const onSave = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={onSave}
|
||||
onClose={onClose}
|
||||
credential={{ credential_id: 'cred-1', credential_name: 'Default' } as unknown as Parameters<typeof ModelLoadBalancingModal>[0]['credential']}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: save triggers handleSave which uses customModelCredential?.credential_id
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
const payload = mockMutateAsync.mock.calls[0][0] as { credential_id: string }
|
||||
// credential_id should come from customModelCredential
|
||||
expect(payload.credential_id).toBe('cred-1')
|
||||
})
|
||||
})
|
||||
|
||||
it('should use null fallback for available_credentials when result.data is missing in handleUpdate', async () => {
|
||||
// Arrange: refetch returns data without available_credentials
|
||||
const onClose = vi.fn()
|
||||
mockRefetch.mockResolvedValue({ data: undefined })
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: trigger handleUpdate which does `result.data?.available_credentials || []`
|
||||
await user.click(screen.getByRole('button', { name: 'config add credential' }))
|
||||
|
||||
// Assert: available_credentials falls back to [], so onClose is called
|
||||
await waitFor(() => {
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should use null fallback for available_credentials in handleUpdateWhenSwitchCredential when result.data is missing', async () => {
|
||||
// Arrange: refetch returns data without available_credentials
|
||||
const onClose = vi.fn()
|
||||
mockRefetch.mockResolvedValue({ data: undefined })
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: trigger handleUpdateWhenSwitchCredential which does `result.data?.available_credentials || []`
|
||||
await user.click(screen.getByRole('button', { name: 'switch credential' }))
|
||||
|
||||
// Assert: available_credentials falls back to [], onClose is called
|
||||
await waitFor(() => {
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should use predefined provider schema without fallback when credential_form_schemas is undefined', () => {
|
||||
// Arrange: provider with no credential_form_schemas → triggers ?? [] fallback
|
||||
const providerWithoutSchemas = {
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: {
|
||||
credential_form_schemas: undefined,
|
||||
},
|
||||
model_credential_schema: {
|
||||
credential_form_schemas: undefined,
|
||||
},
|
||||
} as unknown as ModelProvider
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={providerWithoutSchemas}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert: component renders without error (extendedSecretFormSchemas = [])
|
||||
expect(screen.getByText(/modelProvider\.auth\.configLoadBalancing/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use custom model credential schema without fallback when credential_form_schemas is undefined', () => {
|
||||
// Arrange: provider with no model credential schemas → triggers ?? [] fallback for custom model path
|
||||
const providerWithoutModelSchemas = {
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: {
|
||||
credential_form_schemas: undefined,
|
||||
},
|
||||
model_credential_schema: {
|
||||
credential_form_schemas: undefined,
|
||||
},
|
||||
} as unknown as ModelProvider
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={providerWithoutModelSchemas}
|
||||
configurateMethod={ConfigurationMethodEnum.customizableModel}
|
||||
model={mockModel}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert: component renders without error (extendedSecretFormSchemas = [])
|
||||
expect(screen.getAllByText(/modelProvider\.auth\.specifyModelCredential/).length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should not update draft config when rename finds no matching index in prevIndex', async () => {
|
||||
// Arrange: credential in payload does not match any config (prevIndex = -1)
|
||||
mockRefetch.mockResolvedValue({
|
||||
data: {
|
||||
available_credentials: [
|
||||
{ credential_id: 'cred-99', credential_name: 'Unknown' },
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: "config rename credential" triggers onUpdate with credential: { credential_id: 'cred-1' }
|
||||
// but refetch returns cred-99, so newIndex for cred-1 is -1
|
||||
await user.click(screen.getByRole('button', { name: 'config rename credential' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Save to verify the config was not changed
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
const payload = mockMutateAsync.mock.calls[0][0] as { load_balancing: { configs: unknown[] } }
|
||||
// Config count unchanged (still 2 from original)
|
||||
expect(payload.load_balancing.configs).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
it('should encode credential_name as empty string when available_credentials has no name', async () => {
|
||||
// Arrange: available_credentials has a credential with no credential_name
|
||||
mockRefetch.mockResolvedValue({
|
||||
data: {
|
||||
available_credentials: [
|
||||
{ credential_id: 'cred-1', credential_name: '' },
|
||||
{ credential_id: 'cred-2', credential_name: 'Backup' },
|
||||
],
|
||||
},
|
||||
})
|
||||
|
||||
renderModal(
|
||||
<ModelLoadBalancingModal
|
||||
provider={mockProvider}
|
||||
configurateMethod={ConfigurationMethodEnum.predefinedModel}
|
||||
model={mockModel}
|
||||
open
|
||||
onSave={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Act: rename cred-1 which now has empty credential_name
|
||||
await user.click(screen.getByRole('button', { name: 'config rename credential' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefetch).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText(/operation\.save/))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -163,6 +163,18 @@ const ModelLoadBalancingModal = ({
|
||||
onSave?.(provider.provider)
|
||||
onClose?.()
|
||||
}
|
||||
else {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: (res as { error?: string })?.error || t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }),
|
||||
})
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }),
|
||||
})
|
||||
}
|
||||
finally {
|
||||
setLoading(false)
|
||||
@@ -218,7 +230,7 @@ const ModelLoadBalancingModal = ({
|
||||
}
|
||||
})
|
||||
}
|
||||
}, [refetch, credential])
|
||||
}, [refetch, onClose])
|
||||
|
||||
const handleUpdateWhenSwitchCredential = useCallback(async () => {
|
||||
const result = await refetch()
|
||||
@@ -250,7 +262,7 @@ const ModelLoadBalancingModal = ({
|
||||
modelName={model!.model}
|
||||
/>
|
||||
<ModelName
|
||||
className="system-md-regular grow text-text-secondary"
|
||||
className="grow text-text-secondary system-md-regular"
|
||||
modelItem={model!}
|
||||
showModelType
|
||||
showMode
|
||||
|
||||
@@ -1,14 +1,45 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import type { i18n } from 'i18next'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as reactI18next from 'react-i18next'
|
||||
import PriorityUseTip from './priority-use-tip'
|
||||
|
||||
describe('PriorityUseTip', () => {
|
||||
it('should render tooltip with icon content', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('should render tooltip with icon content', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(<PriorityUseTip />)
|
||||
expect(container.querySelector('[data-state]')).toBeInTheDocument()
|
||||
const trigger = container.querySelector('.cursor-pointer')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(await screen.findByText('common.modelProvider.priorityUsing')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render the component without crashing', () => {
|
||||
const { container } = render(<PriorityUseTip />)
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should exercise || fallback when t() returns empty string', async () => {
|
||||
const user = userEvent.setup()
|
||||
vi.spyOn(reactI18next, 'useTranslation').mockReturnValue({
|
||||
t: () => '',
|
||||
i18n: {} as unknown as i18n,
|
||||
ready: true,
|
||||
} as unknown as ReturnType<typeof reactI18next.useTranslation>)
|
||||
const { container } = render(<PriorityUseTip />)
|
||||
const trigger = container.querySelector('.cursor-pointer')
|
||||
expect(trigger).toBeInTheDocument()
|
||||
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.priorityUsing')).not.toBeInTheDocument()
|
||||
expect(document.querySelector('.rounded-md.bg-components-panel-bg')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import QuotaPanel from './quota-panel'
|
||||
|
||||
let mockWorkspace = {
|
||||
@@ -13,18 +14,6 @@ let mockPlugins = [{
|
||||
latest_package_identifier: 'openai@1.0.0',
|
||||
}]
|
||||
|
||||
vi.mock('@/app/components/base/icons/src/public/llm', () => {
|
||||
const Icon = ({ label }: { label: string }) => <span>{label}</span>
|
||||
return {
|
||||
OpenaiSmall: () => <Icon label="openai" />,
|
||||
AnthropicShortLight: () => <Icon label="anthropic" />,
|
||||
Gemini: () => <Icon label="gemini" />,
|
||||
Grok: () => <Icon label="x" />,
|
||||
Deepseek: () => <Icon label="deepseek" />,
|
||||
Tongyi: () => <Icon label="tongyi" />,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
currentWorkspace: mockWorkspace,
|
||||
@@ -80,6 +69,18 @@ describe('QuotaPanel', () => {
|
||||
mockPlugins = [{ plugin_id: 'langgenius/openai', latest_package_identifier: 'openai@1.0.0' }]
|
||||
})
|
||||
|
||||
const getTrialProviderIconTrigger = (container: HTMLElement) => {
|
||||
const providerIcon = container.querySelector('svg.h-6.w-6.rounded-lg')
|
||||
expect(providerIcon).toBeInTheDocument()
|
||||
const trigger = providerIcon?.closest('[data-state]') as HTMLDivElement | null
|
||||
expect(trigger).toBeInTheDocument()
|
||||
return trigger as HTMLDivElement
|
||||
}
|
||||
|
||||
const clickFirstTrialProviderIcon = (container: HTMLElement) => {
|
||||
fireEvent.click(getTrialProviderIconTrigger(container))
|
||||
}
|
||||
|
||||
it('should render loading state', () => {
|
||||
render(
|
||||
<QuotaPanel
|
||||
@@ -116,17 +117,17 @@ describe('QuotaPanel', () => {
|
||||
})
|
||||
|
||||
it('should open install modal when clicking an unsupported trial provider', () => {
|
||||
render(<QuotaPanel providers={[]} />)
|
||||
const { container } = render(<QuotaPanel providers={[]} />)
|
||||
|
||||
fireEvent.click(screen.getByText('openai'))
|
||||
clickFirstTrialProviderIcon(container)
|
||||
|
||||
expect(screen.getByText('install modal')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should close install modal when provider becomes installed', async () => {
|
||||
const { rerender } = render(<QuotaPanel providers={[]} />)
|
||||
const { rerender, container } = render(<QuotaPanel providers={[]} />)
|
||||
|
||||
fireEvent.click(screen.getByText('openai'))
|
||||
clickFirstTrialProviderIcon(container)
|
||||
expect(screen.getByText('install modal')).toBeInTheDocument()
|
||||
|
||||
rerender(<QuotaPanel providers={mockProviders} />)
|
||||
@@ -135,4 +136,61 @@ describe('QuotaPanel', () => {
|
||||
expect(screen.queryByText('install modal')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not open install modal when clicking an already installed provider', () => {
|
||||
const { container } = render(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
clickFirstTrialProviderIcon(container)
|
||||
|
||||
expect(screen.queryByText('install modal')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not open install modal when plugin is not found in marketplace', () => {
|
||||
mockPlugins = []
|
||||
const { container } = render(<QuotaPanel providers={[]} />)
|
||||
|
||||
clickFirstTrialProviderIcon(container)
|
||||
|
||||
expect(screen.queryByText('install modal')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show destructive border when credits are zero or negative', () => {
|
||||
mockWorkspace = {
|
||||
trial_credits: 0,
|
||||
trial_credits_used: 0,
|
||||
next_credit_reset_date: '',
|
||||
}
|
||||
|
||||
const { container } = render(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
expect(container.querySelector('.border-state-destructive-border')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show modelAPI tooltip for configured provider with custom preference', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
const trigger = getTrialProviderIconTrigger(container)
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(await screen.findByText(/common\.modelProvider\.card\.modelAPI/)).toHaveTextContent('OpenAI')
|
||||
})
|
||||
|
||||
it('should show modelSupported tooltip for installed provider without custom config', async () => {
|
||||
const user = userEvent.setup()
|
||||
const systemProviders = [
|
||||
{
|
||||
provider: 'langgenius/openai/openai',
|
||||
preferred_provider_type: 'system',
|
||||
custom_configuration: { available_credentials: [] },
|
||||
},
|
||||
] as unknown as ModelProvider[]
|
||||
|
||||
const { container } = render(<QuotaPanel providers={systemProviders} />)
|
||||
|
||||
const trigger = getTrialProviderIconTrigger(container)
|
||||
await user.hover(trigger as HTMLElement)
|
||||
|
||||
expect(await screen.findByText(/common\.modelProvider\.card\.modelSupported/)).toHaveTextContent('OpenAI')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { DefaultModelResponse } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { vi } from 'vitest'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { ModelTypeEnum } from '../declarations'
|
||||
import SystemModel from './index'
|
||||
|
||||
@@ -42,11 +43,15 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
useModelList: () => ({
|
||||
@@ -89,18 +94,24 @@ const defaultProps = {
|
||||
}
|
||||
|
||||
describe('SystemModel', () => {
|
||||
const renderSystemModel = (props: typeof defaultProps) => render(
|
||||
<ToastContext.Provider value={{ notify: mockNotify, close: vi.fn() }}>
|
||||
<SystemModel {...props} />
|
||||
</ToastContext.Provider>,
|
||||
)
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsCurrentWorkspaceManager = true
|
||||
})
|
||||
|
||||
it('should render settings button', () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
renderSystemModel(defaultProps)
|
||||
expect(screen.getByRole('button', { name: /system model settings/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open modal when button is clicked', async () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
renderSystemModel(defaultProps)
|
||||
const button = screen.getByRole('button', { name: /system model settings/i })
|
||||
fireEvent.click(button)
|
||||
await waitFor(() => {
|
||||
@@ -109,12 +120,12 @@ describe('SystemModel', () => {
|
||||
})
|
||||
|
||||
it('should disable button when loading', () => {
|
||||
render(<SystemModel {...defaultProps} isLoading />)
|
||||
renderSystemModel({ ...defaultProps, isLoading: true })
|
||||
expect(screen.getByRole('button', { name: /system model settings/i })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should close modal when cancel is clicked', async () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
renderSystemModel(defaultProps)
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /cancel/i })).toBeInTheDocument()
|
||||
@@ -126,7 +137,7 @@ describe('SystemModel', () => {
|
||||
})
|
||||
|
||||
it('should save selected models and show success feedback', async () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
@@ -150,11 +161,103 @@ describe('SystemModel', () => {
|
||||
|
||||
it('should disable save when user is not workspace manager', async () => {
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render primary variant button when notConfigured is true', () => {
|
||||
renderSystemModel({ ...defaultProps, notConfigured: true })
|
||||
const button = screen.getByRole('button', { name: /system model settings/i })
|
||||
expect(button.className).toContain('btn-primary')
|
||||
})
|
||||
|
||||
it('should keep modal open when save returns non-success result', async () => {
|
||||
mockUpdateDefaultModel.mockResolvedValueOnce({ result: 'error' })
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const selectorButtons = screen.getAllByRole('button', { name: 'Mock Model Selector' })
|
||||
selectorButtons.forEach(button => fireEvent.click(button))
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdateDefaultModel).toHaveBeenCalledTimes(1)
|
||||
expect(mockNotify).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Modal should still be open after failed save
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: /cancel/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not add duplicate model type to changedModelTypes when same type is selected twice', async () => {
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Click the first selector twice (textGeneration type)
|
||||
const selectorButtons = screen.getAllByRole('button', { name: 'Mock Model Selector' })
|
||||
fireEvent.click(selectorButtons[0])
|
||||
fireEvent.click(selectorButtons[0])
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdateDefaultModel).toHaveBeenCalledTimes(1)
|
||||
// textGeneration was changed, so updateModelList is called once for it
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('should call updateModelList for speech2text and tts types on save', async () => {
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Click speech2text (index 3) and tts (index 4) selectors
|
||||
const selectorButtons = screen.getAllByRole('button', { name: 'Mock Model Selector' })
|
||||
fireEvent.click(selectorButtons[3])
|
||||
fireEvent.click(selectorButtons[4])
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
it('should call updateModelList for each unique changed model type on save', async () => {
|
||||
renderSystemModel(defaultProps)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('button', { name: /save/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Click embedding and rerank selectors (indices 1 and 2)
|
||||
const selectorButtons = screen.getAllByRole('button', { name: 'Mock Model Selector' })
|
||||
fireEvent.click(selectorButtons[1])
|
||||
fireEvent.click(selectorButtons[2])
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /save/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -33,7 +33,7 @@ vi.mock('@/service/common', () => ({
|
||||
}))
|
||||
|
||||
describe('utils', () => {
|
||||
afterEach(() => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
@@ -97,6 +97,18 @@ describe('utils', () => {
|
||||
const result = await validateCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'network error' })
|
||||
})
|
||||
|
||||
it('should return Unknown error when non-Error is thrown', async () => {
|
||||
(validateModelProvider as unknown as Mock).mockRejectedValue('string error')
|
||||
const result = await validateCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'Unknown error' })
|
||||
})
|
||||
|
||||
it('should return default error message when error field is empty', async () => {
|
||||
(validateModelProvider as unknown as Mock).mockResolvedValue({ result: 'error', error: '' })
|
||||
const result = await validateCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'error' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('validateLoadBalancingCredentials', () => {
|
||||
@@ -140,6 +152,24 @@ describe('utils', () => {
|
||||
const result = await validateLoadBalancingCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'failed' })
|
||||
})
|
||||
|
||||
it('should return Unknown error when non-Error is thrown', async () => {
|
||||
(validateModelLoadBalancingCredentials as unknown as Mock).mockRejectedValue(42)
|
||||
const result = await validateLoadBalancingCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'Unknown error' })
|
||||
})
|
||||
|
||||
it('should handle exception with Error', async () => {
|
||||
(validateModelLoadBalancingCredentials as unknown as Mock).mockRejectedValue(new Error('Timeout'))
|
||||
const result = await validateLoadBalancingCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'Timeout' })
|
||||
})
|
||||
|
||||
it('should return default error message when error field is empty', async () => {
|
||||
(validateModelLoadBalancingCredentials as unknown as Mock).mockResolvedValue({ result: 'error', error: '' })
|
||||
const result = await validateLoadBalancingCredentials(true, 'provider', {})
|
||||
expect(result).toEqual({ status: ValidatedStatus.Error, message: 'error' })
|
||||
})
|
||||
})
|
||||
|
||||
describe('saveCredentials', () => {
|
||||
@@ -216,6 +246,19 @@ describe('utils', () => {
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should remove predefined credentials without credentialId', async () => {
|
||||
await removeCredentials(true, 'provider', {})
|
||||
expect(deleteModelProvider).toHaveBeenCalledWith({
|
||||
url: '/workspaces/current/model-providers/provider/credentials',
|
||||
body: undefined,
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call delete endpoint when non-predefined payload is falsy', async () => {
|
||||
await removeCredentials(false, 'provider', null as unknown as Record<string, unknown>)
|
||||
expect(deleteModelProvider).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('genModelTypeFormSchema', () => {
|
||||
@@ -228,11 +271,22 @@ describe('utils', () => {
|
||||
})
|
||||
|
||||
describe('genModelNameFormSchema', () => {
|
||||
it('should generate form schema', () => {
|
||||
it('should generate default form schema when no model provided', () => {
|
||||
const schema = genModelNameFormSchema()
|
||||
expect(schema.type).toBe(FormTypeEnum.textInput)
|
||||
expect(schema.variable).toBe('__model_name')
|
||||
expect(schema.required).toBe(true)
|
||||
expect(schema.label.en_US).toBe('Model Name')
|
||||
expect(schema.placeholder!.en_US).toBe('Please enter model name')
|
||||
})
|
||||
|
||||
it('should use provided label and placeholder when model is given', () => {
|
||||
const schema = genModelNameFormSchema({
|
||||
label: { en_US: 'Custom', zh_Hans: 'Custom' },
|
||||
placeholder: { en_US: 'Enter custom', zh_Hans: 'Enter custom' },
|
||||
})
|
||||
expect(schema.label.en_US).toBe('Custom')
|
||||
expect(schema.placeholder!.en_US).toBe('Enter custom')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -146,14 +146,15 @@ export const removeCredentials = async (predefined: boolean, provider: string, v
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (v) {
|
||||
const { __model_name, __model_type } = v
|
||||
body = {
|
||||
model: __model_name,
|
||||
model_type: __model_type,
|
||||
}
|
||||
url = `/workspaces/current/model-providers/${provider}/models`
|
||||
if (!v)
|
||||
return
|
||||
|
||||
const { __model_name, __model_type } = v
|
||||
body = {
|
||||
model: __model_name,
|
||||
model_type: __model_type,
|
||||
}
|
||||
url = `/workspaces/current/model-providers/${provider}/models`
|
||||
}
|
||||
|
||||
return deleteModelProvider({ url, body })
|
||||
|
||||
@@ -20,9 +20,13 @@ const mockEventEmitter = vi.hoisted(() => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: vi.fn(),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
|
||||
@@ -14,11 +14,15 @@ vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast/context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/app/components/base/toast/context')>()
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: vi.fn(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
|
||||
@@ -264,4 +264,78 @@ describe('AppNav', () => {
|
||||
await user.click(screen.getByTestId('load-more'))
|
||||
expect(fetchNextPage).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Non-editor link path: isCurrentWorkspaceEditor=false → link ends with /overview
|
||||
it('should build overview links when user is not editor', () => {
|
||||
// Arrange
|
||||
setupDefaultMocks({ isEditor: false })
|
||||
|
||||
// Act
|
||||
render(<AppNav />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('App 1 -> /app/app-1/overview')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// !!appId false: query disabled, no nav items
|
||||
it('should render no nav items when appId is undefined', () => {
|
||||
// Arrange
|
||||
setupDefaultMocks()
|
||||
mockUseParams.mockReturnValue({} as ReturnType<typeof useParams>)
|
||||
mockUseInfiniteAppList.mockReturnValue({
|
||||
data: undefined,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
isFetchingNextPage: false,
|
||||
refetch: vi.fn(),
|
||||
} as unknown as ReturnType<typeof useInfiniteAppList>)
|
||||
|
||||
// Act
|
||||
render(<AppNav />)
|
||||
|
||||
// Assert
|
||||
const navItems = screen.getByTestId('nav-items')
|
||||
expect(navItems.children).toHaveLength(0)
|
||||
})
|
||||
|
||||
// ADVANCED_CHAT OR branch: editor + ADVANCED_CHAT mode → link ends with /workflow
|
||||
it('should build workflow link for ADVANCED_CHAT mode when user is editor', () => {
|
||||
// Arrange
|
||||
setupDefaultMocks({
|
||||
isEditor: true,
|
||||
appData: [
|
||||
{
|
||||
id: 'app-3',
|
||||
name: 'Chat App',
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
icon_type: 'emoji',
|
||||
icon: '💬',
|
||||
icon_background: null,
|
||||
icon_url: null,
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<AppNav />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('Chat App -> /app/app-3/workflow')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// No-match update path: appDetail.id doesn't match any nav item
|
||||
it('should not change nav item names when appDetail id does not match any item', async () => {
|
||||
// Arrange
|
||||
setupDefaultMocks({ isEditor: true })
|
||||
const { rerender } = render(<AppNav />)
|
||||
|
||||
// Act - set appDetail to a non-matching id
|
||||
mockAppDetail = { id: 'non-existent-id', name: 'Unknown' }
|
||||
rerender(<AppNav />)
|
||||
|
||||
// Assert - original name should be unchanged
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('App 1 -> /app/app-1/configuration')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,10 +6,6 @@ function createMockComponent(testId: string) {
|
||||
return () => <div data-testid={testId} />
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/base/logo/dify-logo', () => ({
|
||||
default: createMockComponent('dify-logo'),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-dropdown/workplace-selector', () => ({
|
||||
default: createMockComponent('workplace-selector'),
|
||||
}))
|
||||
@@ -129,7 +125,7 @@ describe('Header', () => {
|
||||
it('should render header with main nav components', () => {
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByTestId('dify-logo')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
expect(screen.getByTestId('workplace-selector')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('app-nav')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('account-dropdown')).toBeInTheDocument()
|
||||
@@ -173,7 +169,7 @@ describe('Header', () => {
|
||||
mockMedia = 'mobile'
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByTestId('dify-logo')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('env-nav')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -186,6 +182,70 @@ describe('Header', () => {
|
||||
|
||||
expect(screen.getByText('Acme Workspace')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /logo/i })).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('dify-logo')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('img', { name: /dify logo/i })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show default Dify logo when branding is enabled but no workspace_logo', () => {
|
||||
mockBrandingEnabled = true
|
||||
mockBrandingTitle = 'Custom Title'
|
||||
mockBrandingLogo = null
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByText('Custom Title')).toBeInTheDocument()
|
||||
expect(screen.getByRole('img', { name: /dify logo/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show default Dify text when branding enabled but no application_title', () => {
|
||||
mockBrandingEnabled = true
|
||||
mockBrandingTitle = null
|
||||
mockBrandingLogo = null
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByText('Dify')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show dataset nav for editor who is not dataset operator', () => {
|
||||
mockIsWorkspaceEditor = true
|
||||
mockIsDatasetOperator = false
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByTestId('dataset-nav')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('explore-nav')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('app-nav')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide dataset nav when neither editor nor dataset operator', () => {
|
||||
mockIsWorkspaceEditor = false
|
||||
mockIsDatasetOperator = false
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.queryByTestId('dataset-nav')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render mobile layout with dataset operator nav restrictions', () => {
|
||||
mockMedia = 'mobile'
|
||||
mockIsDatasetOperator = true
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.queryByTestId('explore-nav')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('app-nav')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('tools-nav')).not.toBeInTheDocument()
|
||||
expect(screen.getByTestId('dataset-nav')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render mobile layout with billing enabled', () => {
|
||||
mockMedia = 'mobile'
|
||||
mockEnableBilling = true
|
||||
mockPlanType = 'sandbox'
|
||||
|
||||
render(<Header />)
|
||||
|
||||
expect(screen.getByTestId('plan-badge')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('license-nav')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
61
web/app/components/header/utils/util.spec.ts
Normal file
61
web/app/components/header/utils/util.spec.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { generateMailToLink, mailToSupport } from './util'
|
||||
|
||||
describe('generateMailToLink', () => {
|
||||
// Email-only: both subject and body branches false
|
||||
it('should return mailto link with email only when no subject or body provided', () => {
|
||||
// Act
|
||||
const result = generateMailToLink('test@example.com')
|
||||
|
||||
// Assert
|
||||
expect(result).toBe('mailto:test@example.com')
|
||||
})
|
||||
|
||||
// Subject provided, body not: subject branch true, body branch false
|
||||
it('should append subject when subject is provided without body', () => {
|
||||
// Act
|
||||
const result = generateMailToLink('test@example.com', 'Hello World')
|
||||
|
||||
// Assert
|
||||
expect(result).toBe('mailto:test@example.com?subject=Hello%20World')
|
||||
})
|
||||
|
||||
// Body provided, no subject: subject branch false, body branch true
|
||||
it('should append body with question mark when body is provided without subject', () => {
|
||||
// Act
|
||||
const result = generateMailToLink('test@example.com', undefined, 'Some body text')
|
||||
|
||||
// Assert
|
||||
expect(result).toBe('mailto:test@example.com&body=Some%20body%20text')
|
||||
})
|
||||
|
||||
// Both subject and body provided: both branches true
|
||||
it('should append both subject and body when both are provided', () => {
|
||||
// Act
|
||||
const result = generateMailToLink('test@example.com', 'Subject', 'Body text')
|
||||
|
||||
// Assert
|
||||
expect(result).toBe('mailto:test@example.com?subject=Subject&body=Body%20text')
|
||||
})
|
||||
})
|
||||
|
||||
describe('mailToSupport', () => {
|
||||
// Transitive coverage: exercises generateMailToLink with all params
|
||||
it('should generate a mailto link with support recipient, plan, account, and version info', () => {
|
||||
// Act
|
||||
const result = mailToSupport('user@test.com', 'Pro', '1.0.0')
|
||||
|
||||
// Assert
|
||||
expect(result.startsWith('mailto:support@dify.ai?')).toBe(true)
|
||||
|
||||
const query = result.split('?')[1]
|
||||
expect(query).toBeDefined()
|
||||
|
||||
const params = new URLSearchParams(query)
|
||||
expect(params.get('subject')).toBe('Technical Support Request Pro user@test.com')
|
||||
|
||||
const body = params.get('body')
|
||||
expect(body).toContain('Current Plan: Pro')
|
||||
expect(body).toContain('Account: user@test.com')
|
||||
expect(body).toContain('Version: 1.0.0')
|
||||
})
|
||||
})
|
||||
@@ -111,11 +111,11 @@ const ToolItem: FC<Props> = ({
|
||||
})
|
||||
}}
|
||||
>
|
||||
<div className={cn('system-sm-medium h-8 truncate border-l-2 border-divider-subtle pl-4 leading-8 text-text-secondary')}>
|
||||
<div className={cn('truncate border-l-2 border-divider-subtle py-2 pl-4 text-text-secondary system-sm-medium')}>
|
||||
<span className={cn(disabled && 'opacity-30')}>{payload.label[language]}</span>
|
||||
</div>
|
||||
{isAdded && (
|
||||
<div className="system-xs-regular mr-4 text-text-tertiary">{t('addToolModal.added', { ns: 'tools' })}</div>
|
||||
<div className="mr-4 text-text-tertiary system-xs-regular">{t('addToolModal.added', { ns: 'tools' })}</div>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip>
|
||||
|
||||
@@ -77,11 +77,11 @@ const TriggerPluginActionItem: FC<Props> = ({
|
||||
})
|
||||
}}
|
||||
>
|
||||
<div className={cn('system-sm-medium h-8 truncate border-l-2 border-divider-subtle pl-4 leading-8 text-text-secondary')}>
|
||||
<div className={cn('truncate border-l-2 border-divider-subtle py-2 pl-4 text-text-secondary system-sm-medium')}>
|
||||
<span className={cn(disabled && 'opacity-30')}>{payload.label[language]}</span>
|
||||
</div>
|
||||
{isAdded && (
|
||||
<div className="system-xs-regular mr-4 text-text-tertiary">{t('addToolModal.added', { ns: 'tools' })}</div>
|
||||
<div className="mr-4 text-text-tertiary system-xs-regular">{t('addToolModal.added', { ns: 'tools' })}</div>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user