Compare commits

..

2 Commits

Author SHA1 Message Date
yyh
6d612c0909 test: improve Jotai atom test quality and add model-provider atoms tests
Replace dynamic imports with static imports in marketplace atom tests.
Convert type-only and not-toThrow assertions into proper state-change
verifications. Add comprehensive test suite for model-provider-page
atoms covering all four hooks, cross-hook interaction, selectAtom
granularity, and Provider isolation.
2026-03-05 22:49:09 +08:00
yyh
56e0dc0ae6 trigger ci
Signed-off-by: yyh <yuanyouhuilyz@gmail.com>
2026-03-05 21:22:03 +08:00
12 changed files with 479 additions and 1070 deletions

View File

@@ -2668,77 +2668,3 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
    "--end-before",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    required=True,
    help="Upper bound (exclusive) for created_at.",
)
@click.option(
    "--filename",
    required=True,
    help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
    app_id: str,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime,
    filename: str,
    use_cloud_storage: bool,
    batch_size: int,
    dry_run: bool,
):
    """CLI entry point: export one app's messages to a JSONL.GZ file.

    Validates the time window and filename up front, then delegates the
    actual export to AppMessageExportService and prints a green/red summary.
    """
    # Reject an empty or inverted time window before doing any work.
    if start_from and start_from >= end_before:
        raise click.UsageError("--start-from must be before --end-before.")

    # Imported lazily so the commands module stays cheap to import.
    from services.retention.conversation.message_export_service import AppMessageExportService

    try:
        validated_filename = AppMessageExportService.validate_export_filename(filename)
    except ValueError as e:
        raise click.BadParameter(str(e), param_hint="--filename") from e

    click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
    started = time.perf_counter()
    try:
        exporter = AppMessageExportService(
            app_id=app_id,
            end_before=end_before,
            filename=validated_filename,
            start_from=start_from,
            batch_size=batch_size,
            use_cloud_storage=use_cloud_storage,
            dry_run=dry_run,
        )
        stats = exporter.run()
        elapsed = time.perf_counter() - started
        click.echo(
            click.style(
                f"export_app_messages: completed in {elapsed:.2f}s\n"
                f" - Batches: {stats.batches}\n"
                f" - Total messages: {stats.total_messages}\n"
                f" - Messages with feedback: {stats.messages_with_feedback}\n"
                f" - Total feedbacks: {stats.total_feedbacks}",
                fg="green",
            )
        )
    except Exception as e:
        elapsed = time.perf_counter() - started
        logger.exception("export_app_messages failed")
        click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
        raise

View File

