mirror of
https://github.com/langgenius/dify.git
synced 2026-03-05 23:25:11 +00:00
Compare commits
2 Commits
deploy/dev
...
feat/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d612c0909 | ||
|
|
56e0dc0ae6 |
@@ -2668,77 +2668,3 @@ def clean_expired_messages(
|
||||
raise
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
|
||||
@@ -39,7 +39,6 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -185,7 +184,6 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -8,7 +6,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -18,7 +16,6 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -280,168 +277,3 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en"
|
||||
|
||||
# Maps dify interface_language prefixes to notification lang tags.
|
||||
# Any unrecognised prefix falls back to _FALLBACK_LANG.
|
||||
_LANG_MAP: dict[str, str] = {
|
||||
"zh": "zh",
|
||||
"ja": "jp",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_lang(interface_language: str | None) -> str:
|
||||
"""Derive the notification lang tag from the user's interface_language.
|
||||
|
||||
e.g. "zh-Hans" → "zh", "ja-JP" → "jp", "en-US" / None → "en"
|
||||
"""
|
||||
if not interface_language:
|
||||
return _FALLBACK_LANG
|
||||
prefix = interface_language.split("-")[0].lower()
|
||||
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"The notification is NOT marked as seen here; call POST /notification/dismiss "
|
||||
"when the user explicitly closes the modal."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = _resolve_lang(current_user.interface_language)
|
||||
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
class NotificationDismissApi(Resource):
|
||||
@console_ns.doc("dismiss_notification")
|
||||
@console_ns.doc(
|
||||
description="Mark a notification as dismissed for the current user.",
|
||||
responses={200: "Success", 401: "Unauthorized"},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
account_id=str(current_user.id),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -393,78 +393,3 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def get_account_notification(cls, account_id: str) -> dict:
|
||||
"""Return the active in-product notification for account_id, if any.
|
||||
|
||||
Calling this endpoint also marks the notification as seen; subsequent
|
||||
calls will return should_show=false when frequency='once'.
|
||||
|
||||
Response shape (mirrors GetAccountNotificationReply):
|
||||
{
|
||||
"should_show": bool,
|
||||
"notification": { # present only when should_show=true
|
||||
"notification_id": str,
|
||||
"contents": { # lang -> LangContent
|
||||
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
|
||||
...
|
||||
},
|
||||
"frequency": "once" | "every_page_load"
|
||||
}
|
||||
}
|
||||
"""
|
||||
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||
|
||||
@classmethod
|
||||
def upsert_notification(
|
||||
cls,
|
||||
contents: list[dict],
|
||||
frequency: str = "once",
|
||||
status: str = "active",
|
||||
notification_id: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update a notification.
|
||||
|
||||
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
|
||||
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||
Returns {"notification_id": str}.
|
||||
"""
|
||||
payload: dict = {
|
||||
"contents": contents,
|
||||
"frequency": frequency,
|
||||
"status": status,
|
||||
}
|
||||
if notification_id:
|
||||
payload["notification_id"] = notification_id
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
return cls._send_request("POST", "/notifications", json=payload)
|
||||
|
||||
@classmethod
|
||||
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||
"""Register target account IDs for a notification (max 1000 per call).
|
||||
|
||||
Returns {"count": int}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/accounts",
|
||||
json={"account_ids": account_ids},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
|
||||
"""Mark a notification as dismissed for an account.
|
||||
|
||||
Returns {"success": bool}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/dismiss",
|
||||
json={"account_id": account_id},
|
||||
)
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
"""
|
||||
Export app messages to JSONL.GZ format.
|
||||
|
||||
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
|
||||
retriever_resources (from message_metadata), feedback (user feedbacks array).
|
||||
|
||||
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
|
||||
Does NOT touch Message.inputs / Message.user_feedback properties.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, tuple_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILENAME_BASE_LENGTH = 1024
|
||||
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportRecord(BaseModel):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
query: str
|
||||
answer: str
|
||||
inputs: dict[str, Any]
|
||||
retriever_resources: list[Any] = Field(default_factory=list)
|
||||
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportStats(BaseModel):
|
||||
batches: int = 0
|
||||
total_messages: int = 0
|
||||
messages_with_feedback: int = 0
|
||||
total_feedbacks: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportService:
|
||||
@staticmethod
|
||||
def validate_export_filename(filename: str) -> str:
|
||||
normalized = filename.strip()
|
||||
if not normalized:
|
||||
raise ValueError("--filename must not be empty.")
|
||||
|
||||
normalized_lower = normalized.lower()
|
||||
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
|
||||
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
|
||||
|
||||
if normalized.startswith("/"):
|
||||
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
|
||||
|
||||
if "\\" in normalized:
|
||||
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
|
||||
|
||||
if "//" in normalized:
|
||||
raise ValueError("--filename must not contain empty path segments ('//').")
|
||||
|
||||
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
|
||||
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
|
||||
|
||||
for ch in normalized:
|
||||
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
|
||||
raise ValueError("--filename must not contain control characters or NUL.")
|
||||
|
||||
parts = PurePosixPath(normalized).parts
|
||||
if not parts:
|
||||
raise ValueError("--filename must include a file name.")
|
||||
|
||||
if any(part in (".", "..") for part in parts):
|
||||
raise ValueError("--filename must not contain '.' or '..' path segments.")
|
||||
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def output_gz_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl.gz"
|
||||
|
||||
@property
|
||||
def output_jsonl_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
*,
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
use_cloud_storage: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if start_from and start_from >= end_before:
|
||||
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
|
||||
|
||||
self._app_id = app_id
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._filename_base = self.validate_export_filename(filename)
|
||||
self._batch_size = batch_size
|
||||
self._use_cloud_storage = use_cloud_storage
|
||||
self._dry_run = dry_run
|
||||
|
||||
def run(self) -> AppMessageExportStats:
|
||||
stats = AppMessageExportStats()
|
||||
|
||||
logger.info(
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
|
||||
self._app_id,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
self._dry_run,
|
||||
self._use_cloud_storage,
|
||||
self.output_gz_name,
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
for _ in self._iter_records_with_stats(stats):
|
||||
pass
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
if self._use_cloud_storage:
|
||||
self._export_to_cloud(stats)
|
||||
else:
|
||||
self._export_to_local(stats)
|
||||
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for batch in self._iter_record_batches():
|
||||
yield from batch
|
||||
|
||||
@staticmethod
|
||||
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
|
||||
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
|
||||
for record in records:
|
||||
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
|
||||
|
||||
def _export_to_local(self, stats: AppMessageExportStats) -> None:
|
||||
output_path = Path.cwd() / self.output_gz_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("wb") as output_file:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
|
||||
|
||||
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
|
||||
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
|
||||
tmp.seek(0)
|
||||
data = tmp.read()
|
||||
|
||||
storage.save(self.output_gz_name, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
|
||||
|
||||
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for record in self.iter_records():
|
||||
self._update_stats(stats, record)
|
||||
yield record
|
||||
|
||||
@staticmethod
|
||||
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
|
||||
stats.total_messages += 1
|
||||
if record.feedback:
|
||||
stats.messages_with_feedback += 1
|
||||
stats.total_feedbacks += len(record.feedback)
|
||||
|
||||
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
|
||||
if stats.total_messages == 0:
|
||||
stats.batches = 0
|
||||
return
|
||||
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
|
||||
|
||||
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
|
||||
cursor: tuple[datetime.datetime, str] | None = None
|
||||
while True:
|
||||
rows, cursor = self._fetch_batch(cursor)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
message_ids = [str(row.id) for row in rows]
|
||||
feedbacks_map = self._fetch_feedbacks(message_ids)
|
||||
yield [self._build_record(row, feedbacks_map) for row in rows]
|
||||
|
||||
def _fetch_batch(
|
||||
self, cursor: tuple[datetime.datetime, str] | None
|
||||
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(
|
||||
Message.id,
|
||||
Message.conversation_id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message._inputs, # pyright: ignore[reportPrivateUsage]
|
||||
Message.message_metadata,
|
||||
Message.created_at,
|
||||
)
|
||||
.where(
|
||||
Message.app_id == self._app_id,
|
||||
Message.created_at < self._end_before,
|
||||
)
|
||||
.order_by(Message.created_at, Message.id)
|
||||
.limit(self._batch_size)
|
||||
)
|
||||
|
||||
if self._start_from:
|
||||
stmt = stmt.where(Message.created_at >= self._start_from)
|
||||
|
||||
if cursor:
|
||||
stmt = stmt.where(
|
||||
tuple_(Message.created_at, Message.id)
|
||||
> tuple_(
|
||||
sa.literal(cursor[0], type_=sa.DateTime()),
|
||||
sa.literal(cursor[1], type_=Message.id.type),
|
||||
)
|
||||
)
|
||||
|
||||
rows = list(session.execute(stmt).all())
|
||||
|
||||
if not rows:
|
||||
return [], cursor
|
||||
|
||||
last = rows[-1]
|
||||
return rows, (last.created_at, last.id)
|
||||
|
||||
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
|
||||
if not message_ids:
|
||||
return {}
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.message_id.in_(message_ids),
|
||||
MessageFeedback.from_source == "user",
|
||||
)
|
||||
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
|
||||
)
|
||||
feedbacks = list(session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
|
||||
for feedback in feedbacks:
|
||||
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
|
||||
retriever_resources: list[Any] = []
|
||||
if row.message_metadata:
|
||||
try:
|
||||
metadata = json.loads(row.message_metadata)
|
||||
value = metadata.get("retriever_resources", [])
|
||||
if isinstance(value, list):
|
||||
retriever_resources = value
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
message_id = str(row.id)
|
||||
return AppMessageExportRecord(
|
||||
conversation_id=str(row.conversation_id),
|
||||
message_id=message_id,
|
||||
query=row.query,
|
||||
answer=row.answer,
|
||||
inputs=row._inputs if isinstance(row._inputs, dict) else {},
|
||||
retriever_resources=retriever_resources,
|
||||
feedback=feedbacks_map.get(message_id, []),
|
||||
)
|
||||
@@ -1,233 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
|
||||
|
||||
|
||||
class TestAppMessageExportServiceIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
yield
|
||||
db_session_with_containers.query(DatasetRetrieverResource).delete()
|
||||
db_session_with_containers.query(AppAnnotationHitHistory).delete()
|
||||
db_session_with_containers.query(SavedMessage).delete()
|
||||
db_session_with_containers.query(MessageFile).delete()
|
||||
db_session_with_containers.query(MessageAgentThought).delete()
|
||||
db_session_with_containers.query(MessageChain).delete()
|
||||
db_session_with_containers.query(MessageAnnotation).delete()
|
||||
db_session_with_containers.query(MessageFeedback).delete()
|
||||
db_session_with_containers.query(Message).delete()
|
||||
db_session_with_containers.query(Conversation).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_app_context(session: Session) -> tuple[App, Conversation]:
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="tester",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
session.add(join)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="export-app",
|
||||
description="integration test app",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=60,
|
||||
api_rph=3600,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=str(uuid.uuid4()),
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
mode="chat",
|
||||
name="conv",
|
||||
inputs={"seed": 1},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
return app, conversation
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
session: Session,
|
||||
app: App,
|
||||
conversation: Conversation,
|
||||
created_at: datetime.datetime,
|
||||
*,
|
||||
query: str,
|
||||
answer: str,
|
||||
inputs: dict,
|
||||
message_metadata: str | None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
answer=answer,
|
||||
message=[{"role": "assistant", "content": answer}],
|
||||
message_tokens=10,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=20,
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
message_metadata=message_metadata,
|
||||
from_source="api",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
session.add(message)
|
||||
session.flush()
|
||||
return message
|
||||
|
||||
def test_iter_records_with_stats(self, db_session_with_containers: Session):
|
||||
app, conversation = self._create_app_context(db_session_with_containers)
|
||||
|
||||
first_inputs = {
|
||||
"plain": "v1",
|
||||
"nested": {"a": 1, "b": [1, {"x": True}]},
|
||||
"list": ["x", 2, {"y": "z"}],
|
||||
}
|
||||
second_inputs = {"other": "value", "items": [1, 2, 3]}
|
||||
|
||||
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
|
||||
first_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time,
|
||||
query="q1",
|
||||
answer="a1",
|
||||
inputs=first_inputs,
|
||||
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
|
||||
)
|
||||
second_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time + datetime.timedelta(minutes=1),
|
||||
query="q2",
|
||||
answer="a2",
|
||||
inputs=second_inputs,
|
||||
message_metadata=None,
|
||||
)
|
||||
|
||||
user_feedback_1 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
user_feedback_2 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
admin_feedback = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
|
||||
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
|
||||
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
|
||||
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = AppMessageExportService(
|
||||
app_id=app.id,
|
||||
start_from=base_time - datetime.timedelta(minutes=1),
|
||||
end_before=base_time + datetime.timedelta(minutes=10),
|
||||
filename="unused",
|
||||
batch_size=1,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = AppMessageExportStats()
|
||||
records = list(service._iter_records_with_stats(stats))
|
||||
service._finalize_stats(stats)
|
||||
|
||||
assert len(records) == 2
|
||||
assert records[0].message_id == first_message.id
|
||||
assert records[1].message_id == second_message.id
|
||||
|
||||
assert records[0].inputs == first_inputs
|
||||
assert records[1].inputs == second_inputs
|
||||
|
||||
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
|
||||
assert records[1].retriever_resources == []
|
||||
|
||||
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
|
||||
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
|
||||
assert records[1].feedback == []
|
||||
|
||||
assert stats.batches == 2
|
||||
assert stats.total_messages == 2
|
||||
assert stats.messages_with_feedback == 1
|
||||
assert stats.total_feedbacks == 2
|
||||
@@ -1,43 +0,0 @@
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
|
||||
def test_validate_export_filename_accepts_relative_path():
|
||||
assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filename",
|
||||
[
|
||||
"test01.jsonl.gz",
|
||||
"test01.jsonl",
|
||||
"test01.gz",
|
||||
"/tmp/test01",
|
||||
"exports/../test01",
|
||||
"bad\x00name",
|
||||
"bad\tname",
|
||||
"a" * 1025,
|
||||
],
|
||||
)
|
||||
def test_validate_export_filename_rejects_invalid_values(filename: str):
|
||||
with pytest.raises(ValueError):
|
||||
AppMessageExportService.validate_export_filename(filename)
|
||||
|
||||
|
||||
def test_service_derives_output_names_from_filename_base():
|
||||
service = AppMessageExportService(
|
||||
app_id="736b9b03-20f2-4697-91da-8d00f6325900",
|
||||
start_from=None,
|
||||
end_before=datetime.datetime(2026, 3, 1),
|
||||
filename="exports/2026/test01",
|
||||
batch_size=1000,
|
||||
use_cloud_storage=True,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
assert service._filename_base == "exports/2026/test01"
|
||||
assert service.output_gz_name == "exports/2026/test01.jsonl.gz"
|
||||
assert service.output_jsonl_name == "exports/2026/test01.jsonl"
|
||||
@@ -0,0 +1,399 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { Provider } from 'jotai'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
import {
|
||||
useExpandModelProviderList,
|
||||
useModelProviderListExpanded,
|
||||
useResetModelProviderListExpanded,
|
||||
useSetModelProviderListExpanded,
|
||||
} from './atoms'
|
||||
|
||||
const createWrapper = () => {
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<Provider>{children}</Provider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('atoms', () => {
|
||||
let wrapper: ReturnType<typeof createWrapper>
|
||||
|
||||
beforeEach(() => {
|
||||
wrapper = createWrapper()
|
||||
})
|
||||
|
||||
// Read hook: returns whether a specific provider is expanded
|
||||
describe('useModelProviderListExpanded', () => {
|
||||
it('should return false when provider has not been expanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => useModelProviderListExpanded('openai'),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for any unknown provider name', () => {
|
||||
const { result } = renderHook(
|
||||
() => useModelProviderListExpanded('nonexistent-provider'),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when provider has been expanded via setter', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Setter hook: toggles expanded state for a specific provider
|
||||
describe('useSetModelProviderListExpanded', () => {
|
||||
it('should expand a provider when called with true', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
setExpanded: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should collapse a provider when called with false', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
setExpanded: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setExpanded(false)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should not affect other providers when setting one', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
expect(result.current.anthropicExpanded).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// Expand hook: expands any provider by name
|
||||
describe('useExpandModelProviderList', () => {
|
||||
it('should expand the specified provider', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('google'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('google')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should expand multiple providers independently', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
expect(result.current.anthropicExpanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should not collapse already expanded providers when expanding another', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Reset hook: clears all expanded state back to empty
|
||||
describe('useResetModelProviderListExpanded', () => {
|
||||
it('should reset all expanded providers to false', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(false)
|
||||
expect(result.current.anthropicExpanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should be safe to call when no providers are expanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should allow re-expanding providers after reset', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
expand: useExpandModelProviderList(),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Cross-hook interaction: verify hooks cooperate through the shared atom
|
||||
describe('Cross-hook interaction', () => {
|
||||
it('should reflect state set by useSetModelProviderListExpanded in useModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should reflect state set by useExpandModelProviderList in useModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should allow useSetModelProviderListExpanded to collapse a provider expanded by useExpandModelProviderList', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
expand: useExpandModelProviderList(),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
expect(result.current.expanded).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(false)
|
||||
})
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should reset state set by useSetModelProviderListExpanded via useResetModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// selectAtom granularity: changing one provider should not affect unrelated reads
|
||||
describe('selectAtom granularity', () => {
|
||||
it('should not cause unrelated provider reads to change when one provider is toggled', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openai: useModelProviderListExpanded('openai'),
|
||||
anthropic: useModelProviderListExpanded('anthropic'),
|
||||
google: useModelProviderListExpanded('google'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
const anthropicBefore = result.current.anthropic
|
||||
const googleBefore = result.current.google
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
|
||||
expect(result.current.openai).toBe(true)
|
||||
expect(result.current.anthropic).toBe(anthropicBefore)
|
||||
expect(result.current.google).toBe(googleBefore)
|
||||
})
|
||||
|
||||
it('should keep individual provider states independent across multiple expansions and collapses', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openai: useModelProviderListExpanded('openai'),
|
||||
anthropic: useModelProviderListExpanded('anthropic'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
setAnthropic: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setAnthropic(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setOpenai(false)
|
||||
})
|
||||
|
||||
expect(result.current.openai).toBe(false)
|
||||
expect(result.current.anthropic).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Isolation: separate Provider instances have independent state
|
||||
describe('Provider isolation', () => {
|
||||
it('should have independent state across different Provider instances', () => {
|
||||
const wrapper1 = createWrapper()
|
||||
const wrapper2 = createWrapper()
|
||||
|
||||
const { result: result1 } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper: wrapper1 },
|
||||
)
|
||||
|
||||
const { result: result2 } = renderHook(
|
||||
() => useModelProviderListExpanded('openai'),
|
||||
{ wrapper: wrapper2 },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result1.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result1.current.expanded).toBe(true)
|
||||
expect(result2.current).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -3,6 +3,16 @@ import { act, renderHook } from '@testing-library/react'
|
||||
import { Provider as JotaiProvider } from 'jotai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createNuqsTestWrapper } from '@/test/nuqs-testing'
|
||||
import {
|
||||
useActivePluginType,
|
||||
useFilterPluginTags,
|
||||
useMarketplaceMoreClick,
|
||||
useMarketplaceSearchMode,
|
||||
useMarketplaceSort,
|
||||
useMarketplaceSortValue,
|
||||
useSearchPluginText,
|
||||
useSetMarketplaceSort,
|
||||
} from '../atoms'
|
||||
import { DEFAULT_SORT } from '../constants'
|
||||
|
||||
const createWrapper = (searchParams = '') => {
|
||||
@@ -22,8 +32,7 @@ describe('Marketplace sort atoms', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return default sort value from useMarketplaceSort', async () => {
|
||||
const { useMarketplaceSort } = await import('../atoms')
|
||||
it('should return default sort value from useMarketplaceSort', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceSort(), { wrapper })
|
||||
|
||||
@@ -31,24 +40,28 @@ describe('Marketplace sort atoms', () => {
|
||||
expect(typeof result.current[1]).toBe('function')
|
||||
})
|
||||
|
||||
it('should return default sort value from useMarketplaceSortValue', async () => {
|
||||
const { useMarketplaceSortValue } = await import('../atoms')
|
||||
it('should return default sort value from useMarketplaceSortValue', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceSortValue(), { wrapper })
|
||||
|
||||
expect(result.current).toEqual(DEFAULT_SORT)
|
||||
})
|
||||
|
||||
it('should return setter from useSetMarketplaceSort', async () => {
|
||||
const { useSetMarketplaceSort } = await import('../atoms')
|
||||
it('should return setter from useSetMarketplaceSort', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useSetMarketplaceSort(), { wrapper })
|
||||
const { result } = renderHook(() => ({
|
||||
setSort: useSetMarketplaceSort(),
|
||||
sortValue: useMarketplaceSortValue(),
|
||||
}), { wrapper })
|
||||
|
||||
expect(typeof result.current).toBe('function')
|
||||
act(() => {
|
||||
result.current.setSort({ sortBy: 'created_at', sortOrder: 'ASC' })
|
||||
})
|
||||
|
||||
expect(result.current.sortValue).toEqual({ sortBy: 'created_at', sortOrder: 'ASC' })
|
||||
})
|
||||
|
||||
it('should update sort value via useMarketplaceSort setter', async () => {
|
||||
const { useMarketplaceSort } = await import('../atoms')
|
||||
it('should update sort value via useMarketplaceSort setter', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceSort(), { wrapper })
|
||||
|
||||
@@ -65,8 +78,7 @@ describe('useSearchPluginText', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return empty string as default', async () => {
|
||||
const { useSearchPluginText } = await import('../atoms')
|
||||
it('should return empty string as default', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
|
||||
|
||||
@@ -74,8 +86,7 @@ describe('useSearchPluginText', () => {
|
||||
expect(typeof result.current[1]).toBe('function')
|
||||
})
|
||||
|
||||
it('should parse q from search params', async () => {
|
||||
const { useSearchPluginText } = await import('../atoms')
|
||||
it('should parse q from search params', () => {
|
||||
const { wrapper } = createWrapper('?q=hello')
|
||||
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
|
||||
|
||||
@@ -83,16 +94,14 @@ describe('useSearchPluginText', () => {
|
||||
})
|
||||
|
||||
it('should expose a setter function for search text', async () => {
|
||||
const { useSearchPluginText } = await import('../atoms')
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useSearchPluginText(), { wrapper })
|
||||
|
||||
expect(typeof result.current[1]).toBe('function')
|
||||
|
||||
// Calling the setter should not throw
|
||||
await act(async () => {
|
||||
result.current[1]('search term')
|
||||
})
|
||||
|
||||
expect(result.current[0]).toBe('search term')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -101,16 +110,14 @@ describe('useActivePluginType', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return "all" as default category', async () => {
|
||||
const { useActivePluginType } = await import('../atoms')
|
||||
it('should return "all" as default category', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useActivePluginType(), { wrapper })
|
||||
|
||||
expect(result.current[0]).toBe('all')
|
||||
})
|
||||
|
||||
it('should parse category from search params', async () => {
|
||||
const { useActivePluginType } = await import('../atoms')
|
||||
it('should parse category from search params', () => {
|
||||
const { wrapper } = createWrapper('?category=tool')
|
||||
const { result } = renderHook(() => useActivePluginType(), { wrapper })
|
||||
|
||||
@@ -123,16 +130,14 @@ describe('useFilterPluginTags', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return empty array as default', async () => {
|
||||
const { useFilterPluginTags } = await import('../atoms')
|
||||
it('should return empty array as default', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useFilterPluginTags(), { wrapper })
|
||||
|
||||
expect(result.current[0]).toEqual([])
|
||||
})
|
||||
|
||||
it('should parse tags from search params', async () => {
|
||||
const { useFilterPluginTags } = await import('../atoms')
|
||||
it('should parse tags from search params', () => {
|
||||
const { wrapper } = createWrapper('?tags=search')
|
||||
const { result } = renderHook(() => useFilterPluginTags(), { wrapper })
|
||||
|
||||
@@ -145,42 +150,35 @@ describe('useMarketplaceSearchMode', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return false when no search text, no tags, and category has collections (all)', async () => {
|
||||
const { useMarketplaceSearchMode } = await import('../atoms')
|
||||
it('should return false when no search text, no tags, and category has collections (all)', () => {
|
||||
const { wrapper } = createWrapper('?category=all')
|
||||
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
|
||||
|
||||
// "all" is in PLUGIN_CATEGORY_WITH_COLLECTIONS, so search mode should be false
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when search text is present', async () => {
|
||||
const { useMarketplaceSearchMode } = await import('../atoms')
|
||||
it('should return true when search text is present', () => {
|
||||
const { wrapper } = createWrapper('?q=test&category=all')
|
||||
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
|
||||
|
||||
expect(result.current).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when tags are present', async () => {
|
||||
const { useMarketplaceSearchMode } = await import('../atoms')
|
||||
it('should return true when tags are present', () => {
|
||||
const { wrapper } = createWrapper('?tags=search&category=all')
|
||||
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
|
||||
|
||||
expect(result.current).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when category does not have collections (e.g. model)', async () => {
|
||||
const { useMarketplaceSearchMode } = await import('../atoms')
|
||||
it('should return true when category does not have collections (e.g. model)', () => {
|
||||
const { wrapper } = createWrapper('?category=model')
|
||||
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
|
||||
|
||||
// "model" is NOT in PLUGIN_CATEGORY_WITH_COLLECTIONS, so search mode = true
|
||||
expect(result.current).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when category has collections (tool) and no search/tags', async () => {
|
||||
const { useMarketplaceSearchMode } = await import('../atoms')
|
||||
it('should return false when category has collections (tool) and no search/tags', () => {
|
||||
const { wrapper } = createWrapper('?category=tool')
|
||||
const { result } = renderHook(() => useMarketplaceSearchMode(), { wrapper })
|
||||
|
||||
@@ -193,27 +191,33 @@ describe('useMarketplaceMoreClick', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return a callback function', async () => {
|
||||
const { useMarketplaceMoreClick } = await import('../atoms')
|
||||
it('should return a callback function', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
|
||||
|
||||
expect(typeof result.current).toBe('function')
|
||||
})
|
||||
|
||||
it('should do nothing when called with no params', async () => {
|
||||
const { useMarketplaceMoreClick } = await import('../atoms')
|
||||
it('should do nothing when called with no params', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
|
||||
const { result } = renderHook(() => ({
|
||||
handleMoreClick: useMarketplaceMoreClick(),
|
||||
sort: useMarketplaceSortValue(),
|
||||
searchText: useSearchPluginText()[0],
|
||||
}), { wrapper })
|
||||
|
||||
const sortBefore = result.current.sort
|
||||
const searchTextBefore = result.current.searchText
|
||||
|
||||
// Should not throw when called with undefined
|
||||
act(() => {
|
||||
result.current(undefined)
|
||||
result.current.handleMoreClick(undefined)
|
||||
})
|
||||
|
||||
expect(result.current.sort).toEqual(sortBefore)
|
||||
expect(result.current.searchText).toBe(searchTextBefore)
|
||||
})
|
||||
|
||||
it('should update search state when called with search params', async () => {
|
||||
const { useMarketplaceMoreClick, useMarketplaceSortValue } = await import('../atoms')
|
||||
it('should update search state when called with search params', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
|
||||
const { result } = renderHook(() => ({
|
||||
@@ -229,17 +233,20 @@ describe('useMarketplaceMoreClick', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Sort should be updated via the jotai atom
|
||||
expect(result.current.sort).toEqual({ sortBy: 'created_at', sortOrder: 'ASC' })
|
||||
})
|
||||
|
||||
it('should use defaults when search params fields are missing', async () => {
|
||||
const { useMarketplaceMoreClick } = await import('../atoms')
|
||||
it('should use defaults when search params fields are missing', () => {
|
||||
const { wrapper } = createWrapper()
|
||||
const { result } = renderHook(() => useMarketplaceMoreClick(), { wrapper })
|
||||
const { result } = renderHook(() => ({
|
||||
handleMoreClick: useMarketplaceMoreClick(),
|
||||
sort: useMarketplaceSortValue(),
|
||||
}), { wrapper })
|
||||
|
||||
act(() => {
|
||||
result.current({})
|
||||
result.current.handleMoreClick({})
|
||||
})
|
||||
|
||||
expect(result.current.sort).toEqual(DEFAULT_SORT)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -74,31 +74,40 @@ describe('PluginTypeSwitch', () => {
|
||||
const { Wrapper } = createWrapper('?category=all')
|
||||
render(<PluginTypeSwitch />, { wrapper: Wrapper })
|
||||
|
||||
// Click on Models option — should not throw
|
||||
expect(() => fireEvent.click(screen.getByText('Models'))).not.toThrow()
|
||||
fireEvent.click(screen.getByText('Models'))
|
||||
|
||||
const modelsButton = screen.getByText('Models').closest('div')
|
||||
expect(modelsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
|
||||
})
|
||||
|
||||
it('should handle clicking on category with collections (Tools)', () => {
|
||||
const { Wrapper } = createWrapper('?category=model')
|
||||
render(<PluginTypeSwitch />, { wrapper: Wrapper })
|
||||
|
||||
// Click on "Tools" which has collections → setSearchMode(null)
|
||||
expect(() => fireEvent.click(screen.getByText('Tools'))).not.toThrow()
|
||||
fireEvent.click(screen.getByText('Tools'))
|
||||
|
||||
const toolsButton = screen.getByText('Tools').closest('div')
|
||||
expect(toolsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
|
||||
})
|
||||
|
||||
it('should handle clicking on category without collections (Models)', () => {
|
||||
const { Wrapper } = createWrapper('?category=all')
|
||||
render(<PluginTypeSwitch />, { wrapper: Wrapper })
|
||||
|
||||
// Click on "Models" which does NOT have collections → no setSearchMode call
|
||||
expect(() => fireEvent.click(screen.getByText('Models'))).not.toThrow()
|
||||
fireEvent.click(screen.getByText('Models'))
|
||||
|
||||
const modelsButton = screen.getByText('Models').closest('div')
|
||||
expect(modelsButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
|
||||
})
|
||||
|
||||
it('should handle clicking on bundles', () => {
|
||||
const { Wrapper } = createWrapper('?category=all')
|
||||
render(<PluginTypeSwitch />, { wrapper: Wrapper })
|
||||
|
||||
expect(() => fireEvent.click(screen.getByText('Bundles'))).not.toThrow()
|
||||
fireEvent.click(screen.getByText('Bundles'))
|
||||
|
||||
const bundlesButton = screen.getByText('Bundles').closest('div')
|
||||
expect(bundlesButton?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
|
||||
})
|
||||
|
||||
it('should handle clicking on each category', () => {
|
||||
@@ -107,7 +116,10 @@ describe('PluginTypeSwitch', () => {
|
||||
|
||||
const categories = ['All', 'Models', 'Tools', 'Data Sources', 'Triggers', 'Agents', 'Extensions', 'Bundles']
|
||||
categories.forEach((category) => {
|
||||
expect(() => fireEvent.click(screen.getByText(category))).not.toThrow()
|
||||
fireEvent.click(screen.getByText(category))
|
||||
|
||||
const button = screen.getByText(category).closest('div')
|
||||
expect(button?.className).toContain('!bg-components-main-nav-nav-button-bg-active')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user