mirror of
https://github.com/langgenius/dify.git
synced 2026-03-05 07:05:14 +00:00
Compare commits
88 Commits
3-4-modern
...
deploy/dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00a79a3e26 | ||
|
|
7471c32612 | ||
|
|
88d87d6053 | ||
|
|
2d333bbbe5 | ||
|
|
4af6788ce0 | ||
|
|
24b072def9 | ||
|
|
909c8c3350 | ||
|
|
89a859ae32 | ||
|
|
80e9c8bee0 | ||
|
|
15b7b304d2 | ||
|
|
61e2672b59 | ||
|
|
5f4ed4c6f6 | ||
|
|
4a1032c628 | ||
|
|
423c97a47e | ||
|
|
a7e3fb2e33 | ||
|
|
ce34937a1c | ||
|
|
ad9ac6978e | ||
|
|
e2b247d762 | ||
|
|
57c1ba3543 | ||
|
|
37b15acd0d | ||
|
|
d7a5af2b9a | ||
|
|
a4373d8b7b | ||
|
|
0ce9ebed63 | ||
|
|
d45edffaa3 | ||
|
|
530515b6ef | ||
|
|
1fa68d0863 | ||
|
|
f13f0d1f9a | ||
|
|
b597d52c11 | ||
|
|
34c42fe666 | ||
|
|
dc109c99f0 | ||
|
|
4a0770192e | ||
|
|
223b9d89c1 | ||
|
|
dd119eb44f | ||
|
|
164ccb7c48 | ||
|
|
2b5ce196ad | ||
|
|
2977a4d2a4 | ||
|
|
a0331b8b45 | ||
|
|
914bd4d00d | ||
|
|
9c9cb50981 | ||
|
|
970493fa85 | ||
|
|
ab87ac333a | ||
|
|
b8b70da9ad | ||
|
|
df3c66a8ac | ||
|
|
7252ce6f26 | ||
|
|
26d96f97a7 | ||
|
|
77d81aebe8 | ||
|
|
deb4cd3ece | ||
|
|
648d9ef1f9 | ||
|
|
5ed4797078 | ||
|
|
62631658e9 | ||
|
|
22a4100dd7 | ||
|
|
0f7ed6f67e | ||
|
|
4d9fcbec57 | ||
|
|
4d7a9bc798 | ||
|
|
d6d04ed657 | ||
|
|
f594a71dae | ||
|
|
04e0ab7eda | ||
|
|
784bda9c86 | ||
|
|
1af1fb6913 | ||
|
|
1f0c36e9f7 | ||
|
|
455ae65025 | ||
|
|
d44682e957 | ||
|
|
8c4afc0c18 | ||
|
|
539cbcae6a | ||
|
|
8d257fea7c | ||
|
|
c3364ac350 | ||
|
|
f991644989 | ||
|
|
29e344ac8b | ||
|
|
1ad9305732 | ||
|
|
17f38f171d | ||
|
|
802088c8eb | ||
|
|
cad6d94491 | ||
|
|
621d0fb2c9 | ||
|
|
efdd88f78a | ||
|
|
a92fb3244b | ||
|
|
97508f8d7b | ||
|
|
b2f84bf081 | ||
|
|
70e677a6ac | ||
|
|
2330aac623 | ||
|
|
97769c5c7a | ||
|
|
09ae3a9b52 | ||
|
|
9589bba713 | ||
|
|
6473c1419b | ||
|
|
d1a0b9695c | ||
|
|
3147e44a0b | ||
|
|
c243e91668 | ||
|
|
004fbbe52b | ||
|
|
63fb0ddde5 |
@@ -29,7 +29,7 @@ The codebase is split into:
|
||||
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
@@ -2668,3 +2668,68 @@ def clean_expired_messages(
|
||||
raise
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option("--filename", required=True, help="Output filename (local path or cloud storage key).")
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
|
||||
@@ -39,6 +39,7 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -184,6 +185,7 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -6,7 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
108
api/controllers/console/notification.py
Normal file
108
api/controllers/console/notification.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en"
|
||||
|
||||
# Maps dify interface_language prefixes to notification lang tags.
|
||||
# Any unrecognised prefix falls back to _FALLBACK_LANG.
|
||||
_LANG_MAP: dict[str, str] = {
|
||||
"zh": "zh",
|
||||
"ja": "jp",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_lang(interface_language: str | None) -> str:
|
||||
"""Derive the notification lang tag from the user's interface_language.
|
||||
|
||||
e.g. "zh-Hans" → "zh", "ja-JP" → "jp", "en-US" / None → "en"
|
||||
"""
|
||||
if not interface_language:
|
||||
return _FALLBACK_LANG
|
||||
prefix = interface_language.split("-")[0].lower()
|
||||
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"The notification is NOT marked as seen here; call POST /notification/dismiss "
|
||||
"when the user explicitly closes the modal."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = _resolve_lang(current_user.interface_language)
|
||||
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
class NotificationDismissApi(Resource):
|
||||
@console_ns.doc("dismiss_notification")
|
||||
@console_ns.doc(
|
||||
description="Mark a notification as dismissed for the current user.",
|
||||
responses={200: "Success", 401: "Unauthorized"},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
account_id=str(current_user.id),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -393,3 +393,78 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def get_account_notification(cls, account_id: str) -> dict:
|
||||
"""Return the active in-product notification for account_id, if any.
|
||||
|
||||
Calling this endpoint also marks the notification as seen; subsequent
|
||||
calls will return should_show=false when frequency='once'.
|
||||
|
||||
Response shape (mirrors GetAccountNotificationReply):
|
||||
{
|
||||
"should_show": bool,
|
||||
"notification": { # present only when should_show=true
|
||||
"notification_id": str,
|
||||
"contents": { # lang -> LangContent
|
||||
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
|
||||
...
|
||||
},
|
||||
"frequency": "once" | "every_page_load"
|
||||
}
|
||||
}
|
||||
"""
|
||||
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||
|
||||
@classmethod
|
||||
def upsert_notification(
|
||||
cls,
|
||||
contents: list[dict],
|
||||
frequency: str = "once",
|
||||
status: str = "active",
|
||||
notification_id: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update a notification.
|
||||
|
||||
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
|
||||
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||
Returns {"notification_id": str}.
|
||||
"""
|
||||
payload: dict = {
|
||||
"contents": contents,
|
||||
"frequency": frequency,
|
||||
"status": status,
|
||||
}
|
||||
if notification_id:
|
||||
payload["notification_id"] = notification_id
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
return cls._send_request("POST", "/notifications", json=payload)
|
||||
|
||||
@classmethod
|
||||
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||
"""Register target account IDs for a notification (max 1000 per call).
|
||||
|
||||
Returns {"count": int}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/accounts",
|
||||
json={"account_ids": account_ids},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
|
||||
"""Mark a notification as dismissed for an account.
|
||||
|
||||
Returns {"success": bool}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/dismiss",
|
||||
json={"account_id": account_id},
|
||||
)
|
||||
|
||||
258
api/services/retention/conversation/message_export_service.py
Normal file
258
api/services/retention/conversation/message_export_service.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Export app messages to JSONL.GZ format.
|
||||
|
||||
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
|
||||
retriever_resources (from message_metadata), feedback (user feedbacks array).
|
||||
|
||||
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
|
||||
Does NOT touch Message.inputs / Message.user_feedback properties.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, tuple_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportRecord(BaseModel):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
query: str
|
||||
answer: str
|
||||
inputs: dict[str, Any]
|
||||
retriever_resources: list[Any] = Field(default_factory=list)
|
||||
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportStats(BaseModel):
|
||||
batches: int = 0
|
||||
total_messages: int = 0
|
||||
messages_with_feedback: int = 0
|
||||
total_feedbacks: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportService:
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
*,
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
use_cloud_storage: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if start_from and start_from >= end_before:
|
||||
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
|
||||
|
||||
self._app_id = app_id
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._filename = filename
|
||||
self._batch_size = batch_size
|
||||
self._use_cloud_storage = use_cloud_storage
|
||||
self._dry_run = dry_run
|
||||
|
||||
def run(self) -> AppMessageExportStats:
|
||||
stats = AppMessageExportStats()
|
||||
|
||||
logger.info(
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s",
|
||||
self._app_id,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
self._dry_run,
|
||||
self._use_cloud_storage,
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
for _ in self._iter_records_with_stats(stats):
|
||||
pass
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
if self._use_cloud_storage:
|
||||
self._export_to_cloud(stats)
|
||||
else:
|
||||
self._export_to_local(stats)
|
||||
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for batch in self._iter_record_batches():
|
||||
yield from batch
|
||||
|
||||
@staticmethod
|
||||
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
|
||||
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
|
||||
for record in records:
|
||||
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
|
||||
|
||||
def _export_to_local(self, stats: AppMessageExportStats) -> None:
|
||||
with open(self._filename, "wb") as output_file:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
|
||||
|
||||
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
|
||||
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
|
||||
tmp.seek(0)
|
||||
data = tmp.read()
|
||||
|
||||
storage.save(self._filename, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self._filename)
|
||||
|
||||
def _iter_records_with_stats(
|
||||
self, stats: AppMessageExportStats
|
||||
) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for record in self.iter_records():
|
||||
self._update_stats(stats, record)
|
||||
yield record
|
||||
|
||||
@staticmethod
|
||||
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
|
||||
stats.total_messages += 1
|
||||
if record.feedback:
|
||||
stats.messages_with_feedback += 1
|
||||
stats.total_feedbacks += len(record.feedback)
|
||||
|
||||
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
|
||||
if stats.total_messages == 0:
|
||||
stats.batches = 0
|
||||
return
|
||||
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
|
||||
|
||||
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
|
||||
cursor: tuple[datetime.datetime, str] | None = None
|
||||
while True:
|
||||
rows, cursor = self._fetch_batch(cursor)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
message_ids = [str(row.id) for row in rows]
|
||||
feedbacks_map = self._fetch_feedbacks(message_ids)
|
||||
yield [self._build_record(row, feedbacks_map) for row in rows]
|
||||
|
||||
def _fetch_batch(
|
||||
self, cursor: tuple[datetime.datetime, str] | None
|
||||
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(
|
||||
Message.id,
|
||||
Message.conversation_id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message._inputs, # pyright: ignore[reportPrivateUsage]
|
||||
Message.message_metadata,
|
||||
Message.created_at,
|
||||
)
|
||||
.where(
|
||||
Message.app_id == self._app_id,
|
||||
Message.created_at < self._end_before,
|
||||
)
|
||||
.order_by(Message.created_at, Message.id)
|
||||
.limit(self._batch_size)
|
||||
)
|
||||
|
||||
if self._start_from:
|
||||
stmt = stmt.where(Message.created_at >= self._start_from)
|
||||
|
||||
if cursor:
|
||||
stmt = stmt.where(
|
||||
tuple_(Message.created_at, Message.id)
|
||||
> tuple_(
|
||||
sa.literal(cursor[0], type_=sa.DateTime()),
|
||||
sa.literal(cursor[1], type_=Message.id.type),
|
||||
)
|
||||
)
|
||||
|
||||
rows = list(session.execute(stmt).all())
|
||||
|
||||
if not rows:
|
||||
return [], cursor
|
||||
|
||||
last = rows[-1]
|
||||
return rows, (last.created_at, last.id)
|
||||
|
||||
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
|
||||
if not message_ids:
|
||||
return {}
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.message_id.in_(message_ids),
|
||||
MessageFeedback.from_source == "user",
|
||||
)
|
||||
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
|
||||
)
|
||||
feedbacks = list(session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
|
||||
for feedback in feedbacks:
|
||||
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_record(
|
||||
row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]
|
||||
) -> AppMessageExportRecord:
|
||||
retriever_resources: list[Any] = []
|
||||
if row.message_metadata:
|
||||
try:
|
||||
metadata = json.loads(row.message_metadata)
|
||||
value = metadata.get("retriever_resources", [])
|
||||
if isinstance(value, list):
|
||||
retriever_resources = value
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
message_id = str(row.id)
|
||||
return AppMessageExportRecord(
|
||||
conversation_id=str(row.conversation_id),
|
||||
message_id=message_id,
|
||||
query=row.query,
|
||||
answer=row.answer,
|
||||
inputs=row._inputs if isinstance(row._inputs, dict) else {},
|
||||
retriever_resources=retriever_resources,
|
||||
feedback=feedbacks_map.get(message_id, []),
|
||||
)
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Container-backed integration tests for DocumentService.rename_document real SQL paths."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from unittest.mock import create_autospec, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Document
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
"""Patch only non-SQL dependency used by rename_document: current_user context."""
|
||||
with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user:
|
||||
current_user.current_tenant_id = str(uuid4())
|
||||
current_user.id = str(uuid4())
|
||||
yield {"current_user": current_user}
|
||||
|
||||
|
||||
def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, built_in_field_enabled=False):
|
||||
"""Persist a dataset row for rename_document integration scenarios."""
|
||||
dataset_id = dataset_id or str(uuid4())
|
||||
tenant_id = tenant_id or str(uuid4())
|
||||
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=f"dataset-{uuid4()}",
|
||||
data_source_type="upload_file",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
dataset.id = dataset_id
|
||||
dataset.built_in_field_enabled = built_in_field_enabled
|
||||
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
|
||||
def make_document(
|
||||
db_session_with_containers,
|
||||
document_id=None,
|
||||
dataset_id=None,
|
||||
tenant_id=None,
|
||||
name="Old Name",
|
||||
data_source_info=None,
|
||||
doc_metadata=None,
|
||||
):
|
||||
"""Persist a document row used by rename_document integration scenarios."""
|
||||
document_id = document_id or str(uuid4())
|
||||
dataset_id = dataset_id or str(uuid4())
|
||||
tenant_id = tenant_id or str(uuid4())
|
||||
|
||||
doc = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_info=json.dumps(data_source_info or {}),
|
||||
batch=f"batch-{uuid4()}",
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_by=str(uuid4()),
|
||||
doc_form="text_model",
|
||||
)
|
||||
doc.id = document_id
|
||||
doc.indexing_status = "completed"
|
||||
doc.doc_metadata = dict(doc_metadata or {})
|
||||
|
||||
db_session_with_containers.add(doc)
|
||||
db_session_with_containers.commit()
|
||||
return doc
|
||||
|
||||
|
||||
def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, name: str):
|
||||
"""Persist an upload file row referenced by document.data_source_info."""
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type="local",
|
||||
key=f"uploads/{uuid4()}",
|
||||
name=name,
|
||||
size=128,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=str(uuid4()),
|
||||
created_at=FIXED_UPLOAD_CREATED_AT,
|
||||
used=False,
|
||||
)
|
||||
upload_file.id = file_id
|
||||
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
return upload_file
|
||||
|
||||
|
||||
def test_rename_document_success(db_session_with_containers, mock_env):
|
||||
"""Rename succeeds and returns the renamed document identity by id."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
new_name = "New Document Name"
|
||||
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
|
||||
document = make_document(
|
||||
db_session_with_containers,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DocumentService.rename_document(dataset.id, document_id, new_name)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(document)
|
||||
assert result.id == document.id
|
||||
assert document.name == new_name
|
||||
|
||||
|
||||
def test_rename_document_with_built_in_fields(db_session_with_containers, mock_env):
|
||||
"""Built-in document_name metadata is updated while existing metadata keys are preserved."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
new_name = "Renamed"
|
||||
dataset = make_dataset(
|
||||
db_session_with_containers,
|
||||
dataset_id,
|
||||
mock_env["current_user"].current_tenant_id,
|
||||
built_in_field_enabled=True,
|
||||
)
|
||||
document = make_document(
|
||||
db_session_with_containers,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
doc_metadata={"foo": "bar"},
|
||||
)
|
||||
|
||||
# Act
|
||||
DocumentService.rename_document(dataset.id, document.id, new_name)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.name == new_name
|
||||
assert document.doc_metadata["document_name"] == new_name
|
||||
assert document.doc_metadata["foo"] == "bar"
|
||||
|
||||
|
||||
def test_rename_document_updates_upload_file_when_present(db_session_with_containers, mock_env):
|
||||
"""Rename propagates to UploadFile.name when upload_file_id is present in data_source_info."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
file_id = str(uuid4())
|
||||
new_name = "Renamed"
|
||||
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
|
||||
document = make_document(
|
||||
db_session_with_containers,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
data_source_info={"upload_file_id": file_id},
|
||||
)
|
||||
upload_file = make_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
file_id=file_id,
|
||||
name="old.pdf",
|
||||
)
|
||||
|
||||
# Act
|
||||
DocumentService.rename_document(dataset.id, document.id, new_name)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(document)
|
||||
db_session_with_containers.refresh(upload_file)
|
||||
assert document.name == new_name
|
||||
assert upload_file.name == new_name
|
||||
|
||||
|
||||
def test_rename_document_does_not_update_upload_file_when_missing_id(db_session_with_containers, mock_env):
|
||||
"""Rename does not update UploadFile when data_source_info lacks upload_file_id."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
new_name = "Another Name"
|
||||
dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id)
|
||||
document = make_document(
|
||||
db_session_with_containers,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
data_source_info={"url": "https://example.com"},
|
||||
)
|
||||
untouched_file = make_upload_file(
|
||||
db_session_with_containers,
|
||||
tenant_id=mock_env["current_user"].current_tenant_id,
|
||||
file_id=str(uuid4()),
|
||||
name="untouched.pdf",
|
||||
)
|
||||
|
||||
# Act
|
||||
DocumentService.rename_document(dataset.id, document.id, new_name)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.refresh(document)
|
||||
db_session_with_containers.refresh(untouched_file)
|
||||
assert document.name == new_name
|
||||
assert untouched_file.name == "untouched.pdf"
|
||||
|
||||
|
||||
def test_rename_document_dataset_not_found(db_session_with_containers, mock_env):
|
||||
"""Rename raises Dataset not found when dataset id does not exist."""
|
||||
# Arrange
|
||||
missing_dataset_id = str(uuid4())
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x")
|
||||
|
||||
|
||||
def test_rename_document_not_found(db_session_with_containers, mock_env):
|
||||
"""Rename raises Document not found when document id is absent in the dataset."""
|
||||
# Arrange
|
||||
dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
DocumentService.rename_document(dataset.id, str(uuid4()), "x")
|
||||
|
||||
|
||||
def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_containers, mock_env):
|
||||
"""Rename raises No permission when document tenant differs from current_user tenant."""
|
||||
# Arrange
|
||||
dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id)
|
||||
document = make_document(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
tenant_id=str(uuid4()),
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="No permission"):
|
||||
DocumentService.rename_document(dataset.id, document.id, "x")
|
||||
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
|
||||
|
||||
|
||||
class TestAppMessageExportServiceIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
yield
|
||||
db_session_with_containers.query(DatasetRetrieverResource).delete()
|
||||
db_session_with_containers.query(AppAnnotationHitHistory).delete()
|
||||
db_session_with_containers.query(SavedMessage).delete()
|
||||
db_session_with_containers.query(MessageFile).delete()
|
||||
db_session_with_containers.query(MessageAgentThought).delete()
|
||||
db_session_with_containers.query(MessageChain).delete()
|
||||
db_session_with_containers.query(MessageAnnotation).delete()
|
||||
db_session_with_containers.query(MessageFeedback).delete()
|
||||
db_session_with_containers.query(Message).delete()
|
||||
db_session_with_containers.query(Conversation).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_app_context(session: Session) -> tuple[App, Conversation]:
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="tester",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
session.add(join)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="export-app",
|
||||
description="integration test app",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=60,
|
||||
api_rph=3600,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=str(uuid.uuid4()),
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
mode="chat",
|
||||
name="conv",
|
||||
inputs={"seed": 1},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
return app, conversation
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
session: Session,
|
||||
app: App,
|
||||
conversation: Conversation,
|
||||
created_at: datetime.datetime,
|
||||
*,
|
||||
query: str,
|
||||
answer: str,
|
||||
inputs: dict,
|
||||
message_metadata: str | None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
answer=answer,
|
||||
message=[{"role": "assistant", "content": answer}],
|
||||
message_tokens=10,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=20,
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
message_metadata=message_metadata,
|
||||
from_source="api",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
session.add(message)
|
||||
session.flush()
|
||||
return message
|
||||
|
||||
def test_iter_records_with_stats(self, db_session_with_containers: Session):
|
||||
app, conversation = self._create_app_context(db_session_with_containers)
|
||||
|
||||
first_inputs = {
|
||||
"plain": "v1",
|
||||
"nested": {"a": 1, "b": [1, {"x": True}]},
|
||||
"list": ["x", 2, {"y": "z"}],
|
||||
}
|
||||
second_inputs = {"other": "value", "items": [1, 2, 3]}
|
||||
|
||||
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
|
||||
first_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time,
|
||||
query="q1",
|
||||
answer="a1",
|
||||
inputs=first_inputs,
|
||||
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
|
||||
)
|
||||
second_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time + datetime.timedelta(minutes=1),
|
||||
query="q2",
|
||||
answer="a2",
|
||||
inputs=second_inputs,
|
||||
message_metadata=None,
|
||||
)
|
||||
|
||||
user_feedback_1 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
user_feedback_2 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
admin_feedback = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
|
||||
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
|
||||
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
|
||||
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = AppMessageExportService(
|
||||
app_id=app.id,
|
||||
start_from=base_time - datetime.timedelta(minutes=1),
|
||||
end_before=base_time + datetime.timedelta(minutes=10),
|
||||
filename="unused.jsonl.gz",
|
||||
batch_size=1,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = AppMessageExportStats()
|
||||
records = list(service._iter_records_with_stats(stats))
|
||||
service._finalize_stats(stats)
|
||||
|
||||
assert len(records) == 2
|
||||
assert records[0].message_id == first_message.id
|
||||
assert records[1].message_id == second_message.id
|
||||
|
||||
assert records[0].inputs == first_inputs
|
||||
assert records[1].inputs == second_inputs
|
||||
|
||||
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
|
||||
assert records[1].retriever_resources == []
|
||||
|
||||
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
|
||||
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
|
||||
assert records[1].feedback == []
|
||||
|
||||
assert stats.batches == 2
|
||||
assert stats.total_messages == 2
|
||||
assert stats.messages_with_feedback == 1
|
||||
assert stats.total_feedbacks == 2
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@@ -282,7 +283,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
|
||||
return dataset, documents
|
||||
|
||||
def test_duplicate_document_indexing_task_success(
|
||||
def _test_duplicate_document_indexing_task_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -324,7 +325,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
processed_documents = call_args[0][0] # First argument should be documents list
|
||||
assert len(processed_documents) == 3
|
||||
|
||||
def test_duplicate_document_indexing_task_with_segment_cleanup(
|
||||
def _test_duplicate_document_indexing_task_with_segment_cleanup(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -374,7 +375,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner"].assert_called_once()
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_duplicate_document_indexing_task_dataset_not_found(
|
||||
def _test_duplicate_document_indexing_task_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -445,7 +446,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
processed_documents = call_args[0][0] # First argument should be documents list
|
||||
assert len(processed_documents) == 2 # Only existing documents
|
||||
|
||||
def test_duplicate_document_indexing_task_indexing_runner_exception(
|
||||
def _test_duplicate_document_indexing_task_indexing_runner_exception(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -486,7 +487,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
|
||||
def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -549,7 +550,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
# Verify indexing runner was not called due to early validation error
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
|
||||
|
||||
def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
|
||||
def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
@@ -783,3 +784,90 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
document_ids=document_ids,
|
||||
)
|
||||
mock_queue.delete_task_key.assert_not_called()
|
||||
|
||||
def test_successful_duplicate_document_indexing(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test successful duplicate document indexing flow."""
|
||||
self._test_duplicate_document_indexing_task_success(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def test_duplicate_document_indexing_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing when dataset is not found."""
|
||||
self._test_duplicate_document_indexing_task_dataset_not_found(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing with billing enabled and sandbox plan."""
|
||||
self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def test_duplicate_document_indexing_with_billing_limit_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing when billing limit is exceeded."""
|
||||
self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def test_duplicate_document_indexing_runner_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing when IndexingRunner raises an error."""
|
||||
self._test_duplicate_document_indexing_task_indexing_runner_exception(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def _test_duplicate_document_indexing_task_document_is_paused(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing when document is paused."""
|
||||
# Arrange
|
||||
dataset, documents = self._create_test_dataset_and_documents(
|
||||
db_session_with_containers, mock_external_service_dependencies, document_count=2
|
||||
)
|
||||
for document in documents:
|
||||
document.is_paused = True
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
document_ids = [doc.id for doc in documents]
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError(
|
||||
"Document paused"
|
||||
)
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset.id, document_ids)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
# Assert
|
||||
for doc_id in document_ids:
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.is_paused is True
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.display_status == "paused"
|
||||
assert updated_document.processing_started_at is not None
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
def test_duplicate_document_indexing_document_is_paused(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test duplicate document indexing when document is paused."""
|
||||
self._test_duplicate_document_indexing_task_document_is_paused(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
def test_duplicate_document_indexing_cleans_old_segments(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
"""Test that duplicate document indexing cleans old segments."""
|
||||
self._test_duplicate_document_indexing_task_with_segment_cleanup(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Account
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
"""Patch dependencies used by DocumentService.rename_document.
|
||||
|
||||
Mocks:
|
||||
- DatasetService.get_dataset
|
||||
- DocumentService.get_document
|
||||
- current_user (with current_tenant_id)
|
||||
- db.session
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as get_dataset,
|
||||
patch("services.dataset_service.DocumentService.get_document") as get_document,
|
||||
patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user,
|
||||
patch("extensions.ext_database.db.session") as db_session,
|
||||
):
|
||||
current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
"get_dataset": get_dataset,
|
||||
"get_document": get_document,
|
||||
"current_user": current_user,
|
||||
"db_session": db_session,
|
||||
}
|
||||
|
||||
|
||||
def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False):
|
||||
return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled)
|
||||
|
||||
|
||||
def make_document(
|
||||
document_id="document-123",
|
||||
dataset_id="dataset-123",
|
||||
tenant_id="tenant-123",
|
||||
name="Old Name",
|
||||
data_source_info=None,
|
||||
doc_metadata=None,
|
||||
):
|
||||
doc = Mock()
|
||||
doc.id = document_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.tenant_id = tenant_id
|
||||
doc.name = name
|
||||
doc.data_source_info = data_source_info or {}
|
||||
# property-like usage in code relies on a dict
|
||||
doc.data_source_info_dict = dict(doc.data_source_info)
|
||||
doc.doc_metadata = dict(doc_metadata or {})
|
||||
return doc
|
||||
|
||||
|
||||
def test_rename_document_success(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "New Document Name"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
document = make_document(document_id=document_id, dataset_id=dataset_id)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
result = DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert result is document
|
||||
assert document.name == new_name
|
||||
mock_env["db_session"].add.assert_called_once_with(document)
|
||||
mock_env["db_session"].commit.assert_called_once()
|
||||
|
||||
|
||||
def test_rename_document_with_built_in_fields(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Renamed"
|
||||
|
||||
dataset = make_dataset(dataset_id, built_in_field_enabled=True)
|
||||
document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"})
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
# BuiltInField.document_name == "document_name" in service code
|
||||
assert document.doc_metadata["document_name"] == new_name
|
||||
assert document.doc_metadata["foo"] == "bar"
|
||||
|
||||
|
||||
def test_rename_document_updates_upload_file_when_present(mock_env):
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Renamed"
|
||||
file_id = "file-123"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
document = make_document(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
data_source_info={"upload_file_id": file_id},
|
||||
)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
# Intercept UploadFile rename UPDATE chain
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_env["db_session"].query.return_value = mock_query
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
mock_env["db_session"].query.assert_called() # update executed
|
||||
|
||||
|
||||
def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env):
|
||||
"""
|
||||
When data_source_info_dict exists but does not contain "upload_file_id",
|
||||
UploadFile should not be updated.
|
||||
"""
|
||||
dataset_id = "dataset-123"
|
||||
document_id = "document-123"
|
||||
new_name = "Another Name"
|
||||
|
||||
dataset = make_dataset(dataset_id)
|
||||
# Ensure data_source_info_dict is truthy but lacks the key
|
||||
document = make_document(
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
data_source_info={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
DocumentService.rename_document(dataset_id, document_id, new_name)
|
||||
|
||||
assert document.name == new_name
|
||||
# Should NOT attempt to update UploadFile
|
||||
mock_env["db_session"].query.assert_not_called()
|
||||
|
||||
|
||||
def test_rename_document_dataset_not_found(mock_env):
|
||||
mock_env["get_dataset"].return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Dataset not found"):
|
||||
DocumentService.rename_document("missing", "doc", "x")
|
||||
|
||||
|
||||
def test_rename_document_not_found(mock_env):
|
||||
dataset = make_dataset("dataset-123")
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
DocumentService.rename_document(dataset.id, "missing", "x")
|
||||
|
||||
|
||||
def test_rename_document_permission_denied_when_tenant_mismatch(mock_env):
|
||||
dataset = make_dataset("dataset-123")
|
||||
# different tenant than current_user.current_tenant_id
|
||||
document = make_document(dataset_id=dataset.id, tenant_id="tenant-other")
|
||||
|
||||
mock_env["get_dataset"].return_value = dataset
|
||||
mock_env["get_document"].return_value = document
|
||||
|
||||
with pytest.raises(ValueError, match="No permission"):
|
||||
DocumentService.rename_document(dataset.id, document.id, "x")
|
||||
@@ -1,158 +1,38 @@
|
||||
"""
|
||||
Unit tests for duplicate document indexing tasks.
|
||||
|
||||
This module tests the duplicate document indexing task functionality including:
|
||||
- Task enqueuing to different queues (normal, priority, tenant-isolated)
|
||||
- Batch processing of multiple duplicate documents
|
||||
- Progress tracking through task lifecycle
|
||||
- Error handling and retry mechanisms
|
||||
- Cleanup of old document data before re-indexing
|
||||
"""
|
||||
"""Unit tests for queue/wrapper behaviors in duplicate document indexing tasks (non-database logic)."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.duplicate_document_indexing_task import (
|
||||
_duplicate_document_indexing_task,
|
||||
_duplicate_document_indexing_task_with_tenant_queue,
|
||||
duplicate_document_indexing_task,
|
||||
normal_duplicate_document_indexing_task,
|
||||
priority_duplicate_document_indexing_task,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
"""Generate a unique tenant ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
"""Generate a unique dataset ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_ids():
|
||||
"""Generate a list of document IDs for testing."""
|
||||
return [str(uuid.uuid4()) for _ in range(3)]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset(dataset_id, tenant_id):
|
||||
"""Create a mock Dataset object."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_documents(document_ids, dataset_id):
|
||||
"""Create mock Document objects."""
|
||||
documents = []
|
||||
for doc_id in document_ids:
|
||||
doc = Mock(spec=Document)
|
||||
doc.id = doc_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.indexing_status = "waiting"
|
||||
doc.error = None
|
||||
doc.stopped_at = None
|
||||
doc.processing_started_at = None
|
||||
doc.doc_form = "text_model"
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document_segments(document_ids):
|
||||
"""Create mock DocumentSegment objects."""
|
||||
segments = []
|
||||
for doc_id in document_ids:
|
||||
for i in range(3):
|
||||
segment = Mock(spec=DocumentSegment)
|
||||
segment.id = str(uuid.uuid4())
|
||||
segment.document_id = doc_id
|
||||
segment.index_node_id = f"node-{doc_id}-{i}"
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf:
|
||||
session = MagicMock()
|
||||
# Allow tests to observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(*args, **kwargs):
|
||||
session.close()
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
session.scalars.return_value = MagicMock()
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_indexing_runner():
|
||||
"""Mock IndexingRunner."""
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class:
|
||||
mock_runner = MagicMock(spec=IndexingRunner)
|
||||
mock_runner_class.return_value = mock_runner
|
||||
yield mock_runner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_feature_service():
|
||||
"""Mock FeatureService."""
|
||||
with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service:
|
||||
mock_features = Mock()
|
||||
mock_features.billing = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_features.vector_space = Mock()
|
||||
mock_features.vector_space.size = 0
|
||||
mock_features.vector_space.limit = 1000
|
||||
mock_service.get_features.return_value = mock_features
|
||||
yield mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean = Mock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
yield mock_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant_isolated_queue():
|
||||
"""Mock TenantIsolatedTaskQueue."""
|
||||
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class:
|
||||
mock_queue = MagicMock(spec=TenantIsolatedTaskQueue)
|
||||
mock_queue = Mock(spec=TenantIsolatedTaskQueue)
|
||||
mock_queue.pull_tasks.return_value = []
|
||||
mock_queue.delete_task_key = Mock()
|
||||
mock_queue.set_task_waiting_time = Mock()
|
||||
@@ -160,11 +40,6 @@ def mock_tenant_isolated_queue():
|
||||
yield mock_queue
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for deprecated duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTask:
|
||||
"""Tests for the deprecated duplicate_document_indexing_task function."""
|
||||
|
||||
@@ -190,258 +65,6 @@ class TestDuplicateDocumentIndexingTask:
|
||||
mock_core_func.assert_called_once_with(dataset_id, document_ids)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for _duplicate_document_indexing_task core function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTaskCore:
|
||||
"""Tests for the _duplicate_document_indexing_task core function."""
|
||||
|
||||
def test_successful_duplicate_document_indexing(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_indexing_runner,
|
||||
mock_feature_service,
|
||||
mock_index_processor_factory,
|
||||
mock_dataset,
|
||||
mock_documents,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test successful duplicate document indexing flow."""
|
||||
# Arrange
|
||||
# Dataset via query.first()
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# scalars() call sequence:
|
||||
# 1) documents list
|
||||
# 2..N) segments per document
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
# First call returns documents; subsequent calls return segments
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Verify IndexingRunner was called
|
||||
mock_indexing_runner.run.assert_called_once()
|
||||
|
||||
# Verify all documents were set to parsing status
|
||||
for doc in mock_documents:
|
||||
assert doc.indexing_status == "parsing"
|
||||
assert doc.processing_started_at is not None
|
||||
|
||||
# Verify session operations
|
||||
assert mock_db_session.commit.called
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids):
|
||||
"""Test duplicate document indexing when dataset is not found."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Should close the session at least once
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_feature_service,
|
||||
mock_dataset,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test duplicate document indexing with billing enabled and sandbox plan."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
mock_features = mock_feature_service.get_features.return_value
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# For sandbox plan with multiple documents, should fail
|
||||
mock_db_session.commit.assert_called()
|
||||
|
||||
def test_duplicate_document_indexing_with_billing_limit_exceeded(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_feature_service,
|
||||
mock_dataset,
|
||||
mock_documents,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test duplicate document indexing when billing limit is exceeded."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
# First scalars() -> documents; subsequent -> empty segments
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_features = mock_feature_service.get_features.return_value
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.billing.subscription.plan = CloudPlan.TEAM
|
||||
mock_features.vector_space.size = 990
|
||||
mock_features.vector_space.limit = 1000
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Should commit the session
|
||||
assert mock_db_session.commit.called
|
||||
# Should close the session
|
||||
assert mock_db_session.close.called
|
||||
|
||||
def test_duplicate_document_indexing_runner_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_indexing_runner,
|
||||
mock_feature_service,
|
||||
mock_index_processor_factory,
|
||||
mock_dataset,
|
||||
mock_documents,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test duplicate document indexing when IndexingRunner raises an error."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Should close the session even after error
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_duplicate_document_indexing_document_is_paused(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_indexing_runner,
|
||||
mock_feature_service,
|
||||
mock_index_processor_factory,
|
||||
mock_dataset,
|
||||
mock_documents,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test duplicate document indexing when document is paused."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = []
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Should handle DocumentIsPausedError gracefully
|
||||
mock_db_session.close.assert_called_once()
|
||||
|
||||
def test_duplicate_document_indexing_cleans_old_segments(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_indexing_runner,
|
||||
mock_feature_service,
|
||||
mock_index_processor_factory,
|
||||
mock_dataset,
|
||||
mock_documents,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_ids,
|
||||
):
|
||||
"""Test that duplicate document indexing cleans old segments."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
||||
|
||||
def _scalars_side_effect(*args, **kwargs):
|
||||
m = MagicMock()
|
||||
if not hasattr(_scalars_side_effect, "_calls"):
|
||||
_scalars_side_effect._calls = 0
|
||||
if _scalars_side_effect._calls == 0:
|
||||
m.all.return_value = mock_documents
|
||||
else:
|
||||
m.all.return_value = mock_document_segments
|
||||
_scalars_side_effect._calls += 1
|
||||
return m
|
||||
|
||||
mock_db_session.scalars.side_effect = _scalars_side_effect
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
|
||||
# Act
|
||||
_duplicate_document_indexing_task(dataset_id, document_ids)
|
||||
|
||||
# Assert
|
||||
# Verify clean was called for each document
|
||||
assert mock_processor.clean.call_count == len(mock_documents)
|
||||
|
||||
# Verify segments were deleted in batch (DELETE FROM document_segments)
|
||||
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for tenant queue wrapper function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
"""Tests for _duplicate_document_indexing_task_with_tenant_queue function."""
|
||||
|
||||
@@ -536,11 +159,6 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
mock_tenant_isolated_queue.pull_tasks.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for normal_duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestNormalDuplicateDocumentIndexingTask:
|
||||
"""Tests for normal_duplicate_document_indexing_task function."""
|
||||
|
||||
@@ -581,11 +199,6 @@ class TestNormalDuplicateDocumentIndexingTask:
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for priority_duplicate_document_indexing_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestPriorityDuplicateDocumentIndexingTask:
|
||||
"""Tests for priority_duplicate_document_indexing_task function."""
|
||||
|
||||
|
||||
12
api/uv.lock
generated
12
api/uv.lock
generated
@@ -441,14 +441,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.6"
|
||||
version = "1.6.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/49/dc/ed1681bf1339dd6ea1ce56136bad4baabc6f7ad466e375810702b0237047/authlib-1.6.7.tar.gz", hash = "sha256:dbf10100011d1e1b34048c9d120e83f13b35d69a826ae762b93d2fb5aafc337b", size = 164950, upload-time = "2026-02-06T14:04:14.171Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1989,11 +1989,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "fickling"
|
||||
version = "0.1.8"
|
||||
version = "0.1.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/be/cd91e3921f064230ac9462479e4647fb91a7b0d01677103fce89f52e3042/fickling-0.1.8.tar.gz", hash = "sha256:25a0bc7acda76176a9087b405b05f7f5021f76079aa26c6fe3270855ec57d9bf", size = 336756, upload-time = "2026-02-21T00:57:26.106Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/25/bd/ca7127df0201596b0b30f9ab3d36e565bb9d6f8f4da1560758b817e81b65/fickling-0.1.9.tar.gz", hash = "sha256:bb518c2fd833555183bc46b6903bb4022f3ae0436a69c3fb149cfc75eebaac33", size = 336940, upload-time = "2026-03-03T23:32:19.449Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/92/af72f783ac57fa2452f8f921c9441366c42ae1f03f5af41718445114c82f/fickling-0.1.8-py3-none-any.whl", hash = "sha256:97218785cfe00a93150808dcf9e3eb512371e0484e3ce0b05bc460b97240f292", size = 52613, upload-time = "2026-02-21T00:57:24.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/49/c597bad508c74917901432b41ae5a8f036839a7fb8d0d29a89765f5d3643/fickling-0.1.9-py3-none-any.whl", hash = "sha256:ccc3ce3b84733406ade2fe749717f6e428047335157c6431eefd3e7e970a06d1", size = 52786, upload-time = "2026-03-03T23:32:17.533Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 6. Close Handling ───────────────────────────────────────────────────
|
||||
describe('Close handling', () => {
|
||||
it('should call onCancel when pressing ESC key', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
// ahooks useKeyPress listens on document for keydown events
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
code: 'Escape',
|
||||
keyCode: 27,
|
||||
bubbles: true,
|
||||
}))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
|
||||
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
|
||||
describe('Pricing page URL', () => {
|
||||
it('should render pricing link with correct URL', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
@@ -8,7 +8,7 @@ import GotoAnything from '@/app/components/goto-anything'
|
||||
import Header from '@/app/components/header'
|
||||
import HeaderWrapper from '@/app/components/header/header-wrapper'
|
||||
import ReadmePanel from '@/app/components/plugins/readme-panel'
|
||||
import { AppContextProvider } from '@/context/app-context'
|
||||
import { AppContextProvider } from '@/context/app-context-provider'
|
||||
import { EventEmitterContextProvider } from '@/context/event-emitter'
|
||||
import { ModalContextProvider } from '@/context/modal-context'
|
||||
import { ProviderContextProvider } from '@/context/provider-context'
|
||||
|
||||
@@ -4,7 +4,7 @@ import { AppInitializer } from '@/app/components/app-initializer'
|
||||
import AmplitudeProvider from '@/app/components/base/amplitude'
|
||||
import GA, { GaType } from '@/app/components/base/ga'
|
||||
import HeaderWrapper from '@/app/components/header/header-wrapper'
|
||||
import { AppContextProvider } from '@/context/app-context'
|
||||
import { AppContextProvider } from '@/context/app-context-provider'
|
||||
import { EventEmitterContextProvider } from '@/context/event-emitter'
|
||||
import { ModalContextProvider } from '@/context/modal-context'
|
||||
import { ProviderContextProvider } from '@/context/provider-context'
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import Loading from '@/app/components/base/loading'
|
||||
|
||||
import Header from '@/app/signin/_header'
|
||||
import { AppContextProvider } from '@/context/app-context'
|
||||
import { AppContextProvider } from '@/context/app-context-provider'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import useDocumentTitle from '@/hooks/use-document-title'
|
||||
import { useIsLogin } from '@/service/use-common'
|
||||
@@ -38,7 +38,7 @@ export default function SignInLayout({ children }: any) {
|
||||
</div>
|
||||
</div>
|
||||
{systemFeatures.branding.enabled === false && (
|
||||
<div className="system-xs-regular px-8 py-6 text-text-tertiary">
|
||||
<div className="px-8 py-6 text-text-tertiary system-xs-regular">
|
||||
©
|
||||
{' '}
|
||||
{new Date().getFullYear()}
|
||||
|
||||
@@ -26,11 +26,10 @@ export const AppInitializer = ({
|
||||
// Tokens are now stored in cookies, no need to check localStorage
|
||||
const pathname = usePathname()
|
||||
const [init, setInit] = useState(false)
|
||||
const [oauthNewUser, setOauthNewUser] = useQueryState(
|
||||
const [oauthNewUser] = useQueryState(
|
||||
'oauth_new_user',
|
||||
parseAsBoolean.withOptions({ history: 'replace' }),
|
||||
)
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
try {
|
||||
const setUpStatus = await fetchSetupStatusWithCache()
|
||||
@@ -69,11 +68,12 @@ export const AppInitializer = ({
|
||||
...utmInfo,
|
||||
})
|
||||
|
||||
// Clean up: remove utm_info cookie and URL params
|
||||
Cookies.remove('utm_info')
|
||||
setOauthNewUser(null)
|
||||
}
|
||||
|
||||
if (oauthNewUser !== null)
|
||||
router.replace(pathname)
|
||||
|
||||
if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION)
|
||||
localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes')
|
||||
|
||||
@@ -96,7 +96,7 @@ export const AppInitializer = ({
|
||||
router.replace('/signin')
|
||||
}
|
||||
})()
|
||||
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser])
|
||||
}, [isSetupFinished, router, pathname, searchParams, oauthNewUser])
|
||||
|
||||
return init ? children : null
|
||||
}
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
|
||||
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "10",
|
||||
"height": "10",
|
||||
"viewBox": "0 0 10 10",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "CreditsCoin"
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './CreditsCoin.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'CreditsCoin'
|
||||
|
||||
export default Icon
|
||||
@@ -1,5 +1,6 @@
|
||||
export { default as Balance } from './Balance'
|
||||
export { default as CoinsStacked01 } from './CoinsStacked01'
|
||||
export { default as CreditsCoin } from './CreditsCoin'
|
||||
export { default as GoldCoin } from './GoldCoin'
|
||||
export { default as ReceiptList } from './ReceiptList'
|
||||
export { default as Tag01 } from './Tag01'
|
||||
|
||||
@@ -43,20 +43,24 @@ type DialogContentProps = {
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
overlayClassName?: string
|
||||
backdropProps?: React.ComponentPropsWithoutRef<typeof BaseDialog.Backdrop>
|
||||
}
|
||||
|
||||
export function DialogContent({
|
||||
children,
|
||||
className,
|
||||
overlayClassName,
|
||||
backdropProps,
|
||||
}: DialogContentProps) {
|
||||
return (
|
||||
<DialogPortal>
|
||||
<BaseDialog.Backdrop
|
||||
{...backdropProps}
|
||||
className={cn(
|
||||
'fixed inset-0 z-50 bg-background-overlay',
|
||||
'transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
|
||||
overlayClassName,
|
||||
backdropProps?.className,
|
||||
)}
|
||||
/>
|
||||
<BaseDialog.Popup
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { CategoryEnum } from '..'
|
||||
import Footer from '../footer'
|
||||
import { CategoryEnum } from '../types'
|
||||
|
||||
vi.mock('next/link', () => ({
|
||||
default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { Dialog } from '@/app/components/base/ui/dialog'
|
||||
import Header from '../header'
|
||||
|
||||
function renderHeader(onClose: () => void) {
|
||||
return render(
|
||||
<Dialog open>
|
||||
<Header onClose={onClose} />
|
||||
</Dialog>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Header', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -11,7 +20,7 @@ describe('Header', () => {
|
||||
it('should render title and description translations', () => {
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Header onClose={handleClose} />)
|
||||
renderHeader(handleClose)
|
||||
|
||||
expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument()
|
||||
expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument()
|
||||
@@ -22,7 +31,7 @@ describe('Header', () => {
|
||||
describe('Props', () => {
|
||||
it('should invoke onClose when close button is clicked', () => {
|
||||
const handleClose = vi.fn()
|
||||
render(<Header onClose={handleClose} />)
|
||||
renderHeader(handleClose)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
@@ -32,7 +41,7 @@ describe('Header', () => {
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should render structural elements with translation keys', () => {
|
||||
const { container } = render(<Header onClose={vi.fn()} />)
|
||||
const { container } = renderHeader(vi.fn())
|
||||
|
||||
expect(container.querySelector('span')).toBeInTheDocument()
|
||||
expect(container.querySelector('p')).toBeInTheDocument()
|
||||
|
||||
@@ -74,15 +74,11 @@ describe('Pricing', () => {
|
||||
})
|
||||
|
||||
describe('Props', () => {
|
||||
it('should allow switching categories and handle esc key', () => {
|
||||
const handleCancel = vi.fn()
|
||||
render(<Pricing onCancel={handleCancel} />)
|
||||
it('should allow switching categories', () => {
|
||||
render(<Pricing onCancel={vi.fn()} />)
|
||||
|
||||
fireEvent.click(screen.getByText('billing.plansCommon.self'))
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.keyDown(window, { key: 'Escape', keyCode: 27 })
|
||||
expect(handleCancel).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import type { Category } from '.'
|
||||
import { RiArrowRightUpLine } from '@remixicon/react'
|
||||
import type { Category } from './types'
|
||||
import Link from 'next/link'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { CategoryEnum } from '.'
|
||||
import { CategoryEnum } from './types'
|
||||
|
||||
type FooterProps = {
|
||||
pricingPageURL: string
|
||||
@@ -34,7 +33,7 @@ const Footer = ({
|
||||
>
|
||||
{t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })}
|
||||
</Link>
|
||||
<RiArrowRightUpLine className="size-4" />
|
||||
<span aria-hidden="true" className="i-ri-arrow-right-up-line size-4" />
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog'
|
||||
import Button from '../../base/button'
|
||||
import DifyLogo from '../../base/logo/dify-logo'
|
||||
|
||||
@@ -20,19 +20,19 @@ const Header = ({
|
||||
<div className="py-[5px]">
|
||||
<DifyLogo className="h-[27px] w-[60px]" />
|
||||
</div>
|
||||
<span className="bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
|
||||
<DialogTitle className="m-0 bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
|
||||
{t('plansCommon.title.plans', { ns: 'billing' })}
|
||||
</span>
|
||||
</DialogTitle>
|
||||
</div>
|
||||
<p className="system-sm-regular text-text-tertiary">
|
||||
<DialogDescription className="m-0 text-text-tertiary system-sm-regular">
|
||||
{t('plansCommon.title.description', { ns: 'billing' })}
|
||||
</p>
|
||||
</DialogDescription>
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="absolute bottom-[40.5px] right-[-18px] z-10 size-9 rounded-full p-2"
|
||||
onClick={onClose}
|
||||
>
|
||||
<RiCloseLine className="size-5" />
|
||||
<span aria-hidden="true" className="i-ri-close-line size-5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import type { Category } from './types'
|
||||
import * as React from 'react'
|
||||
import { useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGetPricingPageLanguage } from '@/context/i18n'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
@@ -13,13 +13,7 @@ import Header from './header'
|
||||
import PlanSwitcher from './plan-switcher'
|
||||
import { PlanRange } from './plan-switcher/plan-range-switcher'
|
||||
import Plans from './plans'
|
||||
|
||||
export enum CategoryEnum {
|
||||
CLOUD = 'cloud',
|
||||
SELF = 'self',
|
||||
}
|
||||
|
||||
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
|
||||
import { CategoryEnum } from './types'
|
||||
|
||||
type PricingProps = {
|
||||
onCancel: () => void
|
||||
@@ -33,42 +27,47 @@ const Pricing: FC<PricingProps> = ({
|
||||
const [planRange, setPlanRange] = React.useState<PlanRange>(PlanRange.monthly)
|
||||
const [currentCategory, setCurrentCategory] = useState<Category>(CategoryEnum.CLOUD)
|
||||
const canPay = isCurrentWorkspaceManager
|
||||
useKeyPress(['esc'], onCancel)
|
||||
|
||||
const pricingPageLanguage = useGetPricingPageLanguage()
|
||||
const pricingPageURL = pricingPageLanguage
|
||||
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
|
||||
: 'https://dify.ai/pricing#plans-and-features'
|
||||
|
||||
return createPortal(
|
||||
<div
|
||||
className="fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] overflow-auto bg-saas-background"
|
||||
onClick={e => e.stopPropagation()}
|
||||
return (
|
||||
<Dialog
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}}
|
||||
>
|
||||
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
|
||||
<div className="absolute -top-12 left-0 right-0 -z-10">
|
||||
<NoiseTop />
|
||||
<DialogContent
|
||||
className="inset-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-auto rounded-none border-none bg-saas-background p-0 shadow-none"
|
||||
>
|
||||
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
|
||||
<div className="absolute -top-12 left-0 right-0 -z-10">
|
||||
<NoiseTop />
|
||||
</div>
|
||||
<Header onClose={onCancel} />
|
||||
<PlanSwitcher
|
||||
currentCategory={currentCategory}
|
||||
onChangeCategory={setCurrentCategory}
|
||||
currentPlanRange={planRange}
|
||||
onChangePlanRange={setPlanRange}
|
||||
/>
|
||||
<Plans
|
||||
plan={plan}
|
||||
currentPlan={currentCategory}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
|
||||
<div className="absolute -bottom-12 left-0 right-0 -z-10">
|
||||
<NoiseBottom />
|
||||
</div>
|
||||
</div>
|
||||
<Header onClose={onCancel} />
|
||||
<PlanSwitcher
|
||||
currentCategory={currentCategory}
|
||||
onChangeCategory={setCurrentCategory}
|
||||
currentPlanRange={planRange}
|
||||
onChangePlanRange={setPlanRange}
|
||||
/>
|
||||
<Plans
|
||||
plan={plan}
|
||||
currentPlan={currentCategory}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
|
||||
<div className="absolute -bottom-12 left-0 right-0 -z-10">
|
||||
<NoiseBottom />
|
||||
</div>
|
||||
</div>
|
||||
</div>,
|
||||
document.body,
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
export default React.memo(Pricing)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { CategoryEnum } from '../../index'
|
||||
import { CategoryEnum } from '../../types'
|
||||
import PlanSwitcher from '../index'
|
||||
import { PlanRange } from '../plan-range-switcher'
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { FC } from 'react'
|
||||
import type { Category } from '../index'
|
||||
import type { Category } from '../types'
|
||||
import type { PlanRange } from './plan-range-switcher'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
6
web/app/components/billing/pricing/types.ts
Normal file
6
web/app/components/billing/pricing/types.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export enum CategoryEnum {
|
||||
CLOUD = 'cloud',
|
||||
SELF = 'self',
|
||||
}
|
||||
|
||||
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
|
||||
@@ -100,10 +100,10 @@ vi.mock('@/app/components/datasets/create/step-two', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting', () => ({
|
||||
default: ({ activeTab, onCancel }: { activeTab?: string, onCancel?: () => void }) => (
|
||||
default: ({ activeTab, onCancelAction }: { activeTab?: string, onCancelAction?: () => void }) => (
|
||||
<div data-testid="account-setting">
|
||||
<span data-testid="active-tab">{activeTab}</span>
|
||||
<button onClick={onCancel} data-testid="close-setting">Close</button>
|
||||
<button onClick={onCancelAction} data-testid="close-setting">Close</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import type { DataSourceProvider, NotionPage } from '@/models/common'
|
||||
import type {
|
||||
CrawlOptions,
|
||||
@@ -19,6 +20,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import StepTwo from '@/app/components/datasets/create/step-two'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import DatasetDetailContext from '@/context/dataset-detail'
|
||||
@@ -33,8 +35,13 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
const [accountSettingTab, setAccountSettingTab] = React.useState<AccountSettingTab>(ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
|
||||
const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
const handleOpenAccountSetting = React.useCallback(() => {
|
||||
setAccountSettingTab(ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
showSetAPIKey()
|
||||
}, [showSetAPIKey])
|
||||
|
||||
const invalidDocumentList = useInvalidDocumentList(datasetId)
|
||||
const invalidDocumentDetail = useInvalidDocumentDetail()
|
||||
@@ -135,7 +142,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
{dataset && documentDetail && (
|
||||
<StepTwo
|
||||
isAPIKeySet={!!embeddingsDefaultModel}
|
||||
onSetting={showSetAPIKey}
|
||||
onSetting={handleOpenAccountSetting}
|
||||
datasetId={datasetId}
|
||||
dataSourceType={documentDetail.data_source_type as DataSourceType}
|
||||
notionPages={currentPage ? [currentPage as unknown as NotionPage] : []}
|
||||
@@ -155,8 +162,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
</div>
|
||||
{isShowSetAPIKey && (
|
||||
<AccountSetting
|
||||
activeTab="provider"
|
||||
onCancel={async () => {
|
||||
activeTab={accountSettingTab}
|
||||
onTabChangeAction={setAccountSettingTab}
|
||||
onCancelAction={async () => {
|
||||
hideSetAPIkey()
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import type { AccountSettingTab } from './constants'
|
||||
import type { AppContextValue } from '@/context/app-context'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { useState } from 'react'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { baseProviderContextValue, useProviderContext } from '@/context/provider-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { ACCOUNT_SETTING_TAB } from './constants'
|
||||
import AccountSetting from './index'
|
||||
|
||||
const mockResetModelProviderListExpanded = vi.fn()
|
||||
|
||||
vi.mock('@/context/provider-context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/context/provider-context')>()
|
||||
return {
|
||||
@@ -47,10 +51,15 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', ()
|
||||
useDefaultModel: vi.fn(() => ({ data: null, isLoading: false })),
|
||||
useUpdateDefaultModel: vi.fn(() => ({ trigger: vi.fn() })),
|
||||
useUpdateModelList: vi.fn(() => vi.fn()),
|
||||
useInvalidateDefaultModel: vi.fn(() => vi.fn()),
|
||||
useModelList: vi.fn(() => ({ data: [], isLoading: false })),
|
||||
useSystemDefaultModelAndModelList: vi.fn(() => [null, vi.fn()]),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/atoms', () => ({
|
||||
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-datasource', () => ({
|
||||
useGetDataSourceListAuth: vi.fn(() => ({ data: { result: [] } })),
|
||||
}))
|
||||
@@ -105,6 +114,38 @@ const baseAppContextValue: AppContextValue = {
|
||||
describe('AccountSetting', () => {
|
||||
const mockOnCancel = vi.fn()
|
||||
const mockOnTabChange = vi.fn()
|
||||
const renderAccountSetting = (props?: {
|
||||
initialTab?: AccountSettingTab
|
||||
onCancel?: () => void
|
||||
onTabChange?: (tab: AccountSettingTab) => void
|
||||
}) => {
|
||||
const {
|
||||
initialTab = ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
onCancel = mockOnCancel,
|
||||
onTabChange = mockOnTabChange,
|
||||
} = props ?? {}
|
||||
|
||||
const StatefulAccountSetting = () => {
|
||||
const [activeTab, setActiveTab] = useState<AccountSettingTab>(initialTab)
|
||||
|
||||
return (
|
||||
<AccountSetting
|
||||
onCancelAction={onCancel}
|
||||
activeTab={activeTab}
|
||||
onTabChangeAction={(tab) => {
|
||||
setActiveTab(tab)
|
||||
onTabChange(tab)
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<StatefulAccountSetting />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -120,11 +161,7 @@ describe('AccountSetting', () => {
|
||||
describe('Rendering', () => {
|
||||
it('should render the sidebar with correct menu items', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('common.userProfile.settings')).toBeInTheDocument()
|
||||
@@ -137,13 +174,9 @@ describe('AccountSetting', () => {
|
||||
expect(screen.getAllByText('common.settings.language').length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should respect the activeTab prop', () => {
|
||||
it('should respect the initial tab', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} activeTab={ACCOUNT_SETTING_TAB.DATA_SOURCE} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.DATA_SOURCE })
|
||||
|
||||
// Assert
|
||||
// Check that the active item title is Data Source
|
||||
@@ -157,11 +190,7 @@ describe('AccountSetting', () => {
|
||||
vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile)
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
// On mobile, the labels should not be rendered as per the implementation
|
||||
@@ -176,11 +205,7 @@ describe('AccountSetting', () => {
|
||||
})
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.settings.provider')).not.toBeInTheDocument()
|
||||
@@ -197,11 +222,7 @@ describe('AccountSetting', () => {
|
||||
})
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.settings.billing')).not.toBeInTheDocument()
|
||||
@@ -212,11 +233,7 @@ describe('AccountSetting', () => {
|
||||
describe('Tab Navigation', () => {
|
||||
it('should change active tab when clicking on menu item', () => {
|
||||
// Arrange
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} onTabChange={mockOnTabChange} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting({ onTabChange: mockOnTabChange })
|
||||
|
||||
// Act
|
||||
fireEvent.click(screen.getByText('common.settings.provider'))
|
||||
@@ -229,11 +246,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should navigate through various tabs and show correct details', () => {
|
||||
// Act & Assert
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Billing
|
||||
fireEvent.click(screen.getByText('common.settings.billing'))
|
||||
@@ -267,13 +280,11 @@ describe('AccountSetting', () => {
|
||||
describe('Interactions', () => {
|
||||
it('should call onCancel when clicking close button', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
const buttons = screen.getAllByRole('button')
|
||||
fireEvent.click(buttons[0])
|
||||
renderAccountSetting()
|
||||
const closeIcon = document.querySelector('.i-ri-close-line')
|
||||
const closeButton = closeIcon?.closest('button')
|
||||
expect(closeButton).not.toBeNull()
|
||||
fireEvent.click(closeButton!)
|
||||
|
||||
// Assert
|
||||
expect(mockOnCancel).toHaveBeenCalled()
|
||||
@@ -281,11 +292,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should call onCancel when pressing Escape key', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
fireEvent.keyDown(document, { key: 'Escape' })
|
||||
|
||||
// Assert
|
||||
@@ -294,12 +301,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should update search value in provider tab', () => {
|
||||
// Arrange
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
fireEvent.click(screen.getByText('common.settings.provider'))
|
||||
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.PROVIDER })
|
||||
|
||||
// Act
|
||||
const input = screen.getByRole('textbox')
|
||||
@@ -312,11 +314,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should handle scroll event in panel', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto')
|
||||
|
||||
// Assert
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SearchInput from '@/app/components/base/search-input'
|
||||
import BillingPage from '@/app/components/billing/billing-page'
|
||||
@@ -20,15 +20,16 @@ import DataSourcePage from './data-source-page-new'
|
||||
import LanguagePage from './language-page'
|
||||
import MembersPage from './members-page'
|
||||
import ModelProviderPage from './model-provider-page'
|
||||
import { useResetModelProviderListExpanded } from './model-provider-page/atoms'
|
||||
|
||||
const iconClassName = `
|
||||
w-5 h-5 mr-2
|
||||
`
|
||||
|
||||
type IAccountSettingProps = {
|
||||
onCancel: () => void
|
||||
activeTab?: AccountSettingTab
|
||||
onTabChange?: (tab: AccountSettingTab) => void
|
||||
onCancelAction: () => void
|
||||
activeTab: AccountSettingTab
|
||||
onTabChangeAction: (tab: AccountSettingTab) => void
|
||||
}
|
||||
|
||||
type GroupItem = {
|
||||
@@ -40,14 +41,12 @@ type GroupItem = {
|
||||
}
|
||||
|
||||
export default function AccountSetting({
|
||||
onCancel,
|
||||
activeTab = ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
onTabChange,
|
||||
onCancelAction,
|
||||
activeTab,
|
||||
onTabChangeAction,
|
||||
}: IAccountSettingProps) {
|
||||
const [activeMenu, setActiveMenu] = useState<AccountSettingTab>(activeTab)
|
||||
useEffect(() => {
|
||||
setActiveMenu(activeTab)
|
||||
}, [activeTab])
|
||||
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
|
||||
const activeMenu = activeTab
|
||||
const { t } = useTranslation()
|
||||
const { enableBilling, enableReplaceWebAppLogo } = useProviderContext()
|
||||
const { isCurrentWorkspaceDatasetOperator } = useAppContext()
|
||||
@@ -148,10 +147,22 @@ export default function AccountSetting({
|
||||
|
||||
const [searchValue, setSearchValue] = useState<string>('')
|
||||
|
||||
const handleTabChange = useCallback((tab: AccountSettingTab) => {
|
||||
if (tab === ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
resetModelProviderListExpanded()
|
||||
|
||||
onTabChangeAction(tab)
|
||||
}, [onTabChangeAction, resetModelProviderListExpanded])
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
resetModelProviderListExpanded()
|
||||
onCancelAction()
|
||||
}, [onCancelAction, resetModelProviderListExpanded])
|
||||
|
||||
return (
|
||||
<MenuDialog
|
||||
show
|
||||
onClose={onCancel}
|
||||
onClose={handleClose}
|
||||
>
|
||||
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
|
||||
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
|
||||
@@ -166,21 +177,22 @@ export default function AccountSetting({
|
||||
<div>
|
||||
{
|
||||
menuItem.items.map(item => (
|
||||
<div
|
||||
<button
|
||||
type="button"
|
||||
key={item.key}
|
||||
className={cn(
|
||||
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
|
||||
'mb-0.5 flex h-[37px] w-full items-center rounded-lg p-1 pl-3 text-left text-sm',
|
||||
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
|
||||
)}
|
||||
aria-label={item.name}
|
||||
title={item.name}
|
||||
onClick={() => {
|
||||
setActiveMenu(item.key)
|
||||
onTabChange?.(item.key)
|
||||
handleTabChange(item.key)
|
||||
}}
|
||||
>
|
||||
{activeMenu === item.key ? item.activeIcon : item.icon}
|
||||
{!isMobile && <div className="truncate">{item.name}</div>}
|
||||
</div>
|
||||
</button>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
@@ -195,7 +207,8 @@ export default function AccountSetting({
|
||||
variant="tertiary"
|
||||
size="large"
|
||||
className="px-2"
|
||||
onClick={onCancel}
|
||||
aria-label={t('operation.close', { ns: 'common' })}
|
||||
onClick={handleClose}
|
||||
>
|
||||
<span className="i-ri-close-line h-5 w-5" />
|
||||
</Button>
|
||||
|
||||
@@ -40,8 +40,7 @@ describe('MenuDialog', () => {
|
||||
)
|
||||
|
||||
// Assert
|
||||
const panel = screen.getByRole('dialog').querySelector('.custom-class')
|
||||
expect(panel).toBeInTheDocument()
|
||||
expect(screen.getByRole('dialog')).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { Dialog, DialogPanel, Transition, TransitionChild } from '@headlessui/react'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { Fragment, useCallback, useEffect } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type DialogProps = {
|
||||
@@ -19,42 +18,25 @@ const MenuDialog = ({
|
||||
}: DialogProps) => {
|
||||
const close = useCallback(() => onClose?.(), [onClose])
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault()
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown)
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown)
|
||||
}
|
||||
}, [close])
|
||||
|
||||
return (
|
||||
<Transition appear show={show} as={Fragment}>
|
||||
<Dialog as="div" className="relative z-[60]" onClose={noop}>
|
||||
<div className="fixed inset-0">
|
||||
<div className="flex min-h-full flex-col items-center justify-center">
|
||||
<TransitionChild>
|
||||
<DialogPanel className={cn(
|
||||
'relative h-full w-full grow overflow-hidden bg-background-sidenav-bg p-0 text-left align-middle backdrop-blur-md transition-all',
|
||||
'duration-300 ease-in data-[closed]:scale-95 data-[closed]:opacity-0',
|
||||
'data-[enter]:scale-100 data-[enter]:opacity-100',
|
||||
'data-[enter]:scale-95 data-[leave]:opacity-0',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
|
||||
{children}
|
||||
</DialogPanel>
|
||||
</TransitionChild>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
</Transition>
|
||||
<Dialog
|
||||
open={show}
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
close()
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
overlayClassName="bg-transparent"
|
||||
className={cn(
|
||||
'left-0 top-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-hidden rounded-none border-none bg-background-sidenav-bg p-0 shadow-none backdrop-blur-md',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
|
||||
{children}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { selectAtom } from 'jotai/utils'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
|
||||
const modelProviderListExpandedAtom = atom<Record<string, boolean>>({})
|
||||
|
||||
const setModelProviderListExpandedAtom = atom(
|
||||
null,
|
||||
(get, set, params: { providerName: string, expanded: boolean }) => {
|
||||
const { providerName, expanded } = params
|
||||
const current = get(modelProviderListExpandedAtom)
|
||||
|
||||
if (expanded) {
|
||||
if (current[providerName])
|
||||
return
|
||||
|
||||
set(modelProviderListExpandedAtom, {
|
||||
...current,
|
||||
[providerName]: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (!current[providerName])
|
||||
return
|
||||
|
||||
const next = { ...current }
|
||||
delete next[providerName]
|
||||
set(modelProviderListExpandedAtom, next)
|
||||
},
|
||||
)
|
||||
|
||||
const resetModelProviderListExpandedAtom = atom(
|
||||
null,
|
||||
(_get, set) => {
|
||||
set(modelProviderListExpandedAtom, {})
|
||||
},
|
||||
)
|
||||
|
||||
export function useModelProviderListExpanded(providerName: string) {
|
||||
const selectedAtom = useMemo(
|
||||
() => selectAtom(modelProviderListExpandedAtom, state => state[providerName] ?? false),
|
||||
[providerName],
|
||||
)
|
||||
return useAtomValue(selectedAtom)
|
||||
}
|
||||
|
||||
export function useSetModelProviderListExpanded(providerName: string) {
|
||||
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
|
||||
return useCallback((expanded: boolean) => {
|
||||
setExpanded({ providerName, expanded })
|
||||
}, [providerName, setExpanded])
|
||||
}
|
||||
|
||||
export function useExpandModelProviderList() {
|
||||
const setExpanded = useSetAtom(setModelProviderListExpandedAtom)
|
||||
return useCallback((providerName: string) => {
|
||||
setExpanded({ providerName, expanded: true })
|
||||
}, [setExpanded])
|
||||
}
|
||||
|
||||
export function useResetModelProviderListExpanded() {
|
||||
return useSetAtom(resetModelProviderListExpandedAtom)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import type {
|
||||
} from './declarations'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
@@ -23,6 +24,7 @@ import {
|
||||
useAnthropicBuyQuota,
|
||||
useCurrentProviderAndModel,
|
||||
useDefaultModel,
|
||||
useInvalidateDefaultModel,
|
||||
useLanguage,
|
||||
useMarketplaceAllPlugins,
|
||||
useModelList,
|
||||
@@ -36,7 +38,6 @@ import {
|
||||
useUpdateModelList,
|
||||
useUpdateModelProviders,
|
||||
} from './hooks'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
@@ -78,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: vi.fn(() => ({
|
||||
eventEmitter: {
|
||||
emit: vi.fn(),
|
||||
},
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
useMarketplacePlugins: vi.fn(() => ({
|
||||
plugins: [],
|
||||
@@ -99,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('./atoms', () => ({
|
||||
useExpandModelProviderList: vi.fn(() => vi.fn()),
|
||||
}))
|
||||
|
||||
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
|
||||
const { getPayUrl } = await import('@/service/common')
|
||||
const { useProviderContext } = await import('@/context/provider-context')
|
||||
const { useModalContextSelector } = await import('@/context/modal-context')
|
||||
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
|
||||
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
|
||||
const { useExpandModelProviderList } = await import('./atoms')
|
||||
|
||||
describe('hooks', () => {
|
||||
beforeEach(() => {
|
||||
@@ -864,6 +861,38 @@ describe('hooks', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('useInvalidateDefaultModel', () => {
|
||||
it('should invalidate default model queries', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const { result } = renderHook(() => useInvalidateDefaultModel())
|
||||
|
||||
act(() => {
|
||||
result.current(ModelTypeEnum.textGeneration)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: ['default-model', ModelTypeEnum.textGeneration],
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle multiple model types', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const { result } = renderHook(() => useInvalidateDefaultModel())
|
||||
|
||||
act(() => {
|
||||
result.current(ModelTypeEnum.textGeneration)
|
||||
result.current(ModelTypeEnum.textEmbedding)
|
||||
result.current(ModelTypeEnum.rerank)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useAnthropicBuyQuota', () => {
|
||||
beforeEach(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
@@ -1167,39 +1196,52 @@ describe('hooks', () => {
|
||||
|
||||
it('should refresh providers and model lists', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
|
||||
})
|
||||
|
||||
it('should emit event when refreshModelList is true and custom config is active', () => {
|
||||
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = createMockProvider()
|
||||
const customFields: CustomConfigurationModelFixedFields = {
|
||||
__model_name: 'gpt-4',
|
||||
__model_type: ModelTypeEnum.textGeneration,
|
||||
}
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1207,23 +1249,30 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, customFields, true)
|
||||
})
|
||||
|
||||
expect(emit).toHaveBeenCalledWith({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: 'openai',
|
||||
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
})
|
||||
|
||||
it('should not emit event when custom config is not active', () => {
|
||||
it('should not expand provider list when custom config is not active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1231,16 +1280,43 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(emit).not.toHaveBeenCalled()
|
||||
expect(expandModelProviderList).not.toHaveBeenCalled()
|
||||
expect(invalidateQueries).not.toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should refetch active model provider list when custom refresh callback is absent', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle provider with single model type', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit: vi.fn() },
|
||||
})
|
||||
|
||||
const provider = {
|
||||
...createMockProvider(),
|
||||
|
||||
@@ -21,10 +21,10 @@ import {
|
||||
useMarketplacePluginsByCollectionId,
|
||||
} from '@/app/components/plugins/marketplace/hooks'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import {
|
||||
fetchDefaultModal,
|
||||
fetchModelList,
|
||||
@@ -32,12 +32,12 @@ import {
|
||||
getPayUrl,
|
||||
} from '@/service/common'
|
||||
import { commonQueryKeys } from '@/service/use-common'
|
||||
import { useExpandModelProviderList } from './atoms'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
ModelStatusEnum,
|
||||
} from './declarations'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
type UseDefaultModelAndModelList = (
|
||||
defaultModel: DefaultModelResponse | undefined,
|
||||
@@ -222,6 +222,14 @@ export const useUpdateModelList = () => {
|
||||
return updateModelList
|
||||
}
|
||||
|
||||
export const useInvalidateDefaultModel = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useCallback((type: ModelTypeEnum) => {
|
||||
queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) })
|
||||
}, [queryClient])
|
||||
}
|
||||
|
||||
export const useAnthropicBuyQuota = () => {
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
@@ -314,7 +322,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
||||
}
|
||||
|
||||
export const useRefreshModel = () => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const expandModelProviderList = useExpandModelProviderList()
|
||||
const queryClient = useQueryClient()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const handleRefreshModel = useCallback((
|
||||
@@ -322,6 +331,19 @@ export const useRefreshModel = () => {
|
||||
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
refreshModelList?: boolean,
|
||||
) => {
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
|
||||
updateModelProviders()
|
||||
|
||||
provider.supported_model_types.forEach((type) => {
|
||||
@@ -329,15 +351,17 @@ export const useRefreshModel = () => {
|
||||
})
|
||||
|
||||
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
expandModelProviderList(provider.provider)
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
|
||||
if (CustomConfigurationModelFixedFields?.__model_type)
|
||||
updateModelList(CustomConfigurationModelFixedFields.__model_type)
|
||||
}
|
||||
}, [eventEmitter, updateModelList, updateModelProviders])
|
||||
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
|
||||
|
||||
return {
|
||||
handleRefreshModel,
|
||||
|
||||
@@ -7,16 +7,7 @@ import {
|
||||
} from './declarations'
|
||||
import ModelProviderPage from './index'
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isValidatingCurrentWorkspace: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockGlobalState = {
|
||||
systemFeatures: { enable_marketplace: true },
|
||||
}
|
||||
let mockEnableMarketplace = true
|
||||
|
||||
const mockQuotaConfig = {
|
||||
quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
@@ -28,7 +19,11 @@ const mockQuotaConfig = {
|
||||
}
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (s: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector(mockGlobalState),
|
||||
useSystemFeaturesQuery: () => ({
|
||||
data: {
|
||||
enable_marketplace: mockEnableMarketplace,
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockProviders = [
|
||||
@@ -60,13 +55,16 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockDefaultModelState = {
|
||||
data: null,
|
||||
isLoading: false,
|
||||
const mockDefaultModels: Record<string, { data: unknown, isLoading: boolean }> = {
|
||||
'llm': { data: null, isLoading: false },
|
||||
'text-embedding': { data: null, isLoading: false },
|
||||
'rerank': { data: null, isLoading: false },
|
||||
'speech2text': { data: null, isLoading: false },
|
||||
'tts': { data: null, isLoading: false },
|
||||
}
|
||||
|
||||
vi.mock('./hooks', () => ({
|
||||
useDefaultModel: () => mockDefaultModelState,
|
||||
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
|
||||
}))
|
||||
|
||||
vi.mock('./install-from-marketplace', () => ({
|
||||
@@ -85,13 +83,18 @@ vi.mock('./system-model-selector', () => ({
|
||||
default: () => <div data-testid="system-model-selector" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useCheckInstalled: () => ({ data: undefined }),
|
||||
}))
|
||||
|
||||
describe('ModelProviderPage', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
mockGlobalState.systemFeatures.enable_marketplace = true
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = false
|
||||
mockEnableMarketplace = true
|
||||
Object.keys(mockDefaultModels).forEach((key) => {
|
||||
mockDefaultModels[key] = { data: null, isLoading: false }
|
||||
})
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'openai',
|
||||
label: { en_US: 'OpenAI' },
|
||||
@@ -149,13 +152,76 @@ describe('ModelProviderPage', () => {
|
||||
})
|
||||
|
||||
it('should hide marketplace section when marketplace feature is disabled', () => {
|
||||
mockGlobalState.systemFeatures.enable_marketplace = false
|
||||
mockEnableMarketplace = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByTestId('install-from-marketplace')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('system model config status', () => {
|
||||
it('should not show top warning when no configured providers exist (empty state card handles it)', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'anthropic',
|
||||
label: { en_US: 'Anthropic' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.emptyProviderTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show none-configured warning when providers exist but no default models set', () => {
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.getByText('common.modelProvider.noneConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show partially-configured warning when some default models are set', () => {
|
||||
mockDefaultModels.llm = {
|
||||
data: { model: 'gpt-4', model_type: 'llm', provider: { provider: 'openai', icon_small: { en_US: '' } } },
|
||||
isLoading: false,
|
||||
}
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show warning when all default models are configured', () => {
|
||||
const makeModel = (model: string, type: string) => ({
|
||||
data: { model, model_type: type, provider: { provider: 'openai', icon_small: { en_US: '' } } },
|
||||
isLoading: false,
|
||||
})
|
||||
mockDefaultModels.llm = makeModel('gpt-4', 'llm')
|
||||
mockDefaultModels['text-embedding'] = makeModel('text-embedding-3', 'text-embedding')
|
||||
mockDefaultModels.rerank = makeModel('rerank-v3', 'rerank')
|
||||
mockDefaultModels.speech2text = makeModel('whisper-1', 'speech2text')
|
||||
mockDefaultModels.tts = makeModel('tts-1', 'tts')
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show warning while loading', () => {
|
||||
Object.keys(mockDefaultModels).forEach((key) => {
|
||||
mockDefaultModels[key] = { data: null, isLoading: true }
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should prioritize fixed providers in visible order', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'zeta-provider',
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import type {
|
||||
ModelProvider,
|
||||
} from './declarations'
|
||||
import {
|
||||
RiAlertFill,
|
||||
RiBrainLine,
|
||||
} from '@remixicon/react'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useEffect, useMemo } from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useSystemFeaturesQuery } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useCheckInstalled } from '@/service/use-plugins'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import {
|
||||
CustomConfigurationStatusEnum,
|
||||
@@ -24,6 +21,9 @@ import InstallFromMarketplace from './install-from-marketplace'
|
||||
import ProviderAddedCard from './provider-added-card'
|
||||
import QuotaPanel from './provider-added-card/quota-panel'
|
||||
import SystemModelSelector from './system-model-selector'
|
||||
import { providerToPluginId } from './utils'
|
||||
|
||||
type SystemModelConfigStatus = 'no-provider' | 'none-configured' | 'partially-configured' | 'fully-configured'
|
||||
|
||||
type Props = {
|
||||
searchText: string
|
||||
@@ -34,20 +34,35 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
|
||||
const ModelProviderPage = ({ searchText }: Props) => {
|
||||
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
|
||||
const { t } = useTranslation()
|
||||
const { mutateCurrentWorkspace, isValidatingCurrentWorkspace } = useAppContext()
|
||||
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
|
||||
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
|
||||
const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text)
|
||||
const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts)
|
||||
const { modelProviders: providers } = useProviderContext()
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: systemFeatures } = useSystemFeaturesQuery()
|
||||
|
||||
const allPluginIds = useMemo(() => {
|
||||
return [...new Set(providers.map(p => providerToPluginId(p.provider)).filter(Boolean))]
|
||||
}, [providers])
|
||||
const { data: installedPlugins } = useCheckInstalled({
|
||||
pluginIds: allPluginIds,
|
||||
enabled: allPluginIds.length > 0,
|
||||
})
|
||||
const pluginDetailMap = useMemo(() => {
|
||||
const map = new Map<string, PluginDetail>()
|
||||
if (installedPlugins?.plugins) {
|
||||
for (const plugin of installedPlugins.plugins)
|
||||
map.set(plugin.plugin_id, plugin)
|
||||
}
|
||||
return map
|
||||
}, [installedPlugins])
|
||||
const enableMarketplace = systemFeatures?.enable_marketplace ?? false
|
||||
const isDefaultModelLoading = isTextGenerationDefaultModelLoading
|
||||
|| isEmbeddingsDefaultModelLoading
|
||||
|| isRerankDefaultModelLoading
|
||||
|| isSpeech2textDefaultModelLoading
|
||||
|| isTTSDefaultModelLoading
|
||||
const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
|
||||
const [configuredProviders, notConfiguredProviders] = useMemo(() => {
|
||||
const configuredProviders: ModelProvider[] = []
|
||||
const notConfiguredProviders: ModelProvider[] = []
|
||||
@@ -79,6 +94,26 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
|
||||
return [configuredProviders, notConfiguredProviders]
|
||||
}, [providers])
|
||||
|
||||
const systemModelConfigStatus: SystemModelConfigStatus = useMemo(() => {
|
||||
const defaultModels = [textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel]
|
||||
const configuredCount = defaultModels.filter(Boolean).length
|
||||
if (configuredCount === 0 && configuredProviders.length === 0)
|
||||
return 'no-provider'
|
||||
if (configuredCount === 0)
|
||||
return 'none-configured'
|
||||
if (configuredCount < defaultModels.length)
|
||||
return 'partially-configured'
|
||||
return 'fully-configured'
|
||||
}, [configuredProviders, textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel])
|
||||
const warningTextKey
|
||||
= systemModelConfigStatus === 'none-configured'
|
||||
? 'modelProvider.noneConfigured'
|
||||
: systemModelConfigStatus === 'partially-configured'
|
||||
? 'modelProvider.notConfigured'
|
||||
: null
|
||||
const showWarning = !isDefaultModelLoading && !!warningTextKey
|
||||
|
||||
const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo(() => {
|
||||
const filteredConfiguredProviders = configuredProviders.filter(
|
||||
provider => provider.provider.toLowerCase().includes(debouncedSearchText.toLowerCase())
|
||||
@@ -92,28 +127,24 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
|
||||
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
|
||||
|
||||
useEffect(() => {
|
||||
mutateCurrentWorkspace()
|
||||
}, [mutateCurrentWorkspace])
|
||||
|
||||
return (
|
||||
<div className="relative -mt-2 pt-1">
|
||||
<div className={cn('mb-2 flex items-center')}>
|
||||
<div className="system-md-semibold grow text-text-primary">{t('modelProvider.models', { ns: 'common' })}</div>
|
||||
<div className="grow text-text-primary system-md-semibold">{t('modelProvider.models', { ns: 'common' })}</div>
|
||||
<div className={cn(
|
||||
'relative flex shrink-0 items-center justify-end gap-2 rounded-lg border border-transparent p-px',
|
||||
defaultModelNotConfigured && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
|
||||
showWarning && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
|
||||
)}
|
||||
>
|
||||
{defaultModelNotConfigured && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
|
||||
{defaultModelNotConfigured && (
|
||||
<div className="system-xs-medium flex items-center gap-1 text-text-primary">
|
||||
<RiAlertFill className="h-4 w-4 text-text-warning-secondary" />
|
||||
<span className="max-w-[460px] truncate" title={t('modelProvider.notConfigured', { ns: 'common' })}>{t('modelProvider.notConfigured', { ns: 'common' })}</span>
|
||||
{showWarning && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
|
||||
{showWarning && (
|
||||
<div className="flex items-center gap-1 text-text-primary system-xs-medium">
|
||||
<span className="i-ri-alert-fill h-4 w-4 text-text-warning-secondary" />
|
||||
<span className="max-w-[460px] truncate" title={t(warningTextKey, { ns: 'common' })}>{t(warningTextKey, { ns: 'common' })}</span>
|
||||
</div>
|
||||
)}
|
||||
<SystemModelSelector
|
||||
notConfigured={defaultModelNotConfigured}
|
||||
notConfigured={showWarning}
|
||||
textGenerationDefaultModel={textGenerationDefaultModel}
|
||||
embeddingsDefaultModel={embeddingsDefaultModel}
|
||||
rerankDefaultModel={rerankDefaultModel}
|
||||
@@ -123,14 +154,14 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} isLoading={isValidatingCurrentWorkspace} />}
|
||||
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} />}
|
||||
{!filteredConfiguredProviders?.length && (
|
||||
<div className="mb-2 rounded-[10px] bg-workflow-process-bg p-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur">
|
||||
<RiBrainLine className="h-5 w-5 text-text-primary" />
|
||||
<span className="i-ri-brain-line h-5 w-5 text-text-primary" />
|
||||
</div>
|
||||
<div className="system-sm-medium mt-2 text-text-secondary">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
|
||||
<div className="system-xs-regular mt-1 text-text-tertiary">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
|
||||
<div className="mt-2 text-text-secondary system-sm-medium">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
|
||||
<div className="mt-1 text-text-tertiary system-xs-regular">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
|
||||
</div>
|
||||
)}
|
||||
{!!filteredConfiguredProviders?.length && (
|
||||
@@ -139,26 +170,28 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
<ProviderAddedCard
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{!!filteredNotConfiguredProviders?.length && (
|
||||
<>
|
||||
<div className="system-md-semibold mb-2 flex items-center pt-2 text-text-primary">{t('modelProvider.toBeConfigured', { ns: 'common' })}</div>
|
||||
<div className="mb-2 flex items-center pt-2 text-text-primary system-md-semibold">{t('modelProvider.toBeConfigured', { ns: 'common' })}</div>
|
||||
<div className="relative">
|
||||
{filteredNotConfiguredProviders?.map(provider => (
|
||||
<ProviderAddedCard
|
||||
notConfigured
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{
|
||||
enable_marketplace && (
|
||||
enableMarketplace && (
|
||||
<InstallFromMarketplace
|
||||
providers={providers}
|
||||
searchText={searchText}
|
||||
|
||||
@@ -2,12 +2,6 @@ import type { Credential } from '../../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import CredentialItem from './credential-item'
|
||||
|
||||
vi.mock('@remixicon/react', () => ({
|
||||
RiCheckLine: () => <div data-testid="check-icon" />,
|
||||
RiDeleteBinLine: () => <div data-testid="delete-icon" />,
|
||||
RiEqualizer2Line: () => <div data-testid="edit-icon" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/indicator', () => ({
|
||||
default: () => <div data-testid="indicator" />,
|
||||
}))
|
||||
@@ -61,8 +55,12 @@ describe('CredentialItem', () => {
|
||||
|
||||
render(<CredentialItem credential={credential} onEdit={onEdit} onDelete={onDelete} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('edit-icon').closest('button') as HTMLButtonElement)
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const editButton = buttons.find(b => b.querySelector('.i-ri-equalizer-2-line'))!
|
||||
const deleteButton = buttons.find(b => b.querySelector('.i-ri-delete-bin-line'))!
|
||||
|
||||
fireEvent.click(editButton)
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onEdit).toHaveBeenCalledWith(credential)
|
||||
expect(onDelete).toHaveBeenCalledWith(credential)
|
||||
@@ -81,7 +79,10 @@ describe('CredentialItem', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
const deleteButton = screen.getAllByRole('button')
|
||||
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
|
||||
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onDelete).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
import type { Credential } from '../../declarations'
|
||||
import {
|
||||
RiCheckLine,
|
||||
RiDeleteBinLine,
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
@@ -11,7 +6,7 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@@ -56,7 +51,7 @@ const CredentialItem = ({
|
||||
key={credential.credential_id}
|
||||
className={cn(
|
||||
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
|
||||
(disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50',
|
||||
(disabled || credential.not_allowed_to_use) ? 'cursor-not-allowed opacity-50' : onItemClick && 'cursor-pointer',
|
||||
)}
|
||||
onClick={() => {
|
||||
if (disabled || credential.not_allowed_to_use)
|
||||
@@ -70,7 +65,7 @@ const CredentialItem = ({
|
||||
<div className="h-4 w-4">
|
||||
{
|
||||
selectedCredentialId === credential.credential_id && (
|
||||
<RiCheckLine className="h-4 w-4 text-text-accent" />
|
||||
<span className="i-ri-check-line h-4 w-4 text-text-accent" />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
@@ -78,7 +73,7 @@ const CredentialItem = ({
|
||||
}
|
||||
<Indicator className="ml-2 mr-1.5 shrink-0" />
|
||||
<div
|
||||
className="system-md-regular truncate text-text-secondary"
|
||||
className="truncate text-text-secondary system-md-regular"
|
||||
title={credential.credential_name}
|
||||
>
|
||||
{credential.credential_name}
|
||||
@@ -96,38 +91,50 @@ const CredentialItem = ({
|
||||
<div className="ml-2 hidden shrink-0 items-center group-hover:flex">
|
||||
{
|
||||
!disableEdit && !credential.not_allowed_to_use && (
|
||||
<Tooltip popupContent={t('operation.edit', { ns: 'common' })}>
|
||||
<ActionButton
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onEdit?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
|
||||
</ActionButton>
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
render={(
|
||||
<ActionButton
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onEdit?.(credential)
|
||||
}}
|
||||
>
|
||||
<span className="i-ri-equalizer-2-line h-4 w-4 text-text-tertiary" />
|
||||
</ActionButton>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>{t('operation.edit', { ns: 'common' })}</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
{
|
||||
!disableDelete && (
|
||||
<Tooltip popupContent={disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}>
|
||||
<ActionButton
|
||||
className="hover:bg-transparent"
|
||||
onClick={(e) => {
|
||||
if (disabled || disableDeleteWhenSelected)
|
||||
return
|
||||
e.stopPropagation()
|
||||
onDelete?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiDeleteBinLine className={cn(
|
||||
'h-4 w-4 text-text-tertiary',
|
||||
!disableDeleteWhenSelected && 'hover:text-text-destructive',
|
||||
disableDeleteWhenSelected && 'opacity-50',
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
render={(
|
||||
<ActionButton
|
||||
className="hover:bg-transparent"
|
||||
onClick={(e) => {
|
||||
if (disabled || disableDeleteWhenSelected)
|
||||
return
|
||||
e.stopPropagation()
|
||||
onDelete?.(credential)
|
||||
}}
|
||||
>
|
||||
<span className={cn(
|
||||
'i-ri-delete-bin-line h-4 w-4 text-text-tertiary',
|
||||
!disableDeleteWhenSelected && 'hover:text-text-destructive',
|
||||
disableDeleteWhenSelected && 'opacity-50',
|
||||
)}
|
||||
/>
|
||||
</ActionButton>
|
||||
)}
|
||||
/>
|
||||
</ActionButton>
|
||||
/>
|
||||
<TooltipContent>
|
||||
{disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
@@ -139,8 +146,9 @@ const CredentialItem = ({
|
||||
|
||||
if (credential.not_allowed_to_use) {
|
||||
return (
|
||||
<Tooltip popupContent={t('auth.customCredentialUnavailable', { ns: 'plugin' })}>
|
||||
{Item}
|
||||
<Tooltip>
|
||||
<TooltipTrigger render={Item} />
|
||||
<TooltipContent>{t('auth.customCredentialUnavailable', { ns: 'plugin' })}</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ const ModelBadge: FC<ModelBadgeProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
return (
|
||||
<div className={cn('system-2xs-medium-uppercase flex h-[18px] cursor-default items-center rounded-[5px] border border-divider-deep px-1 text-text-tertiary', className)}>
|
||||
<div className={cn('inline-flex h-[18px] shrink-0 items-center justify-center whitespace-nowrap rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-[5px] text-text-tertiary system-2xs-medium-uppercase', className)}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { Credential, CredentialFormSchema, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CurrentSystemQuotaTypeEnum,
|
||||
@@ -243,9 +243,10 @@ describe('ModelModal', () => {
|
||||
const credential: Credential = { credential_id: 'cred-1' }
|
||||
const { onCancel } = renderModal({ credential })
|
||||
|
||||
expect(screen.getByText('common.modelProvider.confirmDelete')).toBeInTheDocument()
|
||||
const alertDialog = screen.getByRole('alertdialog', { hidden: true })
|
||||
expect(alertDialog).toHaveTextContent('common.modelProvider.confirmDelete')
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
fireEvent.click(within(alertDialog).getByRole('button', { hidden: true, name: 'common.operation.confirm' }))
|
||||
|
||||
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalledTimes(1)
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
|
||||
@@ -9,11 +9,9 @@ import type {
|
||||
FormRefObject,
|
||||
FormSchema,
|
||||
} from '@/app/components/base/form/types'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
@@ -21,15 +19,23 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
|
||||
import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import {
|
||||
useAuth,
|
||||
useCredentialData,
|
||||
@@ -197,7 +203,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="title-2xl-semi-bold text-text-primary">
|
||||
<div className="text-text-primary title-2xl-semi-bold">
|
||||
{label}
|
||||
</div>
|
||||
)
|
||||
@@ -206,7 +212,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
const modalDesc = useMemo(() => {
|
||||
if (providerFormSchemaPredefined) {
|
||||
return (
|
||||
<div className="system-xs-regular mt-1 text-text-tertiary">
|
||||
<div className="mt-1 text-text-tertiary system-xs-regular">
|
||||
{t('modelProvider.auth.apiKeyModal.desc', { ns: 'common' })}
|
||||
</div>
|
||||
)
|
||||
@@ -223,7 +229,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
className="mr-2 h-4 w-4 shrink-0"
|
||||
provider={provider}
|
||||
/>
|
||||
<div className="system-md-regular mr-1 text-text-secondary">{renderI18nObject(provider.label)}</div>
|
||||
<div className="mr-1 text-text-secondary system-md-regular">{renderI18nObject(provider.label)}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -235,7 +241,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
provider={provider}
|
||||
modelName={model.model}
|
||||
/>
|
||||
<div className="system-md-regular mr-1 text-text-secondary">{model.model}</div>
|
||||
<div className="mr-1 text-text-secondary system-md-regular">{model.model}</div>
|
||||
<Badge>{model.model_type}</Badge>
|
||||
</div>
|
||||
)
|
||||
@@ -275,174 +281,171 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
}, [])
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.stopPropagation()
|
||||
onCancel()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown, true)
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown, true)
|
||||
}
|
||||
const handleOpenChange = useCallback((open: boolean) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}, [onCancel])
|
||||
|
||||
const handleConfirmOpenChange = useCallback((open: boolean) => {
|
||||
if (!open)
|
||||
closeConfirmDelete()
|
||||
}, [closeConfirmDelete])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem open>
|
||||
<PortalToFollowElemContent className="z-[60] h-full w-full">
|
||||
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
|
||||
<div className="relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl">
|
||||
<div
|
||||
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
<div className="p-6 pb-3">
|
||||
{modalTitle}
|
||||
{modalDesc}
|
||||
{modalModel}
|
||||
</div>
|
||||
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
|
||||
{
|
||||
mode === ModelModalModeEnum.configCustomModel && (
|
||||
<AuthForm
|
||||
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={modelNameAndTypeFormValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef1}
|
||||
onChange={handleModelNameAndTypeChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
mode === ModelModalModeEnum.addCustomModelToModelList && (
|
||||
<CredentialSelector
|
||||
credentials={available_credentials || []}
|
||||
onSelect={setSelectedCredential}
|
||||
selectedCredential={selectedCredential}
|
||||
disabled={isLoading}
|
||||
notAllowAddNewCredential={notAllowCustomCredential}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
showCredentialLabel && (
|
||||
<div className="system-xs-medium-uppercase mb-3 mt-6 flex items-center text-text-tertiary">
|
||||
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
|
||||
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-3 flex items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading
|
||||
&& showCredentialForm
|
||||
&& (
|
||||
<AuthForm
|
||||
formSchemas={formSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
showRadioUI: formSchema.type === FormTypeEnum.radio,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={formValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef2}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className="flex justify-between p-6 pt-5">
|
||||
{
|
||||
(provider.help && (provider.help.title || provider.help.url))
|
||||
? (
|
||||
<a
|
||||
href={provider.help?.url[language] || provider.help?.url.en_US}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="system-xs-regular mt-2 inline-block align-middle text-text-accent"
|
||||
onClick={e => !provider.help.url && e.preventDefault()}
|
||||
>
|
||||
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
|
||||
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
|
||||
</a>
|
||||
)
|
||||
: <div />
|
||||
}
|
||||
<div className="ml-2 flex items-center justify-end space-x-2">
|
||||
{
|
||||
isEditMode && (
|
||||
<Button
|
||||
variant="warning"
|
||||
onClick={() => openConfirmDelete(credential, model)}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={isLoading || doingAction}
|
||||
>
|
||||
{saveButtonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
|
||||
<div className="border-t-[0.5px] border-t-divider-regular">
|
||||
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
|
||||
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
|
||||
{t('modelProvider.encrypted.front', { ns: 'common' })}
|
||||
<a
|
||||
className="mx-1 text-text-accent"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
|
||||
>
|
||||
PKCS1_OAEP
|
||||
</a>
|
||||
{t('modelProvider.encrypted.back', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<Dialog open onOpenChange={handleOpenChange}>
|
||||
<DialogContent
|
||||
backdropProps={{ forceRender: true }}
|
||||
className="w-[640px] max-w-[640px] overflow-hidden p-0"
|
||||
>
|
||||
<DialogCloseButton className="right-5 top-5 h-8 w-8" />
|
||||
<div className="p-6 pb-3">
|
||||
{modalTitle}
|
||||
{modalDesc}
|
||||
{modalModel}
|
||||
</div>
|
||||
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
|
||||
{
|
||||
deleteCredentialId && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t('modelProvider.confirmDelete', { ns: 'common' })}
|
||||
isDisabled={doingAction}
|
||||
onCancel={closeConfirmDelete}
|
||||
onConfirm={handleDeleteCredential}
|
||||
mode === ModelModalModeEnum.configCustomModel && (
|
||||
<AuthForm
|
||||
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={modelNameAndTypeFormValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef1}
|
||||
onChange={handleModelNameAndTypeChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
mode === ModelModalModeEnum.addCustomModelToModelList && (
|
||||
<CredentialSelector
|
||||
credentials={available_credentials || []}
|
||||
onSelect={setSelectedCredential}
|
||||
selectedCredential={selectedCredential}
|
||||
disabled={isLoading}
|
||||
notAllowAddNewCredential={notAllowCustomCredential}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
showCredentialLabel && (
|
||||
<div className="mb-3 mt-6 flex items-center text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
|
||||
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-3 flex items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading
|
||||
&& showCredentialForm
|
||||
&& (
|
||||
<AuthForm
|
||||
formSchemas={formSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
showRadioUI: formSchema.type === FormTypeEnum.radio,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={formValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef2}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
<div className="flex justify-between p-6 pt-5">
|
||||
{
|
||||
(provider.help && (provider.help.title || provider.help.url))
|
||||
? (
|
||||
<a
|
||||
href={provider.help?.url[language] || provider.help?.url.en_US}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mt-2 inline-block align-middle text-text-accent system-xs-regular"
|
||||
onClick={e => !provider.help.url && e.preventDefault()}
|
||||
>
|
||||
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
|
||||
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
|
||||
</a>
|
||||
)
|
||||
: <div />
|
||||
}
|
||||
<div className="ml-2 flex items-center justify-end space-x-2">
|
||||
{
|
||||
isEditMode && (
|
||||
<Button
|
||||
variant="warning"
|
||||
onClick={() => openConfirmDelete(credential, model)}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={isLoading || doingAction}
|
||||
>
|
||||
{saveButtonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
|
||||
<div className="border-t-[0.5px] border-t-divider-regular">
|
||||
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
|
||||
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
|
||||
{t('modelProvider.encrypted.front', { ns: 'common' })}
|
||||
<a
|
||||
className="mx-1 text-text-accent"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
|
||||
>
|
||||
PKCS1_OAEP
|
||||
</a>
|
||||
{t('modelProvider.encrypted.back', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</DialogContent>
|
||||
<AlertDialog open={!!deleteCredentialId} onOpenChange={handleConfirmOpenChange}>
|
||||
<AlertDialogContent backdropProps={{ forceRender: true }}>
|
||||
<div className="flex flex-col gap-2 p-6 pb-4">
|
||||
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
|
||||
{t('modelProvider.confirmDelete', { ns: 'common' })}
|
||||
</AlertDialogTitle>
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton>{t('operation.cancel', { ns: 'common' })}</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton
|
||||
disabled={doingAction}
|
||||
onClick={handleDeleteCredential}
|
||||
>
|
||||
{t('operation.confirm', { ns: 'common' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import type { FC } from 'react'
|
||||
import { RiEqualizer2Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type ModelTriggerProps = {
|
||||
@@ -16,24 +14,26 @@ const ModelTrigger: FC<ModelTriggerProps> = ({
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
|
||||
'group flex h-8 cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
|
||||
open && 'bg-components-input-bg-hover',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex grow items-center">
|
||||
<div className="mr-1.5 flex h-4 w-4 items-center justify-center rounded-[5px] border border-dashed border-divider-regular">
|
||||
<CubeOutline className="h-3 w-3 text-text-quaternary" />
|
||||
<div className="flex h-6 w-6 items-center justify-center">
|
||||
<div className="flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle">
|
||||
<span className="i-ri-brain-2-line h-3.5 w-3.5 text-text-quaternary" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex grow items-center gap-1 truncate px-1 py-[3px]">
|
||||
<div
|
||||
className="truncate text-[13px] text-text-tertiary"
|
||||
className="grow truncate text-[13px] text-text-quaternary"
|
||||
title="Configure model"
|
||||
>
|
||||
{t('detailPanel.configureModel', { ns: 'plugin' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<RiEqualizer2Line className="h-3.5 w-3.5 text-text-tertiary" />
|
||||
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<span className="i-ri-arrow-down-s-line h-3.5 w-3.5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import AddModelButton from './add-model-button'
|
||||
|
||||
describe('AddModelButton', () => {
|
||||
it('should render button with text', () => {
|
||||
render(<AddModelButton onClick={vi.fn()} />)
|
||||
expect(screen.getByText('common.modelProvider.addModel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onClick when clicked', () => {
|
||||
const handleClick = vi.fn()
|
||||
render(<AddModelButton onClick={handleClick} />)
|
||||
const button = screen.getByText('common.modelProvider.addModel')
|
||||
fireEvent.click(button)
|
||||
expect(handleClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { PlusCircle } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type AddModelButtonProps = {
|
||||
className?: string
|
||||
onClick: () => void
|
||||
}
|
||||
const AddModelButton: FC<AddModelButtonProps> = ({
|
||||
className,
|
||||
onClick,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<span
|
||||
className={cn('system-xs-medium flex h-6 shrink-0 cursor-pointer items-center rounded-md px-1.5 text-text-tertiary hover:bg-components-button-ghost-bg-hover hover:text-components-button-ghost-text', className)}
|
||||
onClick={onClick}
|
||||
>
|
||||
<PlusCircle className="mr-1 h-3 w-3" />
|
||||
{t('modelProvider.addModel', { ns: 'common' })}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
export default AddModelButton
|
||||
@@ -1,51 +1,54 @@
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { changeModelProviderPriority } from '@/service/common'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CurrentSystemQuotaTypeEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
PreferredProviderTypeEnum,
|
||||
} from '../declarations'
|
||||
import CredentialPanel from './credential-panel'
|
||||
|
||||
const mockEventEmitter = { emit: vi.fn() }
|
||||
const mockNotify = vi.fn()
|
||||
const mockUpdateModelList = vi.fn()
|
||||
const mockUpdateModelProviders = vi.fn()
|
||||
const mockCredentialStatus = {
|
||||
hasCredential: true,
|
||||
authorized: true,
|
||||
authRemoved: false,
|
||||
current_credential_name: 'test-credential',
|
||||
notAllowedToUse: false,
|
||||
}
|
||||
const {
|
||||
mockToastNotify,
|
||||
mockUpdateModelList,
|
||||
mockUpdateModelProviders,
|
||||
mockTrialCredits,
|
||||
mockChangePriorityFn,
|
||||
} = vi.hoisted(() => ({
|
||||
mockToastNotify: vi.fn(),
|
||||
mockUpdateModelList: vi.fn(),
|
||||
mockUpdateModelProviders: vi.fn(),
|
||||
mockTrialCredits: { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false, nextCreditResetDate: undefined },
|
||||
mockChangePriorityFn: vi.fn().mockResolvedValue({ result: 'success' }),
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/config')>()
|
||||
return {
|
||||
...actual,
|
||||
IS_CLOUD_EDITION: true,
|
||||
}
|
||||
return { ...actual, IS_CLOUD_EDITION: true }
|
||||
})
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
default: { notify: mockToastNotify },
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
eventEmitter: mockEventEmitter,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
changeModelProviderPriority: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth', () => ({
|
||||
ConfigProvider: () => <div data-testid="config-provider" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth/hooks', () => ({
|
||||
useCredentialStatus: () => mockCredentialStatus,
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
modelProviders: {
|
||||
models: {
|
||||
queryKey: ({ input }: { input: { params: { provider: string } } }) => ['console', 'modelProviders', 'models', input.params.provider],
|
||||
},
|
||||
changePreferredProviderType: {
|
||||
mutationOptions: (opts: Record<string, unknown>) => ({
|
||||
mutationFn: (...args: unknown[]) => {
|
||||
mockChangePriorityFn(...args)
|
||||
return Promise.resolve({ result: 'success' })
|
||||
},
|
||||
...opts,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
@@ -53,93 +56,375 @@ vi.mock('../hooks', () => ({
|
||||
useUpdateModelProviders: () => mockUpdateModelProviders,
|
||||
}))
|
||||
|
||||
vi.mock('./priority-selector', () => ({
|
||||
default: ({ value, onSelect }: { value: string, onSelect: (key: string) => void }) => (
|
||||
<button data-testid="priority-selector" onClick={() => onSelect('custom')}>
|
||||
Priority Selector
|
||||
{' '}
|
||||
{value}
|
||||
</button>
|
||||
vi.mock('./use-trial-credits', () => ({
|
||||
useTrialCredits: () => mockTrialCredits,
|
||||
}))
|
||||
|
||||
vi.mock('./model-auth-dropdown', () => ({
|
||||
default: ({ state, onChangePriority }: { state: { variant: string, hasCredentials: boolean }, onChangePriority: (key: string) => void }) => (
|
||||
<div data-testid="model-auth-dropdown" data-variant={state.variant}>
|
||||
<button data-testid="change-priority-btn" onClick={() => onChangePriority('custom')}>
|
||||
Change Priority
|
||||
</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('./priority-use-tip', () => ({
|
||||
default: () => <div data-testid="priority-use-tip">Priority Tip</div>,
|
||||
vi.mock('@/app/components/header/indicator', () => ({
|
||||
default: ({ color }: { color: string }) => <div data-testid="indicator" data-color={color} />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/indicator', () => ({
|
||||
default: ({ color }: { color: string }) => <div data-testid="indicator">{color}</div>,
|
||||
vi.mock('@/app/components/base/icons/src/vender/line/alertsAndFeedback/Warning', () => ({
|
||||
default: (props: Record<string, unknown>) => <div data-testid="warning-icon" className={props.className as string} />,
|
||||
}))
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false, gcTime: 0 },
|
||||
mutations: { retry: false },
|
||||
},
|
||||
})
|
||||
|
||||
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: { credential_form_schemas: [] },
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'test-credential',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'test-credential' }],
|
||||
},
|
||||
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
|
||||
supported_model_types: ['llm'],
|
||||
...overrides,
|
||||
} as unknown as ModelProvider)
|
||||
|
||||
const renderWithQueryClient = (provider: ModelProvider) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<CredentialPanel provider={provider} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('CredentialPanel', () => {
|
||||
const mockProvider: ModelProvider = {
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: true,
|
||||
custom_configuration: { status: 'active' },
|
||||
system_configuration: { enabled: true },
|
||||
preferred_provider_type: 'system',
|
||||
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
|
||||
supported_model_types: ['gpt-4'],
|
||||
} as unknown as ModelProvider
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
Object.assign(mockCredentialStatus, {
|
||||
hasCredential: true,
|
||||
authorized: true,
|
||||
authRemoved: false,
|
||||
current_credential_name: 'test-credential',
|
||||
notAllowedToUse: false,
|
||||
Object.assign(mockTrialCredits, { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false })
|
||||
})
|
||||
|
||||
describe('Text label variants', () => {
|
||||
it('should show "AI credits in use" for credits-active variant', () => {
|
||||
renderWithQueryClient(createProvider())
|
||||
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Credits exhausted" for credits-exhausted variant (no credentials)', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
mockTrialCredits.credits = 0
|
||||
renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByText(/quotaExhausted/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "No available usage" for no-usage variant (exhausted + credential unauthorized)', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1' }],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByText(/noAvailableUsage/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "AI credits in use" with warning for credits-fallback (custom priority, no credentials, credits available)', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "AI credits in use" with warning for credits-fallback (custom priority, credential unauthorized, credits available)', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1' }],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByText(/aiCreditsInUse/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show warning icon for credits-fallback variant', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('warning-icon')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show credential name and configuration actions', () => {
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
describe('Status label variants', () => {
|
||||
it('should show green indicator and credential name for api-fallback (exhausted + authorized key)', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider())
|
||||
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'green')
|
||||
expect(screen.getByText('test-credential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('test-credential')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('config-provider')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('priority-selector')).toBeInTheDocument()
|
||||
})
|
||||
it('should show warning icon for api-fallback variant', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider())
|
||||
expect(screen.getByTestId('warning-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show unauthorized status label when credential is missing', () => {
|
||||
mockCredentialStatus.hasCredential = false
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
it('should show green indicator for api-active (custom priority + authorized)', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
}))
|
||||
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'green')
|
||||
expect(screen.getByText('test-credential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.unAuthorized/)).toBeInTheDocument()
|
||||
})
|
||||
it('should NOT show warning icon for api-active variant', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
}))
|
||||
expect(screen.queryByTestId('warning-icon')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show removed credential label and priority tip for custom preference', () => {
|
||||
mockCredentialStatus.authorized = false
|
||||
mockCredentialStatus.authRemoved = true
|
||||
render(<CredentialPanel provider={{ ...mockProvider, preferred_provider_type: 'custom' } as ModelProvider} />)
|
||||
|
||||
expect(screen.getByText(/modelProvider\.auth\.authRemoved/)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('priority-use-tip')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should change priority and refresh related data after success', async () => {
|
||||
const mockChangePriority = changeModelProviderPriority as ReturnType<typeof vi.fn>
|
||||
mockChangePriority.mockResolvedValue({ result: 'success' })
|
||||
render(<CredentialPanel provider={mockProvider} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('priority-selector'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockChangePriority).toHaveBeenCalled()
|
||||
expect(mockNotify).toHaveBeenCalled()
|
||||
expect(mockUpdateModelProviders).toHaveBeenCalled()
|
||||
expect(mockUpdateModelList).toHaveBeenCalledWith('gpt-4')
|
||||
expect(mockEventEmitter.emit).toHaveBeenCalled()
|
||||
it('should show red indicator and "Unavailable" for api-unavailable (exhausted + named unauthorized key)', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: 'Bad Key',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('indicator')).toHaveAttribute('data-color', 'red')
|
||||
expect(screen.getByText(/unavailable/i)).toBeInTheDocument()
|
||||
expect(screen.getByText('Bad Key')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should render standalone priority selector without provider schema', () => {
|
||||
const providerNoSchema = {
|
||||
...mockProvider,
|
||||
provider_credential_schema: null,
|
||||
} as unknown as ModelProvider
|
||||
render(<CredentialPanel provider={providerNoSchema} />)
|
||||
expect(screen.getByTestId('priority-selector')).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('config-provider')).not.toBeInTheDocument()
|
||||
describe('Destructive styling', () => {
|
||||
it('should apply destructive container for credits-exhausted', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const { container } = renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should apply destructive container for no-usage variant', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const { container } = renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1' }],
|
||||
},
|
||||
}))
|
||||
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should apply destructive container for api-unavailable variant', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const { container } = renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: 'Bad Key',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
|
||||
},
|
||||
}))
|
||||
expect(container.querySelector('[class*="border-state-destructive"]')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should apply default container for credits-active', () => {
|
||||
const { container } = renderWithQueryClient(createProvider())
|
||||
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should apply default container for api-active', () => {
|
||||
const { container } = renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
}))
|
||||
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should apply default container for api-fallback', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const { container } = renderWithQueryClient(createProvider())
|
||||
expect(container.querySelector('[class*="bg-white"]')).toBeTruthy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Text color', () => {
|
||||
it('should use destructive text color for credits-exhausted label', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const { container } = renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(container.querySelector('.text-text-destructive')).toBeTruthy()
|
||||
})
|
||||
|
||||
it('should use secondary text color for credits-active label', () => {
|
||||
const { container } = renderWithQueryClient(createProvider())
|
||||
expect(container.querySelector('.text-text-secondary')).toBeTruthy()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Priority change', () => {
|
||||
it('should call mutation with correct params on priority change', async () => {
|
||||
renderWithQueryClient(createProvider())
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByTestId('change-priority-btn'))
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockChangePriorityFn.mock.calls[0]?.[0]).toEqual({
|
||||
params: { provider: 'test-provider' },
|
||||
body: { preferred_provider_type: 'custom' },
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should show success toast and refresh data after successful mutation', async () => {
|
||||
renderWithQueryClient(createProvider())
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByTestId('change-priority-btn'))
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'success' }),
|
||||
)
|
||||
expect(mockUpdateModelProviders).toHaveBeenCalled()
|
||||
expect(mockUpdateModelList).toHaveBeenCalledWith('llm')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('ModelAuthDropdown integration', () => {
|
||||
it('should pass credits-active variant to dropdown when credits available', () => {
|
||||
renderWithQueryClient(createProvider())
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-active')
|
||||
})
|
||||
|
||||
it('should pass api-fallback variant to dropdown when exhausted with valid key', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider())
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-fallback')
|
||||
})
|
||||
|
||||
it('should pass credits-exhausted variant when exhausted with no credentials', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-exhausted')
|
||||
})
|
||||
|
||||
it('should pass api-active variant for custom priority with authorized key', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-active')
|
||||
})
|
||||
|
||||
it('should pass credits-fallback variant for custom priority with no credentials and credits available', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-fallback')
|
||||
})
|
||||
|
||||
it('should pass credits-fallback variant for custom priority with named unauthorized key and credits available', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: 'Bad Key',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'credits-fallback')
|
||||
})
|
||||
|
||||
it('should pass no-usage variant when exhausted + credential but unauthorized', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
renderWithQueryClient(createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1' }],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'no-usage')
|
||||
})
|
||||
})
|
||||
|
||||
describe('apiKeyOnly priority (system disabled)', () => {
|
||||
it('should derive api-required-add when system config disabled and no credentials', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-required-add')
|
||||
expect(screen.getByText(/apiKeyRequired/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should derive api-active when system config disabled but has authorized key', () => {
|
||||
renderWithQueryClient(createProvider({
|
||||
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
|
||||
}))
|
||||
expect(screen.getByTestId('model-auth-dropdown')).toHaveAttribute('data-variant', 'api-active')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,149 +1,157 @@
|
||||
import type {
|
||||
ModelProvider,
|
||||
PreferredProviderTypeEnum,
|
||||
} from '../declarations'
|
||||
import { useMemo } from 'react'
|
||||
import type { CardVariant } from './use-credential-panel-state'
|
||||
import { useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { ConfigProvider } from '@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks'
|
||||
import Warning from '@/app/components/base/icons/src/vender/line/alertsAndFeedback/Warning'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { changeModelProviderPriority } from '@/service/common'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
PreferredProviderTypeEnum,
|
||||
} from '../declarations'
|
||||
import {
|
||||
useUpdateModelList,
|
||||
useUpdateModelProviders,
|
||||
} from '../hooks'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './index'
|
||||
import PrioritySelector from './priority-selector'
|
||||
import PriorityUseTip from './priority-use-tip'
|
||||
import ModelAuthDropdown from './model-auth-dropdown'
|
||||
import SystemQuotaCard from './system-quota-card'
|
||||
import { isDestructiveVariant, useCredentialPanelState } from './use-credential-panel-state'
|
||||
|
||||
type CredentialPanelProps = {
|
||||
provider: ModelProvider
|
||||
}
|
||||
|
||||
const TEXT_LABEL_VARIANTS = new Set<CardVariant>([
|
||||
'credits-active',
|
||||
'credits-fallback',
|
||||
'credits-exhausted',
|
||||
'no-usage',
|
||||
'api-required-add',
|
||||
'api-required-configure',
|
||||
])
|
||||
|
||||
const CredentialPanel = ({
|
||||
provider,
|
||||
}: CredentialPanelProps) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const queryClient = useQueryClient()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const customConfig = provider.custom_configuration
|
||||
const systemConfig = provider.system_configuration
|
||||
const priorityUseType = provider.preferred_provider_type
|
||||
const isCustomConfigured = customConfig.status === CustomConfigurationStatusEnum.active
|
||||
const configurateMethods = provider.configurate_methods
|
||||
const {
|
||||
hasCredential,
|
||||
authorized,
|
||||
authRemoved,
|
||||
current_credential_name,
|
||||
notAllowedToUse,
|
||||
} = useCredentialStatus(provider)
|
||||
|
||||
const showPrioritySelector = systemConfig.enabled && isCustomConfigured && IS_CLOUD_EDITION
|
||||
|
||||
const handleChangePriority = async (key: PreferredProviderTypeEnum) => {
|
||||
const res = await changeModelProviderPriority({
|
||||
url: `/workspaces/current/model-providers/${provider.provider}/preferred-provider-type`,
|
||||
body: {
|
||||
preferred_provider_type: key,
|
||||
const state = useCredentialPanelState(provider)
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { mutate: changePriority, isPending: isChangingPriority } = useMutation(
|
||||
consoleQuery.modelProviders.changePreferredProviderType.mutationOptions({
|
||||
onSuccess: () => {
|
||||
Toast.notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
updateModelProviders()
|
||||
provider.configurate_methods.forEach((method) => {
|
||||
if (method === ConfigurationMethodEnum.predefinedModel)
|
||||
provider.supported_model_types.forEach(modelType => updateModelList(modelType))
|
||||
})
|
||||
},
|
||||
onError: () => {
|
||||
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
|
||||
},
|
||||
}),
|
||||
)
|
||||
|
||||
const handleChangePriority = (key: PreferredProviderTypeEnum) => {
|
||||
changePriority({
|
||||
params: { provider: provider.provider },
|
||||
body: { preferred_provider_type: key },
|
||||
})
|
||||
if (res.result === 'success') {
|
||||
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
updateModelProviders()
|
||||
|
||||
configurateMethods.forEach((method) => {
|
||||
if (method === ConfigurationMethodEnum.predefinedModel)
|
||||
provider.supported_model_types.forEach(modelType => updateModelList(modelType))
|
||||
})
|
||||
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
}
|
||||
}
|
||||
const credentialLabel = useMemo(() => {
|
||||
if (!hasCredential)
|
||||
return t('modelProvider.auth.unAuthorized', { ns: 'common' })
|
||||
if (authorized)
|
||||
return current_credential_name
|
||||
if (authRemoved)
|
||||
return t('modelProvider.auth.authRemoved', { ns: 'common' })
|
||||
|
||||
return ''
|
||||
}, [authorized, authRemoved, current_credential_name, hasCredential])
|
||||
const { variant, credentialName } = state
|
||||
const isDestructive = isDestructiveVariant(variant)
|
||||
const isTextLabel = TEXT_LABEL_VARIANTS.has(variant)
|
||||
const needsGap = !isTextLabel || variant === 'credits-fallback'
|
||||
|
||||
const color = useMemo(() => {
|
||||
if (authRemoved || !hasCredential)
|
||||
return 'red'
|
||||
if (notAllowedToUse)
|
||||
return 'gray'
|
||||
return 'green'
|
||||
}, [authRemoved, notAllowedToUse, hasCredential])
|
||||
return (
|
||||
<SystemQuotaCard variant={isDestructive ? 'destructive' : 'default'}>
|
||||
<SystemQuotaCard.Label className={needsGap ? 'gap-1' : undefined}>
|
||||
{isTextLabel
|
||||
? <TextLabel variant={variant} />
|
||||
: <StatusLabel variant={variant} credentialName={credentialName} />}
|
||||
</SystemQuotaCard.Label>
|
||||
<SystemQuotaCard.Actions>
|
||||
<ModelAuthDropdown
|
||||
provider={provider}
|
||||
state={state}
|
||||
isChangingPriority={isChangingPriority}
|
||||
onChangePriority={handleChangePriority}
|
||||
/>
|
||||
</SystemQuotaCard.Actions>
|
||||
</SystemQuotaCard>
|
||||
)
|
||||
}
|
||||
|
||||
const TEXT_LABEL_KEYS = {
|
||||
'credits-active': 'modelProvider.card.aiCreditsInUse',
|
||||
'credits-fallback': 'modelProvider.card.aiCreditsInUse',
|
||||
'credits-exhausted': 'modelProvider.card.quotaExhausted',
|
||||
'no-usage': 'modelProvider.card.noAvailableUsage',
|
||||
'api-required-add': 'modelProvider.card.apiKeyRequired',
|
||||
'api-required-configure': 'modelProvider.card.apiKeyRequired',
|
||||
} as const satisfies Partial<Record<CardVariant, string>>
|
||||
|
||||
function TextLabel({ variant }: { variant: CardVariant }) {
|
||||
const { t } = useTranslation()
|
||||
const isDestructive = isDestructiveVariant(variant)
|
||||
const labelKey = TEXT_LABEL_KEYS[variant as keyof typeof TEXT_LABEL_KEYS]
|
||||
|
||||
return (
|
||||
<>
|
||||
{
|
||||
provider.provider_credential_schema && (
|
||||
<div className={cn(
|
||||
'relative ml-1 w-[120px] shrink-0 rounded-lg border-[0.5px] border-components-panel-border bg-white/[0.18] p-1',
|
||||
authRemoved && 'border-state-destructive-border bg-state-destructive-hover',
|
||||
)}
|
||||
>
|
||||
<div className="system-xs-medium mb-1 flex h-5 items-center justify-between pl-2 pr-[7px] pt-1 text-text-tertiary">
|
||||
<div
|
||||
className={cn(
|
||||
'grow truncate',
|
||||
authRemoved && 'text-text-destructive',
|
||||
)}
|
||||
title={credentialLabel}
|
||||
>
|
||||
{credentialLabel}
|
||||
</div>
|
||||
<Indicator className="shrink-0" color={color} />
|
||||
</div>
|
||||
<div className="flex items-center gap-0.5">
|
||||
<ConfigProvider
|
||||
provider={provider}
|
||||
/>
|
||||
{
|
||||
showPrioritySelector && (
|
||||
<PrioritySelector
|
||||
value={priorityUseType}
|
||||
onSelect={handleChangePriority}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{
|
||||
priorityUseType === PreferredProviderTypeEnum.custom && systemConfig.enabled && (
|
||||
<PriorityUseTip />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
showPrioritySelector && !provider.provider_credential_schema && (
|
||||
<div className="ml-1">
|
||||
<PrioritySelector
|
||||
value={priorityUseType}
|
||||
onSelect={handleChangePriority}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<span className={isDestructive ? 'text-text-destructive' : 'text-text-secondary'}>
|
||||
{t(labelKey, { ns: 'common' })}
|
||||
</span>
|
||||
{variant === 'credits-fallback' && (
|
||||
<Warning className="h-3 w-3 shrink-0 text-text-warning" />
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
function StatusLabel({ variant, credentialName }: {
|
||||
variant: CardVariant
|
||||
credentialName: string | undefined
|
||||
}) {
|
||||
const { t } = useTranslation()
|
||||
const dotColor = variant === 'api-unavailable' ? 'red' : 'green'
|
||||
const showWarning = variant === 'api-fallback'
|
||||
|
||||
return (
|
||||
<>
|
||||
<Indicator className="shrink-0" color={dotColor} />
|
||||
<span
|
||||
className="truncate text-text-secondary"
|
||||
title={credentialName}
|
||||
>
|
||||
{credentialName}
|
||||
</span>
|
||||
{showWarning && (
|
||||
<Warning className="h-3 w-3 shrink-0 text-text-warning" />
|
||||
)}
|
||||
{variant === 'api-unavailable' && (
|
||||
<span className="shrink-0 text-text-destructive system-2xs-medium">
|
||||
{t('modelProvider.card.unavailable', { ns: 'common' })}
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import type { ReactNode } from 'react'
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { fetchModelProviderModelList } from '@/service/common'
|
||||
import { createStore, Provider as JotaiProvider } from 'jotai'
|
||||
import { useExpandModelProviderList } from '../atoms'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ProviderAddedCard from './index'
|
||||
|
||||
let mockIsCurrentWorkspaceManager = true
|
||||
const mockEventEmitter = {
|
||||
useSubscription: vi.fn(),
|
||||
emit: vi.fn(),
|
||||
}
|
||||
const mockFetchModelProviderModels = vi.fn()
|
||||
const mockQueryOptions = vi.fn(({ input, ...options }: { input: { params: { provider: string } }, enabled?: boolean }) => ({
|
||||
queryKey: ['console', 'modelProviders', 'models', input.params.provider],
|
||||
queryFn: () => mockFetchModelProviderModels(input.params.provider),
|
||||
...options,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
fetchModelProviderModelList: vi.fn(),
|
||||
vi.mock('@/service/client', () => ({
|
||||
consoleQuery: {
|
||||
modelProviders: {
|
||||
models: {
|
||||
queryOptions: (options: { input: { params: { provider: string } }, enabled?: boolean }) => mockQueryOptions(options),
|
||||
},
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@@ -20,12 +31,6 @@ vi.mock('@/context/app-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: () => ({
|
||||
eventEmitter: mockEventEmitter,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock internal components to simplify testing of the index file
|
||||
vi.mock('./credential-panel', () => ({
|
||||
default: () => <div data-testid="credential-panel" />,
|
||||
@@ -53,6 +58,38 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
ManageCustomModelCredentials: () => <div data-testid="manage-custom-model" />,
|
||||
}))
|
||||
|
||||
const createTestQueryClient = () => new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false, gcTime: 0 },
|
||||
},
|
||||
})
|
||||
|
||||
const renderWithQueryClient = (node: ReactNode) => {
|
||||
const queryClient = createTestQueryClient()
|
||||
const store = createStore()
|
||||
return render(
|
||||
<JotaiProvider store={store}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{node}
|
||||
</QueryClientProvider>
|
||||
</JotaiProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
const ExternalExpandControls = () => {
|
||||
const expandModelProviderList = useExpandModelProviderList()
|
||||
return (
|
||||
<>
|
||||
<button type="button" data-testid="expand-other-provider" onClick={() => expandModelProviderList('langgenius/anthropic/anthropic')}>
|
||||
expand other
|
||||
</button>
|
||||
<button type="button" data-testid="expand-current-provider" onClick={() => expandModelProviderList('langgenius/openai/openai')}>
|
||||
expand current
|
||||
</button>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
describe('ProviderAddedCard', () => {
|
||||
const mockProvider = {
|
||||
provider: 'langgenius/openai/openai',
|
||||
@@ -67,19 +104,21 @@ describe('ProviderAddedCard', () => {
|
||||
})
|
||||
|
||||
it('should render provider added card component', () => {
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
expect(screen.getByTestId('provider-added-card')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('provider-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open, refresh and collapse model list', async () => {
|
||||
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [{ model: 'gpt-4' }] } as unknown as { data: ModelItem[] })
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
const showModelsBtn = screen.getByTestId('show-models-button')
|
||||
fireEvent.click(showModelsBtn)
|
||||
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledWith(`/workspaces/current/model-providers/${mockProvider.provider}/models`)
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
|
||||
})
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
|
||||
// Test line 71-72: Opening when already fetched
|
||||
@@ -90,13 +129,13 @@ describe('ProviderAddedCard', () => {
|
||||
// Explicitly re-find and click to re-open
|
||||
fireEvent.click(screen.getByTestId('show-models-button'))
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1) // Should not fetch again
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(2) // Re-open fetches again with default stale/gc behavior
|
||||
|
||||
// Refresh list from ModelList
|
||||
const refreshBtn = screen.getByRole('button', { name: 'refresh list' })
|
||||
fireEvent.click(refreshBtn)
|
||||
await waitFor(() => {
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(2)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -105,18 +144,20 @@ describe('ProviderAddedCard', () => {
|
||||
const promise = new Promise((resolve) => {
|
||||
resolveOuter = resolve
|
||||
})
|
||||
vi.mocked(fetchModelProviderModelList).mockReturnValue(promise as unknown as ReturnType<typeof fetchModelProviderModelList>)
|
||||
mockFetchModelProviderModels.mockReturnValue(promise)
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={mockProvider} />)
|
||||
const showModelsBtn = screen.getByTestId('show-models-button')
|
||||
|
||||
// First call sets loading to true
|
||||
fireEvent.click(showModelsBtn)
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Second call should return early because loading is true
|
||||
fireEvent.click(showModelsBtn)
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
|
||||
await act(async () => {
|
||||
resolveOuter({ data: [] })
|
||||
@@ -125,56 +166,49 @@ describe('ProviderAddedCard', () => {
|
||||
expect(await screen.findByTestId('model-list')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should only react to external expansion for the matching provider', async () => {
|
||||
mockFetchModelProviderModels.mockResolvedValue({ data: [{ model: 'gpt-4' }] })
|
||||
renderWithQueryClient(
|
||||
<>
|
||||
<ProviderAddedCard provider={mockProvider} />
|
||||
<ExternalExpandControls />
|
||||
</>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('expand-other-provider'))
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(0)
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('expand-current-provider'))
|
||||
await waitFor(() => {
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledWith(mockProvider.provider)
|
||||
})
|
||||
expect(mockFetchModelProviderModels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render configure tip when provider is not in quota list and not configured', () => {
|
||||
const providerWithoutQuota = {
|
||||
...mockProvider,
|
||||
provider: 'custom/provider',
|
||||
} as unknown as ModelProvider
|
||||
render(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={providerWithoutQuota} notConfigured />)
|
||||
expect(screen.getByText('common.modelProvider.configureTip')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should refresh model list on event subscription', async () => {
|
||||
let capturedHandler: (v: { type: string, payload: string } | null) => void = () => { }
|
||||
mockEventEmitter.useSubscription.mockImplementation((handler: (v: unknown) => void) => {
|
||||
capturedHandler = handler as (v: { type: string, payload: string } | null) => void
|
||||
})
|
||||
vi.mocked(fetchModelProviderModelList).mockResolvedValue({ data: [] } as unknown as { data: ModelItem[] })
|
||||
|
||||
render(<ProviderAddedCard provider={mockProvider} />)
|
||||
|
||||
expect(capturedHandler).toBeDefined()
|
||||
act(() => {
|
||||
capturedHandler({
|
||||
type: 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST',
|
||||
payload: mockProvider.provider,
|
||||
})
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Should ignore non-matching events
|
||||
act(() => {
|
||||
capturedHandler({ type: 'OTHER', payload: '' })
|
||||
capturedHandler(null)
|
||||
})
|
||||
expect(fetchModelProviderModelList).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render custom model actions for workspace managers', () => {
|
||||
const customConfigProvider = {
|
||||
...mockProvider,
|
||||
configurate_methods: [ConfigurationMethodEnum.customizableModel],
|
||||
} as unknown as ModelProvider
|
||||
const { rerender } = render(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
const { unmount } = renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
|
||||
expect(screen.getByTestId('manage-custom-model')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('add-custom-model')).toBeInTheDocument()
|
||||
|
||||
unmount()
|
||||
mockIsCurrentWorkspaceManager = false
|
||||
rerender(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
renderWithQueryClient(<ProviderAddedCard provider={customConfigProvider} />)
|
||||
expect(screen.queryByTestId('manage-custom-model')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
ModelItem,
|
||||
ModelProvider,
|
||||
} from '../declarations'
|
||||
import type { ModelProviderQuotaGetPaid } from '../utils'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
AddCustomModel,
|
||||
@@ -13,9 +14,10 @@ import {
|
||||
} from '@/app/components/header/account-setting/model-provider-page/model-auth'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { fetchModelProviderModelList } from '@/service/common'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useModelProviderListExpanded, useSetModelProviderListExpanded } from '../atoms'
|
||||
import { ConfigurationMethodEnum } from '../declarations'
|
||||
import ModelBadge from '../model-badge'
|
||||
import ProviderIcon from '../provider-icon'
|
||||
@@ -25,121 +27,123 @@ import {
|
||||
} from '../utils'
|
||||
import CredentialPanel from './credential-panel'
|
||||
import ModelList from './model-list'
|
||||
import ProviderCardActions from './provider-card-actions'
|
||||
|
||||
export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
|
||||
type ProviderAddedCardProps = {
|
||||
notConfigured?: boolean
|
||||
provider: ModelProvider
|
||||
pluginDetail?: PluginDetail
|
||||
}
|
||||
const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
notConfigured,
|
||||
provider,
|
||||
pluginDetail,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const [fetched, setFetched] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [collapsed, setCollapsed] = useState(true)
|
||||
const [modelList, setModelList] = useState<ModelItem[]>([])
|
||||
const configurationMethods = provider.configurate_methods.filter(method => method !== ConfigurationMethodEnum.fetchFromRemote)
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const currentProviderName = provider.provider
|
||||
const expanded = useModelProviderListExpanded(currentProviderName)
|
||||
const setExpanded = useSetModelProviderListExpanded(currentProviderName)
|
||||
const supportsPredefinedModel = provider.configurate_methods.includes(ConfigurationMethodEnum.predefinedModel)
|
||||
const supportsCustomizableModel = provider.configurate_methods.includes(ConfigurationMethodEnum.customizableModel)
|
||||
const systemConfig = provider.system_configuration
|
||||
const hasModelList = fetched && !!modelList.length
|
||||
const {
|
||||
data: modelList = [],
|
||||
isFetching: loading,
|
||||
isSuccess: hasFetchedModelList,
|
||||
refetch: refetchModelList,
|
||||
} = useQuery(consoleQuery.modelProviders.models.queryOptions({
|
||||
input: { params: { provider: currentProviderName } },
|
||||
enabled: expanded,
|
||||
refetchOnWindowFocus: false,
|
||||
select: response => response.data,
|
||||
}))
|
||||
const hasModelList = hasFetchedModelList && !!modelList.length
|
||||
const showCollapsedSection = !expanded || !hasFetchedModelList
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(provider.provider as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
|
||||
const showCredential = configurationMethods.includes(ConfigurationMethodEnum.predefinedModel) && isCurrentWorkspaceManager
|
||||
const showModelProvider = systemConfig.enabled && MODEL_PROVIDER_QUOTA_GET_PAID.includes(currentProviderName as ModelProviderQuotaGetPaid) && !IS_CE_EDITION
|
||||
const showCredential = supportsPredefinedModel && isCurrentWorkspaceManager
|
||||
const showCustomModelActions = supportsCustomizableModel && isCurrentWorkspaceManager
|
||||
|
||||
const getModelList = async (providerName: string) => {
|
||||
const refreshModelList = useCallback((targetProviderName: string) => {
|
||||
if (targetProviderName !== currentProviderName)
|
||||
return
|
||||
|
||||
if (!expanded)
|
||||
setExpanded(true)
|
||||
|
||||
refetchModelList().catch(() => {})
|
||||
}, [currentProviderName, expanded, refetchModelList, setExpanded])
|
||||
|
||||
const handleOpenModelList = useCallback(() => {
|
||||
if (loading)
|
||||
return
|
||||
try {
|
||||
setLoading(true)
|
||||
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
|
||||
setModelList(modelsData.data)
|
||||
setCollapsed(false)
|
||||
setFetched(true)
|
||||
}
|
||||
finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}
|
||||
const handleOpenModelList = () => {
|
||||
if (fetched) {
|
||||
setCollapsed(false)
|
||||
|
||||
if (!expanded) {
|
||||
setExpanded(true)
|
||||
return
|
||||
}
|
||||
|
||||
getModelList(provider.provider)
|
||||
}
|
||||
|
||||
eventEmitter?.useSubscription((v: any) => {
|
||||
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST && v.payload === provider.provider)
|
||||
getModelList(v.payload)
|
||||
})
|
||||
refetchModelList().catch(() => {})
|
||||
}, [expanded, loading, refetchModelList, setExpanded])
|
||||
|
||||
return (
|
||||
<div
|
||||
data-testid="provider-added-card"
|
||||
className={cn(
|
||||
'mb-2 rounded-xl border-[0.5px] border-divider-regular bg-third-party-model-bg-default shadow-xs',
|
||||
provider.provider === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
|
||||
provider.provider === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
|
||||
currentProviderName === 'langgenius/openai/openai' && 'bg-third-party-model-bg-openai',
|
||||
currentProviderName === 'langgenius/anthropic/anthropic' && 'bg-third-party-model-bg-anthropic',
|
||||
)}
|
||||
>
|
||||
<div className="flex rounded-t-xl py-2 pl-3 pr-2">
|
||||
<div className="grow px-1 pb-0.5 pt-1">
|
||||
<ProviderIcon
|
||||
className="mb-2"
|
||||
provider={provider}
|
||||
/>
|
||||
<div className="mb-2 flex items-center gap-1">
|
||||
<ProviderIcon provider={provider} />
|
||||
{pluginDetail && (
|
||||
<ProviderCardActions
|
||||
detail={pluginDetail}
|
||||
onUpdate={refreshModelProviders}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-0.5">
|
||||
{
|
||||
provider.supported_model_types.map(modelType => (
|
||||
<ModelBadge key={modelType}>
|
||||
{modelTypeFormat(modelType)}
|
||||
</ModelBadge>
|
||||
))
|
||||
}
|
||||
{provider.supported_model_types.map(modelType => (
|
||||
<ModelBadge key={modelType}>
|
||||
{modelTypeFormat(modelType)}
|
||||
</ModelBadge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
showCredential && (
|
||||
<CredentialPanel
|
||||
provider={provider}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{showCredential && (
|
||||
<CredentialPanel
|
||||
provider={provider}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{
|
||||
collapsed && (
|
||||
showCollapsedSection && (
|
||||
<div className="group flex items-center justify-between border-t border-t-divider-subtle py-1.5 pl-2 pr-[11px] text-text-tertiary system-xs-medium">
|
||||
{(showModelProvider || !notConfigured) && (
|
||||
<>
|
||||
<div className="flex h-6 items-center pl-1 pr-1.5 leading-6 group-hover:hidden">
|
||||
{
|
||||
hasModelList
|
||||
? t('modelProvider.modelsNum', { ns: 'common', num: modelList.length })
|
||||
: t('modelProvider.showModels', { ns: 'common' })
|
||||
}
|
||||
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
|
||||
</div>
|
||||
<div
|
||||
data-testid="show-models-button"
|
||||
className="hidden h-6 cursor-pointer items-center rounded-lg pl-1 pr-1.5 hover:bg-components-button-ghost-bg-hover group-hover:flex"
|
||||
onClick={handleOpenModelList}
|
||||
>
|
||||
{
|
||||
hasModelList
|
||||
? t('modelProvider.showModelsNum', { ns: 'common', num: modelList.length })
|
||||
: t('modelProvider.showModels', { ns: 'common' })
|
||||
}
|
||||
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
|
||||
{
|
||||
loading && (
|
||||
<div className="i-ri-loader-2-line ml-0.5 h-3 w-3 animate-spin" />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</>
|
||||
<button
|
||||
type="button"
|
||||
data-testid="show-models-button"
|
||||
className="flex h-6 items-center rounded-lg pl-1 pr-1.5 hover:bg-components-button-ghost-bg-hover"
|
||||
aria-label={t('modelProvider.showModels', { ns: 'common' })}
|
||||
onClick={handleOpenModelList}
|
||||
>
|
||||
{
|
||||
hasModelList
|
||||
? t('modelProvider.modelsNum', { ns: 'common', num: modelList.length })
|
||||
: t('modelProvider.showModels', { ns: 'common' })
|
||||
}
|
||||
{!loading && <div className="i-ri-arrow-right-s-line h-4 w-4" />}
|
||||
{
|
||||
loading && (
|
||||
<div className="i-ri-loader-2-line ml-0.5 h-3 w-3 animate-spin" />
|
||||
)
|
||||
}
|
||||
</button>
|
||||
)}
|
||||
{!showModelProvider && notConfigured && (
|
||||
<div className="flex h-6 items-center pl-1 pr-1.5">
|
||||
@@ -148,7 +152,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
</div>
|
||||
)}
|
||||
{
|
||||
configurationMethods.includes(ConfigurationMethodEnum.customizableModel) && isCurrentWorkspaceManager && (
|
||||
showCustomModelActions && (
|
||||
<div className="flex grow justify-end">
|
||||
<ManageCustomModelCredentials
|
||||
provider={provider}
|
||||
@@ -166,12 +170,12 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
|
||||
)
|
||||
}
|
||||
{
|
||||
!collapsed && (
|
||||
!showCollapsedSection && (
|
||||
<ModelList
|
||||
provider={provider}
|
||||
models={modelList}
|
||||
onCollapse={() => setCollapsed(true)}
|
||||
onChange={(provider: string) => getModelList(provider)}
|
||||
onCollapse={() => setExpanded(false)}
|
||||
onChange={refreshModelList}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
import type { Credential, ModelProvider } from '../../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
|
||||
import ApiKeySection from './api-key-section'
|
||||
|
||||
const createCredential = (overrides: Partial<Credential> = {}): Credential => ({
|
||||
credential_id: 'cred-1',
|
||||
credential_name: 'Test API Key',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
|
||||
provider: 'test-provider',
|
||||
allow_custom_token: true,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
available_credentials: [],
|
||||
},
|
||||
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
...overrides,
|
||||
} as unknown as ModelProvider)
|
||||
|
||||
describe('ApiKeySection', () => {
|
||||
const handlers = {
|
||||
onItemClick: vi.fn(),
|
||||
onEdit: vi.fn(),
|
||||
onDelete: vi.fn(),
|
||||
onAdd: vi.fn(),
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Empty state
|
||||
describe('Empty state (no credentials)', () => {
|
||||
it('should show empty state message', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider()}
|
||||
credentials={[]}
|
||||
selectedCredentialId={undefined}
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/noApiKeysTitle/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/noApiKeysDescription/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show Add API Key button', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider()}
|
||||
credentials={[]}
|
||||
selectedCredentialId={undefined}
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onAdd when Add API Key is clicked', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider()}
|
||||
credentials={[]}
|
||||
selectedCredentialId={undefined}
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /addApiKey/ }))
|
||||
|
||||
expect(handlers.onAdd).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should hide Add API Key button when allow_custom_token is false', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider({ allow_custom_token: false })}
|
||||
credentials={[]}
|
||||
selectedCredentialId={undefined}
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByRole('button', { name: /addApiKey/ })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// With credentials
|
||||
describe('With credentials', () => {
|
||||
const credentials = [
|
||||
createCredential({ credential_id: 'cred-1', credential_name: 'Key Alpha' }),
|
||||
createCredential({ credential_id: 'cred-2', credential_name: 'Key Beta' }),
|
||||
]
|
||||
|
||||
it('should render credential list with header', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider()}
|
||||
credentials={credentials}
|
||||
selectedCredentialId="cred-1"
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(/apiKeys/)).toBeInTheDocument()
|
||||
expect(screen.getByText('Key Alpha')).toBeInTheDocument()
|
||||
expect(screen.getByText('Key Beta')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show Add API Key button in footer', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider()}
|
||||
credentials={credentials}
|
||||
selectedCredentialId="cred-1"
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide Add API Key when allow_custom_token is false', () => {
|
||||
render(
|
||||
<ApiKeySection
|
||||
provider={createProvider({ allow_custom_token: false })}
|
||||
credentials={credentials}
|
||||
selectedCredentialId="cred-1"
|
||||
{...handlers}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByRole('button', { name: /addApiKey/ })).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,91 @@
|
||||
import type { Credential, CustomModel, ModelProvider } from '../../declarations'
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import CredentialItem from '../../model-auth/authorized/credential-item'
|
||||
|
||||
type ApiKeySectionProps = {
|
||||
provider: ModelProvider
|
||||
credentials: Credential[]
|
||||
selectedCredentialId: string | undefined
|
||||
isActivating?: boolean
|
||||
onItemClick: (credential: Credential, model?: CustomModel) => void
|
||||
onEdit: (credential?: Credential) => void
|
||||
onDelete: (credential?: Credential) => void
|
||||
onAdd: () => void
|
||||
}
|
||||
|
||||
function ApiKeySection({
|
||||
provider,
|
||||
credentials,
|
||||
selectedCredentialId,
|
||||
isActivating,
|
||||
onItemClick,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onAdd,
|
||||
}: ApiKeySectionProps) {
|
||||
const { t } = useTranslation()
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
if (!credentials.length) {
|
||||
return (
|
||||
<div className="flex flex-col gap-2 p-2">
|
||||
<div className="rounded-[10px] bg-gradient-to-r from-state-base-hover to-transparent p-4">
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="text-text-secondary system-sm-medium">
|
||||
{t('modelProvider.card.noApiKeysTitle', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="text-text-tertiary system-xs-regular">
|
||||
{t('modelProvider.card.noApiKeysDescription', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{!notAllowCustomCredential && (
|
||||
<Button
|
||||
onClick={onAdd}
|
||||
className="w-full"
|
||||
>
|
||||
{t('modelProvider.auth.addApiKey', { ns: 'common' })}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="border-t border-t-divider-subtle">
|
||||
<div className="px-1">
|
||||
<div className="pb-1 pl-7 pr-2 pt-3 text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('modelProvider.auth.apiKeys', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="max-h-[200px] overflow-y-auto">
|
||||
{credentials.map(credential => (
|
||||
<CredentialItem
|
||||
key={credential.credential_id}
|
||||
credential={credential}
|
||||
disabled={isActivating}
|
||||
showSelectedIcon
|
||||
selectedCredentialId={selectedCredentialId}
|
||||
onItemClick={onItemClick}
|
||||
onEdit={onEdit}
|
||||
onDelete={onDelete}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
{!notAllowCustomCredential && (
|
||||
<div className="p-2">
|
||||
<Button
|
||||
onClick={onAdd}
|
||||
className="w-full"
|
||||
>
|
||||
{t('modelProvider.auth.addApiKey', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(ApiKeySection)
|
||||
@@ -0,0 +1,63 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import CreditsExhaustedAlert from './credits-exhausted-alert'
|
||||
|
||||
const mockTrialCredits = { credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false, nextCreditResetDate: undefined }
|
||||
|
||||
vi.mock('../use-trial-credits', () => ({
|
||||
useTrialCredits: () => mockTrialCredits,
|
||||
}))
|
||||
|
||||
describe('CreditsExhaustedAlert', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
Object.assign(mockTrialCredits, { credits: 0 })
|
||||
})
|
||||
|
||||
// Without API key fallback
|
||||
describe('Without API key fallback', () => {
|
||||
it('should show exhausted message', () => {
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
|
||||
|
||||
expect(screen.getByText(/creditsExhaustedMessage/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show description with upgrade link', () => {
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
|
||||
|
||||
expect(screen.getByText(/creditsExhaustedDescription/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// With API key fallback
|
||||
describe('With API key fallback', () => {
|
||||
it('should show fallback message', () => {
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback />)
|
||||
|
||||
expect(screen.getByText(/creditsExhaustedFallback(?!Description)/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show fallback description', () => {
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback />)
|
||||
|
||||
expect(screen.getByText(/creditsExhaustedFallbackDescription/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Usage display
|
||||
describe('Usage display', () => {
|
||||
it('should show usage label', () => {
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
|
||||
|
||||
expect(screen.getByText(/usageLabel/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show usage amounts', () => {
|
||||
mockTrialCredits.credits = 200
|
||||
|
||||
render(<CreditsExhaustedAlert hasApiKeyFallback={false} />)
|
||||
|
||||
expect(screen.getByText(/9,800/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/10,000/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,65 @@
|
||||
import { Trans, useTranslation } from 'react-i18next'
|
||||
import { CreditsCoin } from '@/app/components/base/icons/src/vender/line/financeAndECommerce'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { formatNumber } from '@/utils/format'
|
||||
import { useTrialCredits } from '../use-trial-credits'
|
||||
|
||||
type CreditsExhaustedAlertProps = {
|
||||
hasApiKeyFallback: boolean
|
||||
}
|
||||
|
||||
export default function CreditsExhaustedAlert({ hasApiKeyFallback }: CreditsExhaustedAlertProps) {
|
||||
const { t } = useTranslation()
|
||||
const setShowPricingModal = useModalContextSelector(s => s.setShowPricingModal)
|
||||
const { credits, totalCredits } = useTrialCredits()
|
||||
|
||||
const titleKey = hasApiKeyFallback
|
||||
? 'modelProvider.card.creditsExhaustedFallback'
|
||||
: 'modelProvider.card.creditsExhaustedMessage'
|
||||
const descriptionKey = hasApiKeyFallback
|
||||
? 'modelProvider.card.creditsExhaustedFallbackDescription'
|
||||
: 'modelProvider.card.creditsExhaustedDescription'
|
||||
|
||||
const usedCredits = totalCredits - credits
|
||||
const usagePercent = totalCredits > 0 ? Math.min((usedCredits / totalCredits) * 100, 100) : 100
|
||||
|
||||
return (
|
||||
<div className="mx-2 mb-1 mt-0.5 rounded-lg bg-background-section-burn p-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="text-text-primary system-sm-medium">
|
||||
{t(titleKey, { ns: 'common' })}
|
||||
</div>
|
||||
<div className="text-text-tertiary system-xs-regular">
|
||||
<Trans
|
||||
i18nKey={descriptionKey}
|
||||
ns="common"
|
||||
components={{
|
||||
upgradeLink: <span className="cursor-pointer text-text-accent system-xs-medium" onClick={setShowPricingModal} />,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3 flex flex-col gap-1">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-text-tertiary system-xs-medium">
|
||||
{t('modelProvider.card.usageLabel', { ns: 'common' })}
|
||||
</span>
|
||||
<div className="flex items-center gap-0.5 text-text-tertiary system-xs-regular">
|
||||
<CreditsCoin className="h-3 w-3" />
|
||||
<span>
|
||||
{formatNumber(usedCredits)}
|
||||
/
|
||||
{formatNumber(totalCredits)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="h-1 overflow-hidden rounded-[6px] bg-components-progress-error-bg">
|
||||
<div
|
||||
className="h-full rounded-l-[6px] bg-components-progress-error-progress"
|
||||
style={{ width: `${usagePercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
type CreditsFallbackAlertProps = {
|
||||
hasCredentials: boolean
|
||||
}
|
||||
|
||||
export default function CreditsFallbackAlert({ hasCredentials }: CreditsFallbackAlertProps) {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const titleKey = hasCredentials
|
||||
? 'modelProvider.card.apiKeyUnavailableFallback'
|
||||
: 'modelProvider.card.noApiKeysFallback'
|
||||
|
||||
return (
|
||||
<div className="mx-2 mb-1 mt-0.5 rounded-lg bg-background-section-burn p-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="text-text-primary system-sm-medium">
|
||||
{t(titleKey, { ns: 'common' })}
|
||||
</div>
|
||||
{hasCredentials && (
|
||||
<div className="text-text-tertiary system-xs-regular">
|
||||
{t('modelProvider.card.apiKeyUnavailableFallbackDescription', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,435 @@
|
||||
import type { ModelProvider } from '../../declarations'
|
||||
import type { CredentialPanelState } from '../use-credential-panel-state'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
|
||||
import DropdownContent from './dropdown-content'
|
||||
|
||||
const mockHandleOpenModal = vi.fn()
|
||||
const mockActivate = vi.fn()
|
||||
const mockOpenConfirmDelete = vi.fn()
|
||||
const mockCloseConfirmDelete = vi.fn()
|
||||
const mockHandleConfirmDelete = vi.fn()
|
||||
let mockDeleteCredentialId: string | null = null
|
||||
|
||||
vi.mock('../use-trial-credits', () => ({
|
||||
useTrialCredits: () => ({ credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false }),
|
||||
}))
|
||||
|
||||
vi.mock('./use-activate-credential', () => ({
|
||||
useActivateCredential: () => ({
|
||||
selectedCredentialId: 'cred-1',
|
||||
isActivating: false,
|
||||
activate: mockActivate,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../model-auth/hooks', () => ({
|
||||
useAuth: () => ({
|
||||
openConfirmDelete: mockOpenConfirmDelete,
|
||||
closeConfirmDelete: mockCloseConfirmDelete,
|
||||
doingAction: false,
|
||||
handleConfirmDelete: mockHandleConfirmDelete,
|
||||
deleteCredentialId: mockDeleteCredentialId,
|
||||
handleOpenModal: mockHandleOpenModal,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../model-auth/authorized/credential-item', () => ({
|
||||
default: ({ credential, onItemClick, onEdit, onDelete }: {
|
||||
credential: { credential_id: string, credential_name: string }
|
||||
onItemClick?: (c: unknown) => void
|
||||
onEdit?: (c: unknown) => void
|
||||
onDelete?: (c: unknown) => void
|
||||
}) => (
|
||||
<div data-testid={`credential-${credential.credential_id}`}>
|
||||
<span>{credential.credential_name}</span>
|
||||
<button data-testid={`click-${credential.credential_id}`} onClick={() => onItemClick?.(credential)}>select</button>
|
||||
<button data-testid={`edit-${credential.credential_id}`} onClick={() => onEdit?.(credential)}>edit</button>
|
||||
<button data-testid={`delete-${credential.credential_id}`} onClick={() => onDelete?.(credential)}>delete</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
|
||||
provider: 'test',
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'My Key',
|
||||
available_credentials: [
|
||||
{ credential_id: 'cred-1', credential_name: 'My Key' },
|
||||
{ credential_id: 'cred-2', credential_name: 'Other Key' },
|
||||
],
|
||||
},
|
||||
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
configurate_methods: ['predefined-model'],
|
||||
supported_model_types: ['llm'],
|
||||
...overrides,
|
||||
} as unknown as ModelProvider)
|
||||
|
||||
const createState = (overrides: Partial<CredentialPanelState> = {}): CredentialPanelState => ({
|
||||
variant: 'api-active',
|
||||
priority: 'apiKey',
|
||||
supportsCredits: true,
|
||||
showPrioritySwitcher: true,
|
||||
hasCredentials: true,
|
||||
isCreditsExhausted: false,
|
||||
credentialName: 'My Key',
|
||||
credits: 100,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('DropdownContent', () => {
|
||||
const onChangePriority = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockDeleteCredentialId = null
|
||||
})
|
||||
|
||||
describe('UsagePrioritySection visibility', () => {
|
||||
it('should show when showPrioritySwitcher is true', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({ showPrioritySwitcher: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText(/usagePriority/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide when showPrioritySwitcher is false', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({ showPrioritySwitcher: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/usagePriority/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('CreditsExhaustedAlert', () => {
|
||||
it('should show when credits exhausted and supports credits', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({ isCreditsExhausted: true, supportsCredits: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getAllByText(/creditsExhausted/).length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should hide when credits not exhausted', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({ isCreditsExhausted: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/creditsExhausted/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should hide when credits exhausted but supportsCredits is false', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({ isCreditsExhausted: true, supportsCredits: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/creditsExhausted/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show fallback message when api-fallback variant with exhausted credits', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'api-fallback',
|
||||
isCreditsExhausted: true,
|
||||
supportsCredits: true,
|
||||
priority: 'credits',
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getAllByText(/creditsExhaustedFallback/).length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should show non-fallback message when credits-exhausted variant', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'credits-exhausted',
|
||||
isCreditsExhausted: true,
|
||||
supportsCredits: true,
|
||||
hasCredentials: false,
|
||||
priority: 'credits',
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText(/creditsExhaustedMessage/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('CreditsFallbackAlert', () => {
|
||||
it('should show when priority is apiKey, supports credits, not exhausted, and variant is not api-active', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'api-required-add',
|
||||
priority: 'apiKey',
|
||||
supportsCredits: true,
|
||||
isCreditsExhausted: false,
|
||||
hasCredentials: false,
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText(/noApiKeysFallback/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show unavailable message when priority is apiKey with credentials but not api-active', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'api-unavailable',
|
||||
priority: 'apiKey',
|
||||
supportsCredits: true,
|
||||
isCreditsExhausted: false,
|
||||
hasCredentials: true,
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getAllByText(/apiKeyUnavailableFallback/).length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should NOT show when variant is api-active', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'api-active',
|
||||
priority: 'apiKey',
|
||||
supportsCredits: true,
|
||||
isCreditsExhausted: false,
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/noApiKeysFallback/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/apiKeyUnavailableFallback/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should NOT show when priority is credits', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState({
|
||||
variant: 'credits-active',
|
||||
priority: 'credits',
|
||||
supportsCredits: true,
|
||||
isCreditsExhausted: false,
|
||||
})}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/noApiKeysFallback/)).not.toBeInTheDocument()
|
||||
expect(screen.queryByText(/apiKeyUnavailableFallback/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('API key section', () => {
|
||||
it('should render all credential items', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText('My Key')).toBeInTheDocument()
|
||||
expect(screen.getByText('Other Key')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show empty state when no credentials', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
})}
|
||||
state={createState({ hasCredentials: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText(/noApiKeysTitle/)).toBeInTheDocument()
|
||||
expect(screen.getByText(/noApiKeysDescription/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call activate without closing on credential item click', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('click-cred-2'))
|
||||
|
||||
expect(mockActivate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ credential_id: 'cred-2' }),
|
||||
)
|
||||
expect(onClose).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call handleOpenModal and close on edit credential', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('edit-cred-2'))
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ credential_id: 'cred-2' }),
|
||||
)
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call openConfirmDelete on delete credential', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('delete-cred-2'))
|
||||
|
||||
expect(mockOpenConfirmDelete).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ credential_id: 'cred-2' }),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Add API Key', () => {
|
||||
it('should call handleOpenModal with no args and close on add', () => {
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
})}
|
||||
state={createState({ hasCredentials: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /addApiKey/ }))
|
||||
|
||||
expect(mockHandleOpenModal).toHaveBeenCalledWith()
|
||||
expect(onClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('AlertDialog for delete confirmation', () => {
|
||||
it('should show confirm dialog when deleteCredentialId is set', () => {
|
||||
mockDeleteCredentialId = 'cred-1'
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByText(/confirmDelete/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show confirm dialog when deleteCredentialId is null', () => {
|
||||
mockDeleteCredentialId = null
|
||||
render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(screen.queryByText(/confirmDelete/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Layout', () => {
|
||||
it('should have 320px width container', () => {
|
||||
const { container } = render(
|
||||
<DropdownContent
|
||||
provider={createProvider()}
|
||||
state={createState()}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
expect(container.querySelector('.w-\\[320px\\]')).toBeTruthy()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,131 @@
|
||||
import type { Credential, ModelProvider, PreferredProviderTypeEnum } from '../../declarations'
|
||||
import type { CredentialPanelState } from '../use-credential-panel-state'
|
||||
import { memo, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import { ConfigurationMethodEnum } from '../../declarations'
|
||||
import { useAuth } from '../../model-auth/hooks'
|
||||
import ApiKeySection from './api-key-section'
|
||||
import CreditsExhaustedAlert from './credits-exhausted-alert'
|
||||
import CreditsFallbackAlert from './credits-fallback-alert'
|
||||
import UsagePrioritySection from './usage-priority-section'
|
||||
import { useActivateCredential } from './use-activate-credential'
|
||||
|
||||
const EMPTY_CREDENTIALS: Credential[] = []
|
||||
|
||||
type DropdownContentProps = {
|
||||
provider: ModelProvider
|
||||
state: CredentialPanelState
|
||||
isChangingPriority: boolean
|
||||
onChangePriority: (key: PreferredProviderTypeEnum) => void
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
function DropdownContent({
|
||||
provider,
|
||||
state,
|
||||
isChangingPriority,
|
||||
onChangePriority,
|
||||
onClose,
|
||||
}: DropdownContentProps) {
|
||||
const { t } = useTranslation()
|
||||
const { available_credentials } = provider.custom_configuration
|
||||
|
||||
const {
|
||||
openConfirmDelete,
|
||||
closeConfirmDelete,
|
||||
doingAction,
|
||||
handleConfirmDelete,
|
||||
deleteCredentialId,
|
||||
handleOpenModal,
|
||||
} = useAuth(provider, ConfigurationMethodEnum.predefinedModel)
|
||||
|
||||
const { selectedCredentialId, isActivating, activate } = useActivateCredential(provider)
|
||||
|
||||
const handleEdit = useCallback((credential?: Credential) => {
|
||||
handleOpenModal(credential)
|
||||
onClose()
|
||||
}, [handleOpenModal, onClose])
|
||||
|
||||
const handleDelete = useCallback((credential?: Credential) => {
|
||||
if (credential)
|
||||
openConfirmDelete(credential)
|
||||
}, [openConfirmDelete])
|
||||
|
||||
const handleAdd = useCallback(() => {
|
||||
handleOpenModal()
|
||||
onClose()
|
||||
}, [handleOpenModal, onClose])
|
||||
|
||||
const showCreditsExhaustedAlert = state.isCreditsExhausted && state.supportsCredits
|
||||
const hasApiKeyFallback = state.variant === 'api-fallback'
|
||||
|| (state.variant === 'api-active' && state.priority === 'apiKey')
|
||||
const showCreditsFallbackAlert = state.priority === 'apiKey'
|
||||
&& state.supportsCredits
|
||||
&& !state.isCreditsExhausted
|
||||
&& state.variant !== 'api-active'
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="w-[320px]">
|
||||
{state.showPrioritySwitcher && (
|
||||
<UsagePrioritySection
|
||||
value={state.priority}
|
||||
disabled={isChangingPriority}
|
||||
onSelect={onChangePriority}
|
||||
/>
|
||||
)}
|
||||
{showCreditsFallbackAlert && (
|
||||
<CreditsFallbackAlert hasCredentials={state.hasCredentials} />
|
||||
)}
|
||||
{showCreditsExhaustedAlert && (
|
||||
<CreditsExhaustedAlert hasApiKeyFallback={hasApiKeyFallback} />
|
||||
)}
|
||||
<ApiKeySection
|
||||
provider={provider}
|
||||
credentials={available_credentials ?? EMPTY_CREDENTIALS}
|
||||
selectedCredentialId={selectedCredentialId}
|
||||
isActivating={isActivating}
|
||||
onItemClick={activate}
|
||||
onEdit={handleEdit}
|
||||
onDelete={handleDelete}
|
||||
onAdd={handleAdd}
|
||||
/>
|
||||
</div>
|
||||
<AlertDialog
|
||||
open={!!deleteCredentialId}
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
closeConfirmDelete()
|
||||
}}
|
||||
>
|
||||
<AlertDialogContent>
|
||||
<div className="p-6 pb-0">
|
||||
<AlertDialogTitle className="text-text-primary system-xl-semibold">
|
||||
{t('modelProvider.confirmDelete', { ns: 'common' })}
|
||||
</AlertDialogTitle>
|
||||
<AlertDialogDescription className="mt-1 text-text-secondary system-sm-regular" />
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton disabled={doingAction}>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton disabled={doingAction} onClick={handleConfirmDelete}>
|
||||
{t('operation.delete', { ns: 'common' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(DropdownContent)
|
||||
@@ -0,0 +1,211 @@
|
||||
import type { ModelProvider } from '../../declarations'
|
||||
import type { CredentialPanelState } from '../use-credential-panel-state'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { CustomConfigurationStatusEnum, PreferredProviderTypeEnum } from '../../declarations'
|
||||
import ModelAuthDropdown from './index'
|
||||
|
||||
vi.mock('../../model-auth/hooks', () => ({
|
||||
useAuth: () => ({
|
||||
openConfirmDelete: vi.fn(),
|
||||
closeConfirmDelete: vi.fn(),
|
||||
doingAction: false,
|
||||
handleConfirmDelete: vi.fn(),
|
||||
deleteCredentialId: null,
|
||||
handleOpenModal: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('./use-activate-credential', () => ({
|
||||
useActivateCredential: () => ({
|
||||
selectedCredentialId: undefined,
|
||||
isActivating: false,
|
||||
activate: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../use-trial-credits', () => ({
|
||||
useTrialCredits: () => ({ credits: 0, totalCredits: 10_000, isExhausted: true, isLoading: false }),
|
||||
}))
|
||||
|
||||
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
|
||||
provider: 'test',
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
available_credentials: [],
|
||||
},
|
||||
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
...overrides,
|
||||
} as unknown as ModelProvider)
|
||||
|
||||
const createState = (overrides: Partial<CredentialPanelState> = {}): CredentialPanelState => ({
|
||||
variant: 'credits-active',
|
||||
priority: 'credits',
|
||||
supportsCredits: true,
|
||||
showPrioritySwitcher: false,
|
||||
hasCredentials: false,
|
||||
isCreditsExhausted: false,
|
||||
credentialName: undefined,
|
||||
credits: 100,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('ModelAuthDropdown', () => {
|
||||
const onChangePriority = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Button text', () => {
|
||||
it('should show "Add API Key" when no credentials for credits-active', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ hasCredentials: false, variant: 'credits-active' })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Configure" when has credentials for api-active', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ hasCredentials: true, variant: 'api-active' })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Add API Key" for api-required-add variant', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-required-add', hasCredentials: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Configure" for api-required-configure variant', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-required-configure', hasCredentials: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Configure" for credits-active when has credentials', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ hasCredentials: true, variant: 'credits-active' })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Add API Key" for credits-exhausted (no credentials)', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'credits-exhausted', hasCredentials: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /addApiKey/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Configure" for api-unavailable (has credentials)', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-unavailable', hasCredentials: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show "Configure" for api-fallback (has credentials)', () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-fallback', hasCredentials: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
expect(screen.getByRole('button', { name: /config/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Button variant styling', () => {
|
||||
it('should use secondary-accent for api-required-add', () => {
|
||||
const { container } = render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-required-add', hasCredentials: false })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
const button = container.querySelector('button')
|
||||
expect(button?.getAttribute('data-variant') ?? button?.className).toMatch(/accent/)
|
||||
})
|
||||
|
||||
it('should use secondary-accent for api-required-configure', () => {
|
||||
const { container } = render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider()}
|
||||
state={createState({ variant: 'api-required-configure', hasCredentials: true })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
const button = container.querySelector('button')
|
||||
expect(button?.getAttribute('data-variant') ?? button?.className).toMatch(/accent/)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Popover behavior', () => {
|
||||
it('should open popover on button click and show dropdown content', async () => {
|
||||
render(
|
||||
<ModelAuthDropdown
|
||||
provider={createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
available_credentials: [{ credential_id: 'c1', credential_name: 'Key 1' }],
|
||||
current_credential_id: 'c1',
|
||||
current_credential_name: 'Key 1',
|
||||
},
|
||||
})}
|
||||
state={createState({ hasCredentials: true, variant: 'api-active' })}
|
||||
isChangingPriority={false}
|
||||
onChangePriority={onChangePriority}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /config/i }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('Key 1')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,85 @@
|
||||
import type { ModelProvider, PreferredProviderTypeEnum } from '../../declarations'
|
||||
import type { CardVariant, CredentialPanelState } from '../use-credential-panel-state'
|
||||
import { memo, useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/app/components/base/ui/popover'
|
||||
import DropdownContent from './dropdown-content'
|
||||
|
||||
type ModelAuthDropdownProps = {
|
||||
provider: ModelProvider
|
||||
state: CredentialPanelState
|
||||
isChangingPriority: boolean
|
||||
onChangePriority: (key: PreferredProviderTypeEnum) => void
|
||||
}
|
||||
|
||||
const ACCENT_VARIANTS = new Set<CardVariant>([
|
||||
'api-required-add',
|
||||
'api-required-configure',
|
||||
])
|
||||
|
||||
function getButtonConfig(variant: CardVariant, hasCredentials: boolean, t: (key: string, opts?: Record<string, string>) => string) {
|
||||
if (ACCENT_VARIANTS.has(variant)) {
|
||||
return {
|
||||
text: variant === 'api-required-add'
|
||||
? t('modelProvider.auth.addApiKey', { ns: 'common' })
|
||||
: t('operation.config', { ns: 'common' }),
|
||||
variant: 'secondary-accent' as const,
|
||||
}
|
||||
}
|
||||
|
||||
const text = hasCredentials
|
||||
? t('operation.config', { ns: 'common' })
|
||||
: t('modelProvider.auth.addApiKey', { ns: 'common' })
|
||||
|
||||
return { text, variant: 'secondary' as const }
|
||||
}
|
||||
|
||||
function ModelAuthDropdown({
|
||||
provider,
|
||||
state,
|
||||
isChangingPriority,
|
||||
onChangePriority,
|
||||
}: ModelAuthDropdownProps) {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
const handleClose = useCallback(() => setOpen(false), [])
|
||||
|
||||
const buttonConfig = getButtonConfig(state.variant, state.hasCredentials, t)
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger
|
||||
render={(
|
||||
<Button
|
||||
className="flex grow"
|
||||
size="small"
|
||||
variant={buttonConfig.variant}
|
||||
title={buttonConfig.text}
|
||||
>
|
||||
<span className="i-ri-equalizer-2-line mr-1 h-3.5 w-3.5 shrink-0" />
|
||||
<span className="w-0 grow truncate text-left">
|
||||
{buttonConfig.text}
|
||||
</span>
|
||||
</Button>
|
||||
)}
|
||||
/>
|
||||
<PopoverContent placement="bottom-end">
|
||||
<DropdownContent
|
||||
provider={provider}
|
||||
state={state}
|
||||
isChangingPriority={isChangingPriority}
|
||||
onChangePriority={onChangePriority}
|
||||
onClose={handleClose}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(ModelAuthDropdown)
|
||||
@@ -0,0 +1,66 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { PreferredProviderTypeEnum } from '../../declarations'
|
||||
import UsagePrioritySection from './usage-priority-section'
|
||||
|
||||
describe('UsagePrioritySection', () => {
|
||||
const onSelect = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Rendering
|
||||
describe('Rendering', () => {
|
||||
it('should render title and both option buttons', () => {
|
||||
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
|
||||
|
||||
expect(screen.getByText(/usagePriority/)).toBeInTheDocument()
|
||||
expect(screen.getAllByRole('button')).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
// Selection state
|
||||
describe('Selection state', () => {
|
||||
it('should highlight AI credits option when value is credits', () => {
|
||||
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
expect(buttons[0].className).toContain('border-components-option-card-option-selected-border')
|
||||
expect(buttons[1].className).not.toContain('border-components-option-card-option-selected-border')
|
||||
})
|
||||
|
||||
it('should highlight API key option when value is apiKey', () => {
|
||||
render(<UsagePrioritySection value="apiKey" onSelect={onSelect} />)
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
expect(buttons[0].className).not.toContain('border-components-option-card-option-selected-border')
|
||||
expect(buttons[1].className).toContain('border-components-option-card-option-selected-border')
|
||||
})
|
||||
|
||||
it('should highlight API key option when value is apiKeyOnly', () => {
|
||||
render(<UsagePrioritySection value="apiKeyOnly" onSelect={onSelect} />)
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
expect(buttons[1].className).toContain('border-components-option-card-option-selected-border')
|
||||
})
|
||||
})
|
||||
|
||||
// User interactions
|
||||
describe('User interactions', () => {
|
||||
it('should call onSelect with system when clicking AI credits option', () => {
|
||||
render(<UsagePrioritySection value="apiKey" onSelect={onSelect} />)
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button')[0])
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(PreferredProviderTypeEnum.system)
|
||||
})
|
||||
|
||||
it('should call onSelect with custom when clicking API key option', () => {
|
||||
render(<UsagePrioritySection value="credits" onSelect={onSelect} />)
|
||||
|
||||
fireEvent.click(screen.getAllByRole('button')[1])
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith(PreferredProviderTypeEnum.custom)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,56 @@
|
||||
import type { UsagePriority } from '../use-credential-panel-state'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { PreferredProviderTypeEnum } from '../../declarations'
|
||||
|
||||
type UsagePrioritySectionProps = {
|
||||
value: UsagePriority
|
||||
disabled?: boolean
|
||||
onSelect: (key: PreferredProviderTypeEnum) => void
|
||||
}
|
||||
|
||||
const options = [
|
||||
{ key: PreferredProviderTypeEnum.system, labelKey: 'modelProvider.card.aiCreditsOption' },
|
||||
{ key: PreferredProviderTypeEnum.custom, labelKey: 'modelProvider.card.apiKeyOption' },
|
||||
] as const
|
||||
|
||||
export default function UsagePrioritySection({ value, disabled, onSelect }: UsagePrioritySectionProps) {
|
||||
const { t } = useTranslation()
|
||||
const selectedKey = value === 'credits'
|
||||
? PreferredProviderTypeEnum.system
|
||||
: PreferredProviderTypeEnum.custom
|
||||
|
||||
return (
|
||||
<div className="p-1">
|
||||
<div className="flex items-center gap-1 rounded-lg p-1">
|
||||
<div className="shrink-0 px-0.5 py-1">
|
||||
<span className="i-ri-arrow-up-double-line block h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-0.5 py-0.5">
|
||||
<span className="truncate text-text-secondary system-sm-medium">
|
||||
{t('modelProvider.card.usagePriority', { ns: 'common' })}
|
||||
</span>
|
||||
<span className="i-ri-question-line h-3.5 w-3.5 shrink-0 text-text-quaternary" />
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-1">
|
||||
{options.map(option => (
|
||||
<button
|
||||
key={option.key}
|
||||
type="button"
|
||||
className={cn(
|
||||
'shrink-0 whitespace-nowrap rounded-md px-2 py-1 text-center transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-components-button-primary-border disabled:opacity-50',
|
||||
selectedKey === option.key
|
||||
? 'border-[1.5px] border-components-option-card-option-selected-border bg-components-panel-bg text-text-primary shadow-xs system-xs-medium'
|
||||
: 'border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary system-xs-regular hover:bg-components-option-card-option-bg-hover',
|
||||
)}
|
||||
disabled={disabled}
|
||||
onClick={() => onSelect(option.key)}
|
||||
>
|
||||
{t(option.labelKey, { ns: 'common' })}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
import type { Credential, ModelProvider } from '../../declarations'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useActiveProviderCredential } from '@/service/use-models'
|
||||
import {
|
||||
useUpdateModelList,
|
||||
useUpdateModelProviders,
|
||||
} from '../../hooks'
|
||||
|
||||
export function useActivateCredential(provider: ModelProvider) {
|
||||
const { t } = useTranslation()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const { mutate, isPending } = useActiveProviderCredential(provider.provider)
|
||||
const [optimisticId, setOptimisticId] = useState<string>()
|
||||
|
||||
const currentId = provider.custom_configuration.current_credential_id
|
||||
const selectedCredentialId = optimisticId ?? currentId
|
||||
|
||||
const selectedIdRef = useRef(selectedCredentialId)
|
||||
selectedIdRef.current = selectedCredentialId
|
||||
|
||||
const supportedModelTypes = provider.supported_model_types
|
||||
|
||||
const activate = useCallback((credential: Credential) => {
|
||||
if (credential.credential_id === selectedIdRef.current)
|
||||
return
|
||||
setOptimisticId(credential.credential_id)
|
||||
mutate(
|
||||
{ credential_id: credential.credential_id },
|
||||
{
|
||||
onSuccess: () => {
|
||||
Toast.notify({ type: 'success', message: t('api.actionSuccess', { ns: 'common' }) })
|
||||
updateModelProviders()
|
||||
supportedModelTypes.forEach(type => updateModelList(type))
|
||||
},
|
||||
onError: () => {
|
||||
setOptimisticId(undefined)
|
||||
Toast.notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [mutate, t, updateModelProviders, updateModelList, supportedModelTypes])
|
||||
|
||||
return {
|
||||
selectedCredentialId,
|
||||
isActivating: isPending,
|
||||
activate,
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,17 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { disableModel, enableModel } from '@/service/common'
|
||||
import { ModelStatusEnum } from '../declarations'
|
||||
import ModelListItem from './model-list-item'
|
||||
|
||||
function createWrapper() {
|
||||
const queryClient = new QueryClient({ defaultOptions: { queries: { retry: false } } })
|
||||
return ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
)
|
||||
}
|
||||
|
||||
let mockModelLoadBalancingEnabled = false
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@@ -69,6 +77,7 @@ describe('ModelListItem', () => {
|
||||
provider={mockProvider}
|
||||
isConfigurable={false}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
expect(screen.getByTestId('model-icon')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('model-name')).toBeInTheDocument()
|
||||
@@ -83,6 +92,7 @@ describe('ModelListItem', () => {
|
||||
isConfigurable={false}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
|
||||
@@ -102,6 +112,7 @@ describe('ModelListItem', () => {
|
||||
isConfigurable={false}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
fireEvent.click(screen.getByRole('switch'))
|
||||
|
||||
@@ -122,6 +133,7 @@ describe('ModelListItem', () => {
|
||||
isConfigurable={false}
|
||||
onModifyLoadBalancing={onModifyLoadBalancing}
|
||||
/>,
|
||||
{ wrapper: createWrapper() },
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'modify load balancing' }))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { ModelItem, ModelProvider } from '../declarations'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { memo, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@@ -9,6 +10,7 @@ import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext, useProviderContextSelector } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { disableModel, enableModel } from '@/service/common'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { ModelStatusEnum } from '../declarations'
|
||||
@@ -30,16 +32,30 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad
|
||||
const { plan } = useProviderContext()
|
||||
const modelLoadBalancingEnabled = useProviderContextSelector(state => state.modelLoadBalancingEnabled)
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const queryClient = useQueryClient()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const toggleModelEnablingStatus = useCallback(async (enabled: boolean) => {
|
||||
if (enabled)
|
||||
await enableModel(`/workspaces/current/model-providers/${provider.provider}/models/enable`, { model: model.model, model_type: model.model_type })
|
||||
else
|
||||
await disableModel(`/workspaces/current/model-providers/${provider.provider}/models/disable`, { model: model.model, model_type: model.model_type })
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
updateModelList(model.model_type)
|
||||
onChange?.(provider.provider)
|
||||
}, [model.model, model.model_type, onChange, provider.provider, updateModelList])
|
||||
}, [model.model, model.model_type, modelProviderModelListQueryKey, onChange, provider.provider, queryClient, updateModelList])
|
||||
|
||||
const { run: debouncedToggleModelEnablingStatus } = useDebounceFn(toggleModelEnablingStatus, { wait: 500 })
|
||||
|
||||
@@ -58,7 +74,7 @@ const ModelListItem = ({ model, provider, isConfigurable, onChange, onModifyLoad
|
||||
modelName={model.model}
|
||||
/>
|
||||
<ModelName
|
||||
className="system-md-regular grow text-text-secondary"
|
||||
className="grow text-text-secondary system-md-regular"
|
||||
modelItem={model}
|
||||
showModelType
|
||||
showMode
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
import type { FC } from 'react'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { HeaderModals } from '@/app/components/plugins/plugin-detail-panel/detail-header/components'
|
||||
import { useDetailHeaderState, usePluginOperations } from '@/app/components/plugins/plugin-detail-panel/detail-header/hooks'
|
||||
import OperationDropdown from '@/app/components/plugins/plugin-detail-panel/operation-dropdown'
|
||||
import { PluginSource } from '@/app/components/plugins/types'
|
||||
import PluginVersionPicker from '@/app/components/plugins/update-plugin/plugin-version-picker'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { getMarketplaceUrl } from '@/utils/var'
|
||||
|
||||
type Props = {
|
||||
detail: PluginDetail
|
||||
onUpdate?: () => void
|
||||
}
|
||||
|
||||
const ProviderCardActions: FC<Props> = ({ detail, onUpdate }) => {
|
||||
const { t } = useTranslation()
|
||||
const { theme } = useTheme()
|
||||
const locale = useLocale()
|
||||
|
||||
const { source, version, latest_version, latest_unique_identifier, meta } = detail
|
||||
const author = detail.declaration?.author ?? ''
|
||||
const name = detail.declaration?.name ?? detail.name
|
||||
|
||||
const {
|
||||
modalStates,
|
||||
versionPicker,
|
||||
hasNewVersion,
|
||||
isAutoUpgradeEnabled,
|
||||
isFromMarketplace,
|
||||
isFromGitHub,
|
||||
} = useDetailHeaderState(detail)
|
||||
|
||||
const {
|
||||
handleUpdate,
|
||||
handleUpdatedFromMarketplace,
|
||||
handleDelete,
|
||||
} = usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace,
|
||||
onUpdate,
|
||||
})
|
||||
|
||||
const handleVersionSelect = (state: { version: string, unique_identifier: string, isDowngrade?: boolean }) => {
|
||||
versionPicker.setTargetVersion(state)
|
||||
handleUpdate(state.isDowngrade)
|
||||
}
|
||||
|
||||
const handleTriggerLatestUpdate = () => {
|
||||
if (isFromMarketplace) {
|
||||
versionPicker.setTargetVersion({
|
||||
version: latest_version,
|
||||
unique_identifier: latest_unique_identifier,
|
||||
})
|
||||
}
|
||||
handleUpdate()
|
||||
}
|
||||
|
||||
const detailUrl = useMemo(() => {
|
||||
if (source === PluginSource.github)
|
||||
return meta?.repo ? `https://github.com/${meta.repo}` : ''
|
||||
if (source === PluginSource.marketplace)
|
||||
return getMarketplaceUrl(`/plugins/${author}/${name}`, { language: locale, theme })
|
||||
return ''
|
||||
}, [source, meta?.repo, author, name, locale, theme])
|
||||
|
||||
return (
|
||||
<>
|
||||
{!!version && (
|
||||
<PluginVersionPicker
|
||||
disabled={!isFromMarketplace}
|
||||
isShow={versionPicker.isShow}
|
||||
onShowChange={versionPicker.setIsShow}
|
||||
pluginID={detail.plugin_id}
|
||||
currentVersion={version}
|
||||
onSelect={handleVersionSelect}
|
||||
sideOffset={4}
|
||||
alignOffset={0}
|
||||
trigger={(
|
||||
<span
|
||||
className={cn(
|
||||
'relative inline-flex min-w-5 items-center justify-center gap-[3px] rounded-md border border-divider-deep bg-state-base-hover px-[5px] py-[2px] text-text-tertiary system-xs-medium-uppercase',
|
||||
isFromMarketplace && 'cursor-pointer hover:bg-state-base-hover-alt',
|
||||
)}
|
||||
>
|
||||
<span>{version}</span>
|
||||
{isFromMarketplace && <span aria-hidden className="i-ri-arrow-left-right-line h-3 w-3" />}
|
||||
{hasNewVersion && (
|
||||
<span className="absolute -right-0.5 -top-0.5 h-1.5 w-1.5 rounded-full bg-state-destructive-solid" />
|
||||
)}
|
||||
</span>
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{(hasNewVersion || isFromGitHub) && (
|
||||
<Button
|
||||
variant="secondary-accent"
|
||||
size="small"
|
||||
className="!h-5"
|
||||
onClick={handleTriggerLatestUpdate}
|
||||
>
|
||||
{t('detailPanel.operation.update', { ns: 'plugin' })}
|
||||
</Button>
|
||||
)}
|
||||
|
||||
<OperationDropdown
|
||||
source={source}
|
||||
onInfo={modalStates.showPluginInfo}
|
||||
onCheckVersion={() => handleUpdate()}
|
||||
onRemove={modalStates.showDeleteConfirm}
|
||||
detailUrl={detailUrl}
|
||||
placement="bottom-start"
|
||||
popupClassName="w-[192px]"
|
||||
/>
|
||||
|
||||
<HeaderModals
|
||||
detail={detail}
|
||||
modalStates={modalStates}
|
||||
targetVersion={versionPicker.targetVersion}
|
||||
isDowngrade={versionPicker.isDowngrade}
|
||||
isAutoUpgradeEnabled={isAutoUpgradeEnabled}
|
||||
onUpdatedFromMarketplace={handleUpdatedFromMarketplace}
|
||||
onDelete={handleDelete}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default ProviderCardActions
|
||||
@@ -0,0 +1,8 @@
|
||||
.gridBg {
|
||||
background-size: 4px 4px;
|
||||
background-image:
|
||||
linear-gradient(to right, var(--color-divider-subtle) 0.5px, transparent 0.5px),
|
||||
linear-gradient(to bottom, var(--color-divider-subtle) 0.5px, transparent 0.5px);
|
||||
-webkit-mask-image: radial-gradient(ellipse at center, rgba(0, 0, 0, 0.6), transparent 70%);
|
||||
mask-image: radial-gradient(ellipse at center, rgba(0, 0, 0, 0.6), transparent 70%);
|
||||
}
|
||||
@@ -2,11 +2,16 @@ import type { ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import QuotaPanel from './quota-panel'
|
||||
|
||||
let mockWorkspace = {
|
||||
let mockWorkspaceData: {
|
||||
trial_credits: number
|
||||
trial_credits_used: number
|
||||
next_credit_reset_date: string
|
||||
} | undefined = {
|
||||
trial_credits: 100,
|
||||
trial_credits_used: 30,
|
||||
next_credit_reset_date: '2024-12-31',
|
||||
}
|
||||
let mockWorkspaceIsPending = false
|
||||
let mockTrialModels: string[] = ['langgenius/openai/openai']
|
||||
let mockPlugins = [{
|
||||
plugin_id: 'langgenius/openai',
|
||||
@@ -25,15 +30,16 @@ vi.mock('@/app/components/base/icons/src/public/llm', () => {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
currentWorkspace: mockWorkspace,
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useCurrentWorkspace: () => ({
|
||||
data: mockWorkspaceData,
|
||||
isPending: mockWorkspaceIsPending,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (state: { systemFeatures: { trial_models: string[] } }) => unknown) => selector({
|
||||
systemFeatures: {
|
||||
useSystemFeaturesQuery: () => ({
|
||||
data: {
|
||||
trial_models: mockTrialModels,
|
||||
},
|
||||
}),
|
||||
@@ -71,22 +77,21 @@ describe('QuotaPanel', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkspace = {
|
||||
mockWorkspaceData = {
|
||||
trial_credits: 100,
|
||||
trial_credits_used: 30,
|
||||
next_credit_reset_date: '2024-12-31',
|
||||
}
|
||||
mockWorkspaceIsPending = false
|
||||
mockTrialModels = ['langgenius/openai/openai']
|
||||
mockPlugins = [{ plugin_id: 'langgenius/openai', latest_package_identifier: 'openai@1.0.0' }]
|
||||
})
|
||||
|
||||
it('should render loading state', () => {
|
||||
render(
|
||||
<QuotaPanel
|
||||
providers={mockProviders}
|
||||
isLoading
|
||||
/>,
|
||||
)
|
||||
mockWorkspaceData = undefined
|
||||
mockWorkspaceIsPending = true
|
||||
|
||||
render(<QuotaPanel providers={mockProviders} />)
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -102,8 +107,17 @@ describe('QuotaPanel', () => {
|
||||
expect(screen.getByText(/modelProvider\.resetDate/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should keep quota content during background refetch when cached workspace exists', () => {
|
||||
mockWorkspaceIsPending = true
|
||||
|
||||
render(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
expect(screen.queryByRole('status')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('70')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should floor credits at zero when usage is higher than quota', () => {
|
||||
mockWorkspace = {
|
||||
mockWorkspaceData = {
|
||||
trial_credits: 10,
|
||||
trial_credits_used: 999,
|
||||
next_credit_reset_date: '',
|
||||
@@ -111,7 +125,7 @@ describe('QuotaPanel', () => {
|
||||
|
||||
render(<QuotaPanel providers={mockProviders} />)
|
||||
|
||||
expect(screen.getByText('0')).toBeInTheDocument()
|
||||
expect(screen.getByText(/modelProvider\.card\.quotaExhausted/)).toBeInTheDocument()
|
||||
expect(screen.queryByText(/modelProvider\.resetDate/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
|
||||
@@ -7,10 +7,9 @@ import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { AnthropicShortLight, Deepseek, Gemini, Grok, OpenaiSmall, Tongyi } from '@/app/components/base/icons/src/public/llm'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import InstallFromMarketplace from '@/app/components/plugins/install-plugin/install-from-marketplace'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useSystemFeaturesQuery } from '@/context/global-public-context'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { ModelProviderQuotaGetPaid } from '@/types/model-provider'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@@ -18,8 +17,9 @@ import { formatNumber } from '@/utils/format'
|
||||
import { PreferredProviderTypeEnum } from '../declarations'
|
||||
import { useMarketplaceAllPlugins } from '../hooks'
|
||||
import { MODEL_PROVIDER_QUOTA_GET_PAID, modelNameMap } from '../utils'
|
||||
import styles from './quota-panel.module.css'
|
||||
import { useTrialCredits } from './use-trial-credits'
|
||||
|
||||
// Icon map for each provider - single source of truth for provider icons
|
||||
const providerIconMap: Record<ModelProviderQuotaGetPaid, ComponentType<{ className?: string }>> = {
|
||||
[ModelProviderQuotaGetPaid.OPENAI]: OpenaiSmall,
|
||||
[ModelProviderQuotaGetPaid.ANTHROPIC]: AnthropicShortLight,
|
||||
@@ -29,14 +29,11 @@ const providerIconMap: Record<ModelProviderQuotaGetPaid, ComponentType<{ classNa
|
||||
[ModelProviderQuotaGetPaid.TONGYI]: Tongyi,
|
||||
}
|
||||
|
||||
// Derive allProviders from the shared constant
|
||||
const allProviders = MODEL_PROVIDER_QUOTA_GET_PAID.map(key => ({
|
||||
key,
|
||||
Icon: providerIconMap[key],
|
||||
}))
|
||||
|
||||
// Map provider key to plugin ID
|
||||
// provider key format: langgenius/provider/model, plugin ID format: langgenius/provider
|
||||
const providerKeyToPluginId: Record<ModelProviderQuotaGetPaid, string> = {
|
||||
[ModelProviderQuotaGetPaid.OPENAI]: 'langgenius/openai',
|
||||
[ModelProviderQuotaGetPaid.ANTHROPIC]: 'langgenius/anthropic',
|
||||
@@ -48,16 +45,14 @@ const providerKeyToPluginId: Record<ModelProviderQuotaGetPaid, string> = {
|
||||
|
||||
type QuotaPanelProps = {
|
||||
providers: ModelProvider[]
|
||||
isLoading?: boolean
|
||||
}
|
||||
const QuotaPanel: FC<QuotaPanelProps> = ({
|
||||
providers,
|
||||
isLoading = false,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { currentWorkspace } = useAppContext()
|
||||
const { trial_models } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const credits = Math.max((currentWorkspace.trial_credits - currentWorkspace.trial_credits_used) || 0, 0)
|
||||
const { credits, isExhausted, isLoading, nextCreditResetDate } = useTrialCredits()
|
||||
const { data: systemFeatures } = useSystemFeaturesQuery()
|
||||
const trialModels = systemFeatures?.trial_models ?? []
|
||||
const providerMap = useMemo(() => new Map(
|
||||
providers.map(p => [p.provider, p.preferred_provider_type]),
|
||||
), [providers])
|
||||
@@ -98,6 +93,11 @@ const QuotaPanel: FC<QuotaPanelProps> = ({
|
||||
}
|
||||
}, [providers, isShowInstallModal, hideInstallFromMarketplace])
|
||||
|
||||
const tipText = t('modelProvider.card.tip', {
|
||||
ns: 'common',
|
||||
modelNames: trialModels.map(key => modelNameMap[key as keyof typeof modelNameMap]).filter(Boolean).join(', '),
|
||||
})
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="my-2 flex min-h-[72px] items-center justify-center rounded-xl border-[0.5px] border-components-panel-border bg-third-party-model-bg-default shadow-xs">
|
||||
@@ -107,59 +107,88 @@ const QuotaPanel: FC<QuotaPanelProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn('my-2 min-w-[72px] shrink-0 rounded-xl border-[0.5px] pb-2.5 pl-4 pr-2.5 pt-3 shadow-xs', credits <= 0 ? 'border-state-destructive-border hover:bg-state-destructive-hover' : 'border-components-panel-border bg-third-party-model-bg-default')}>
|
||||
<div className="system-xs-medium-uppercase mb-2 flex h-4 items-center text-text-tertiary">
|
||||
{t('modelProvider.quota', { ns: 'common' })}
|
||||
<Tooltip popupContent={t('modelProvider.card.tip', { ns: 'common', modelNames: trial_models.map(key => modelNameMap[key as keyof typeof modelNameMap]).filter(Boolean).join(', ') })} />
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-1 text-xs text-text-tertiary">
|
||||
<span className="system-md-semibold-uppercase mr-0.5 text-text-secondary">{formatNumber(credits)}</span>
|
||||
<span>{t('modelProvider.credits', { ns: 'common' })}</span>
|
||||
{currentWorkspace.next_credit_reset_date
|
||||
? (
|
||||
<>
|
||||
<span>·</span>
|
||||
<span>
|
||||
{t('modelProvider.resetDate', {
|
||||
ns: 'common',
|
||||
date: formatTime(currentWorkspace.next_credit_reset_date, t('dateFormat', { ns: 'appLog' })),
|
||||
interpolation: { escapeValue: false },
|
||||
})}
|
||||
</span>
|
||||
</>
|
||||
)
|
||||
: null}
|
||||
<div className={cn(
|
||||
'relative my-2 min-w-[72px] shrink-0 overflow-hidden rounded-xl border-[0.5px] pb-2.5 pl-4 pr-2.5 pt-3 shadow-xs',
|
||||
isExhausted
|
||||
? 'border-state-destructive-border hover:bg-state-destructive-hover'
|
||||
: 'border-components-panel-border bg-third-party-model-bg-default',
|
||||
)}
|
||||
>
|
||||
<div className={cn('pointer-events-none absolute inset-0', styles.gridBg)} />
|
||||
<div className="relative">
|
||||
<div className="mb-2 flex h-4 items-center text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('modelProvider.quota', { ns: 'common' })}
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
aria-label={tipText}
|
||||
delay={0}
|
||||
render={(
|
||||
<span className="ml-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<span aria-hidden className="i-ri-question-line h-3.5 w-3.5 text-text-quaternary hover:text-text-tertiary" />
|
||||
</span>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>
|
||||
{tipText}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{allProviders.filter(({ key }) => trial_models.includes(key)).map(({ key, Icon }) => {
|
||||
const providerType = providerMap.get(key)
|
||||
const isConfigured = (installedProvidersMap.get(key)?.length ?? 0) > 0 // means the provider is configured API key
|
||||
const getTooltipKey = () => {
|
||||
// if provider type is not set, it means the provider is not installed
|
||||
if (!providerType)
|
||||
return 'modelProvider.card.modelNotSupported'
|
||||
if (isConfigured && providerType === PreferredProviderTypeEnum.custom)
|
||||
return 'modelProvider.card.modelAPI'
|
||||
return 'modelProvider.card.modelSupported'
|
||||
}
|
||||
return (
|
||||
<Tooltip
|
||||
key={key}
|
||||
popupContent={t(getTooltipKey(), { modelName: modelNameMap[key], ns: 'common' })}
|
||||
>
|
||||
<div
|
||||
className={cn('relative h-6 w-6', !providerType && 'cursor-pointer hover:opacity-80')}
|
||||
onClick={() => handleIconClick(key)}
|
||||
>
|
||||
<Icon className="h-6 w-6 rounded-lg" />
|
||||
{!providerType && (
|
||||
<div className="absolute inset-0 rounded-lg border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge opacity-30" />
|
||||
)}
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
})}
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-1 text-xs text-text-tertiary">
|
||||
{credits > 0
|
||||
? <span className="mr-0.5 text-text-secondary system-xl-semibold">{formatNumber(credits)}</span>
|
||||
: <span className="mr-0.5 text-text-destructive system-xl-semibold">{t('modelProvider.card.quotaExhausted', { ns: 'common' })}</span>}
|
||||
{nextCreditResetDate
|
||||
? (
|
||||
<>
|
||||
<span>·</span>
|
||||
<span>
|
||||
{t('modelProvider.resetDate', {
|
||||
ns: 'common',
|
||||
date: formatTime(nextCreditResetDate!, t('dateFormat', { ns: 'appLog' })),
|
||||
interpolation: { escapeValue: false },
|
||||
})}
|
||||
</span>
|
||||
</>
|
||||
)
|
||||
: null}
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{allProviders.filter(({ key }) => trialModels.includes(key)).map(({ key, Icon }) => {
|
||||
const providerType = providerMap.get(key)
|
||||
const isConfigured = (installedProvidersMap.get(key)?.length ?? 0) > 0
|
||||
const getTooltipKey = () => {
|
||||
if (!providerType)
|
||||
return 'modelProvider.card.modelNotSupported'
|
||||
if (isConfigured && providerType === PreferredProviderTypeEnum.custom)
|
||||
return 'modelProvider.card.modelAPI'
|
||||
return 'modelProvider.card.modelSupported'
|
||||
}
|
||||
const tooltipText = t(getTooltipKey(), { modelName: modelNameMap[key], ns: 'common' })
|
||||
return (
|
||||
<Tooltip key={key}>
|
||||
<TooltipTrigger
|
||||
aria-label={tooltipText}
|
||||
delay={0}
|
||||
render={(
|
||||
<div
|
||||
className={cn('relative h-6 w-6', !providerType && 'cursor-pointer hover:opacity-80')}
|
||||
onClick={() => handleIconClick(key)}
|
||||
>
|
||||
<Icon className="h-6 w-6 rounded-lg" />
|
||||
{!providerType && (
|
||||
<div className="absolute inset-0 rounded-lg border-[0.5px] border-components-panel-border-subtle bg-background-default-dodge opacity-30" />
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>
|
||||
{tooltipText}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{isShowInstallModal && selectedPlugin && (
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import SystemQuotaCard from './system-quota-card'
|
||||
|
||||
describe('SystemQuotaCard', () => {
|
||||
// Renders container with children
|
||||
describe('Rendering', () => {
|
||||
it('should render children', () => {
|
||||
render(
|
||||
<SystemQuotaCard>
|
||||
<span>content</span>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply default variant styles', () => {
|
||||
const { container } = render(
|
||||
<SystemQuotaCard>
|
||||
<span>test</span>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
const card = container.firstElementChild!
|
||||
expect(card.className).toContain('bg-white')
|
||||
})
|
||||
|
||||
it('should apply destructive variant styles', () => {
|
||||
const { container } = render(
|
||||
<SystemQuotaCard variant="destructive">
|
||||
<span>test</span>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
const card = container.firstElementChild!
|
||||
expect(card.className).toContain('border-state-destructive-border')
|
||||
})
|
||||
})
|
||||
|
||||
// Label sub-component
|
||||
describe('Label', () => {
|
||||
it('should apply default variant text color when no className provided', () => {
|
||||
render(
|
||||
<SystemQuotaCard>
|
||||
<SystemQuotaCard.Label>Default label</SystemQuotaCard.Label>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Default label').className).toContain('text-text-secondary')
|
||||
})
|
||||
|
||||
it('should apply destructive variant text color when no className provided', () => {
|
||||
render(
|
||||
<SystemQuotaCard variant="destructive">
|
||||
<SystemQuotaCard.Label>Error label</SystemQuotaCard.Label>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Error label').className).toContain('text-text-destructive')
|
||||
})
|
||||
|
||||
it('should override variant color with custom className', () => {
|
||||
render(
|
||||
<SystemQuotaCard variant="destructive">
|
||||
<SystemQuotaCard.Label className="gap-1">Custom label</SystemQuotaCard.Label>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
const label = screen.getByText('Custom label')
|
||||
expect(label.className).toContain('gap-1')
|
||||
expect(label.className).not.toContain('text-text-destructive')
|
||||
})
|
||||
})
|
||||
|
||||
// Actions sub-component
|
||||
describe('Actions', () => {
|
||||
it('should render action children', () => {
|
||||
render(
|
||||
<SystemQuotaCard>
|
||||
<SystemQuotaCard.Actions>
|
||||
<button>Click me</button>
|
||||
</SystemQuotaCard.Actions>
|
||||
</SystemQuotaCard>,
|
||||
)
|
||||
|
||||
expect(screen.getByRole('button', { name: /click me/i })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,67 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { createContext, useContext } from 'react'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import styles from './quota-panel.module.css'
|
||||
|
||||
type Variant = 'default' | 'destructive'
|
||||
|
||||
const VariantContext = createContext<Variant>('default')
|
||||
|
||||
const containerVariants: Record<Variant, string> = {
|
||||
default: 'border-components-panel-border bg-white/[0.18]',
|
||||
destructive: 'border-state-destructive-border bg-state-destructive-hover',
|
||||
}
|
||||
|
||||
const labelVariants: Record<Variant, string> = {
|
||||
default: 'text-text-secondary',
|
||||
destructive: 'text-text-destructive',
|
||||
}
|
||||
|
||||
type SystemQuotaCardProps = {
|
||||
variant?: Variant
|
||||
children: ReactNode
|
||||
}
|
||||
|
||||
const SystemQuotaCard = ({
|
||||
variant = 'default',
|
||||
children,
|
||||
}: SystemQuotaCardProps) => {
|
||||
return (
|
||||
<VariantContext.Provider value={variant}>
|
||||
<div className={cn(
|
||||
'relative isolate ml-1 flex w-[128px] shrink-0 flex-col justify-between rounded-lg border-[0.5px] p-1 shadow-xs',
|
||||
containerVariants[variant],
|
||||
)}
|
||||
>
|
||||
<div className={cn('pointer-events-none absolute inset-0 rounded-[7px]', styles.gridBg)} />
|
||||
{children}
|
||||
</div>
|
||||
</VariantContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
const Label = ({ children, className }: { children: ReactNode, className?: string }) => {
|
||||
const variant = useContext(VariantContext)
|
||||
return (
|
||||
<div className={cn(
|
||||
'relative z-[1] flex items-center gap-1 truncate px-1.5 pt-1 system-xs-medium',
|
||||
className ?? labelVariants[variant],
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const Actions = ({ children }: { children: ReactNode }) => {
|
||||
return (
|
||||
<div className="relative z-[1] flex items-center gap-0.5">
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
SystemQuotaCard.Label = Label
|
||||
SystemQuotaCard.Actions = Actions
|
||||
|
||||
export default SystemQuotaCard
|
||||
@@ -0,0 +1,235 @@
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { renderHook } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CurrentSystemQuotaTypeEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
PreferredProviderTypeEnum,
|
||||
} from '../declarations'
|
||||
import { isDestructiveVariant, useCredentialPanelState } from './use-credential-panel-state'
|
||||
|
||||
const mockTrialCredits = { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false, nextCreditResetDate: undefined }
|
||||
|
||||
vi.mock('./use-trial-credits', () => ({
|
||||
useTrialCredits: () => mockTrialCredits,
|
||||
}))
|
||||
|
||||
vi.mock('@/config', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/config')>()
|
||||
return { ...actual, IS_CLOUD_EDITION: true }
|
||||
})
|
||||
|
||||
const createProvider = (overrides: Partial<ModelProvider> = {}): ModelProvider => ({
|
||||
provider: 'test-provider',
|
||||
provider_credential_schema: { credential_form_schemas: [] },
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: 'cred-1',
|
||||
current_credential_name: 'My Key',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
|
||||
},
|
||||
system_configuration: { enabled: true, current_quota_type: 'trial', quota_configurations: [] },
|
||||
preferred_provider_type: PreferredProviderTypeEnum.system,
|
||||
configurate_methods: [ConfigurationMethodEnum.predefinedModel],
|
||||
supported_model_types: ['llm'],
|
||||
...overrides,
|
||||
} as unknown as ModelProvider)
|
||||
|
||||
describe('useCredentialPanelState', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
Object.assign(mockTrialCredits, { credits: 100, totalCredits: 10_000, isExhausted: false, isLoading: false })
|
||||
})
|
||||
|
||||
// Credits priority variants
|
||||
describe('Credits priority variants', () => {
|
||||
it('should return credits-active when credits available', () => {
|
||||
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
|
||||
|
||||
expect(result.current.variant).toBe('credits-active')
|
||||
expect(result.current.priority).toBe('credits')
|
||||
expect(result.current.supportsCredits).toBe(true)
|
||||
})
|
||||
|
||||
it('should return api-fallback when credits exhausted but API key authorized', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
mockTrialCredits.credits = 0
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
|
||||
|
||||
expect(result.current.variant).toBe('api-fallback')
|
||||
})
|
||||
|
||||
it('should return no-usage when credits exhausted and API key unauthorized', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const provider = createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('no-usage')
|
||||
})
|
||||
|
||||
it('should return credits-exhausted when credits exhausted and no credentials', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
const provider = createProvider({
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('credits-exhausted')
|
||||
})
|
||||
})
|
||||
|
||||
// API key priority variants
|
||||
describe('API key priority variants', () => {
|
||||
it('should return api-active when API key authorized', () => {
|
||||
const provider = createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('api-active')
|
||||
expect(result.current.priority).toBe('apiKey')
|
||||
})
|
||||
|
||||
it('should return credits-fallback when API key unauthorized and credits available', () => {
|
||||
const provider = createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: undefined,
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'My Key' }],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('credits-fallback')
|
||||
})
|
||||
|
||||
it('should return credits-fallback when no credentials and credits available', () => {
|
||||
const provider = createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('credits-fallback')
|
||||
})
|
||||
|
||||
it('should return no-usage when no credentials and credits exhausted', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
mockTrialCredits.credits = 0
|
||||
const provider = createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.noConfigure,
|
||||
available_credentials: [],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('no-usage')
|
||||
})
|
||||
|
||||
it('should return api-unavailable when credential with name unauthorized and credits exhausted', () => {
|
||||
mockTrialCredits.isExhausted = true
|
||||
mockTrialCredits.credits = 0
|
||||
const provider = createProvider({
|
||||
preferred_provider_type: PreferredProviderTypeEnum.custom,
|
||||
custom_configuration: {
|
||||
status: CustomConfigurationStatusEnum.active,
|
||||
current_credential_id: undefined,
|
||||
current_credential_name: 'Bad Key',
|
||||
available_credentials: [{ credential_id: 'cred-1', credential_name: 'Bad Key' }],
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.variant).toBe('api-unavailable')
|
||||
})
|
||||
})
|
||||
|
||||
// apiKeyOnly priority
|
||||
describe('apiKeyOnly priority (non-cloud / system disabled)', () => {
|
||||
it('should return apiKeyOnly when system config disabled', () => {
|
||||
const provider = createProvider({
|
||||
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.priority).toBe('apiKeyOnly')
|
||||
expect(result.current.supportsCredits).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// Derived metadata
|
||||
describe('Derived metadata', () => {
|
||||
it('should show priority switcher when credits supported and custom config active', () => {
|
||||
const provider = createProvider()
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.showPrioritySwitcher).toBe(true)
|
||||
})
|
||||
|
||||
it('should hide priority switcher when system config disabled', () => {
|
||||
const provider = createProvider({
|
||||
system_configuration: { enabled: false, current_quota_type: CurrentSystemQuotaTypeEnum.trial, quota_configurations: [] },
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(provider))
|
||||
|
||||
expect(result.current.showPrioritySwitcher).toBe(false)
|
||||
})
|
||||
|
||||
it('should expose credential name from provider', () => {
|
||||
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
|
||||
|
||||
expect(result.current.credentialName).toBe('My Key')
|
||||
})
|
||||
|
||||
it('should expose credits amount', () => {
|
||||
mockTrialCredits.credits = 500
|
||||
|
||||
const { result } = renderHook(() => useCredentialPanelState(createProvider()))
|
||||
|
||||
expect(result.current.credits).toBe(500)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isDestructiveVariant', () => {
|
||||
it.each([
|
||||
['credits-exhausted', true],
|
||||
['no-usage', true],
|
||||
['api-unavailable', true],
|
||||
['credits-active', false],
|
||||
['api-fallback', false],
|
||||
['api-active', false],
|
||||
['api-required-add', false],
|
||||
['api-required-configure', false],
|
||||
] as const)('should return %s for variant %s', (variant, expected) => {
|
||||
expect(isDestructiveVariant(variant)).toBe(expected)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,106 @@
|
||||
import type { ModelProvider } from '../declarations'
|
||||
import { useCredentialStatus } from '@/app/components/header/account-setting/model-provider-page/model-auth/hooks'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import {
|
||||
PreferredProviderTypeEnum,
|
||||
} from '../declarations'
|
||||
import { useTrialCredits } from './use-trial-credits'
|
||||
|
||||
export type UsagePriority = 'credits' | 'apiKey' | 'apiKeyOnly'
|
||||
|
||||
export type CardVariant
|
||||
= | 'credits-active'
|
||||
| 'credits-fallback'
|
||||
| 'credits-exhausted'
|
||||
| 'no-usage'
|
||||
| 'api-fallback'
|
||||
| 'api-active'
|
||||
| 'api-required-add'
|
||||
| 'api-required-configure'
|
||||
| 'api-unavailable'
|
||||
|
||||
export type CredentialPanelState = {
|
||||
variant: CardVariant
|
||||
priority: UsagePriority
|
||||
supportsCredits: boolean
|
||||
showPrioritySwitcher: boolean
|
||||
hasCredentials: boolean
|
||||
isCreditsExhausted: boolean
|
||||
credentialName: string | undefined
|
||||
credits: number
|
||||
}
|
||||
|
||||
const DESTRUCTIVE_VARIANTS = new Set<CardVariant>([
|
||||
'credits-exhausted',
|
||||
'no-usage',
|
||||
'api-unavailable',
|
||||
])
|
||||
|
||||
export const isDestructiveVariant = (variant: CardVariant) =>
|
||||
DESTRUCTIVE_VARIANTS.has(variant)
|
||||
|
||||
function deriveVariant(
|
||||
priority: UsagePriority,
|
||||
isExhausted: boolean,
|
||||
hasCredential: boolean,
|
||||
authorized: boolean | undefined,
|
||||
credentialName: string | undefined,
|
||||
): CardVariant {
|
||||
if (priority === 'credits') {
|
||||
if (!isExhausted)
|
||||
return 'credits-active'
|
||||
if (hasCredential && authorized)
|
||||
return 'api-fallback'
|
||||
if (hasCredential && !authorized)
|
||||
return 'no-usage'
|
||||
return 'credits-exhausted'
|
||||
}
|
||||
|
||||
if (hasCredential && authorized)
|
||||
return 'api-active'
|
||||
|
||||
if (priority === 'apiKey' && !isExhausted)
|
||||
return 'credits-fallback'
|
||||
|
||||
if (priority === 'apiKey' && !hasCredential)
|
||||
return 'no-usage'
|
||||
|
||||
if (hasCredential && !authorized)
|
||||
return credentialName ? 'api-unavailable' : 'api-required-configure'
|
||||
return 'api-required-add'
|
||||
}
|
||||
|
||||
export function useCredentialPanelState(provider: ModelProvider): CredentialPanelState {
|
||||
const { isExhausted, credits } = useTrialCredits()
|
||||
const {
|
||||
hasCredential,
|
||||
authorized,
|
||||
current_credential_name,
|
||||
} = useCredentialStatus(provider)
|
||||
|
||||
const systemConfig = provider.system_configuration
|
||||
const preferredType = provider.preferred_provider_type
|
||||
|
||||
const supportsCredits = systemConfig.enabled && IS_CLOUD_EDITION
|
||||
|
||||
const priority: UsagePriority = !supportsCredits
|
||||
? 'apiKeyOnly'
|
||||
: preferredType === PreferredProviderTypeEnum.system
|
||||
? 'credits'
|
||||
: 'apiKey'
|
||||
|
||||
const showPrioritySwitcher = supportsCredits
|
||||
|
||||
const variant = deriveVariant(priority, isExhausted, hasCredential, !!authorized, current_credential_name)
|
||||
|
||||
return {
|
||||
variant,
|
||||
priority,
|
||||
supportsCredits,
|
||||
showPrioritySwitcher,
|
||||
hasCredentials: hasCredential,
|
||||
isCreditsExhausted: isExhausted,
|
||||
credentialName: current_credential_name,
|
||||
credits,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
import { useCurrentWorkspace } from '@/service/use-common'
|
||||
|
||||
export const useTrialCredits = () => {
|
||||
const { data: currentWorkspace, isPending } = useCurrentWorkspace()
|
||||
const totalCredits = currentWorkspace?.trial_credits ?? 0
|
||||
const credits = Math.max(totalCredits - (currentWorkspace?.trial_credits_used ?? 0), 0)
|
||||
|
||||
return {
|
||||
credits,
|
||||
totalCredits,
|
||||
isExhausted: credits <= 0,
|
||||
isLoading: isPending && !currentWorkspace,
|
||||
nextCreditResetDate: currentWorkspace?.next_credit_reset_date,
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ const ProviderIcon: FC<ProviderIconProps> = ({
|
||||
|
||||
if (provider.provider === 'langgenius/anthropic/anthropic') {
|
||||
return (
|
||||
<div className="mb-2 py-[7px]">
|
||||
<div className={cn('py-[7px]', className)}>
|
||||
{theme === Theme.dark && <AnthropicLight className="h-2.5 w-[90px]" />}
|
||||
{theme === Theme.light && <AnthropicDark className="h-2.5 w-[90px]" />}
|
||||
</div>
|
||||
@@ -30,7 +30,7 @@ const ProviderIcon: FC<ProviderIconProps> = ({
|
||||
|
||||
if (provider.provider === 'langgenius/openai/openai') {
|
||||
return (
|
||||
<div className="mb-2">
|
||||
<div className={className}>
|
||||
<Openai className="h-6 w-auto text-text-inverted-dimmed" />
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ vi.mock('react-i18next', async () => {
|
||||
|
||||
const mockNotify = vi.hoisted(() => vi.fn())
|
||||
const mockUpdateModelList = vi.hoisted(() => vi.fn())
|
||||
const mockInvalidateDefaultModel = vi.hoisted(() => vi.fn())
|
||||
const mockUpdateDefaultModel = vi.hoisted(() => vi.fn(() => Promise.resolve({ result: 'success' })))
|
||||
|
||||
let mockIsCurrentWorkspaceManager = true
|
||||
@@ -57,6 +58,7 @@ vi.mock('../hooks', () => ({
|
||||
vi.fn(),
|
||||
],
|
||||
useUpdateModelList: () => mockUpdateModelList,
|
||||
useInvalidateDefaultModel: () => mockInvalidateDefaultModel,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
@@ -99,7 +101,7 @@ describe('SystemModel', () => {
|
||||
expect(screen.getByRole('button', { name: /system model settings/i })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should open modal when button is clicked', async () => {
|
||||
it('should open dialog when button is clicked', async () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
const button = screen.getByRole('button', { name: /system model settings/i })
|
||||
fireEvent.click(button)
|
||||
@@ -113,7 +115,7 @@ describe('SystemModel', () => {
|
||||
expect(screen.getByRole('button', { name: /system model settings/i })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should close modal when cancel is clicked', async () => {
|
||||
it('should close dialog when cancel is clicked', async () => {
|
||||
render(<SystemModel {...defaultProps} />)
|
||||
fireEvent.click(screen.getByRole('button', { name: /system model settings/i }))
|
||||
await waitFor(() => {
|
||||
@@ -144,6 +146,7 @@ describe('SystemModel', () => {
|
||||
type: 'success',
|
||||
message: 'Modified successfully',
|
||||
})
|
||||
expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5)
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(5)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,22 +3,27 @@ import type {
|
||||
DefaultModel,
|
||||
DefaultModelResponse,
|
||||
} from '../declarations'
|
||||
import { RiEqualizer2Line, RiLoader2Line } from '@remixicon/react'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
DialogTitle,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/app/components/base/ui/tooltip'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { updateDefaultModel } from '@/service/common'
|
||||
import { ModelTypeEnum } from '../declarations'
|
||||
import {
|
||||
useInvalidateDefaultModel,
|
||||
useModelList,
|
||||
useSystemDefaultModelAndModelList,
|
||||
useUpdateModelList,
|
||||
@@ -34,6 +39,21 @@ type SystemModelSelectorProps = {
|
||||
notConfigured: boolean
|
||||
isLoading?: boolean
|
||||
}
|
||||
|
||||
type SystemModelLabelKey
|
||||
= | 'modelProvider.systemReasoningModel.key'
|
||||
| 'modelProvider.embeddingModel.key'
|
||||
| 'modelProvider.rerankModel.key'
|
||||
| 'modelProvider.speechToTextModel.key'
|
||||
| 'modelProvider.ttsModel.key'
|
||||
|
||||
type SystemModelTipKey
|
||||
= | 'modelProvider.systemReasoningModel.tip'
|
||||
| 'modelProvider.embeddingModel.tip'
|
||||
| 'modelProvider.rerankModel.tip'
|
||||
| 'modelProvider.speechToTextModel.tip'
|
||||
| 'modelProvider.ttsModel.tip'
|
||||
|
||||
const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
textGenerationDefaultModel,
|
||||
embeddingsDefaultModel,
|
||||
@@ -48,6 +68,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const { textGenerationModelList } = useProviderContext()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const invalidateDefaultModel = useInvalidateDefaultModel()
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
|
||||
const { data: speech2textModelList } = useModelList(ModelTypeEnum.speech2text)
|
||||
@@ -106,154 +127,124 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
setOpen(false)
|
||||
|
||||
changedModelTypes.forEach((modelType) => {
|
||||
if (modelType === ModelTypeEnum.textGeneration)
|
||||
updateModelList(modelType)
|
||||
else if (modelType === ModelTypeEnum.textEmbedding)
|
||||
updateModelList(modelType)
|
||||
else if (modelType === ModelTypeEnum.rerank)
|
||||
updateModelList(modelType)
|
||||
else if (modelType === ModelTypeEnum.speech2text)
|
||||
updateModelList(modelType)
|
||||
else if (modelType === ModelTypeEnum.tts)
|
||||
updateModelList(modelType)
|
||||
})
|
||||
const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts]
|
||||
allModelTypes.forEach(type => invalidateDefaultModel(type))
|
||||
changedModelTypes.forEach(type => updateModelList(type))
|
||||
}
|
||||
}
|
||||
|
||||
const renderModelLabel = (labelKey: SystemModelLabelKey, tipKey: SystemModelTipKey) => {
|
||||
const tipText = t(tipKey, { ns: 'common' })
|
||||
|
||||
return (
|
||||
<div className="flex min-h-6 items-center text-[13px] font-medium text-text-secondary">
|
||||
{t(labelKey, { ns: 'common' })}
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
aria-label={tipText}
|
||||
delay={0}
|
||||
render={(
|
||||
<span className="ml-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<span aria-hidden className="i-ri-question-line h-3.5 w-3.5 text-text-quaternary hover:text-text-tertiary" />
|
||||
</span>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{tipText}
|
||||
</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement="bottom-end"
|
||||
offset={{
|
||||
mainAxis: 4,
|
||||
crossAxis: 8,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger asChild onClick={() => setOpen(v => !v)}>
|
||||
<Button
|
||||
className="relative"
|
||||
variant={notConfigured ? 'primary' : 'secondary'}
|
||||
size="small"
|
||||
disabled={isLoading}
|
||||
<>
|
||||
<Button
|
||||
className="relative"
|
||||
variant={notConfigured ? 'primary' : 'secondary'}
|
||||
size="small"
|
||||
disabled={isLoading}
|
||||
onClick={() => setOpen(true)}
|
||||
>
|
||||
{isLoading
|
||||
? <span className="i-ri-loader-2-line mr-1 h-3.5 w-3.5 animate-spin" />
|
||||
: <span className="i-ri-equalizer-2-line mr-1 h-3.5 w-3.5" />}
|
||||
{t('modelProvider.systemModelSettings', { ns: 'common' })}
|
||||
</Button>
|
||||
<Dialog open={open} onOpenChange={setOpen}>
|
||||
<DialogContent
|
||||
backdropProps={{ forceRender: true }}
|
||||
className="w-[480px] max-w-[480px] overflow-hidden p-0"
|
||||
>
|
||||
{isLoading
|
||||
? <RiLoader2Line className="mr-1 h-3.5 w-3.5 animate-spin" />
|
||||
: <RiEqualizer2Line className="mr-1 h-3.5 w-3.5" />}
|
||||
{t('modelProvider.systemModelSettings', { ns: 'common' })}
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className="z-[60]">
|
||||
<div className="w-[360px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg pt-4 shadow-xl">
|
||||
<div className="px-6 py-1">
|
||||
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
|
||||
{t('modelProvider.systemReasoningModel.key', { ns: 'common' })}
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{t('modelProvider.systemReasoningModel.tip', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
|
||||
/>
|
||||
<DialogCloseButton className="right-5 top-5" />
|
||||
<div className="px-6 pb-3 pr-14 pt-6">
|
||||
<DialogTitle className="text-text-primary title-2xl-semi-bold">
|
||||
{t('modelProvider.systemModelSettings', { ns: 'common' })}
|
||||
</DialogTitle>
|
||||
</div>
|
||||
<div className="flex flex-col gap-4 px-6 py-3">
|
||||
<div className="flex flex-col gap-1">
|
||||
{renderModelLabel('modelProvider.systemReasoningModel.key', 'modelProvider.systemReasoningModel.tip')}
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentTextGenerationDefaultModel}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textGeneration, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentTextGenerationDefaultModel}
|
||||
modelList={textGenerationModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textGeneration, model)}
|
||||
/>
|
||||
<div className="flex flex-col gap-1">
|
||||
{renderModelLabel('modelProvider.embeddingModel.key', 'modelProvider.embeddingModel.tip')}
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentEmbeddingsDefaultModel}
|
||||
modelList={embeddingModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textEmbedding, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
{renderModelLabel('modelProvider.rerankModel.key', 'modelProvider.rerankModel.tip')}
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentRerankDefaultModel}
|
||||
modelList={rerankModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.rerank, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
{renderModelLabel('modelProvider.speechToTextModel.key', 'modelProvider.speechToTextModel.tip')}
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentSpeech2textDefaultModel}
|
||||
modelList={speech2textModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.speech2text, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
{renderModelLabel('modelProvider.ttsModel.key', 'modelProvider.ttsModel.tip')}
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentTTSDefaultModel}
|
||||
modelList={ttsModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.tts, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-6 py-1">
|
||||
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
|
||||
{t('modelProvider.embeddingModel.key', { ns: 'common' })}
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{t('modelProvider.embeddingModel.tip', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentEmbeddingsDefaultModel}
|
||||
modelList={embeddingModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.textEmbedding, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-6 py-1">
|
||||
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
|
||||
{t('modelProvider.rerankModel.key', { ns: 'common' })}
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{t('modelProvider.rerankModel.tip', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentRerankDefaultModel}
|
||||
modelList={rerankModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.rerank, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-6 py-1">
|
||||
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
|
||||
{t('modelProvider.speechToTextModel.key', { ns: 'common' })}
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{t('modelProvider.speechToTextModel.tip', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentSpeech2textDefaultModel}
|
||||
modelList={speech2textModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.speech2text, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-6 py-1">
|
||||
<div className="flex h-8 items-center text-[13px] font-medium text-text-primary">
|
||||
{t('modelProvider.ttsModel.key', { ns: 'common' })}
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="w-[261px] text-text-tertiary">
|
||||
{t('modelProvider.ttsModel.tip', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
triggerClassName="ml-0.5 w-4 h-4 shrink-0"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={currentTTSDefaultModel}
|
||||
modelList={ttsModelList}
|
||||
onSelect={model => handleChangeDefaultModel(ModelTypeEnum.tts, model)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-end px-6 py-4">
|
||||
<div className="flex items-center justify-end gap-2 px-6 pb-6 pt-5">
|
||||
<Button
|
||||
className="min-w-[72px]"
|
||||
onClick={() => setOpen(false)}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
<Button
|
||||
className="ml-2"
|
||||
className="min-w-[72px]"
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={!isCurrentWorkspaceManager}
|
||||
@@ -261,9 +252,9 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
{t('operation.save', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,11 @@ import {
|
||||
|
||||
export { ModelProviderQuotaGetPaid } from '@/types/model-provider'
|
||||
|
||||
export const providerToPluginId = (providerKey: string): string => {
|
||||
const lastSlash = providerKey.lastIndexOf('/')
|
||||
return lastSlash > 0 ? providerKey.slice(0, lastSlash) : ''
|
||||
}
|
||||
|
||||
export const MODEL_PROVIDER_QUOTA_GET_PAID = [ModelProviderQuotaGetPaid.OPENAI, ModelProviderQuotaGetPaid.ANTHROPIC, ModelProviderQuotaGetPaid.GEMINI, ModelProviderQuotaGetPaid.X, ModelProviderQuotaGetPaid.DEEPSEEK, ModelProviderQuotaGetPaid.TONGYI]
|
||||
|
||||
export const modelNameMap = {
|
||||
|
||||
@@ -79,6 +79,10 @@ vi.mock('@/service/plugins', () => ({
|
||||
uninstallPlugin: mockUninstallPlugin,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useInvalidateCheckInstalled: () => vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useAllToolProviders: () => ({ data: [] }),
|
||||
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
|
||||
@@ -218,23 +222,6 @@ vi.mock('../../plugin-auth', () => ({
|
||||
PluginAuth: () => <div data-testid="plugin-auth" />,
|
||||
}))
|
||||
|
||||
// Mock Confirm component
|
||||
vi.mock('@/app/components/base/confirm', () => ({
|
||||
default: ({ isShow, onCancel, onConfirm, isLoading }: {
|
||||
isShow: boolean
|
||||
onCancel: () => void
|
||||
onConfirm: () => void
|
||||
isLoading: boolean
|
||||
}) => isShow
|
||||
? (
|
||||
<div data-testid="delete-confirm">
|
||||
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
|
||||
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
const createPluginDetail = (overrides: Partial<PluginDetail> = {}): PluginDetail => ({
|
||||
id: 'test-id',
|
||||
created_at: '2024-01-01',
|
||||
@@ -801,7 +788,7 @@ describe('DetailHeader', () => {
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -810,13 +797,13 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-cancel'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -825,10 +812,10 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockUninstallPlugin).toHaveBeenCalledWith('test-id')
|
||||
@@ -840,10 +827,10 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnUpdate).toHaveBeenCalledWith(true)
|
||||
@@ -861,10 +848,10 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRefreshModelProviders).toHaveBeenCalled()
|
||||
@@ -876,10 +863,10 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockInvalidateAllToolProviders).toHaveBeenCalled()
|
||||
@@ -891,10 +878,10 @@ describe('DetailHeader', () => {
|
||||
|
||||
fireEvent.click(screen.getByTestId('remove-btn'))
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(amplitude.trackEvent).toHaveBeenCalledWith('plugin_uninstalled', expect.any(Object))
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import type { ReactElement, ReactNode } from 'react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { cloneElement } from 'react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginSource } from '../../types'
|
||||
import OperationDropdown from '../operation-dropdown'
|
||||
@@ -12,24 +14,22 @@ vi.mock('@/utils/classnames', () => ({
|
||||
cn: (...args: (string | undefined | false | null)[]) => args.filter(Boolean).join(' '),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/action-button', () => ({
|
||||
default: ({ children, className, onClick }: { children: React.ReactNode, className?: string, onClick?: () => void }) => (
|
||||
<button data-testid="action-button" className={className} onClick={onClick}>
|
||||
{children}
|
||||
</button>
|
||||
vi.mock('@/app/components/base/ui/dropdown-menu', () => ({
|
||||
DropdownMenu: ({ children, open }: { children: ReactNode, open: boolean }) => (
|
||||
<div data-testid="dropdown-menu" data-open={open}>{children}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
|
||||
<div data-testid="portal-elem" data-open={open}>{children}</div>
|
||||
DropdownMenuTrigger: ({ children, className }: { children: ReactNode, className?: string }) => (
|
||||
<button data-testid="dropdown-trigger" className={className}>{children}</button>
|
||||
),
|
||||
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
|
||||
<div data-testid="portal-trigger" onClick={onClick}>{children}</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className?: string }) => (
|
||||
<div data-testid="portal-content" className={className}>{children}</div>
|
||||
DropdownMenuContent: ({ children }: { children: ReactNode }) => (
|
||||
<div data-testid="dropdown-content">{children}</div>
|
||||
),
|
||||
DropdownMenuItem: ({ children, onClick, render, destructive }: { children: ReactNode, onClick?: () => void, render?: ReactElement, destructive?: boolean }) => {
|
||||
if (render)
|
||||
return cloneElement(render, { onClick, 'data-destructive': destructive } as Record<string, unknown>, children)
|
||||
return <div data-testid="dropdown-item" data-destructive={destructive} onClick={onClick}>{children}</div>
|
||||
},
|
||||
DropdownMenuSeparator: () => <hr data-testid="dropdown-separator" />,
|
||||
}))
|
||||
|
||||
describe('OperationDropdown', () => {
|
||||
@@ -52,14 +52,13 @@ describe('OperationDropdown', () => {
|
||||
it('should render trigger button', () => {
|
||||
render(<OperationDropdown {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('portal-trigger')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('action-button')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dropdown-trigger')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render dropdown content', () => {
|
||||
render(<OperationDropdown {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dropdown-content')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render info option for github source', () => {
|
||||
@@ -118,14 +117,10 @@ describe('OperationDropdown', () => {
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should toggle dropdown when trigger is clicked', () => {
|
||||
it('should render dropdown menu root', () => {
|
||||
render(<OperationDropdown {...defaultProps} />)
|
||||
|
||||
const trigger = screen.getByTestId('portal-trigger')
|
||||
fireEvent.click(trigger)
|
||||
|
||||
// The portal-elem should reflect the open state
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dropdown-menu')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onInfo when info option is clicked', () => {
|
||||
@@ -174,7 +169,7 @@ describe('OperationDropdown', () => {
|
||||
const { unmount } = render(
|
||||
<OperationDropdown {...defaultProps} source={source} />,
|
||||
)
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('dropdown-menu')).toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.detailPanel.operation.remove')).toBeInTheDocument()
|
||||
unmount()
|
||||
})
|
||||
@@ -199,9 +194,7 @@ describe('OperationDropdown', () => {
|
||||
|
||||
describe('Memoization', () => {
|
||||
it('should be wrapped with React.memo', () => {
|
||||
// Verify the component is exported as a memo component
|
||||
expect(OperationDropdown).toBeDefined()
|
||||
// React.memo wraps the component, so it should have $$typeof
|
||||
expect((OperationDropdown as { $$typeof?: symbol }).$$typeof).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,24 +9,6 @@ vi.mock('@/context/i18n', () => ({
|
||||
useGetLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/confirm', () => ({
|
||||
default: ({ isShow, title, onCancel, onConfirm, isLoading }: {
|
||||
isShow: boolean
|
||||
title: string
|
||||
onCancel: () => void
|
||||
onConfirm: () => void
|
||||
isLoading: boolean
|
||||
}) => isShow
|
||||
? (
|
||||
<div data-testid="delete-confirm">
|
||||
<div data-testid="delete-title">{title}</div>
|
||||
<button data-testid="confirm-cancel" onClick={onCancel}>Cancel</button>
|
||||
<button data-testid="confirm-ok" onClick={onConfirm} disabled={isLoading}>Confirm</button>
|
||||
</div>
|
||||
)
|
||||
: null,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/plugin-page/plugin-info', () => ({
|
||||
default: ({ repository, release, packageName, onHide }: {
|
||||
repository: string
|
||||
@@ -230,7 +212,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('delete-confirm')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('alertdialog')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render delete confirm when isShowDeleteConfirm is true', () => {
|
||||
@@ -247,7 +229,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show correct delete title', () => {
|
||||
@@ -264,7 +246,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('delete-title')).toHaveTextContent('plugin.action.delete')
|
||||
expect(screen.getByRole('alertdialog')).toHaveTextContent('plugin.action.delete')
|
||||
})
|
||||
|
||||
it('should call hideDeleteConfirm when cancel is clicked', () => {
|
||||
@@ -281,7 +263,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-cancel'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
expect(modalStates.hideDeleteConfirm).toHaveBeenCalled()
|
||||
})
|
||||
@@ -300,7 +282,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('confirm-ok'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.confirm' }))
|
||||
|
||||
expect(mockOnDelete).toHaveBeenCalled()
|
||||
})
|
||||
@@ -319,7 +301,7 @@ describe('HeaderModals', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('confirm-ok')).toBeDisabled()
|
||||
expect(screen.getByRole('button', { name: /common\.operation\.confirm/ })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -485,7 +467,7 @@ describe('HeaderModals', () => {
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('plugin-info')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('delete-confirm')).toBeInTheDocument()
|
||||
expect(screen.getByRole('alertdialog')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('update-modal')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,7 +4,15 @@ import type { FC } from 'react'
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from '../hooks'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import PluginInfo from '@/app/components/plugins/plugin-page/plugin-info'
|
||||
import UpdateFromMarketplace from '@/app/components/plugins/update-plugin/from-market-place'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
@@ -50,7 +58,6 @@ const HeaderModals: FC<HeaderModalsProps> = ({
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Plugin Info Modal */}
|
||||
{isShowPluginInfo && (
|
||||
<PluginInfo
|
||||
repository={isFromGitHub ? meta?.repo : ''}
|
||||
@@ -60,27 +67,35 @@ const HeaderModals: FC<HeaderModalsProps> = ({
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Delete Confirm Modal */}
|
||||
{isShowDeleteConfirm && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
content={(
|
||||
<div>
|
||||
<AlertDialog
|
||||
open={isShowDeleteConfirm}
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
hideDeleteConfirm()
|
||||
}}
|
||||
>
|
||||
<AlertDialogContent backdropProps={{ forceRender: true }}>
|
||||
<div className="flex flex-col gap-2 px-6 pb-4 pt-6">
|
||||
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
|
||||
{t(`${i18nPrefix}.delete`, { ns: 'plugin' })}
|
||||
</AlertDialogTitle>
|
||||
<AlertDialogDescription className="w-full whitespace-pre-wrap break-words text-text-tertiary system-md-regular">
|
||||
{t(`${i18nPrefix}.deleteContentLeft`, { ns: 'plugin' })}
|
||||
<span className="system-md-semibold">{label[locale]}</span>
|
||||
<span className="text-text-secondary system-md-semibold">{label[locale]}</span>
|
||||
{t(`${i18nPrefix}.deleteContentRight`, { ns: 'plugin' })}
|
||||
<br />
|
||||
</div>
|
||||
)}
|
||||
onCancel={hideDeleteConfirm}
|
||||
onConfirm={onDelete}
|
||||
isLoading={deleting}
|
||||
isDisabled={deleting}
|
||||
/>
|
||||
)}
|
||||
</AlertDialogDescription>
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton disabled={deleting}>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton loading={deleting} disabled={deleting} onClick={onDelete}>
|
||||
{t('operation.confirm', { ns: 'common' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
|
||||
{/* Update from Marketplace Modal */}
|
||||
{isShowUpdateModal && (
|
||||
<UpdateFromMarketplace
|
||||
pluginId={detail.plugin_id}
|
||||
|
||||
@@ -15,6 +15,7 @@ type VersionPickerMock = {
|
||||
const {
|
||||
mockSetShowUpdatePluginModal,
|
||||
mockRefreshModelProviders,
|
||||
mockInvalidateCheckInstalled,
|
||||
mockInvalidateAllToolProviders,
|
||||
mockUninstallPlugin,
|
||||
mockFetchReleases,
|
||||
@@ -23,6 +24,7 @@ const {
|
||||
return {
|
||||
mockSetShowUpdatePluginModal: vi.fn(),
|
||||
mockRefreshModelProviders: vi.fn(),
|
||||
mockInvalidateCheckInstalled: vi.fn(),
|
||||
mockInvalidateAllToolProviders: vi.fn(),
|
||||
mockUninstallPlugin: vi.fn(() => Promise.resolve({ success: true })),
|
||||
mockFetchReleases: vi.fn(() => Promise.resolve([{ tag_name: 'v2.0.0' }])),
|
||||
@@ -46,6 +48,10 @@ vi.mock('@/service/plugins', () => ({
|
||||
uninstallPlugin: mockUninstallPlugin,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useInvalidateCheckInstalled: () => mockInvalidateCheckInstalled,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders,
|
||||
}))
|
||||
@@ -178,6 +184,7 @@ describe('usePluginOperations', () => {
|
||||
result.current.handleUpdatedFromMarketplace()
|
||||
})
|
||||
|
||||
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
|
||||
expect(mockOnUpdate).toHaveBeenCalled()
|
||||
expect(modalStates.hideUpdateModal).toHaveBeenCalled()
|
||||
})
|
||||
@@ -251,6 +258,32 @@ describe('usePluginOperations', () => {
|
||||
expect(mockSetShowUpdatePluginModal).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should invalidate checkInstalled when GitHub update save callback fires', async () => {
|
||||
const detail = createPluginDetail({
|
||||
source: PluginSource.github,
|
||||
meta: { repo: 'owner/repo', version: 'v1.0.0', package: 'pkg' },
|
||||
})
|
||||
const { result } = renderHook(() =>
|
||||
usePluginOperations({
|
||||
detail,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
isFromMarketplace: false,
|
||||
onUpdate: mockOnUpdate,
|
||||
}),
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.handleUpdate()
|
||||
})
|
||||
|
||||
const firstCall = mockSetShowUpdatePluginModal.mock.calls.at(0)?.[0]
|
||||
firstCall?.onSaveCallback()
|
||||
|
||||
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
|
||||
expect(mockOnUpdate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not show modal when no releases found', async () => {
|
||||
mockFetchReleases.mockResolvedValueOnce([])
|
||||
const detail = createPluginDetail({
|
||||
@@ -388,6 +421,7 @@ describe('usePluginOperations', () => {
|
||||
await result.current.handleDelete()
|
||||
})
|
||||
|
||||
expect(mockInvalidateCheckInstalled).toHaveBeenCalled()
|
||||
expect(mockOnUpdate).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
import type { PluginDetail } from '../../../types'
|
||||
import type { ModalStates, VersionTarget } from './use-detail-header-state'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { uninstallPlugin } from '@/service/plugins'
|
||||
import { useInvalidateCheckInstalled } from '@/service/use-plugins'
|
||||
import { useInvalidateAllToolProviders } from '@/service/use-tools'
|
||||
import { useGitHubReleases } from '../../../install-plugin/hooks'
|
||||
import { PluginCategoryEnum, PluginSource } from '../../../types'
|
||||
@@ -36,13 +38,19 @@ export const usePluginOperations = ({
|
||||
isFromMarketplace,
|
||||
onUpdate,
|
||||
}: UsePluginOperationsParams): UsePluginOperationsReturn => {
|
||||
const { t } = useTranslation()
|
||||
const { checkForUpdates, fetchReleases } = useGitHubReleases()
|
||||
const { setShowUpdatePluginModal } = useModalContext()
|
||||
const { refreshModelProviders } = useProviderContext()
|
||||
const invalidateCheckInstalled = useInvalidateCheckInstalled()
|
||||
const invalidateAllToolProviders = useInvalidateAllToolProviders()
|
||||
|
||||
const { id, meta, plugin_id } = detail
|
||||
const { author, category, name } = detail.declaration || detail
|
||||
const handlePluginUpdated = useCallback((isDelete?: boolean) => {
|
||||
invalidateCheckInstalled()
|
||||
onUpdate?.(isDelete)
|
||||
}, [invalidateCheckInstalled, onUpdate])
|
||||
|
||||
const handleUpdate = useCallback(async (isDowngrade?: boolean) => {
|
||||
if (isFromMarketplace) {
|
||||
@@ -71,7 +79,7 @@ export const usePluginOperations = ({
|
||||
if (needUpdate) {
|
||||
setShowUpdatePluginModal({
|
||||
onSaveCallback: () => {
|
||||
onUpdate?.()
|
||||
handlePluginUpdated()
|
||||
},
|
||||
payload: {
|
||||
type: PluginSource.github,
|
||||
@@ -97,15 +105,15 @@ export const usePluginOperations = ({
|
||||
checkForUpdates,
|
||||
setShowUpdatePluginModal,
|
||||
detail,
|
||||
onUpdate,
|
||||
handlePluginUpdated,
|
||||
modalStates,
|
||||
versionPicker,
|
||||
])
|
||||
|
||||
const handleUpdatedFromMarketplace = useCallback(() => {
|
||||
onUpdate?.()
|
||||
handlePluginUpdated()
|
||||
modalStates.hideUpdateModal()
|
||||
}, [onUpdate, modalStates])
|
||||
}, [handlePluginUpdated, modalStates])
|
||||
|
||||
const handleDelete = useCallback(async () => {
|
||||
modalStates.showDeleting()
|
||||
@@ -114,7 +122,11 @@ export const usePluginOperations = ({
|
||||
|
||||
if (res.success) {
|
||||
modalStates.hideDeleteConfirm()
|
||||
onUpdate?.(true)
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('action.deleteSuccess', { ns: 'plugin' }),
|
||||
})
|
||||
handlePluginUpdated(true)
|
||||
|
||||
if (PluginCategoryEnum.model.includes(category))
|
||||
refreshModelProviders()
|
||||
@@ -130,7 +142,7 @@ export const usePluginOperations = ({
|
||||
plugin_id,
|
||||
name,
|
||||
modalStates,
|
||||
onUpdate,
|
||||
handlePluginUpdated,
|
||||
refreshModelProviders,
|
||||
invalidateAllToolProviders,
|
||||
])
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import { RiArrowRightUpLine, RiMoreFill } from '@remixicon/react'
|
||||
import type { Placement } from '@/app/components/base/ui/placement'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
// import Button from '@/app/components/base/button'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/app/components/base/ui/dropdown-menu'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { PluginSource } from '../types'
|
||||
@@ -21,6 +20,10 @@ type Props = {
|
||||
onCheckVersion: () => void
|
||||
onRemove: () => void
|
||||
detailUrl: string
|
||||
placement?: Placement
|
||||
sideOffset?: number
|
||||
alignOffset?: number
|
||||
popupClassName?: string
|
||||
}
|
||||
|
||||
const OperationDropdown: FC<Props> = ({
|
||||
@@ -29,83 +32,52 @@ const OperationDropdown: FC<Props> = ({
|
||||
onInfo,
|
||||
onCheckVersion,
|
||||
onRemove,
|
||||
placement = 'bottom-end',
|
||||
sideOffset = 4,
|
||||
alignOffset = 0,
|
||||
popupClassName,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, doSetOpen] = useState(false)
|
||||
const openRef = useRef(open)
|
||||
const setOpen = useCallback((v: boolean) => {
|
||||
doSetOpen(v)
|
||||
openRef.current = v
|
||||
}, [doSetOpen])
|
||||
|
||||
const handleTrigger = useCallback(() => {
|
||||
setOpen(!openRef.current)
|
||||
}, [setOpen])
|
||||
|
||||
const [open, setOpen] = React.useState(false)
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement="bottom-end"
|
||||
offset={{
|
||||
mainAxis: -12,
|
||||
crossAxis: 36,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger onClick={handleTrigger}>
|
||||
<div>
|
||||
<ActionButton className={cn(open && 'bg-state-base-hover')}>
|
||||
<RiMoreFill className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className="z-50">
|
||||
<div className="w-[160px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg">
|
||||
{source === PluginSource.github && (
|
||||
<div
|
||||
onClick={() => {
|
||||
onInfo()
|
||||
handleTrigger()
|
||||
}}
|
||||
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
|
||||
>
|
||||
{t('detailPanel.operation.info', { ns: 'plugin' })}
|
||||
</div>
|
||||
)}
|
||||
{source === PluginSource.github && (
|
||||
<div
|
||||
onClick={() => {
|
||||
onCheckVersion()
|
||||
handleTrigger()
|
||||
}}
|
||||
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover"
|
||||
>
|
||||
{t('detailPanel.operation.checkUpdate', { ns: 'plugin' })}
|
||||
</div>
|
||||
)}
|
||||
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
|
||||
<a href={detailUrl} target="_blank" className="system-md-regular flex cursor-pointer items-center rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-base-hover">
|
||||
<span className="grow">{t('detailPanel.operation.viewDetail', { ns: 'plugin' })}</span>
|
||||
<RiArrowRightUpLine className="h-3.5 w-3.5 shrink-0 text-text-tertiary" />
|
||||
</a>
|
||||
)}
|
||||
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
|
||||
<div className="my-1 h-px bg-divider-subtle"></div>
|
||||
)}
|
||||
<div
|
||||
onClick={() => {
|
||||
onRemove()
|
||||
handleTrigger()
|
||||
}}
|
||||
className="system-md-regular cursor-pointer rounded-lg px-3 py-1.5 text-text-secondary hover:bg-state-destructive-hover hover:text-text-destructive"
|
||||
>
|
||||
{t('detailPanel.operation.remove', { ns: 'plugin' })}
|
||||
</div>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
<DropdownMenu open={open} onOpenChange={setOpen}>
|
||||
<DropdownMenuTrigger
|
||||
className={cn('action-btn action-btn-m', open && 'bg-state-base-hover')}
|
||||
>
|
||||
<span className="i-ri-more-fill h-4 w-4" />
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
placement={placement}
|
||||
sideOffset={sideOffset}
|
||||
alignOffset={alignOffset}
|
||||
popupClassName={cn('w-[160px]', popupClassName)}
|
||||
>
|
||||
{source === PluginSource.github && (
|
||||
<DropdownMenuItem onClick={onInfo}>
|
||||
{t('detailPanel.operation.info', { ns: 'plugin' })}
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{source === PluginSource.github && (
|
||||
<DropdownMenuItem onClick={onCheckVersion}>
|
||||
{t('detailPanel.operation.checkUpdate', { ns: 'plugin' })}
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
|
||||
<DropdownMenuItem render={<a href={detailUrl} target="_blank" rel="noopener noreferrer" />}>
|
||||
<span className="grow">{t('detailPanel.operation.viewDetail', { ns: 'plugin' })}</span>
|
||||
<span className="i-ri-arrow-right-up-line h-3.5 w-3.5 shrink-0 text-text-tertiary" />
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
{(source === PluginSource.marketplace || source === PluginSource.github) && enable_marketplace && (
|
||||
<DropdownMenuSeparator />
|
||||
)}
|
||||
<DropdownMenuItem destructive onClick={onRemove}>
|
||||
{t('detailPanel.operation.remove', { ns: 'plugin' })}
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)
|
||||
}
|
||||
export default React.memo(OperationDropdown)
|
||||
|
||||
@@ -104,36 +104,6 @@ vi.mock('../../install-plugin/install-from-github', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
// Mock Portal components for PluginVersionPicker
|
||||
let mockPortalOpen = false
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
|
||||
PortalToFollowElem: ({ children, open, onOpenChange: _onOpenChange }: {
|
||||
children: React.ReactNode
|
||||
open: boolean
|
||||
onOpenChange: (open: boolean) => void
|
||||
}) => {
|
||||
mockPortalOpen = open
|
||||
return <div data-testid="portal-elem" data-open={open}>{children}</div>
|
||||
},
|
||||
PortalToFollowElemTrigger: ({ children, onClick, className }: {
|
||||
children: React.ReactNode
|
||||
onClick: () => void
|
||||
className?: string
|
||||
}) => (
|
||||
<div data-testid="portal-trigger" onClick={onClick} className={className}>
|
||||
{children}
|
||||
</div>
|
||||
),
|
||||
PortalToFollowElemContent: ({ children, className }: {
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
}) => {
|
||||
if (!mockPortalOpen)
|
||||
return null
|
||||
return <div data-testid="portal-content" className={className}>{children}</div>
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock semver
|
||||
vi.mock('semver', () => ({
|
||||
lt: (v1: string, v2: string) => {
|
||||
@@ -247,7 +217,6 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
describe('update-plugin', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpen = false
|
||||
mockCheck.mockResolvedValue({ status: TaskStatus.success })
|
||||
})
|
||||
|
||||
@@ -946,7 +915,7 @@ describe('update-plugin', () => {
|
||||
onShowChange: vi.fn(),
|
||||
pluginID: 'test-plugin-id',
|
||||
currentVersion: '1.0.0',
|
||||
trigger: <button>Select Version</button>,
|
||||
trigger: <span>Select Version</span>,
|
||||
onSelect: vi.fn(),
|
||||
}
|
||||
|
||||
@@ -964,7 +933,7 @@ describe('update-plugin', () => {
|
||||
render(<PluginVersionPicker {...defaultProps} isShow={false} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('plugin.detailPanel.switchVersion')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render version list when isShow is true', () => {
|
||||
@@ -972,7 +941,6 @@ describe('update-plugin', () => {
|
||||
render(<PluginVersionPicker {...defaultProps} isShow={true} />)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId('portal-content')).toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -1002,7 +970,7 @@ describe('update-plugin', () => {
|
||||
|
||||
// Act
|
||||
render(<PluginVersionPicker {...defaultProps} onShowChange={onShowChange} />)
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
fireEvent.click(screen.getByText('Select Version'))
|
||||
|
||||
// Assert
|
||||
expect(onShowChange).toHaveBeenCalledWith(true)
|
||||
@@ -1014,7 +982,7 @@ describe('update-plugin', () => {
|
||||
|
||||
// Act
|
||||
render(<PluginVersionPicker {...defaultProps} disabled={true} onShowChange={onShowChange} />)
|
||||
fireEvent.click(screen.getByTestId('portal-trigger'))
|
||||
fireEvent.click(screen.getByText('Select Version'))
|
||||
|
||||
// Assert
|
||||
expect(onShowChange).not.toHaveBeenCalled()
|
||||
@@ -1116,7 +1084,7 @@ describe('update-plugin', () => {
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should support custom offset', () => {
|
||||
@@ -1125,12 +1093,13 @@ describe('update-plugin', () => {
|
||||
<PluginVersionPicker
|
||||
{...defaultProps}
|
||||
isShow={true}
|
||||
offset={{ mainAxis: 10, crossAxis: 20 }}
|
||||
sideOffset={10}
|
||||
alignOffset={20}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
expect(screen.getByText('plugin.detailPanel.switchVersion')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1190,7 +1159,7 @@ describe('update-plugin', () => {
|
||||
onShowChange: vi.fn(),
|
||||
pluginID: 'test',
|
||||
currentVersion: '1.0.0',
|
||||
trigger: <button>Select</button>,
|
||||
trigger: <span>Select</span>,
|
||||
onSelect: vi.fn(),
|
||||
}}
|
||||
/>,
|
||||
|
||||
@@ -18,8 +18,8 @@ const DowngradeWarningModal = ({
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col items-start gap-2 self-stretch">
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t(`${i18nPrefix}.title`, { ns: 'plugin' })}</div>
|
||||
<div className="system-md-regular text-text-secondary">
|
||||
<div className="text-text-primary title-2xl-semi-bold">{t(`${i18nPrefix}.title`, { ns: 'plugin' })}</div>
|
||||
<div className="text-text-secondary system-md-regular">
|
||||
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -6,7 +6,12 @@ import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge, { BadgeState } from '@/app/components/base/badge/index'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
DialogTitle,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import Card from '@/app/components/plugins/card'
|
||||
import checkTaskStatus from '@/app/components/plugins/install-plugin/base/check-task-status'
|
||||
import { pluginManifestToCardPluginProps } from '@/app/components/plugins/install-plugin/utils'
|
||||
@@ -125,63 +130,65 @@ const UpdatePluginModal: FC<Props> = ({
|
||||
const doShowDowngradeWarningModal = isShowDowngradeWarningModal && uploadStep === UploadStep.notStarted
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isShow={true}
|
||||
onClose={onCancel}
|
||||
className={cn('min-w-[560px]', doShowDowngradeWarningModal && 'min-w-[640px]')}
|
||||
closable
|
||||
title={!doShowDowngradeWarningModal && t(`${i18nPrefix}.${uploadStep === UploadStep.installed ? 'successfulTitle' : 'title'}`, { ns: 'plugin' })}
|
||||
>
|
||||
{doShowDowngradeWarningModal && (
|
||||
<DowngradeWarningModal
|
||||
onCancel={onCancel}
|
||||
onJustDowngrade={handleConfirm}
|
||||
onExcludeAndDowngrade={handleExcludeAndDownload}
|
||||
/>
|
||||
)}
|
||||
{!doShowDowngradeWarningModal && (
|
||||
<>
|
||||
<div className="system-md-regular mb-2 mt-3 text-text-secondary">
|
||||
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
|
||||
</div>
|
||||
<div className="flex flex-wrap content-start items-start gap-1 self-stretch rounded-2xl bg-background-section-burn p-2">
|
||||
<Card
|
||||
installed={uploadStep === UploadStep.installed}
|
||||
payload={pluginManifestToCardPluginProps({
|
||||
...originalPackageInfo.payload,
|
||||
icon: icon!,
|
||||
})}
|
||||
className="w-full"
|
||||
titleLeft={(
|
||||
<>
|
||||
<Badge className="mx-1" size="s" state={BadgeState.Warning}>
|
||||
{`${originalPackageInfo.payload.version} -> ${targetPackageInfo.version}`}
|
||||
</Badge>
|
||||
</>
|
||||
<Dialog open onOpenChange={() => onCancel()}>
|
||||
<DialogContent
|
||||
backdropProps={{ forceRender: true }}
|
||||
className={cn('min-w-[560px]', doShowDowngradeWarningModal && 'min-w-[640px]')}
|
||||
>
|
||||
<DialogCloseButton />
|
||||
{doShowDowngradeWarningModal && (
|
||||
<DowngradeWarningModal
|
||||
onCancel={onCancel}
|
||||
onJustDowngrade={handleConfirm}
|
||||
onExcludeAndDowngrade={handleExcludeAndDownload}
|
||||
/>
|
||||
)}
|
||||
{!doShowDowngradeWarningModal && (
|
||||
<>
|
||||
<DialogTitle className="text-text-primary title-2xl-semi-bold">
|
||||
{t(`${i18nPrefix}.${uploadStep === UploadStep.installed ? 'successfulTitle' : 'title'}`, { ns: 'plugin' })}
|
||||
</DialogTitle>
|
||||
<div className="mb-2 mt-3 text-text-secondary system-md-regular">
|
||||
{t(`${i18nPrefix}.description`, { ns: 'plugin' })}
|
||||
</div>
|
||||
<div className="flex flex-wrap content-start items-start gap-1 self-stretch rounded-2xl bg-background-section-burn p-2">
|
||||
<Card
|
||||
installed={uploadStep === UploadStep.installed}
|
||||
payload={pluginManifestToCardPluginProps({
|
||||
...originalPackageInfo.payload,
|
||||
icon: icon!,
|
||||
})}
|
||||
className="w-full"
|
||||
titleLeft={(
|
||||
<>
|
||||
<Badge className="mx-1" size="s" state={BadgeState.Warning}>
|
||||
{`${originalPackageInfo.payload.version} -> ${targetPackageInfo.version}`}
|
||||
</Badge>
|
||||
</>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2 self-stretch pt-5">
|
||||
{uploadStep === UploadStep.notStarted && (
|
||||
<Button
|
||||
onClick={handleCancel}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2 self-stretch pt-5">
|
||||
{uploadStep === UploadStep.notStarted && (
|
||||
<Button
|
||||
onClick={handleCancel}
|
||||
variant="primary"
|
||||
loading={uploadStep === UploadStep.upgrading}
|
||||
onClick={handleConfirm}
|
||||
disabled={uploadStep === UploadStep.upgrading}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
{configBtnText}
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
variant="primary"
|
||||
loading={uploadStep === UploadStep.upgrading}
|
||||
onClick={handleConfirm}
|
||||
disabled={uploadStep === UploadStep.upgrading}
|
||||
>
|
||||
{configBtnText}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
</Modal>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
export default React.memo(UpdatePluginModal)
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
'use client'
|
||||
import type {
|
||||
OffsetOptions,
|
||||
Placement,
|
||||
} from '@floating-ui/react'
|
||||
import type { FC } from 'react'
|
||||
import type { Placement } from '@/app/components/base/ui/placement'
|
||||
import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { lt } from 'semver'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/app/components/base/ui/popover'
|
||||
import useTimestamp from '@/hooks/use-timestamp'
|
||||
import { useVersionListOfPlugin } from '@/service/use-plugins'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@@ -26,7 +23,8 @@ type Props = {
|
||||
currentVersion: string
|
||||
trigger: React.ReactNode
|
||||
placement?: Placement
|
||||
offset?: OffsetOptions
|
||||
sideOffset?: number
|
||||
alignOffset?: number
|
||||
onSelect: ({
|
||||
version,
|
||||
unique_identifier,
|
||||
@@ -46,22 +44,14 @@ const PluginVersionPicker: FC<Props> = ({
|
||||
currentVersion,
|
||||
trigger,
|
||||
placement = 'bottom-start',
|
||||
offset = {
|
||||
mainAxis: 4,
|
||||
crossAxis: -16,
|
||||
},
|
||||
sideOffset = 4,
|
||||
alignOffset = -16,
|
||||
onSelect,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const format = t('dateTimeFormat', { ns: 'appLog' }).split(' ')[0]
|
||||
const { formatDate } = useTimestamp()
|
||||
|
||||
const handleTriggerClick = () => {
|
||||
if (disabled)
|
||||
return
|
||||
onShowChange(true)
|
||||
}
|
||||
|
||||
const { data: res } = useVersionListOfPlugin(pluginID)
|
||||
|
||||
const handleSelect = useCallback(({ version, unique_identifier, isDowngrade }: {
|
||||
@@ -76,49 +66,52 @@ const PluginVersionPicker: FC<Props> = ({
|
||||
}, [currentVersion, onSelect, onShowChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
placement={placement}
|
||||
offset={offset}
|
||||
<Popover
|
||||
open={isShow}
|
||||
onOpenChange={onShowChange}
|
||||
onOpenChange={(open) => {
|
||||
if (!disabled)
|
||||
onShowChange(open)
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger
|
||||
<PopoverTrigger
|
||||
className={cn('inline-flex cursor-pointer items-center', disabled && 'cursor-default')}
|
||||
onClick={handleTriggerClick}
|
||||
>
|
||||
{trigger}
|
||||
</PortalToFollowElemTrigger>
|
||||
</PopoverTrigger>
|
||||
|
||||
<PortalToFollowElemContent className="z-[1000]">
|
||||
<div className="relative w-[209px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg backdrop-blur-sm">
|
||||
<div className="system-xs-medium-uppercase px-3 pb-0.5 pt-1 text-text-tertiary">
|
||||
{t('detailPanel.switchVersion', { ns: 'plugin' })}
|
||||
</div>
|
||||
<div className="relative">
|
||||
{res?.data.versions.map(version => (
|
||||
<div
|
||||
key={version.unique_identifier}
|
||||
className={cn(
|
||||
'flex h-7 cursor-pointer items-center gap-1 rounded-lg px-3 py-1 hover:bg-state-base-hover',
|
||||
currentVersion === version.version && 'cursor-default opacity-30 hover:bg-transparent',
|
||||
)}
|
||||
onClick={() => handleSelect({
|
||||
version: version.version,
|
||||
unique_identifier: version.unique_identifier,
|
||||
isDowngrade: lt(version.version, currentVersion),
|
||||
})}
|
||||
>
|
||||
<div className="flex grow items-center">
|
||||
<div className="system-sm-medium text-text-secondary">{version.version}</div>
|
||||
{currentVersion === version.version && <Badge className="ml-1" text="CURRENT" />}
|
||||
</div>
|
||||
<div className="system-xs-regular shrink-0 text-text-tertiary">{formatDate(version.created_at, format)}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<PopoverContent
|
||||
placement={placement}
|
||||
sideOffset={sideOffset}
|
||||
alignOffset={alignOffset}
|
||||
popupClassName="relative w-[209px] bg-components-panel-bg-blur p-1 backdrop-blur-sm"
|
||||
>
|
||||
<div className="px-3 pb-0.5 pt-1 text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('detailPanel.switchVersion', { ns: 'plugin' })}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
<div className="relative max-h-[224px] overflow-y-auto">
|
||||
{res?.data.versions.map(version => (
|
||||
<div
|
||||
key={version.unique_identifier}
|
||||
className={cn(
|
||||
'flex h-7 cursor-pointer items-center gap-1 rounded-lg px-3 py-1 hover:bg-state-base-hover',
|
||||
currentVersion === version.version && 'cursor-default opacity-30 hover:bg-transparent',
|
||||
)}
|
||||
onClick={() => handleSelect({
|
||||
version: version.version,
|
||||
unique_identifier: version.unique_identifier,
|
||||
isDowngrade: lt(version.version, currentVersion),
|
||||
})}
|
||||
>
|
||||
<div className="flex grow items-center">
|
||||
<div className="text-text-secondary system-sm-medium">{version.version}</div>
|
||||
{currentVersion === version.version && <Badge className="ml-1" text="CURRENT" />}
|
||||
</div>
|
||||
<div className="shrink-0 text-text-tertiary system-xs-regular">{formatDate(version.created_at, format)}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,13 +3,18 @@
|
||||
import type { FC, ReactNode } from 'react'
|
||||
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { useCallback, useEffect, useMemo } from 'react'
|
||||
import { createContext, useContext, useContextSelector } from 'use-context-selector'
|
||||
import { setUserId, setUserProperties } from '@/app/components/base/amplitude'
|
||||
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
|
||||
import MaintenanceNotice from '@/app/components/header/maintenance-notice'
|
||||
import { ZENDESK_FIELD_IDS } from '@/config'
|
||||
import {
|
||||
AppContext,
|
||||
initialLangGeniusVersionInfo,
|
||||
initialWorkspaceInfo,
|
||||
userProfilePlaceholder,
|
||||
useSelector,
|
||||
} from '@/context/app-context'
|
||||
import { env } from '@/env'
|
||||
import {
|
||||
useCurrentWorkspace,
|
||||
@@ -18,72 +23,6 @@ import {
|
||||
} from '@/service/use-common'
|
||||
import { useGlobalPublicStore } from './global-public-context'
|
||||
|
||||
export type AppContextValue = {
|
||||
userProfile: UserProfileResponse
|
||||
mutateUserProfile: VoidFunction
|
||||
currentWorkspace: ICurrentWorkspace
|
||||
isCurrentWorkspaceManager: boolean
|
||||
isCurrentWorkspaceOwner: boolean
|
||||
isCurrentWorkspaceEditor: boolean
|
||||
isCurrentWorkspaceDatasetOperator: boolean
|
||||
mutateCurrentWorkspace: VoidFunction
|
||||
langGeniusVersionInfo: LangGeniusVersionResponse
|
||||
useSelector: typeof useSelector
|
||||
isLoadingCurrentWorkspace: boolean
|
||||
isValidatingCurrentWorkspace: boolean
|
||||
}
|
||||
|
||||
const userProfilePlaceholder = {
|
||||
id: '',
|
||||
name: '',
|
||||
email: '',
|
||||
avatar: '',
|
||||
avatar_url: '',
|
||||
is_password_set: false,
|
||||
}
|
||||
|
||||
const initialLangGeniusVersionInfo = {
|
||||
current_env: '',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
release_date: '',
|
||||
release_notes: '',
|
||||
version: '',
|
||||
can_auto_update: false,
|
||||
}
|
||||
|
||||
const initialWorkspaceInfo: ICurrentWorkspace = {
|
||||
id: '',
|
||||
name: '',
|
||||
plan: '',
|
||||
status: '',
|
||||
created_at: 0,
|
||||
role: 'normal',
|
||||
providers: [],
|
||||
trial_credits: 200,
|
||||
trial_credits_used: 0,
|
||||
next_credit_reset_date: 0,
|
||||
}
|
||||
|
||||
const AppContext = createContext<AppContextValue>({
|
||||
userProfile: userProfilePlaceholder,
|
||||
currentWorkspace: initialWorkspaceInfo,
|
||||
isCurrentWorkspaceManager: false,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
isCurrentWorkspaceEditor: false,
|
||||
isCurrentWorkspaceDatasetOperator: false,
|
||||
mutateUserProfile: noop,
|
||||
mutateCurrentWorkspace: noop,
|
||||
langGeniusVersionInfo: initialLangGeniusVersionInfo,
|
||||
useSelector,
|
||||
isLoadingCurrentWorkspace: false,
|
||||
isValidatingCurrentWorkspace: false,
|
||||
})
|
||||
|
||||
export function useSelector<T>(selector: (value: AppContextValue) => T): T {
|
||||
return useContextSelector(AppContext, selector)
|
||||
}
|
||||
|
||||
export type AppContextProviderProps = {
|
||||
children: ReactNode
|
||||
}
|
||||
@@ -170,7 +109,7 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
|
||||
// Report user and workspace info to Amplitude when loaded
|
||||
if (userProfile?.id) {
|
||||
setUserId(userProfile.email)
|
||||
const properties: Record<string, any> = {
|
||||
const properties: Record<string, string | number | boolean> = {
|
||||
email: userProfile.email,
|
||||
name: userProfile.name,
|
||||
has_password: userProfile.is_password_set,
|
||||
@@ -213,7 +152,3 @@ export const AppContextProvider: FC<AppContextProviderProps> = ({ children }) =>
|
||||
</AppContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export const useAppContext = () => useContext(AppContext)
|
||||
|
||||
export default AppContext
|
||||
73
web/context/app-context.ts
Normal file
73
web/context/app-context.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
'use client'
|
||||
|
||||
import type { ICurrentWorkspace, LangGeniusVersionResponse, UserProfileResponse } from '@/models/common'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { createContext, useContext, useContextSelector } from 'use-context-selector'
|
||||
|
||||
export type AppContextValue = {
|
||||
userProfile: UserProfileResponse
|
||||
mutateUserProfile: VoidFunction
|
||||
currentWorkspace: ICurrentWorkspace
|
||||
isCurrentWorkspaceManager: boolean
|
||||
isCurrentWorkspaceOwner: boolean
|
||||
isCurrentWorkspaceEditor: boolean
|
||||
isCurrentWorkspaceDatasetOperator: boolean
|
||||
mutateCurrentWorkspace: VoidFunction
|
||||
langGeniusVersionInfo: LangGeniusVersionResponse
|
||||
useSelector: typeof useSelector
|
||||
isLoadingCurrentWorkspace: boolean
|
||||
isValidatingCurrentWorkspace: boolean
|
||||
}
|
||||
|
||||
export const userProfilePlaceholder = {
|
||||
id: '',
|
||||
name: '',
|
||||
email: '',
|
||||
avatar: '',
|
||||
avatar_url: '',
|
||||
is_password_set: false,
|
||||
}
|
||||
|
||||
export const initialLangGeniusVersionInfo = {
|
||||
current_env: '',
|
||||
current_version: '',
|
||||
latest_version: '',
|
||||
release_date: '',
|
||||
release_notes: '',
|
||||
version: '',
|
||||
can_auto_update: false,
|
||||
}
|
||||
|
||||
export const initialWorkspaceInfo: ICurrentWorkspace = {
|
||||
id: '',
|
||||
name: '',
|
||||
plan: '',
|
||||
status: '',
|
||||
created_at: 0,
|
||||
role: 'normal',
|
||||
providers: [],
|
||||
trial_credits: 200,
|
||||
trial_credits_used: 0,
|
||||
next_credit_reset_date: 0,
|
||||
}
|
||||
|
||||
export const AppContext = createContext<AppContextValue>({
|
||||
userProfile: userProfilePlaceholder,
|
||||
currentWorkspace: initialWorkspaceInfo,
|
||||
isCurrentWorkspaceManager: false,
|
||||
isCurrentWorkspaceOwner: false,
|
||||
isCurrentWorkspaceEditor: false,
|
||||
isCurrentWorkspaceDatasetOperator: false,
|
||||
mutateUserProfile: noop,
|
||||
mutateCurrentWorkspace: noop,
|
||||
langGeniusVersionInfo: initialLangGeniusVersionInfo,
|
||||
useSelector,
|
||||
isLoadingCurrentWorkspace: false,
|
||||
isValidatingCurrentWorkspace: false,
|
||||
})
|
||||
|
||||
export function useSelector<T>(selector: (value: AppContextValue) => T): T {
|
||||
return useContextSelector(AppContext, selector)
|
||||
}
|
||||
|
||||
export const useAppContext = () => useContext(AppContext)
|
||||
@@ -343,8 +343,8 @@ export const ModalContextProvider = ({
|
||||
accountSettingTab && (
|
||||
<AccountSetting
|
||||
activeTab={accountSettingTab}
|
||||
onCancel={handleCancelAccountSettingModal}
|
||||
onTabChange={handleAccountSettingTabChange}
|
||||
onCancelAction={handleCancelAccountSettingModal}
|
||||
onTabChangeAction={handleAccountSettingTabChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
33
web/contract/console/model-providers.ts
Normal file
33
web/contract/console/model-providers.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import type { ModelItem, PreferredProviderTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { CommonResponse } from '@/models/common'
|
||||
import { type } from '@orpc/contract'
|
||||
import { base } from '../base'
|
||||
|
||||
export const modelProvidersModelsContract = base
|
||||
.route({
|
||||
path: '/workspaces/current/model-providers/{provider}/models',
|
||||
method: 'GET',
|
||||
})
|
||||
.input(type<{
|
||||
params: {
|
||||
provider: string
|
||||
}
|
||||
}>())
|
||||
.output(type<{
|
||||
data: ModelItem[]
|
||||
}>())
|
||||
|
||||
export const changePreferredProviderTypeContract = base
|
||||
.route({
|
||||
path: '/workspaces/current/model-providers/{provider}/preferred-provider-type',
|
||||
method: 'POST',
|
||||
})
|
||||
.input(type<{
|
||||
params: {
|
||||
provider: string
|
||||
}
|
||||
body: {
|
||||
preferred_provider_type: PreferredProviderTypeEnum
|
||||
}
|
||||
}>())
|
||||
.output(type<CommonResponse>())
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
exploreInstalledAppsContract,
|
||||
exploreInstalledAppUninstallContract,
|
||||
} from './console/explore'
|
||||
import { changePreferredProviderTypeContract, modelProvidersModelsContract } from './console/model-providers'
|
||||
import { systemFeaturesContract } from './console/system'
|
||||
import {
|
||||
triggerOAuthConfigContract,
|
||||
@@ -63,6 +64,10 @@ export const consoleRouterContract = {
|
||||
parameters: trialAppParametersContract,
|
||||
workflows: trialAppWorkflowsContract,
|
||||
},
|
||||
modelProviders: {
|
||||
models: modelProvidersModelsContract,
|
||||
changePreferredProviderType: changePreferredProviderTypeContract,
|
||||
},
|
||||
billing: {
|
||||
invoices: invoicesContract,
|
||||
bindPartnerStack: bindPartnerStackContract,
|
||||
|
||||
@@ -10,6 +10,9 @@ This document tracks the migration away from legacy overlay APIs.
|
||||
- `@/app/components/base/modal`
|
||||
- `@/app/components/base/confirm`
|
||||
- `@/app/components/base/select` (including `custom` / `pure`)
|
||||
- `@/app/components/base/popover`
|
||||
- `@/app/components/base/dropdown`
|
||||
- `@/app/components/base/dialog`
|
||||
- Replacement primitives:
|
||||
- `@/app/components/base/ui/tooltip`
|
||||
- `@/app/components/base/ui/dropdown-menu`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user