@@ -39,7 +39,6 @@ from . import (
feature,
human_input_form,
init_validate,
notification,
ping,
setup,
spec,
@@ -185,7 +184,6 @@ __all__ = [
"model_config",
"model_providers",
"models",
"notification",
"oauth",
"oauth_server",
"ops_trace",

View File

@@ -1,5 +1,3 @@
import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
@@ -8,7 +6,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
@@ -18,7 +16,6 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
@@ -280,168 +277,3 @@ class DeleteExploreBannerApi(Resource):
db.session.commit()
return {"result": "success"}, 204
class LangContentPayload(BaseModel):
    """One language variant of a notification's content."""

    lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
    title: str = Field(...)
    subtitle: str | None = Field(default=None)
    body: str = Field(...)
    title_pic_url: str | None = Field(default=None)
class UpsertNotificationPayload(BaseModel):
    """Request body for creating or updating an in-product notification."""

    notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
    contents: list[LangContentPayload] = Field(..., min_length=1)
    start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
    end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
    frequency: str = Field(default="once", description="'once' | 'every_page_load'")
    status: str = Field(default="active", description="'active' | 'inactive'")
class BatchAddNotificationAccountsPayload(BaseModel):
    """Request body for registering notification target accounts by email."""

    notification_id: str = Field(...)
    user_email: list[str] = Field(..., description="List of account email addresses")


# Expose both payload schemas to flask-restx so they render in Swagger docs.
console_ns.schema_model(
    UpsertNotificationPayload.__name__,
    UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
    BatchAddNotificationAccountsPayload.__name__,
    BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/admin/upsert_notification")
class UpsertNotificationApi(Resource):
    """Cloud-admin endpoint for creating or updating in-product notifications."""

    @console_ns.doc("upsert_notification")
    @console_ns.doc(
        description=(
            "Create or update an in-product notification. "
            "Supply notification_id to update an existing one; omit it to create a new one. "
            "Pass at least one language variant in contents (zh / en / jp)."
        )
    )
    @console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
    @console_ns.response(200, "Notification upserted successfully")
    @only_edition_cloud
    @admin_required
    def post(self):
        """Validate the payload and forward the upsert to the billing service."""
        body = UpsertNotificationPayload.model_validate(console_ns.payload)
        upserted = BillingService.upsert_notification(
            contents=[variant.model_dump() for variant in body.contents],
            frequency=body.frequency,
            status=body.status,
            notification_id=body.notification_id,
            start_time=body.start_time,
            end_time=body.end_time,
        )
        # Billing service replies with camelCase keys (proto JSON marshaling).
        return {"result": "success", "notification_id": upserted.get("notificationId")}, 200
@console_ns.route("/admin/batch_add_notification_accounts")
class BatchAddNotificationAccountsApi(Resource):
    """Cloud-admin endpoint that registers target accounts for a notification.

    Accepts either a JSON body or a multipart CSV/TXT upload of emails,
    resolves emails to account IDs in chunks, and forwards the IDs to the
    billing service in batches.
    """

    @console_ns.doc("batch_add_notification_accounts")
    @console_ns.doc(
        description=(
            "Register target accounts for a notification by email address. "
            'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
            "File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
            "plus a 'notification_id' field. "
            "Emails that do not match any account are silently skipped."
        )
    )
    @console_ns.response(200, "Accounts added successfully")
    @only_edition_cloud
    @admin_required
    def post(self):
        from models.account import Account

        # Two input modes: multipart file upload or a plain JSON body.
        if "file" in request.files:
            notification_id = request.form.get("notification_id", "").strip()
            if not notification_id:
                raise BadRequest("notification_id is required.")
            emails = self._parse_emails_from_file()
        else:
            body = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
            notification_id = body.notification_id
            emails = body.user_email

        if not emails:
            raise BadRequest("No valid email addresses provided.")

        # Resolve emails → account IDs in chunks to avoid a huge IN-clause.
        matched_account_ids: list[str] = []
        lookup_chunk = 500
        for offset in range(0, len(emails), lookup_chunk):
            email_chunk = emails[offset : offset + lookup_chunk]
            rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(email_chunk))).all()
            matched_account_ids.extend(str(row.id) for row in rows)

        if not matched_account_ids:
            raise BadRequest("None of the provided emails matched an existing account.")

        # Send to dify-saas in batches of 1000 (upstream per-call limit).
        total_count = 0
        push_batch = 1000
        for offset in range(0, len(matched_account_ids), push_batch):
            reply = BillingService.batch_add_notification_accounts(
                notification_id=notification_id,
                account_ids=matched_account_ids[offset : offset + push_batch],
            )
            total_count += reply.get("count", 0)

        return {
            "result": "success",
            "emails_provided": len(emails),
            "accounts_matched": len(matched_account_ids),
            "count": total_count,
        }, 200

    @staticmethod
    def _parse_emails_from_file() -> list[str]:
        """Parse email addresses from an uploaded CSV or TXT file."""
        upload = request.files["file"]
        if not upload.filename:
            raise BadRequest("Uploaded file has no filename.")
        lowered_name = upload.filename.lower()
        if not lowered_name.endswith((".csv", ".txt")):
            raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")

        # Try UTF-8 first, then retry with GBK for legacy exports.
        try:
            content = upload.read().decode("utf-8")
        except UnicodeDecodeError:
            try:
                upload.seek(0)
                content = upload.read().decode("gbk")
            except UnicodeDecodeError:
                # B904 fix: suppress the chained decode traceback deliberately —
                # the client only needs the encoding hint, not the stack.
                raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") from None

        emails: list[str] = []
        if lowered_name.endswith(".csv"):
            for row in csv.reader(io.StringIO(content)):
                emails.extend(cell.strip() for cell in row if cell.strip())
        else:
            emails.extend(line.strip() for line in content.splitlines() if line.strip())

        # Deduplicate case-insensitively while preserving first-seen order/case.
        seen: set[str] = set()
        unique_emails: list[str] = []
        for email in emails:
            key = email.lower()
            if key not in seen:
                seen.add(key)
                unique_emails.append(email)
        return unique_emails

View File

@@ -1,108 +0,0 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
_FALLBACK_LANG = "en"
# Maps dify interface_language prefixes to notification lang tags.
# Any unrecognised prefix falls back to _FALLBACK_LANG.
_LANG_MAP: dict[str, str] = {
"zh": "zh",
"ja": "jp",
}
def _resolve_lang(interface_language: str | None) -> str:
"""Derive the notification lang tag from the user's interface_language.
e.g. "zh-Hans""zh", "ja-JP""jp", "en-US" / None → "en"
"""
if not interface_language:
return _FALLBACK_LANG
prefix = interface_language.split("-")[0].lower()
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
class DismissNotificationPayload(BaseModel):
    """Request body for POST /notification/dismiss."""

    # UUID of the notification the user closed.
    notification_id: str = Field(...)
@console_ns.route("/notification")
class NotificationApi(Resource):
    """Serves the active in-product notification, localized for the user."""

    @console_ns.doc("get_notification")
    @console_ns.doc(
        description=(
            "Return the active in-product notification for the current user "
            "in their interface language (falls back to English if unavailable). "
            "The notification is NOT marked as seen here; call POST /notification/dismiss "
            "when the user explicitly closes the modal."
        ),
        responses={
            200: "Success — inspect should_show to decide whether to render the modal",
            401: "Unauthorized",
        },
    )
    @setup_required
    @login_required
    @account_initialization_required
    @only_edition_cloud
    def get(self):
        """Fetch active notifications from billing and localize each one."""
        account, _ = current_account_with_tenant()
        reply = BillingService.get_account_notification(str(account.id))
        # Proto JSON uses camelCase field names (Kratos default marshaling).
        if not reply.get("shouldShow"):
            return {"should_show": False, "notifications": []}, 200

        lang = _resolve_lang(account.interface_language)
        localized = []
        for raw in reply.get("notifications") or []:
            contents: dict = raw.get("contents") or {}
            chosen = _pick_lang_content(contents, lang)
            localized.append(
                {
                    "notification_id": raw.get("notificationId"),
                    "frequency": raw.get("frequency"),
                    "lang": chosen.get("lang", lang),
                    "title": chosen.get("title", ""),
                    "subtitle": chosen.get("subtitle", ""),
                    "body": chosen.get("body", ""),
                    "title_pic_url": chosen.get("titlePicUrl", ""),
                }
            )
        return {"should_show": bool(localized), "notifications": localized}, 200
@console_ns.route("/notification/dismiss")
class NotificationDismissApi(Resource):
    """Marks a notification as dismissed for the current user."""

    @console_ns.doc("dismiss_notification")
    @console_ns.doc(
        description="Mark a notification as dismissed for the current user.",
        responses={200: "Success", 401: "Unauthorized"},
    )
    @setup_required
    @login_required
    @account_initialization_required
    @only_edition_cloud
    def post(self):
        """Validate the body and forward the dismissal to the billing service."""
        account, _ = current_account_with_tenant()
        body = DismissNotificationPayload.model_validate(request.get_json())
        BillingService.dismiss_notification(
            notification_id=body.notification_id,
            account_id=str(account.id),
        )
        return {"result": "success"}, 200

View File

@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -393,78 +393,3 @@ class BillingService:
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist
@classmethod
def get_account_notification(cls, account_id: str) -> dict:
    """Return the active in-product notification for *account_id*, if any.

    Calling this endpoint also marks the notification as seen; subsequent
    calls will return should_show=false when frequency='once'.

    Response shape (mirrors GetAccountNotificationReply):
        {
            "should_show": bool,
            "notification": {            # present only when should_show=true
                "notification_id": str,
                "contents": {            # lang -> LangContent
                    "en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
                    ...
                },
                "frequency": "once" | "every_page_load"
            }
        }
    """
    query_params = {"account_id": account_id}
    return cls._send_request("GET", "/notifications/active", params=query_params)
@classmethod
def upsert_notification(
cls,
contents: list[dict],
frequency: str = "once",
status: str = "active",
notification_id: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
) -> dict:
"""Create or update a notification.
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
Returns {"notification_id": str}.
"""
payload: dict = {
"contents": contents,
"frequency": frequency,
"status": status,
}
if notification_id:
payload["notification_id"] = notification_id
if start_time:
payload["start_time"] = start_time
if end_time:
payload["end_time"] = end_time
return cls._send_request("POST", "/notifications", json=payload)
@classmethod
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
    """Register target account IDs for a notification (max 1000 per call).

    Returns {"count": int}.
    """
    endpoint = f"/notifications/{notification_id}/accounts"
    return cls._send_request("POST", endpoint, json={"account_ids": account_ids})
@classmethod
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
    """Mark a notification as dismissed for an account.

    Returns {"success": bool}.
    """
    endpoint = f"/notifications/{notification_id}/dismiss"
    return cls._send_request("POST", endpoint, json={"account_id": account_id})

View File

@@ -1,304 +0,0 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
    """Flattened MessageFeedback row as it appears in the export."""

    id: str
    app_id: str
    conversation_id: str
    message_id: str
    rating: str
    content: str | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: str
    updated_at: str

    # Reject unexpected keys so schema drift surfaces immediately.
    model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
    """One exported message: a single line in the JSONL output."""

    conversation_id: str
    message_id: str
    query: str
    answer: str
    inputs: dict[str, Any]
    retriever_resources: list[Any] = Field(default_factory=list)
    feedback: list[AppMessageExportFeedback] = Field(default_factory=list)

    # Reject unexpected keys so schema drift surfaces immediately.
    model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
    """Aggregate counters accumulated while streaming an export."""

    batches: int = 0
    total_messages: int = 0
    messages_with_feedback: int = 0
    total_feedbacks: int = 0

    # Reject unexpected keys so schema drift surfaces immediately.
    model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
    """Streams one app's messages (plus user feedback) into a .jsonl.gz export.

    Uses (created_at, id) keyset pagination and bulk-loads feedback per batch
    to avoid N+1 queries. Supports dry-run (count only), local file output,
    and cloud storage upload.
    """

    def __init__(
        self,
        app_id: str,
        end_before: datetime.datetime,
        filename: str,
        *,
        start_from: datetime.datetime | None = None,
        batch_size: int = 1000,
        use_cloud_storage: bool = False,
        dry_run: bool = False,
    ) -> None:
        if start_from and start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
        self._app_id = app_id
        self._end_before = end_before
        self._start_from = start_from
        self._filename_base = self.validate_export_filename(filename)
        self._batch_size = batch_size
        self._use_cloud_storage = use_cloud_storage
        self._dry_run = dry_run

    @staticmethod
    def validate_export_filename(filename: str) -> str:
        """Validate the user-supplied base filename and return it stripped.

        Rejects forbidden suffixes, absolute paths, backslashes, empty path
        segments, over-long names, control characters, and dot segments.
        Raises ValueError on any violation.
        """
        candidate = filename.strip()
        if not candidate:
            raise ValueError("--filename must not be empty.")
        if candidate.lower().endswith(FORBIDDEN_FILENAME_SUFFIXES):
            raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
        if candidate.startswith("/"):
            raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
        if "\\" in candidate:
            raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
        if "//" in candidate:
            raise ValueError("--filename must not contain empty path segments ('//').")
        if len(candidate) > MAX_FILENAME_BASE_LENGTH:
            raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
        # NUL has ord 0, so the control-character range check covers it too.
        if any(ord(ch) < 32 or ord(ch) == 127 for ch in candidate):
            raise ValueError("--filename must not contain control characters or NUL.")
        segments = PurePosixPath(candidate).parts
        if not segments:
            raise ValueError("--filename must include a file name.")
        if any(segment in (".", "..") for segment in segments):
            raise ValueError("--filename must not contain '.' or '..' path segments.")
        return candidate

    @property
    def output_gz_name(self) -> str:
        """Compressed output name derived from the validated base filename."""
        return f"{self._filename_base}.jsonl.gz"

    @property
    def output_jsonl_name(self) -> str:
        """Uncompressed counterpart of output_gz_name."""
        return f"{self._filename_base}.jsonl"

    def run(self) -> AppMessageExportStats:
        """Execute the export and return aggregate statistics."""
        stats = AppMessageExportStats()
        logger.info(
            "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
            self._app_id,
            self._start_from,
            self._end_before,
            self._dry_run,
            self._use_cloud_storage,
            self.output_gz_name,
        )
        if self._dry_run:
            # Drain the stream purely for counting; nothing is written.
            for _ in self._iter_records_with_stats(stats):
                pass
        elif self._use_cloud_storage:
            self._export_to_cloud(stats)
        else:
            self._export_to_local(stats)
        self._finalize_stats(stats)
        return stats

    def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
        """Yield every export record, flattening the internal batches."""
        for batch in self._iter_record_batches():
            yield from batch

    @staticmethod
    def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
        """Serialize *records* as gzip-compressed JSON Lines into *fileobj*."""
        with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
            for record in records:
                gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")

    def _export_to_local(self, stats: AppMessageExportStats) -> None:
        # Writes <cwd>/<filename_base>.jsonl.gz, creating parent dirs as needed.
        target = Path.cwd() / self.output_gz_name
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("wb") as sink:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), sink)

    def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
        # Spooled temp file keeps small exports in memory (spills past 64 MiB).
        with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
            tmp.seek(0)
            payload = tmp.read()
            storage.save(self.output_gz_name, payload)
            logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(payload), self.output_gz_name)

    def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
        # Pass-through generator that tallies stats as records stream by.
        for record in self.iter_records():
            self._update_stats(stats, record)
            yield record

    @staticmethod
    def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
        stats.total_messages += 1
        if record.feedback:
            stats.messages_with_feedback += 1
            stats.total_feedbacks += len(record.feedback)

    def _finalize_stats(self, stats: AppMessageExportStats) -> None:
        # Batch count is derived from the totals by ceiling division.
        if stats.total_messages == 0:
            stats.batches = 0
        else:
            stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size

    def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
        # Keyset pagination over (created_at, id); feedback loaded per batch.
        position: tuple[datetime.datetime, str] | None = None
        while True:
            rows, position = self._fetch_batch(position)
            if not rows:
                return
            feedback_index = self._fetch_feedbacks([str(row.id) for row in rows])
            yield [self._build_record(row, feedback_index) for row in rows]

    def _fetch_batch(
        self, cursor: tuple[datetime.datetime, str] | None
    ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
        """Fetch the next page of message rows after *cursor*.

        Returns (rows, new_cursor); rows is empty when the window is exhausted.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(
                    Message.id,
                    Message.conversation_id,
                    Message.query,
                    Message.answer,
                    Message._inputs,  # pyright: ignore[reportPrivateUsage]
                    Message.message_metadata,
                    Message.created_at,
                )
                .where(
                    Message.app_id == self._app_id,
                    Message.created_at < self._end_before,
                )
                .order_by(Message.created_at, Message.id)
                .limit(self._batch_size)
            )
            if self._start_from:
                stmt = stmt.where(Message.created_at >= self._start_from)
            if cursor:
                # Typed literals keep the row-value comparison type-correct.
                stmt = stmt.where(
                    tuple_(Message.created_at, Message.id)
                    > tuple_(
                        sa.literal(cursor[0], type_=sa.DateTime()),
                        sa.literal(cursor[1], type_=Message.id.type),
                    )
                )
            page = list(session.execute(stmt).all())
        if not page:
            return [], cursor
        tail = page[-1]
        return page, (tail.created_at, tail.id)

    def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
        """Bulk-load user-sourced feedback for a batch of messages (avoids N+1)."""
        if not message_ids:
            return {}
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(MessageFeedback)
                .where(
                    MessageFeedback.message_id.in_(message_ids),
                    MessageFeedback.from_source == "user",
                )
                .order_by(MessageFeedback.message_id, MessageFeedback.created_at)
            )
            grouped: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
            for fb in session.scalars(stmt).all():
                grouped[str(fb.message_id)].append(AppMessageExportFeedback.model_validate(fb.to_dict()))
        return grouped

    @staticmethod
    def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
        """Convert one message row (+ preloaded feedback) into an export record."""
        resources: list[Any] = []
        if row.message_metadata:
            try:
                metadata = json.loads(row.message_metadata)
                value = metadata.get("retriever_resources", [])
                if isinstance(value, list):
                    resources = value
            except (json.JSONDecodeError, TypeError):
                # Malformed metadata is tolerated; export continues without resources.
                pass
        message_id = str(row.id)
        return AppMessageExportRecord(
            conversation_id=str(row.conversation_id),
            message_id=message_id,
            query=row.query,
            answer=row.answer,
            inputs=row._inputs if isinstance(row._inputs, dict) else {},
            retriever_resources=resources,
            feedback=feedbacks_map.get(message_id, []),
        )

View File

@@ -1,233 +0,0 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
    """DB-backed integration tests for AppMessageExportService."""

    @pytest.fixture(autouse=True)
    def cleanup_database(self, db_session_with_containers: Session):
        # Run the test first, then wipe every table we may have touched,
        # children before parents so no dangling references remain.
        yield
        for model in (
            DatasetRetrieverResource,
            AppAnnotationHitHistory,
            SavedMessage,
            MessageFile,
            MessageAgentThought,
            MessageChain,
            MessageAnnotation,
            MessageFeedback,
            Message,
            Conversation,
            App,
            TenantAccountJoin,
            Tenant,
            Account,
        ):
            db_session_with_containers.query(model).delete()
        db_session_with_containers.commit()

    @staticmethod
    def _create_app_context(session: Session) -> tuple[App, Conversation]:
        """Create account → tenant → join → app → conversation, then commit."""
        account = Account(
            email=f"test-{uuid.uuid4()}@example.com",
            name="tester",
            interface_language="en-US",
            status="active",
        )
        session.add(account)
        session.flush()

        tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
        session.add(tenant)
        session.flush()

        membership = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        session.add(membership)
        session.flush()

        app = App(
            tenant_id=tenant.id,
            name="export-app",
            description="integration test app",
            mode="chat",
            enable_site=True,
            enable_api=True,
            api_rpm=60,
            api_rph=3600,
            is_demo=False,
            is_public=False,
            created_by=account.id,
            updated_by=account.id,
        )
        session.add(app)
        session.flush()

        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=str(uuid.uuid4()),
            model_provider="openai",
            model_id="gpt-4o-mini",
            mode="chat",
            name="conv",
            inputs={"seed": 1},
            status="normal",
            from_source="api",
            from_end_user_id=str(uuid.uuid4()),
        )
        session.add(conversation)
        session.commit()
        return app, conversation

    @staticmethod
    def _create_message(
        session: Session,
        app: App,
        conversation: Conversation,
        created_at: datetime.datetime,
        *,
        query: str,
        answer: str,
        inputs: dict,
        message_metadata: str | None,
    ) -> Message:
        """Insert one message row with fixed pricing fields; flush, no commit."""
        message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            model_provider="openai",
            model_id="gpt-4o-mini",
            inputs=inputs,
            query=query,
            answer=answer,
            message=[{"role": "assistant", "content": answer}],
            message_tokens=10,
            message_unit_price=Decimal("0.001"),
            answer_tokens=20,
            answer_unit_price=Decimal("0.002"),
            total_price=Decimal("0.003"),
            currency="USD",
            message_metadata=message_metadata,
            from_source="api",
            from_end_user_id=conversation.from_end_user_id,
            created_at=created_at,
        )
        session.add(message)
        session.flush()
        return message

    def test_iter_records_with_stats(self, db_session_with_containers: Session):
        app, conversation = self._create_app_context(db_session_with_containers)

        first_inputs = {
            "plain": "v1",
            "nested": {"a": 1, "b": [1, {"x": True}]},
            "list": ["x", 2, {"y": "z"}],
        }
        second_inputs = {"other": "value", "items": [1, 2, 3]}
        base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)

        first_message = self._create_message(
            db_session_with_containers,
            app,
            conversation,
            created_at=base_time,
            query="q1",
            answer="a1",
            inputs=first_inputs,
            message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
        )
        second_message = self._create_message(
            db_session_with_containers,
            app,
            conversation,
            created_at=base_time + datetime.timedelta(minutes=1),
            query="q2",
            answer="a2",
            inputs=second_inputs,
            message_metadata=None,
        )

        # Two user feedbacks on the first message plus one admin feedback
        # that the export is expected to filter out.
        user_feedback_1 = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="like",
            from_source="user",
            content="first",
            from_end_user_id=conversation.from_end_user_id,
        )
        user_feedback_2 = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="dislike",
            from_source="user",
            content="second",
            from_end_user_id=conversation.from_end_user_id,
        )
        admin_feedback = MessageFeedback(
            app_id=app.id,
            conversation_id=conversation.id,
            message_id=first_message.id,
            rating="like",
            from_source="admin",
            content="should-be-filtered",
            from_account_id=str(uuid.uuid4()),
        )
        db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
        # NOTE(review): created_at is assigned after add_all — presumably to
        # override a column default and pin the ordering asserted below; confirm.
        user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
        user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
        admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
        db_session_with_containers.commit()

        service = AppMessageExportService(
            app_id=app.id,
            start_from=base_time - datetime.timedelta(minutes=1),
            end_before=base_time + datetime.timedelta(minutes=10),
            filename="unused",
            batch_size=1,
            dry_run=True,
        )
        stats = AppMessageExportStats()
        records = list(service._iter_records_with_stats(stats))
        service._finalize_stats(stats)

        assert len(records) == 2
        assert records[0].message_id == first_message.id
        assert records[1].message_id == second_message.id
        assert records[0].inputs == first_inputs
        assert records[1].inputs == second_inputs
        assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
        assert records[1].retriever_resources == []
        # Only user-sourced feedback survives, ordered by created_at.
        assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
        assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
        assert records[1].feedback == []
        # batch_size=1 with two messages yields two batches.
        assert stats.batches == 2
        assert stats.total_messages == 2
        assert stats.messages_with_feedback == 1
        assert stats.total_feedbacks == 2

View File

@@ -1,43 +0,0 @@
import datetime
import pytest
from services.retention.conversation.message_export_service import AppMessageExportService
def test_validate_export_filename_accepts_relative_path():
    # A plain nested relative path passes validation unchanged.
    expected = "exports/2026/test01"
    assert AppMessageExportService.validate_export_filename(expected) == expected
@pytest.mark.parametrize(
    "filename",
    [
        "test01.jsonl.gz",
        "test01.jsonl",
        "test01.gz",
        "/tmp/test01",
        "exports/../test01",
        "bad\x00name",
        "bad\tname",
        "a" * 1025,
    ],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
    # Forbidden suffixes, absolute paths, dot segments, control characters,
    # and over-long names must all raise ValueError.
    with pytest.raises(ValueError):
        AppMessageExportService.validate_export_filename(filename)
def test_service_derives_output_names_from_filename_base():
    # Output names are the validated base plus fixed suffixes.
    exporter = AppMessageExportService(
        app_id="736b9b03-20f2-4697-91da-8d00f6325900",
        start_from=None,
        end_before=datetime.datetime(2026, 3, 1),
        filename="exports/2026/test01",
        batch_size=1000,
        use_cloud_storage=True,
        dry_run=True,
    )
    assert exporter._filename_base == "exports/2026/test01"
    assert exporter.output_gz_name == "exports/2026/test01.jsonl.gz"
    assert exporter.output_jsonl_name == "exports/2026/test01.jsonl"

View File

@@ -0,0 +1,399 @@
import type { ReactNode } from 'react'
import { act, renderHook } from '@testing-library/react'
import { Provider } from 'jotai'
import { beforeEach, describe, expect, it } from 'vitest'
import {
useExpandModelProviderList,
useModelProviderListExpanded,
useResetModelProviderListExpanded,
useSetModelProviderListExpanded,
} from './atoms'
// Build a fresh Jotai Provider wrapper so each test gets an isolated atom store.
const createWrapper = () => {
  const JotaiWrapper = ({ children }: { children: ReactNode }) => (
    <Provider>{children}</Provider>
  )
  return JotaiWrapper
}
// Test suite for the model-provider-page expand/collapse atoms.
// Each test renders the hooks under a fresh Jotai Provider so atom state
// never leaks between cases. Hooks that must observe each other's writes
// are rendered in the same renderHook call so they share one store.
describe('atoms', () => {
  let wrapper: ReturnType<typeof createWrapper>

  // New Provider (and therefore a new atom store) for every test.
  beforeEach(() => {
    wrapper = createWrapper()
  })

  // Read hook: returns whether a specific provider is expanded
  describe('useModelProviderListExpanded', () => {
    it('should return false when provider has not been expanded', () => {
      const { result } = renderHook(
        () => useModelProviderListExpanded('openai'),
        { wrapper },
      )
      expect(result.current).toBe(false)
    })

    it('should return false for any unknown provider name', () => {
      // Unknown keys fall back to the collapsed default rather than throwing.
      const { result } = renderHook(
        () => useModelProviderListExpanded('nonexistent-provider'),
        { wrapper },
      )
      expect(result.current).toBe(false)
    })

    it('should return true when provider has been expanded via setter', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          setExpanded: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setExpanded(true)
      })
      expect(result.current.expanded).toBe(true)
    })
  })

  // Setter hook: toggles expanded state for a specific provider
  describe('useSetModelProviderListExpanded', () => {
    it('should expand a provider when called with true', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('anthropic'),
          setExpanded: useSetModelProviderListExpanded('anthropic'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setExpanded(true)
      })
      expect(result.current.expanded).toBe(true)
    })

    it('should collapse a provider when called with false', () => {
      // Expand first, then collapse — two separate act() calls so each
      // state transition is flushed before the next write.
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('anthropic'),
          setExpanded: useSetModelProviderListExpanded('anthropic'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setExpanded(true)
      })
      act(() => {
        result.current.setExpanded(false)
      })
      expect(result.current.expanded).toBe(false)
    })

    it('should not affect other providers when setting one', () => {
      const { result } = renderHook(
        () => ({
          openaiExpanded: useModelProviderListExpanded('openai'),
          anthropicExpanded: useModelProviderListExpanded('anthropic'),
          setOpenai: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setOpenai(true)
      })
      expect(result.current.openaiExpanded).toBe(true)
      expect(result.current.anthropicExpanded).toBe(false)
    })
  })

  // Expand hook: expands any provider by name
  describe('useExpandModelProviderList', () => {
    it('should expand the specified provider', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('google'),
          expand: useExpandModelProviderList(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('google')
      })
      expect(result.current.expanded).toBe(true)
    })

    it('should expand multiple providers independently', () => {
      const { result } = renderHook(
        () => ({
          openaiExpanded: useModelProviderListExpanded('openai'),
          anthropicExpanded: useModelProviderListExpanded('anthropic'),
          expand: useExpandModelProviderList(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('openai')
      })
      act(() => {
        result.current.expand('anthropic')
      })
      expect(result.current.openaiExpanded).toBe(true)
      expect(result.current.anthropicExpanded).toBe(true)
    })

    it('should not collapse already expanded providers when expanding another', () => {
      const { result } = renderHook(
        () => ({
          openaiExpanded: useModelProviderListExpanded('openai'),
          anthropicExpanded: useModelProviderListExpanded('anthropic'),
          expand: useExpandModelProviderList(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('openai')
      })
      act(() => {
        result.current.expand('anthropic')
      })
      // The earlier expansion must survive the later one.
      expect(result.current.openaiExpanded).toBe(true)
    })
  })

  // Reset hook: clears all expanded state back to empty
  describe('useResetModelProviderListExpanded', () => {
    it('should reset all expanded providers to false', () => {
      const { result } = renderHook(
        () => ({
          openaiExpanded: useModelProviderListExpanded('openai'),
          anthropicExpanded: useModelProviderListExpanded('anthropic'),
          expand: useExpandModelProviderList(),
          reset: useResetModelProviderListExpanded(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('openai')
      })
      act(() => {
        result.current.expand('anthropic')
      })
      act(() => {
        result.current.reset()
      })
      expect(result.current.openaiExpanded).toBe(false)
      expect(result.current.anthropicExpanded).toBe(false)
    })

    it('should be safe to call when no providers are expanded', () => {
      // Resetting an already-empty store is a no-op, not an error.
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          reset: useResetModelProviderListExpanded(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.reset()
      })
      expect(result.current.expanded).toBe(false)
    })

    it('should allow re-expanding providers after reset', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          expand: useExpandModelProviderList(),
          reset: useResetModelProviderListExpanded(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('openai')
      })
      act(() => {
        result.current.reset()
      })
      act(() => {
        result.current.expand('openai')
      })
      expect(result.current.expanded).toBe(true)
    })
  })

  // Cross-hook interaction: verify hooks cooperate through the shared atom
  describe('Cross-hook interaction', () => {
    it('should reflect state set by useSetModelProviderListExpanded in useModelProviderListExpanded', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          setExpanded: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setExpanded(true)
      })
      expect(result.current.expanded).toBe(true)
    })

    it('should reflect state set by useExpandModelProviderList in useModelProviderListExpanded', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('anthropic'),
          expand: useExpandModelProviderList(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('anthropic')
      })
      expect(result.current.expanded).toBe(true)
    })

    it('should allow useSetModelProviderListExpanded to collapse a provider expanded by useExpandModelProviderList', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          expand: useExpandModelProviderList(),
          setExpanded: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.expand('openai')
      })
      // Mid-test assertion: the expand hook's write must be visible before
      // the setter collapses it again.
      expect(result.current.expanded).toBe(true)
      act(() => {
        result.current.setExpanded(false)
      })
      expect(result.current.expanded).toBe(false)
    })

    it('should reset state set by useSetModelProviderListExpanded via useResetModelProviderListExpanded', () => {
      const { result } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          setExpanded: useSetModelProviderListExpanded('openai'),
          reset: useResetModelProviderListExpanded(),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setExpanded(true)
      })
      act(() => {
        result.current.reset()
      })
      expect(result.current.expanded).toBe(false)
    })
  })

  // selectAtom granularity: changing one provider should not affect unrelated reads
  describe('selectAtom granularity', () => {
    it('should not cause unrelated provider reads to change when one provider is toggled', () => {
      const { result } = renderHook(
        () => ({
          openai: useModelProviderListExpanded('openai'),
          anthropic: useModelProviderListExpanded('anthropic'),
          google: useModelProviderListExpanded('google'),
          setOpenai: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper },
      )
      // Snapshot unrelated providers before the write so we can assert
      // they are byte-for-byte unchanged afterwards.
      const anthropicBefore = result.current.anthropic
      const googleBefore = result.current.google
      act(() => {
        result.current.setOpenai(true)
      })
      expect(result.current.openai).toBe(true)
      expect(result.current.anthropic).toBe(anthropicBefore)
      expect(result.current.google).toBe(googleBefore)
    })

    it('should keep individual provider states independent across multiple expansions and collapses', () => {
      const { result } = renderHook(
        () => ({
          openai: useModelProviderListExpanded('openai'),
          anthropic: useModelProviderListExpanded('anthropic'),
          setOpenai: useSetModelProviderListExpanded('openai'),
          setAnthropic: useSetModelProviderListExpanded('anthropic'),
        }),
        { wrapper },
      )
      act(() => {
        result.current.setOpenai(true)
      })
      act(() => {
        result.current.setAnthropic(true)
      })
      act(() => {
        result.current.setOpenai(false)
      })
      expect(result.current.openai).toBe(false)
      expect(result.current.anthropic).toBe(true)
    })
  })

  // Isolation: separate Provider instances have independent state
  describe('Provider isolation', () => {
    it('should have independent state across different Provider instances', () => {
      // Two wrappers → two Jotai stores; a write in one must not leak
      // into the other.
      const wrapper1 = createWrapper()
      const wrapper2 = createWrapper()
      const { result: result1 } = renderHook(
        () => ({
          expanded: useModelProviderListExpanded('openai'),
          setExpanded: useSetModelProviderListExpanded('openai'),
        }),
        { wrapper: wrapper1 },
      )
      const { result: result2 } = renderHook(
        () => useModelProviderListExpanded('openai'),
        { wrapper: wrapper2 },
      )
      act(() => {
        result1.current.setExpanded(true)
      })
      expect(result1.current.expanded).toBe(true)
      expect(result2.current).toBe(false)
    })
  })
})

View File

@@ -3,6 +3,16 @@ import { act, renderHook } from '@testing-library/react'
import { Provider as JotaiProvider } from 'jotai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
import {
useActivePluginType,
useFilterPluginTags,
useMarketplaceMoreClick,
useMarketplaceSearchMode,
useMarketplaceSort,
useMarketplaceSortValue,
useSearchPluginText,
useSetMarketplaceSort,
} from '../atoms'
import { DEFAULT_SORT } from '../constants'
const createWrapper = (searchParams = '') => {
@@ -22,8 +32,7 @@ describe('Marketplace sort atoms', () => {
vi.clearAllMocks()
})
it('should return default sort value from useMarketplaceSort', async () => {
const { useMarketplaceSort } = await import('../atoms')
it('should return default sort value from useMarketplaceSort', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceSort(), { wrapper })
@@ -31,24 +40,28 @@ describe('Marketplace sort atoms', () => {
expect(typeof result.current[1]).toBe('function')
})
it('should return default sort value from useMarketplaceSortValue', async () => {
const { useMarketplaceSortValue } = await import('../atoms')
it('should return default sort value from useMarketplaceSortValue', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceSortValue(), { wrapper })
expect(result.current).toEqual(DEFAULT_SORT)
})
it('should return setter from useSetMarketplaceSort', async () => {
const { useSetMarketplaceSort } = await import('../atoms')
it('should return setter from useSetMarketplaceSort', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useSetMarketplaceSort(), { wrapper })
const { result } = renderHook(() => ({
setSort: useSetMarketplaceSort(),
sortValue: useMarketplaceSortValue(),
}), { wrapper })
expect(typeof result.current).toBe('function')
act(() => {
result.current.setSort({ sortBy: 'created_at', sortOrder: 'ASC' })
})
expect(result.current.sortValue).toEqual({ sortBy: 'created_at', sortOrder: 'ASC' })
})
it('should update sort value via useMarketplaceSort setter', async () => {
const { useMarketplaceSort } = await import('../atoms')
it('should update sort value via useMarketplaceSort setter', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceSort(), { wrapper })
@@ -65,8 +78,7 @@ describe('useSearchPluginText', () => {
vi.clearAllMocks()
})
it('should return empty string as default', async () => {
const { useSearchPluginText } = await import('../atoms')
it('should return empty string as default', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
@@ -74,8 +86,7 @@ describe('useSearchPluginText', () => {
expect(typeof result.current[1]).toBe('function')
})
it('should parse q from search params', async () => {
const { useSearchPluginText } = await import('../atoms')
it('should parse q from search params', () => {
const { wrapper } = createWrapper('?q=hello')
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
@@ -83,16 +94,14 @@ describe('useSearchPluginText', () => {
})
it('should expose a setter function for search text', async () => {
const { useSearchPluginText } = await import('../atoms')
const { wrapper } = createWrapper()
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
expect(typeof result.current[1]).toBe('function')
// Calling the setter should not throw
await act(async () => {
result.current[1]('search term')
})
expect(result.current[0]).toBe('search term')
})
})
@@ -101,16 +110,14 @@ describe('useActivePluginType', () => {
vi.clearAllMocks()
})
it('should return "all" as default category', async () => {
const { useActivePluginType } = await import('../atoms')
it('should return "all" as default category', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useActivePluginType(), { wrapper })
expect(result.current[0]).toBe('all')
})
it('should parse category from search params', async () => {
const { useActivePluginType } = await import('../atoms')
it('should parse category from search params', () => {
const { wrapper } = createWrapper('?category=tool')
const { result } = renderHook(() => useActivePluginType(), { wrapper })
@@ -123,16 +130,14 @@ describe('useFilterPluginTags', () => {
vi.clearAllMocks()
})
it('should return empty array as default', async () => {
const { useFilterPluginTags } = await import('../atoms')
it('should return empty array as default', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useFilterPluginTags(), { wrapper })
expect(result.current[0]).toEqual([])
})
it('should parse tags from search params', async () => {
const { useFilterPluginTags } = await import('../atoms')
it('should parse tags from search params', () => {
const { wrapper } = createWrapper('?tags=search')
const { result } = renderHook(() => useFilterPluginTags(), { wrapper })
@@ -145,42 +150,35 @@ describe('useMarketplaceSearchMode', () => {
vi.clearAllMocks()
})
it('should return false when no search text, no tags, and category has collections (all)', async () => {
const { useMarketplaceSearchMode } = await import('../atoms')
it('should return false when no search text, no tags, and category has collections (all)', () => {
const { wrapper } = createWrapper('?category=all')
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
// "all" is in PLUGIN_CATEGORY_WITH_COLLECTIONS, so search mode should be false
expect(result.current).toBe(false)
})
it('should return true when search text is present', async () => {
const { useMarketplaceSearchMode } = await import('../atoms')
it('should return true when search text is present', () => {
const { wrapper } = createWrapper('?q=test&category=all')
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
expect(result.current).toBe(true)
})
it('should return true when tags are present', async () => {
const { useMarketplaceSearchMode } = await import('../atoms')
it('should return true when tags are present', () => {
const { wrapper } = createWrapper('?tags=search&category=all')
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
expect(result.current).toBe(true)
})
it('should return true when category does not have collections (e.g. model)', async () => {
const { useMarketplaceSearchMode } = await import('../atoms')
it('should return true when category does not have collections (e.g. model)', () => {
const { wrapper } = createWrapper('?category=model')
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
// "model" is NOT in PLUGIN_CATEGORY_WITH_COLLECTIONS, so search mode = true
expect(result.current).toBe(true)
})
it('should return false when category has collections (tool) and no search/tags', async () => {
const { useMarketplaceSearchMode } = await import('../atoms')
it('should return false when category has collections (tool) and no search/tags', () => {
const { wrapper } = createWrapper('?category=tool')
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
@@ -193,27 +191,33 @@ describe('useMarketplaceMoreClick', () => {
vi.clearAllMocks()
})
it('should return a callback function', async () => {
const { useMarketplaceMoreClick } = await import('../atoms')
it('should return a callback function', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
expect(typeof result.current).toBe('function')
})
it('should do nothing when called with no params', async () => {
const { useMarketplaceMoreClick } = await import('../atoms')
it('should do nothing when called with no params', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
const { result } = renderHook(() => ({
handleMoreClick: useMarketplaceMoreClick(),
sort: useMarketplaceSortValue(),
searchText: useSearchPluginText()[0],
}), { wrapper })
const sortBefore = result.current.sort
const searchTextBefore = result.current.searchText
// Should not throw when called with undefined
act(() => {
result.current(undefined)
result.current.handleMoreClick(undefined)
})
expect(result.current.sort).toEqual(sortBefore)
expect(result.current.searchText).toBe(searchTextBefore)
})
it('should update search state when called with search params', async () => {
const { useMarketplaceMoreClick, useMarketplaceSortValue } = await import('../atoms')
it('should update search state when called with search params', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => ({
@@ -229,17 +233,20 @@ describe('useMarketplaceMoreClick', () => {
})
})
// Sort should be updated via the jotai atom
expect(result.current.sort).toEqual({ sortBy: 'created_at', sortOrder: 'ASC' })
})
it('should use defaults when search params fields are missing', async () => {
const { useMarketplaceMoreClick } = await import('../atoms')
it('should use defaults when search params fields are missing', () => {
const { wrapper } = createWrapper()
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
const { result } = renderHook(() => ({
handleMoreClick: useMarketplaceMoreClick(),
sort: useMarketplaceSortValue(),
}), { wrapper })
act(() => {
result.current({})
result.current.handleMoreClick({})
})
expect(result.current.sort).toEqual(DEFAULT_SORT)
})
})

View File

@@ -74,31 +74,40 @@ describe('PluginTypeSwitch', () => {
const { Wrapper } = createWrapper('?category=all')
render(<PluginTypeSwitch />, { wrapper: Wrapper })
// Click on Models option — should not throw
expect(() => fireEvent.click(screen.getByText('Models'))).not.toThrow()
fireEvent.click(screen.getByText('Models'))
const modelsButton = screen.getByText('Models').closest('div')
expect(modelsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
})
it('should handle clicking on category with collections (Tools)', () => {
const { Wrapper } = createWrapper('?category=model')
render(<PluginTypeSwitch />, { wrapper: Wrapper })
// Click on "Tools" which has collections → setSearchMode(null)
expect(() => fireEvent.click(screen.getByText('Tools'))).not.toThrow()
fireEvent.click(screen.getByText('Tools'))
const toolsButton = screen.getByText('Tools').closest('div')
expect(toolsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
})
it('should handle clicking on category without collections (Models)', () => {
const { Wrapper } = createWrapper('?category=all')
render(<PluginTypeSwitch />, { wrapper: Wrapper })
// Click on "Models" which does NOT have collections → no setSearchMode call
expect(() => fireEvent.click(screen.getByText('Models'))).not.toThrow()
fireEvent.click(screen.getByText('Models'))
const modelsButton = screen.getByText('Models').closest('div')
expect(modelsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
})
it('should handle clicking on bundles', () => {
const { Wrapper } = createWrapper('?category=all')
render(<PluginTypeSwitch />, { wrapper: Wrapper })
expect(() => fireEvent.click(screen.getByText('Bundles'))).not.toThrow()
fireEvent.click(screen.getByText('Bundles'))
const bundlesButton = screen.getByText('Bundles').closest('div')
expect(bundlesButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
})
it('should handle clicking on each category', () => {
@@ -107,7 +116,10 @@ describe('PluginTypeSwitch', () => {
const categories = ['All', 'Models', 'Tools', 'Data Sources', 'Triggers', 'Agents', 'Extensions', 'Bundles']
categories.forEach((category) => {
expect(() => fireEvent.click(screen.getByText(category))).not.toThrow()
fireEvent.click(screen.getByText(category))
const button = screen.getByText(category).closest('div')
expect(button?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
})
})