mirror of
https://github.com/langgenius/dify.git
synced 2026-03-09 17:25:10 +00:00
Compare commits
104 Commits
main
...
feat/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3c98e417d | ||
|
|
dfe389c017 | ||
|
|
b364b06e51 | ||
|
|
ce0197b107 | ||
|
|
164cefc65c | ||
|
|
f6d80b9fa7 | ||
|
|
e845fa7e6a | ||
|
|
bab7bd5ecc | ||
|
|
cfb02bceaf | ||
|
|
694ca840e1 | ||
|
|
2d979e2cec | ||
|
|
5cee7cf8ce | ||
|
|
0c17823c8b | ||
|
|
49c6696d08 | ||
|
|
292c98a8f3 | ||
|
|
0e0a6ad043 | ||
|
|
456c95adb1 | ||
|
|
1abbaf9fd5 | ||
|
|
1a26e1669b | ||
|
|
02444af2e3 | ||
|
|
56038e3684 | ||
|
|
eb9341e7ec | ||
|
|
e40b31b9c4 | ||
|
|
b89ee4807f | ||
|
|
9907cf9e06 | ||
|
|
208a31719f | ||
|
|
3d1ef1f7f5 | ||
|
|
24b14e2c1a | ||
|
|
53f122f717 | ||
|
|
fced2f9e65 | ||
|
|
0c08c4016d | ||
|
|
ff4e4a8d64 | ||
|
|
948efa129f | ||
|
|
e371bfd676 | ||
|
|
6d612c0909 | ||
|
|
56e0dc0ae6 | ||
|
|
975eca00c3 | ||
|
|
f049bafcc3 | ||
|
|
dd9c526447 | ||
|
|
922dc71e36 | ||
|
|
f03ec7f671 | ||
|
|
29f275442d | ||
|
|
c9532ffd43 | ||
|
|
840dc33b8b | ||
|
|
cae58a0649 | ||
|
|
1752edc047 | ||
|
|
7471c32612 | ||
|
|
2d333bbbe5 | ||
|
|
4af6788ce0 | ||
|
|
24b072def9 | ||
|
|
909c8c3350 | ||
|
|
80e9c8bee0 | ||
|
|
15b7b304d2 | ||
|
|
61e2672b59 | ||
|
|
5f4ed4c6f6 | ||
|
|
4a1032c628 | ||
|
|
423c97a47e | ||
|
|
a7e3fb2e33 | ||
|
|
ce34937a1c | ||
|
|
ad9ac6978e | ||
|
|
57c1ba3543 | ||
|
|
d7a5af2b9a | ||
|
|
d45edffaa3 | ||
|
|
530515b6ef | ||
|
|
f13f0d1f9a | ||
|
|
b597d52c11 | ||
|
|
34c42fe666 | ||
|
|
dc109c99f0 | ||
|
|
223b9d89c1 | ||
|
|
dd119eb44f | ||
|
|
970493fa85 | ||
|
|
ab87ac333a | ||
|
|
b8b70da9ad | ||
|
|
77d81aebe8 | ||
|
|
deb4cd3ece | ||
|
|
648d9ef1f9 | ||
|
|
5ed4797078 | ||
|
|
62631658e9 | ||
|
|
22a4100dd7 | ||
|
|
0f7ed6f67e | ||
|
|
4d9fcbec57 | ||
|
|
4d7a9bc798 | ||
|
|
d6d04ed657 | ||
|
|
f594a71dae | ||
|
|
04e0ab7eda | ||
|
|
784bda9c86 | ||
|
|
1af1fb6913 | ||
|
|
1f0c36e9f7 | ||
|
|
455ae65025 | ||
|
|
d44682e957 | ||
|
|
8c4afc0c18 | ||
|
|
539cbcae6a | ||
|
|
8d257fea7c | ||
|
|
c3364ac350 | ||
|
|
f991644989 | ||
|
|
29e344ac8b | ||
|
|
1ad9305732 | ||
|
|
17f38f171d | ||
|
|
802088c8eb | ||
|
|
cad6d94491 | ||
|
|
621d0fb2c9 | ||
|
|
a92fb3244b | ||
|
|
97508f8d7b | ||
|
|
70e677a6ac |
@@ -1,92 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.console.app import annotation as annotation_module
|
||||
|
||||
|
||||
def test_annotation_reply_payload_valid():
|
||||
"""Test AnnotationReplyPayload with valid data."""
|
||||
payload = annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5,
|
||||
embedding_provider_name="openai",
|
||||
embedding_model_name="text-embedding-3-small",
|
||||
)
|
||||
assert payload.score_threshold == 0.5
|
||||
assert payload.embedding_provider_name == "openai"
|
||||
assert payload.embedding_model_name == "text-embedding-3-small"
|
||||
|
||||
|
||||
def test_annotation_setting_update_payload_valid():
|
||||
"""Test AnnotationSettingUpdatePayload with valid data."""
|
||||
payload = annotation_module.AnnotationSettingUpdatePayload(
|
||||
score_threshold=0.75,
|
||||
)
|
||||
assert payload.score_threshold == 0.75
|
||||
|
||||
|
||||
def test_annotation_list_query_defaults():
|
||||
"""Test AnnotationListQuery with default parameters."""
|
||||
query = annotation_module.AnnotationListQuery()
|
||||
assert query.page == 1
|
||||
assert query.limit == 20
|
||||
assert query.keyword == ""
|
||||
|
||||
|
||||
def test_annotation_list_query_custom_page():
|
||||
"""Test AnnotationListQuery with custom page."""
|
||||
query = annotation_module.AnnotationListQuery(page=3, limit=50)
|
||||
assert query.page == 3
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_annotation_list_query_with_keyword():
|
||||
"""Test AnnotationListQuery with keyword."""
|
||||
query = annotation_module.AnnotationListQuery(keyword="test")
|
||||
assert query.keyword == "test"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_message_id():
|
||||
"""Test CreateAnnotationPayload with message ID."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
question="What is AI?",
|
||||
)
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
assert payload.question == "What is AI?"
|
||||
|
||||
|
||||
def test_create_annotation_payload_with_text():
|
||||
"""Test CreateAnnotationPayload with text content."""
|
||||
payload = annotation_module.CreateAnnotationPayload(
|
||||
question="What is ML?",
|
||||
answer="Machine learning is...",
|
||||
)
|
||||
assert payload.question == "What is ML?"
|
||||
assert payload.answer == "Machine learning is..."
|
||||
|
||||
|
||||
def test_update_annotation_payload():
|
||||
"""Test UpdateAnnotationPayload."""
|
||||
payload = annotation_module.UpdateAnnotationPayload(
|
||||
question="Updated question",
|
||||
answer="Updated answer",
|
||||
)
|
||||
assert payload.question == "Updated question"
|
||||
assert payload.answer == "Updated answer"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_enable():
|
||||
"""Test AnnotationReplyStatusQuery with enable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
|
||||
assert query.action == "enable"
|
||||
|
||||
|
||||
def test_annotation_reply_status_query_disable():
|
||||
"""Test AnnotationReplyStatusQuery with disable action."""
|
||||
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
|
||||
assert query.action == "disable"
|
||||
|
||||
|
||||
def test_annotation_file_payload_valid():
|
||||
"""Test AnnotationFilePayload with valid message ID."""
|
||||
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
@@ -13,9 +13,6 @@ from pandas.errors import ParserError
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
|
||||
class TestAnnotationImportRateLimiting:
|
||||
@@ -36,6 +33,8 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-minute rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-minute limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
|
||||
@@ -55,6 +54,7 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that per-hour rate limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate exceeding per-hour limit
|
||||
mock_redis.zcard.side_effect = [
|
||||
@@ -74,6 +74,7 @@ class TestAnnotationImportRateLimiting:
|
||||
|
||||
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_rate_limit
|
||||
|
||||
# Simulate being under both limits
|
||||
mock_redis.zcard.return_value = 2
|
||||
@@ -109,6 +110,7 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
|
||||
"""Test that concurrent task limit is enforced."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate max concurrent tasks already running
|
||||
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
|
||||
@@ -125,6 +127,7 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
|
||||
"""Test that requests within concurrency limits are allowed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
# Simulate being under concurrent task limit
|
||||
mock_redis.zcard.return_value = 1
|
||||
@@ -139,6 +142,7 @@ class TestAnnotationImportConcurrencyControl:
|
||||
|
||||
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
|
||||
"""Test that old/stale job entries are removed."""
|
||||
from controllers.console.wraps import annotation_import_concurrency_limit
|
||||
|
||||
mock_redis.zcard.return_value = 0
|
||||
|
||||
@@ -199,6 +203,7 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too many records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with too many records
|
||||
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
|
||||
@@ -224,6 +229,7 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
|
||||
"""Test that files with too few valid records are rejected."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create CSV with only header (no data rows)
|
||||
csv_content = "question,answer\n"
|
||||
@@ -243,6 +249,7 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
|
||||
"""Test that invalid CSV format is handled gracefully."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Any content is fine once we force ParserError
|
||||
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
|
||||
@@ -263,6 +270,7 @@ class TestAnnotationImportServiceValidation:
|
||||
|
||||
def test_valid_import_succeeds(self, mock_app, mock_db_session):
|
||||
"""Test that valid import request succeeds."""
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Create valid CSV
|
||||
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
|
||||
@@ -292,10 +300,18 @@ class TestAnnotationImportServiceValidation:
|
||||
class TestAnnotationImportTaskOptimization:
|
||||
"""Test optimizations in batch import task."""
|
||||
|
||||
def test_task_is_registered_with_queue(self):
|
||||
"""Test that task is registered with the correct queue."""
|
||||
assert hasattr(batch_import_annotations_task, "apply_async")
|
||||
assert hasattr(batch_import_annotations_task, "delay")
|
||||
def test_task_has_timeout_configured(self):
|
||||
"""Test that task has proper timeout configuration."""
|
||||
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
|
||||
|
||||
# Verify task configuration
|
||||
assert hasattr(batch_import_annotations_task, "time_limit")
|
||||
assert hasattr(batch_import_annotations_task, "soft_time_limit")
|
||||
|
||||
# Check timeout values are reasonable
|
||||
# Hard limit should be 6 minutes (360s)
|
||||
# Soft limit should be 5 minutes (300s)
|
||||
# Note: actual values depend on Celery configuration
|
||||
|
||||
|
||||
class TestConfigurationValues:
|
||||
|
||||
@@ -1,585 +0,0 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import (
|
||||
annotation as annotation_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
completion as completion_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
message as message_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
ops_trace as ops_trace_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
site as site_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
statistic as statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_app_log as workflow_app_log_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_draft_variable as workflow_draft_variable_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_statistic as workflow_statistic_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
workflow_trigger as workflow_trigger_module,
|
||||
)
|
||||
from controllers.console.app import (
|
||||
wraps as wraps_module,
|
||||
)
|
||||
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
|
||||
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
|
||||
from controllers.console.app.site import AppSiteUpdatePayload
|
||||
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
|
||||
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
|
||||
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
|
||||
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
|
||||
from controllers.console.app.workflow_trigger import Parser, ParserEnable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
def test_chat_message_payload_uuid_validation(self):
|
||||
payload = ChatMessagePayload(
|
||||
inputs={},
|
||||
model_config={},
|
||||
query="hi",
|
||||
conversation_id=str(uuid.uuid4()),
|
||||
parent_message_id=str(uuid.uuid4()),
|
||||
)
|
||||
assert payload.query == "hi"
|
||||
|
||||
def test_completion_api_success(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: {"text": "ok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
completion_module.helper,
|
||||
"compact_generate_response",
|
||||
lambda response: {"result": response},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(
|
||||
completion_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app, monkeypatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
def test_ops_trace_config_payload(self):
|
||||
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
|
||||
assert payload.tracing_config["api_key"] == "k"
|
||||
|
||||
def test_trace_app_config_get_empty(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"get_tracing_app_config",
|
||||
lambda **_kwargs: None,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
result = method(app_id="app-1")
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"create_tracing_app_config",
|
||||
lambda **_kwargs: {"error": True},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/",
|
||||
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
"delete_tracing_app_config",
|
||||
lambda **_kwargs: False,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
def test_site_default_language_validation(self):
|
||||
payload = AppSiteUpdatePayload(default_language="en-US")
|
||||
assert payload.default_language == "en-US"
|
||||
|
||||
def test_app_site_update_post(self, app, monkeypatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
def test_app_site_access_token_reset(self, app, monkeypatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = site
|
||||
monkeypatch.setattr(
|
||||
site_module.db,
|
||||
"session",
|
||||
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
|
||||
)
|
||||
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
|
||||
monkeypatch.setattr(
|
||||
site_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
def test_workflow_app_log_query_detail_bool(self):
|
||||
query = WorkflowAppLogQuery(detail="true")
|
||||
assert query.detail is True
|
||||
|
||||
def test_workflow_app_log_api_get(self, app, monkeypatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
def fake_get_paginate(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_log_module.WorkflowAppService,
|
||||
"get_paginate_workflow_app_logs",
|
||||
fake_get_paginate,
|
||||
)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
def test_workflow_variable_collection_get(self, app, monkeypatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return "session"
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyDraftService:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
def list_variables_without_values(self, **_kwargs):
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
|
||||
|
||||
class DummyWorkflowService:
|
||||
def is_workflow_exist(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
def test_workflow_statistic_blank_to_none(self):
|
||||
query = WorkflowStatisticQuery(start="", end="")
|
||||
assert query.start is None
|
||||
assert query.end is None
|
||||
|
||||
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
|
||||
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(
|
||||
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
|
||||
assert enable_payload.enable_trigger is True
|
||||
|
||||
def test_webhook_trigger_api_get(self, app, monkeypatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
trigger = MagicMock()
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = trigger
|
||||
|
||||
class DummySession:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
|
||||
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
def test_mcp_server_update_payload(self):
|
||||
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
),
|
||||
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
|
||||
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
|
||||
]
|
||||
assert len(payloads) == 3
|
||||
assert all(p is not None for p in payloads)
|
||||
@@ -1,157 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
@@ -1,292 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
ProviderNotSupportTextToSpeechLanageServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exc", "expected"),
|
||||
[
|
||||
(AppModelConfigBrokenError(), AppUnavailableError),
|
||||
(NoAudioUploadedServiceError(), NoAudioUploadedError),
|
||||
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
|
||||
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
|
||||
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
|
||||
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
|
||||
(QuotaExceededError(), ProviderQuotaExceededError),
|
||||
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
response = handler(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
|
||||
)
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(app_model=app_model)
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
@@ -1,156 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import audio as audio_module
|
||||
from controllers.console.app.error import AudioTooLargeError
|
||||
from services.errors.audio import AudioTooLargeServiceError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_asr",
|
||||
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices",
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/audio-to-text",
|
||||
method="POST",
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio",
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
"transcript_tts_voices",
|
||||
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
@@ -1,130 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import conversation as conversation_module
|
||||
from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account():
|
||||
return SimpleNamespace(timezone="UTC", id="u1")
|
||||
|
||||
|
||||
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-1/completion-conversations",
|
||||
method="GET",
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.ChatConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
|
||||
paginate_result = MagicMock()
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response is paginate_result
|
||||
|
||||
|
||||
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
conversation = SimpleNamespace(id="c1", app_id="app-1")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = conversation
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
|
||||
|
||||
assert result is conversation
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(conversation)
|
||||
|
||||
|
||||
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(conversation_module.db, "session", session)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
|
||||
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
"delete",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
@@ -1,260 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import generator as generator_module
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _model_config_payload():
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
|
||||
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
class _Service:
|
||||
def get_draft_workflow(self, app_model):
|
||||
return workflow
|
||||
|
||||
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
|
||||
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/rule-code-generate",
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method()
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
_install_workflow_service(monkeypatch, workflow=None)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(graph_dict={"nodes": []})
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{"id": "node-1", "data": {"type": "code"}},
|
||||
]
|
||||
}
|
||||
)
|
||||
_install_workflow_service(monkeypatch, workflow=workflow)
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "node-1",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"code": "x"}
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
generator_module.LLMGenerator,
|
||||
"instruction_modify_legacy",
|
||||
lambda **_kwargs: {"instruction": "ok"},
|
||||
)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "old",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate",
|
||||
method="POST",
|
||||
json={
|
||||
"flow_id": "app-1",
|
||||
"node_id": "",
|
||||
"current": "",
|
||||
"instruction": "do",
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
|
||||
def test_instruction_template_prompt(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "prompt"},
|
||||
):
|
||||
response = method()
|
||||
|
||||
assert "data" in response
|
||||
|
||||
|
||||
def test_instruction_template_invalid_type(app) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "unknown"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
@@ -1,122 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
limit=50,
|
||||
)
|
||||
assert query.limit == 50
|
||||
|
||||
|
||||
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery with defaults."""
|
||||
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert query.first_id is None
|
||||
assert query.limit == 20
|
||||
|
||||
|
||||
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test ChatMessagesQuery converts empty first_id to None."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
conversation_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
first_id="",
|
||||
)
|
||||
assert query.first_id is None
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with like rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="like",
|
||||
content="Good answer",
|
||||
)
|
||||
assert payload.rating == "like"
|
||||
assert payload.content == "Good answer"
|
||||
|
||||
|
||||
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload with dislike rating."""
|
||||
payload = message_module.MessageFeedbackPayload(
|
||||
message_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
rating="dislike",
|
||||
)
|
||||
assert payload.rating == "dislike"
|
||||
|
||||
|
||||
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test MessageFeedbackPayload without rating."""
|
||||
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
|
||||
assert payload.rating is None
|
||||
|
||||
|
||||
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with default format."""
|
||||
query = message_module.FeedbackExportQuery()
|
||||
assert query.format == "csv"
|
||||
assert query.from_source is None
|
||||
|
||||
|
||||
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with JSON format."""
|
||||
query = message_module.FeedbackExportQuery(format="json")
|
||||
assert query.format == "json"
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as true string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="true")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as false string."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="false")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 1."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="1")
|
||||
assert query.has_comment is True
|
||||
|
||||
|
||||
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with has_comment as 0."""
|
||||
query = message_module.FeedbackExportQuery(has_comment="0")
|
||||
assert query.has_comment is False
|
||||
|
||||
|
||||
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test FeedbackExportQuery with rating filter."""
|
||||
query = message_module.FeedbackExportQuery(rating="like")
|
||||
assert query.rating == "like"
|
||||
|
||||
|
||||
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test AnnotationCountResponse creation."""
|
||||
response = message_module.AnnotationCountResponse(count=10)
|
||||
assert response.count == 10
|
||||
|
||||
|
||||
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test SuggestedQuestionsResponse creation."""
|
||||
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
|
||||
assert len(response.data) == 2
|
||||
assert response.data[0] == "What is AI?"
|
||||
@@ -1,151 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.CHAT.value,
|
||||
is_agent=False,
|
||||
app_model_config_id=None,
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {"pre_prompt": "hi"},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
def _from_model_config_dict(self, model_config):
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.id = "config-1"
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
send_mock.assert_called_once()
|
||||
assert app_model.app_model_config_id == "config-1"
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
mode=AppMode.AGENT_CHAT.value,
|
||||
is_agent=True,
|
||||
app_model_config_id="config-0",
|
||||
updated_by=None,
|
||||
updated_at=None,
|
||||
)
|
||||
|
||||
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
|
||||
original_config.agent_mode = json.dumps(
|
||||
{
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
}
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = original_config
|
||||
session.query.return_value = query
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_config_module.AppModelConfigService,
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {
|
||||
"pre_prompt": "hi",
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": "function-calling",
|
||||
"tools": [
|
||||
{
|
||||
"provider_id": "provider",
|
||||
"provider_type": "builtin",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {"secret": "masked"},
|
||||
}
|
||||
],
|
||||
"prompt": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
|
||||
|
||||
class _ParamManager:
|
||||
def __init__(self, **_kwargs):
|
||||
self.delete_called = False
|
||||
|
||||
def decrypt_tool_parameters(self, _value):
|
||||
return {"secret": "decrypted"}
|
||||
|
||||
def mask_tool_parameters(self, _value):
|
||||
return {"secret": "masked"}
|
||||
|
||||
def encrypt_tool_parameters(self, _value):
|
||||
return {"secret": "encrypted"}
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
self.delete_called = True
|
||||
|
||||
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
|
||||
send_mock = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
|
||||
assert response["result"] == "success"
|
||||
@@ -1,215 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
|
||||
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
|
||||
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
|
||||
|
||||
|
||||
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: (None, None),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["date"] == "2024-01-03"
|
||||
assert data["data"][0]["token_count"] == 10
|
||||
assert data["data"][0]["total_price"] == 0.25
|
||||
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
|
||||
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
api = statistic_module.AverageSessionInteractionStatistic()
|
||||
method = _unwrap(api.get)
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
def mock_parse(*args, **kwargs):
|
||||
raise ValueError("Invalid time range")
|
||||
|
||||
_install_db(monkeypatch, [])
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", message_count=10),
|
||||
SimpleNamespace(date="2024-01-02", message_count=15),
|
||||
SimpleNamespace(date="2024-01-03", message_count=12),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
statistic_module,
|
||||
"parse_time_range",
|
||||
lambda *_args, **_kwargs: ("s", "e"),
|
||||
)
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
|
||||
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
|
||||
]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
@@ -1,163 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.file.models import File
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
|
||||
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
|
||||
|
||||
|
||||
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config = object()
|
||||
file_list = [
|
||||
File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="http://u",
|
||||
)
|
||||
]
|
||||
build_mock = Mock(return_value=file_list)
|
||||
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
|
||||
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
|
||||
|
||||
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
|
||||
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
|
||||
|
||||
assert result == file_list
|
||||
build_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert exc.value.code == 415
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
data="[]",
|
||||
content_type="application/json",
|
||||
):
|
||||
response, status = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert status == 400
|
||||
assert response["message"] == "Invalid JSON data"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="h",
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
|
||||
)
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise workflow_module.WorkflowHashNotEqualError()
|
||||
|
||||
service = SimpleNamespace(sync_draft_workflow=_raise)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module.AppGenerateService,
|
||||
"generate",
|
||||
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||
workflow_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/advanced-chat/workflows/draft/run",
|
||||
method="POST",
|
||||
json={"inputs": {}},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
@@ -1,47 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import wraps as wraps_module
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
assert handler(app_id="app-1") == "app-1"
|
||||
|
||||
|
||||
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
|
||||
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
|
||||
|
||||
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(AppNotFoundError):
|
||||
handler(app_id="app-1")
|
||||
|
||||
|
||||
def test_get_app_model_requires_app_id() -> None:
|
||||
@wraps_module.get_app_model
|
||||
def handler(app_model):
|
||||
return app_model.id
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler()
|
||||
@@ -1,483 +1,13 @@
|
||||
"""Final working unit tests for admin endpoints - tests business logic directly."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.console.admin import (
|
||||
DeleteExploreBannerApi,
|
||||
InsertExploreAppApi,
|
||||
InsertExploreAppListApi,
|
||||
InsertExploreAppPayload,
|
||||
InsertExploreBannerApi,
|
||||
InsertExploreBannerPayload,
|
||||
)
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_only_edition_cloud(mocker):
|
||||
"""
|
||||
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.wraps.dify_config.EDITION",
|
||||
new="CLOUD",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_auth(mocker):
|
||||
"""
|
||||
Provide valid admin authentication for controller tests.
|
||||
"""
|
||||
mocker.patch(
|
||||
"controllers.console.admin.dify_config.ADMIN_API_KEY",
|
||||
"test-admin-key",
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.admin.extract_access_token",
|
||||
return_value="test-admin-key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_console_payload(mocker):
|
||||
payload = {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
"language": "en-US",
|
||||
"category": "Productivity",
|
||||
"position": 1,
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_banner_payload(mocker):
|
||||
mocker.patch(
|
||||
"flask_restx.namespace.Namespace.payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value={
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory(mocker):
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.add = Mock()
|
||||
mock_session.commit = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = DeleteExploreBannerApi()
|
||||
|
||||
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.delete(uuid.uuid4())
|
||||
|
||||
def test_delete_banner_success(self, mocker, mock_admin_auth):
|
||||
mock_banner = Mock()
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
|
||||
)
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
class TestInsertExploreBannerApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreBannerApi()
|
||||
|
||||
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_banner_payload_valid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
model = InsertExploreBannerPayload.model_validate(payload)
|
||||
assert model.language == "en-US"
|
||||
|
||||
def test_banner_payload_invalid_language(self):
|
||||
payload = {
|
||||
"title": "Test Banner",
|
||||
"description": "Banner description",
|
||||
"img-src": "https://example.com/banner.png",
|
||||
"link": "https://example.com",
|
||||
"sort": 1,
|
||||
"category": "homepage",
|
||||
"language": "invalid-lang",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
|
||||
InsertExploreBannerPayload.model_validate(payload)
|
||||
|
||||
|
||||
class TestInsertExploreAppApiDelete:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppApi()
|
||||
|
||||
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=Mock(return_value=False),
|
||||
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
|
||||
),
|
||||
)
|
||||
|
||||
response, status = self.api.delete(uuid.uuid4())
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app from explore that has a trial app."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_trial = Mock()
|
||||
|
||||
# Mock session context manager and its execute
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
# Set up side effects for execute calls
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
|
||||
Mock(scalar_one_or_none=lambda: mock_trial),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is False
|
||||
|
||||
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
|
||||
"""Test deleting an app that has installed apps in other tenants."""
|
||||
app_id = uuid.uuid4()
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
mock_recommended.app_id = "app-123"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.is_public = True
|
||||
|
||||
mock_installed_app = Mock(spec=InstalledApp)
|
||||
|
||||
# Mock session
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock()
|
||||
mock_session.delete = Mock()
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.delete")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.delete(app_id)
|
||||
|
||||
assert status == 204
|
||||
assert mock_session.delete.called
|
||||
|
||||
|
||||
class TestInsertExploreAppListApi:
|
||||
def setup_method(self):
|
||||
self.api = InsertExploreAppListApi()
|
||||
|
||||
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: None),
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound, match="is not found"):
|
||||
self.api.post()
|
||||
|
||||
def test_create_recommended_app(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
# db.session.execute → fetch App
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
return_value=Mock(scalar_one_or_none=lambda: mock_app),
|
||||
)
|
||||
|
||||
# session_factory.create_session → recommended_app lookup
|
||||
mock_session = Mock()
|
||||
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.session_factory.create_session",
|
||||
return_value=Mock(
|
||||
__enter__=lambda s: mock_session,
|
||||
__exit__=Mock(return_value=False),
|
||||
),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 201
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_site_data_overrides_payload(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
site = Mock()
|
||||
site.description = "Site Desc"
|
||||
site.copyright = "Site Copyright"
|
||||
site.privacy_policy = "Site Privacy"
|
||||
site.custom_disclaimer = "Site Disclaimer"
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = site
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
commit_spy.assert_called_once()
|
||||
|
||||
def test_create_trial_app_when_can_trial_enabled(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 5
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.tenant_id = "tenant"
|
||||
mock_app.is_public = False
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
self.api.post()
|
||||
|
||||
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
|
||||
|
||||
def test_update_recommended_app_with_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app when trial is enabled."""
|
||||
mock_console_payload["can_trial"] = True
|
||||
mock_console_payload["trial_limit"] = 10
|
||||
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
mock_app.tenant_id = "tenant-123"
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
Mock(scalar_one_or_none=lambda: None),
|
||||
],
|
||||
)
|
||||
|
||||
add_spy = mocker.patch("controllers.console.admin.db.session.add")
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
|
||||
def test_update_recommended_app_without_trial(
|
||||
self,
|
||||
mocker,
|
||||
mock_admin_auth,
|
||||
mock_console_payload,
|
||||
mock_session_factory,
|
||||
):
|
||||
"""Test updating a recommended app without trial enabled."""
|
||||
mock_app = Mock(spec=App)
|
||||
mock_app.id = "app-id"
|
||||
mock_app.site = None
|
||||
mock_app.is_public = False
|
||||
|
||||
mock_recommended = Mock(spec=RecommendedApp)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.admin.db.session.execute",
|
||||
side_effect=[
|
||||
Mock(scalar_one_or_none=lambda: mock_app),
|
||||
Mock(scalar_one_or_none=lambda: mock_recommended),
|
||||
],
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.admin.db.session.commit")
|
||||
|
||||
response, status = self.api.post()
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
assert mock_app.is_public is True
|
||||
from controllers.console.admin import InsertExploreAppPayload
|
||||
from models.model import App, RecommendedApp
|
||||
|
||||
|
||||
class TestInsertExploreAppPayload:
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.apikey import (
|
||||
BaseApiKeyListResource,
|
||||
BaseApiKeyResource,
|
||||
_get_resource,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = True
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context_non_admin():
|
||||
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
|
||||
user = MagicMock()
|
||||
user.is_admin_or_owner = False
|
||||
mock.return_value = (user, "tenant-123")
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_mock():
|
||||
with patch("controllers.console.apikey.db") as mock_db:
|
||||
mock_db.session = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_permissions():
|
||||
with patch(
|
||||
"controllers.console.apikey.edit_permission_required",
|
||||
lambda f: f,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class DummyApiKeyListResource(BaseApiKeyListResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
|
||||
class DummyApiKeyResource(BaseApiKeyResource):
|
||||
resource_type = "app"
|
||||
resource_model = MagicMock()
|
||||
resource_id_field = "app_id"
|
||||
|
||||
|
||||
class TestGetResource:
|
||||
def test_get_resource_success(self):
|
||||
fake_resource = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
|
||||
|
||||
result = _get_resource("rid", "tid", MagicMock)
|
||||
assert result == fake_resource
|
||||
|
||||
def test_get_resource_not_found(self):
|
||||
with (
|
||||
patch("controllers.console.apikey.select") as mock_select,
|
||||
patch("controllers.console.apikey.Session") as mock_session,
|
||||
patch("controllers.console.apikey.db") as mock_db,
|
||||
patch("controllers.console.apikey.flask_restx.abort") as abort,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_select.return_value.filter_by.return_value = MagicMock()
|
||||
|
||||
session = mock_session.return_value.__enter__.return_value
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
_get_resource("rid", "tid", MagicMock)
|
||||
|
||||
abort.assert_called_once()
|
||||
|
||||
|
||||
class TestBaseApiKeyListResource:
|
||||
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyListResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
|
||||
|
||||
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
|
||||
assert "items" in result
|
||||
|
||||
|
||||
class TestBaseApiKeyResource:
|
||||
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Forbidden):
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with patch("controllers.console.apikey._get_resource"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
# flask_restx.abort raises HTTPException with message in data attribute
|
||||
assert exc_info.value.data["message"] == "API key not found"
|
||||
|
||||
def test_delete_success(self, tenant_context_admin, db_mock):
|
||||
resource = DummyApiKeyResource()
|
||||
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch("controllers.console.apikey._get_resource"),
|
||||
patch("controllers.console.apikey.ApiTokenCache.delete"),
|
||||
):
|
||||
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
|
||||
|
||||
assert status == 204
|
||||
assert result == {"result": "success"}
|
||||
db_mock.session.commit.assert_called_once()
|
||||
@@ -0,0 +1,46 @@
|
||||
import builtins
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from extensions import ext_fastopenapi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.secret_key = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
|
||||
client = app.test_client()
|
||||
response = client.get("/console/api/init")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"status": "finished"}
|
||||
|
||||
|
||||
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
ext_fastopenapi.init_app(app)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
|
||||
|
||||
with (
|
||||
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
|
||||
):
|
||||
client = app.test_client()
|
||||
response = client.post("/console/api/init", json={"password": "test-init-password"})
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.get_json() == {"result": "success"}
|
||||
@@ -0,0 +1,286 @@
|
||||
"""Tests for remote file upload API endpoints using Flask-RESTX."""
|
||||
|
||||
import contextlib
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Create Flask app for testing."""
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create test client with console blueprint registered."""
|
||||
from controllers.console import bp
|
||||
|
||||
app.register_blueprint(bp)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account():
|
||||
"""Create a mock account for testing."""
|
||||
from models import Account
|
||||
|
||||
account = Mock(spec=Account)
|
||||
account.id = "test-account-id"
|
||||
account.current_tenant_id = "test-tenant-id"
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_ctx(app, mock_account):
|
||||
"""Context manager to set auth/tenant context in flask.g for a request."""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _ctx():
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
yield
|
||||
|
||||
return _ctx
|
||||
|
||||
|
||||
class TestGetRemoteFileInfo:
|
||||
"""Test GET /console/api/remote-files/<path:url> endpoint."""
|
||||
|
||||
def test_get_remote_file_info_success(self, app, client, mock_account):
|
||||
"""Test successful retrieval of remote file info."""
|
||||
response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.txt"),
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "text/plain"
|
||||
assert data["file_length"] == 1024
|
||||
|
||||
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
|
||||
"""Test fallback to GET when HEAD returns non-200 status."""
|
||||
head_response = httpx.Response(
|
||||
404,
|
||||
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
|
||||
)
|
||||
get_response = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", "http://example.com/file.pdf"),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_account
|
||||
g._current_tenant = mock_account.current_tenant_id
|
||||
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
|
||||
resp = client.get(f"/console/api/remote-files/{encoded_url}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["file_type"] == "application/pdf"
|
||||
assert data["file_length"] == 2048
|
||||
|
||||
|
||||
class TestRemoteFileUpload:
|
||||
"""Test POST /console/api/remote-files/upload endpoint."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("head_status", "use_get"),
|
||||
[
|
||||
(200, False), # HEAD succeeds
|
||||
(405, True), # HEAD fails -> fallback GET
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
|
||||
url = "http://example.com/file.pdf"
|
||||
head_resp = httpx.Response(
|
||||
head_status,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
)
|
||||
get_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
|
||||
content=b"file content",
|
||||
)
|
||||
|
||||
file_info = SimpleNamespace(
|
||||
extension="pdf",
|
||||
size=1024,
|
||||
filename="file.pdf",
|
||||
mimetype="application/pdf",
|
||||
)
|
||||
uploaded_file = SimpleNamespace(
|
||||
id="uploaded-file-id",
|
||||
name="file.pdf",
|
||||
size=1024,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="test-account-id",
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0),
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
|
||||
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=True,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("controllers.console.remote_files.FileService") as mock_file_service,
|
||||
patch(
|
||||
"controllers.console.remote_files.file_helpers.get_signed_file_url",
|
||||
return_value="http://example.com/signed-url",
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
mock_file_service.return_value.upload_file.return_value = uploaded_file
|
||||
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
p_head.assert_called_once()
|
||||
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
|
||||
p_get.assert_called_once()
|
||||
mock_file_service.return_value.upload_file.assert_called_once()
|
||||
|
||||
data = resp.get_json()
|
||||
assert data["id"] == "uploaded-file-id"
|
||||
assert data["name"] == "file.pdf"
|
||||
assert data["size"] == 1024
|
||||
assert data["extension"] == "pdf"
|
||||
assert data["url"] == "http://example.com/signed-url"
|
||||
assert data["mime_type"] == "application/pdf"
|
||||
assert data["created_by"] == "test-account-id"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("size_ok", "raises", "expected_status", "expected_msg"),
|
||||
[
|
||||
# When size check fails in controller, API returns 413 with message "File size exceeded..."
|
||||
(False, None, 413, "file size exceeded"),
|
||||
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
|
||||
(True, "unsupported", 415, "file type not allowed"),
|
||||
],
|
||||
)
|
||||
def test_upload_remote_file_errors(
|
||||
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
|
||||
):
|
||||
url = "http://example.com/x.pdf"
|
||||
head_resp = httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("HEAD", url),
|
||||
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
|
||||
)
|
||||
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
|
||||
patch(
|
||||
"controllers.console.remote_files.helpers.guess_file_info_from_response",
|
||||
return_value=file_info,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.FileService.is_file_size_within_limit",
|
||||
return_value=size_ok,
|
||||
),
|
||||
patch("controllers.console.remote_files.db", spec=["engine"]),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
if raises == "unsupported":
|
||||
from services.errors.file import UnsupportedFileTypeError
|
||||
|
||||
with patch("controllers.console.remote_files.FileService") as mock_file_service:
|
||||
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
else:
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": url},
|
||||
)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert expected_msg in msg.lower()
|
||||
|
||||
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
|
||||
"""Test upload when fetching of remote file fails."""
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.remote_files.current_account_with_tenant",
|
||||
return_value=(mock_account, "test-tenant-id"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.remote_files.ssrf_proxy.head",
|
||||
side_effect=httpx.RequestError("Connection failed"),
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
with auth_ctx():
|
||||
resp = client.post(
|
||||
"/console/api/remote-files/upload",
|
||||
json={"url": "http://unreachable.com/file.pdf"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
data = resp.get_json()
|
||||
msg = (data.get("error") or {}).get("message") or data.get("message", "")
|
||||
assert "failed to fetch" in msg.lower()
|
||||
@@ -1,81 +0,0 @@
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestFeatureApi:
|
||||
def test_get_tenant_features_success(self, mocker):
|
||||
from controllers.console.feature import FeatureApi
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_account_with_tenant",
|
||||
return_value=("account_id", "tenant_123"),
|
||||
)
|
||||
|
||||
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
|
||||
"features": {"feature_a": True}
|
||||
}
|
||||
|
||||
api = FeatureApi()
|
||||
|
||||
raw_get = unwrap(FeatureApi.get)
|
||||
result = raw_get(api)
|
||||
|
||||
assert result == {"features": {"feature_a": True}}
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
def test_get_system_features_authenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated == True
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
fake_user.is_authenticated = True
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": True}}
|
||||
|
||||
def test_get_system_features_unauthenticated(self, mocker):
|
||||
"""
|
||||
current_user.is_authenticated raises Unauthorized
|
||||
"""
|
||||
|
||||
from controllers.console.feature import SystemFeatureApi
|
||||
|
||||
fake_user = mocker.Mock()
|
||||
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.current_user",
|
||||
fake_user,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.feature.FeatureService.get_system_features"
|
||||
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
|
||||
|
||||
api = SystemFeatureApi()
|
||||
result = api.get()
|
||||
|
||||
assert result == {"features": {"sys_feature": False}}
|
||||
@@ -1,300 +0,0 @@
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.files import (
|
||||
FileApi,
|
||||
FilePreviewApi,
|
||||
FileSupportTypeApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""
|
||||
Recursively unwrap decorated functions.
|
||||
"""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
app.testing = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_decorators():
|
||||
"""
|
||||
Make decorators no-ops so logic is directly testable
|
||||
"""
|
||||
with (
|
||||
patch("controllers.console.files.setup_required", new=lambda f: f),
|
||||
patch("controllers.console.files.login_required", new=lambda f: f),
|
||||
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
|
||||
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user():
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account_context(mock_current_user):
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
with patch("controllers.console.files.db") as db_mock:
|
||||
db_mock.engine = MagicMock()
|
||||
yield db_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_service(mock_db):
|
||||
with patch("controllers.console.files.FileService") as fs:
|
||||
instance = fs.return_value
|
||||
yield instance
|
||||
|
||||
|
||||
class TestFileApiGet:
|
||||
def test_get_upload_config(self, app):
|
||||
api = FileApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
data, status = get_method(api)
|
||||
|
||||
assert status == 200
|
||||
assert "file_size_limit" in data
|
||||
assert "batch_count_limit" in data
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api)
|
||||
|
||||
def test_too_many_files(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST"):
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
with patch("controllers.console.files.request") as mock_request:
|
||||
mock_request.files = MagicMock()
|
||||
mock_request.files.__len__.return_value = 2
|
||||
mock_request.files.__contains__.return_value = True
|
||||
mock_request.form = MagicMock()
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api)
|
||||
|
||||
def test_filename_missing(self, app, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), ""),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
|
||||
with patch(
|
||||
"controllers.console.files.current_account_with_tenant",
|
||||
return_value=(mock_current_user, None),
|
||||
):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"abc"), "test.txt"),
|
||||
"source": "datasets",
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api)
|
||||
|
||||
def test_successful_upload(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-123"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 1024
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-123"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = "http://example.com/preview/file-id-123"
|
||||
mock_file.source_url = "http://example.com/source/file-id-123"
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-123"
|
||||
mock_file.tenant_id = "tenant-123"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-123"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"hello"), "test.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
# Create a properly structured mock file object
|
||||
mock_file = MagicMock()
|
||||
mock_file.id = "file-id-456"
|
||||
mock_file.filename = "test.txt"
|
||||
mock_file.name = "test.txt"
|
||||
mock_file.size = 512
|
||||
mock_file.extension = "txt"
|
||||
mock_file.mime_type = "text/plain"
|
||||
mock_file.created_by = "user-456"
|
||||
mock_file.created_at = 1234567890
|
||||
mock_file.preview_url = None
|
||||
mock_file.source_url = None
|
||||
mock_file.original_url = None
|
||||
mock_file.user_id = "user-456"
|
||||
mock_file.tenant_id = "tenant-456"
|
||||
mock_file.conversation_id = None
|
||||
mock_file.file_key = "file-key-456"
|
||||
|
||||
mock_file_service.upload_file.return_value = mock_file
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"content"), "test.txt"),
|
||||
"source": "invalid_source", # Should be normalized to None
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
# Verify that FileService was called with source=None
|
||||
mock_file_service.upload_file.assert_called_once()
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
|
||||
error = ServiceFileTooLargeError("File is too large")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api)
|
||||
|
||||
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
error = ServiceUnsupportedFileTypeError()
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "bad.exe"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api)
|
||||
|
||||
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
|
||||
|
||||
error = ServiceBlockedFileExtensionError("File extension is blocked")
|
||||
mock_file_service.upload_file.side_effect = error
|
||||
|
||||
data = {
|
||||
"file": (io.BytesIO(b"x"), "blocked.txt"),
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
class TestFileSupportTypeApi:
|
||||
def test_get_supported_types(self, app):
|
||||
api = FileSupportTypeApi()
|
||||
get_method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api)
|
||||
|
||||
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
@@ -1,293 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console.human_input_form import (
|
||||
ConsoleHumanInputFormApi,
|
||||
ConsoleWorkflowEventsApi,
|
||||
DifyAPIRepositoryFactory,
|
||||
WorkflowResponseConverter,
|
||||
_jsonify_form_definition,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_jsonify_form_definition() -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
|
||||
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
response = _jsonify_form_definition(form)
|
||||
|
||||
assert isinstance(response, Response)
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["expiration_time"] == int(expiration.timestamp())
|
||||
|
||||
|
||||
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1")
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form)
|
||||
|
||||
|
||||
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
|
||||
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["fields"] == ["a"]
|
||||
|
||||
|
||||
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_definition_by_token_for_console(self, _token):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
|
||||
|
||||
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock = Mock()
|
||||
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
|
||||
|
||||
class _ServiceStub:
|
||||
def __init__(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
def get_form_by_token(self, _token):
|
||||
return form
|
||||
|
||||
def submit_form_by_token(self, **kwargs):
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response = handler(api, form_token="token")
|
||||
|
||||
assert response.get_json() == {}
|
||||
submit_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-2",
|
||||
tenant_id="t1",
|
||||
)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow_run = SimpleNamespace(
|
||||
id="run-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="user-1",
|
||||
tenant_id="t1",
|
||||
app_id="app-1",
|
||||
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
|
||||
class _RepoStub:
|
||||
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
|
||||
return workflow_run
|
||||
|
||||
response_obj = SimpleNamespace(
|
||||
event=SimpleNamespace(value="finished"),
|
||||
model_dump=lambda mode="json": {"status": "done"},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
|
||||
lambda *_args, **_kwargs: app_model,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
WorkflowResponseConverter,
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: response_obj,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
response = handler(api, workflow_run_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
assert "data" in response.get_data(as_text=True)
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import init_validate
|
||||
from controllers.console.error import AlreadySetupError, InitValidateFailedError
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, has_setup: bool):
|
||||
self._has_setup = has_setup
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, *_args, **_kwargs):
|
||||
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
|
||||
|
||||
|
||||
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "finished"
|
||||
|
||||
|
||||
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
|
||||
result = init_validate.get_init_status()
|
||||
assert result.status == "not_started"
|
||||
|
||||
|
||||
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(AlreadySetupError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
|
||||
|
||||
|
||||
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
with pytest.raises(InitValidateFailedError):
|
||||
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
|
||||
assert init_validate.session.get("is_init_validated") is False
|
||||
|
||||
|
||||
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="POST"):
|
||||
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
|
||||
assert result.result == "success"
|
||||
assert init_validate.session.get("is_init_validated") is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session["is_init_validated"] = True
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is True
|
||||
|
||||
|
||||
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
|
||||
monkeypatch.setenv("INIT_PASSWORD", "expected")
|
||||
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
|
||||
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
|
||||
app.secret_key = "test-secret"
|
||||
|
||||
with app.test_request_context("/console/api/init", method="GET"):
|
||||
init_validate.session.pop("is_init_validated", None)
|
||||
assert init_validate.get_init_validate_status() is False
|
||||
@@ -1,281 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = 200,
|
||||
headers: dict[str, str] | None = None,
|
||||
method: str = "GET",
|
||||
content: bytes = b"",
|
||||
text: str = "",
|
||||
error: Exception | None = None,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.request = SimpleNamespace(method=method)
|
||||
self.content = content
|
||||
self.text = text
|
||||
self._error = error
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
if self._error:
|
||||
raise self._error
|
||||
|
||||
|
||||
def _mock_upload_dependencies(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
mimetype="text/plain",
|
||||
size=3,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.helpers,
|
||||
"guess_file_info_from_response",
|
||||
MagicMock(return_value=file_info),
|
||||
)
|
||||
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
"get_signed_file_url",
|
||||
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
|
||||
)
|
||||
|
||||
return file_service_cls
|
||||
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
head_resp = _FakeResponse(
|
||||
status_code=200,
|
||||
headers={"Content-Type": "text/plain", "Content-Length": "128"},
|
||||
method="HEAD",
|
||||
)
|
||||
head_mock = MagicMock(return_value=head_resp)
|
||||
get_mock = MagicMock()
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "text/plain", "file_length": 128}
|
||||
head_mock.assert_called_once_with(decoded_url)
|
||||
get_mock.assert_not_called()
|
||||
|
||||
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
|
||||
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
with app.test_request_context(method="GET"):
|
||||
payload = handler(api, url=encoded_url)
|
||||
|
||||
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
|
||||
get_mock.assert_called_once_with(decoded_url, timeout=3)
|
||||
|
||||
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
|
||||
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
|
||||
get_mock = MagicMock(return_value=get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="report.txt",
|
||||
size=16,
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-1"
|
||||
assert payload["url"] == "https://signed.example/file-1"
|
||||
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
|
||||
file_service_cls.return_value.upload_file.assert_called_once_with(
|
||||
filename="report.txt",
|
||||
content=b"fallback-content",
|
||||
mimetype="text/plain",
|
||||
user=SimpleNamespace(id="u1"),
|
||||
source_url=url,
|
||||
)
|
||||
|
||||
|
||||
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/photo.jpg"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
|
||||
)
|
||||
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
|
||||
get_mock = MagicMock(return_value=extra_get_resp)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
|
||||
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-2",
|
||||
name="photo.jpg",
|
||||
size=18,
|
||||
extension=".jpg",
|
||||
mime_type="image/jpeg",
|
||||
created_by="u1",
|
||||
created_at=datetime(2024, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
file_service_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
payload, status = handler(api)
|
||||
|
||||
assert status == 201
|
||||
assert payload["id"] == "file-2"
|
||||
get_mock.assert_called_once_with(url)
|
||||
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"get",
|
||||
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
request = httpx.Request("HEAD", url)
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
|
||||
)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
|
||||
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(FileTooLargeError, match="size exceeded"):
|
||||
handler(api)
|
||||
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.ssrf_proxy,
|
||||
"head",
|
||||
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
|
||||
)
|
||||
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
|
||||
file_service_cls = _mock_upload_dependencies(monkeypatch)
|
||||
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
handler(api)
|
||||
@@ -1,49 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import controllers.console.spec as spec_module
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestSpecSchemaDefinitionsApi:
|
||||
def test_get_success(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
schema_definitions = [{"type": "string"}]
|
||||
|
||||
with patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
) as schema_manager_cls:
|
||||
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
|
||||
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == schema_definitions
|
||||
|
||||
def test_get_exception_returns_empty_list(self):
|
||||
api = spec_module.SpecSchemaDefinitionsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
spec_module,
|
||||
"SchemaManager",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
spec_module.logger,
|
||||
"exception",
|
||||
) as log_exception,
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp == []
|
||||
log_exception.assert_called_once()
|
||||
@@ -1,162 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import controllers.console.version as version_module
|
||||
|
||||
|
||||
class TestHasNewVersion:
|
||||
def test_has_new_version_true(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.2.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_has_new_version_false(self):
|
||||
result = version_module._has_new_version(
|
||||
latest_version="1.0.0",
|
||||
current_version="1.1.0",
|
||||
)
|
||||
assert result is False
|
||||
|
||||
def test_has_new_version_invalid_version(self):
|
||||
with patch.object(version_module.logger, "warning") as log_warning:
|
||||
result = version_module._has_new_version(
|
||||
latest_version="invalid",
|
||||
current_version="1.0.0",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
log_warning.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckVersionUpdate:
|
||||
def test_no_check_update_url(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
True,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
False,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
assert result.can_auto_update is False
|
||||
assert result.features.can_replace_logo is True
|
||||
assert result.features.model_load_balancing_enabled is False
|
||||
|
||||
def test_http_error_fallback(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch.object(
|
||||
version_module.logger,
|
||||
"warning",
|
||||
) as log_warning,
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.0.0"
|
||||
log_warning.assert_called_once()
|
||||
|
||||
def test_new_version_available(self):
|
||||
query = version_module.VersionQuery(current_version="1.0.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.2.0",
|
||||
"releaseDate": "2024-01-01",
|
||||
"releaseNotes": "New features",
|
||||
"canAutoUpdate": True,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.0.0",
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CAN_REPLACE_LOGO",
|
||||
False,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"MODEL_LB_ENABLED",
|
||||
True,
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.release_date == "2024-01-01"
|
||||
assert result.release_notes == "New features"
|
||||
assert result.can_auto_update is True
|
||||
|
||||
def test_no_new_version(self):
|
||||
query = version_module.VersionQuery(current_version="1.2.0")
|
||||
|
||||
response = MagicMock()
|
||||
response.json.return_value = {
|
||||
"version": "1.1.0",
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
version_module.dify_config,
|
||||
"CHECK_UPDATE_URL",
|
||||
"http://example.com",
|
||||
),
|
||||
patch.object(
|
||||
version_module.httpx,
|
||||
"get",
|
||||
return_value=response,
|
||||
),
|
||||
patch.object(
|
||||
version_module.dify_config.project,
|
||||
"version",
|
||||
"1.2.0",
|
||||
),
|
||||
):
|
||||
result = version_module.check_version_update(query)
|
||||
|
||||
assert result.version == "1.2.0"
|
||||
assert result.can_auto_update is False
|
||||
@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 6. Close Handling ───────────────────────────────────────────────────
|
||||
describe('Close handling', () => {
|
||||
it('should call onCancel when pressing ESC key', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
// ahooks useKeyPress listens on document for keydown events
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
code: 'Escape',
|
||||
keyCode: 27,
|
||||
bubbles: true,
|
||||
}))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
|
||||
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
|
||||
describe('Pricing page URL', () => {
|
||||
it('should render pricing link with correct URL', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
@@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
||||
isShow={isShowDeleteConfirm}
|
||||
onClose={() => setIsShowDeleteConfirm(false)}
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-3 text-text-primary">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<div className="mb-3 text-text-primary title-2xl-semi-bold">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<p className="mb-8 text-text-secondary">{t('avatar.deleteDescription', { ns: 'common' })}</p>
|
||||
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
|
||||
@@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
</div>
|
||||
{step === STEP.start && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-medium text-text-warning">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-warning body-md-medium">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content1"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyOrigin && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content2"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -278,25 +278,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToOriginEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToOriginEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.newEmail && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
<div className="text-text-secondary body-md-regular">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.emailPlaceholder', { ns: 'common' })}
|
||||
@@ -305,10 +305,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
destructive={newEmailExited || unAvailableEmail}
|
||||
/>
|
||||
{newEmailExited && (
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
{unAvailableEmail && (
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-3 space-y-2">
|
||||
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyNew && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content4"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email: mail }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -368,13 +368,13 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToNewEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToNewEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -138,7 +138,7 @@ export default function AccountPage() {
|
||||
imageUrl={icon_url}
|
||||
/>
|
||||
</div>
|
||||
<div className="system-sm-medium mt-[3px] text-text-secondary">{item.name}</div>
|
||||
<div className="mt-[3px] text-text-secondary system-sm-medium">{item.name}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -146,12 +146,12 @@ export default function AccountPage() {
|
||||
return (
|
||||
<>
|
||||
<div className="pb-3 pt-2">
|
||||
<h4 className="title-2xl-semi-bold text-text-primary">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
</div>
|
||||
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
|
||||
<div className="ml-4">
|
||||
<p className="system-xl-semibold text-text-primary">
|
||||
<p className="text-text-primary system-xl-semibold">
|
||||
{userProfile.name}
|
||||
{isEducationAccount && (
|
||||
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
|
||||
@@ -160,16 +160,16 @@ export default function AccountPage() {
|
||||
</PremiumBadge>
|
||||
)}
|
||||
</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{userProfile.email}</p>
|
||||
<p className="text-text-tertiary system-xs-regular">{userProfile.email}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<span className="pl-1">{userProfile.name}</span>
|
||||
</div>
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={handleEditName}>
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={handleEditName}>
|
||||
{t('operation.edit', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -177,11 +177,11 @@ export default function AccountPage() {
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.email', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<span className="pl-1">{userProfile.email}</span>
|
||||
</div>
|
||||
{systemFeatures.enable_change_email && (
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={() => setShowUpdateEmail(true)}>
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={() => setShowUpdateEmail(true)}>
|
||||
{t('operation.change', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
@@ -191,8 +191,8 @@ export default function AccountPage() {
|
||||
systemFeatures.enable_email_password_login && (
|
||||
<div className="mb-8 flex justify-between gap-2">
|
||||
<div>
|
||||
<div className="system-sm-semibold mb-1 text-text-secondary">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="body-xs-regular mb-2 text-text-tertiary">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
<div className="mb-1 text-text-secondary system-sm-semibold">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="mb-2 text-text-tertiary body-xs-regular">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<Button onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</Button>
|
||||
</div>
|
||||
@@ -219,7 +219,7 @@ export default function AccountPage() {
|
||||
onClose={() => setEditNameModalVisible(false)}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="mt-2"
|
||||
@@ -249,7 +249,7 @@ export default function AccountPage() {
|
||||
}}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
{userProfile.is_password_set && (
|
||||
<>
|
||||
<div className={titleClassName}>{t('account.currentPassword', { ns: 'common' })}</div>
|
||||
@@ -272,7 +272,7 @@ export default function AccountPage() {
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">
|
||||
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="relative mt-2">
|
||||
@@ -291,7 +291,7 @@ export default function AccountPage() {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="relative mt-2">
|
||||
<Input
|
||||
type={showConfirmPassword ? 'text' : 'password'}
|
||||
|
||||
@@ -94,7 +94,7 @@ const CSVUploader: FC<Props> = ({
|
||||
/>
|
||||
<div ref={dropRef}>
|
||||
{!file && (
|
||||
<div className={cn('system-sm-regular flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
|
||||
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg system-sm-regular', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
|
||||
<div className="flex w-full items-center justify-center space-x-2">
|
||||
<CSVIcon className="shrink-0" />
|
||||
<div className="text-text-tertiary">
|
||||
|
||||
@@ -178,7 +178,7 @@ const Prompt: FC<ISimplePromptInput> = ({
|
||||
{!noTitle && (
|
||||
<div className="flex h-11 items-center justify-between pl-3 pr-2.5">
|
||||
<div className="flex items-center space-x-1">
|
||||
<div className="h2 system-sm-semibold-uppercase text-text-secondary">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
|
||||
<div className="h2 text-text-secondary system-sm-semibold-uppercase">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
|
||||
{!readonly && (
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
|
||||
@@ -96,7 +96,7 @@ const Editor: FC<Props> = ({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className={cn(editorHeight, ' min-h-[102px] overflow-y-auto px-4 text-sm text-gray-700')}>
|
||||
<div className={cn(editorHeight, 'min-h-[102px] overflow-y-auto px-4 text-sm text-gray-700')}>
|
||||
<PromptEditor
|
||||
className={editorHeight}
|
||||
value={value}
|
||||
|
||||
@@ -3,8 +3,10 @@ import type { FormValue } from '@/app/components/header/account-setting/model-pr
|
||||
import type { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
|
||||
import type { GenRes } from '@/service/debug'
|
||||
import type { AppModeEnum, CompletionParams, Model, ModelModeType } from '@/types/app'
|
||||
import { useSessionStorageState } from 'ahooks'
|
||||
import useBoolean from 'ahooks/lib/useBoolean'
|
||||
import {
|
||||
useBoolean,
|
||||
useSessionStorageState,
|
||||
} from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@@ -224,7 +226,7 @@ export const GetCodeGeneratorResModal: FC<IGetCodeGeneratorResProps> = (
|
||||
</div>
|
||||
<div>
|
||||
<div className="text-[0px]">
|
||||
<div className="system-sm-semibold-uppercase mb-1.5 text-text-secondary">{t('codegen.instruction', { ns: 'appDebug' })}</div>
|
||||
<div className="mb-1.5 text-text-secondary system-sm-semibold-uppercase">{t('codegen.instruction', { ns: 'appDebug' })}</div>
|
||||
<InstructionEditor
|
||||
editorKey={editorKey}
|
||||
value={instruction}
|
||||
@@ -248,7 +250,7 @@ export const GetCodeGeneratorResModal: FC<IGetCodeGeneratorResProps> = (
|
||||
disabled={isLoading}
|
||||
>
|
||||
<Generator className="h-4 w-4" />
|
||||
<span className="text-xs font-semibold ">{t('codegen.generate', { ns: 'appDebug' })}</span>
|
||||
<span className="text-xs font-semibold">{t('codegen.generate', { ns: 'appDebug' })}</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -210,7 +210,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
<div className="overflow-y-auto border-b border-divider-regular p-6 pb-[68px] pt-5">
|
||||
<div className={cn(rowClass, 'items-center')}>
|
||||
<div className={labelClass}>
|
||||
<div className="system-sm-semibold text-text-secondary">{t('form.name', { ns: 'datasetSettings' })}</div>
|
||||
<div className="text-text-secondary system-sm-semibold">{t('form.name', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
<Input
|
||||
value={localeCurrentDataset.name}
|
||||
@@ -221,7 +221,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
</div>
|
||||
<div className={cn(rowClass)}>
|
||||
<div className={labelClass}>
|
||||
<div className="system-sm-semibold text-text-secondary">{t('form.desc', { ns: 'datasetSettings' })}</div>
|
||||
<div className="text-text-secondary system-sm-semibold">{t('form.desc', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<Textarea
|
||||
@@ -234,7 +234,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
</div>
|
||||
<div className={rowClass}>
|
||||
<div className={labelClass}>
|
||||
<div className="system-sm-semibold text-text-secondary">{t('form.permissions', { ns: 'datasetSettings' })}</div>
|
||||
<div className="text-text-secondary system-sm-semibold">{t('form.permissions', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<PermissionSelector
|
||||
@@ -250,7 +250,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
{!!(currentDataset && currentDataset.indexing_technique) && (
|
||||
<div className={cn(rowClass)}>
|
||||
<div className={labelClass}>
|
||||
<div className="system-sm-semibold text-text-secondary">{t('form.indexMethod', { ns: 'datasetSettings' })}</div>
|
||||
<div className="text-text-secondary system-sm-semibold">{t('form.indexMethod', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
<div className="grow">
|
||||
<IndexMethod
|
||||
@@ -267,7 +267,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
{indexMethod === IndexingType.QUALIFIED && (
|
||||
<div className={cn(rowClass)}>
|
||||
<div className={labelClass}>
|
||||
<div className="system-sm-semibold text-text-secondary">{t('form.embeddingModel', { ns: 'datasetSettings' })}</div>
|
||||
<div className="text-text-secondary system-sm-semibold">{t('form.embeddingModel', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<div className="h-8 w-full rounded-lg bg-components-input-bg-normal opacity-60">
|
||||
|
||||
@@ -394,7 +394,7 @@ const Debug: FC<IDebug> = ({
|
||||
<>
|
||||
<div className="shrink-0">
|
||||
<div className="flex items-center justify-between px-4 pb-2 pt-3">
|
||||
<div className="system-xl-semibold text-text-primary">{t('inputs.title', { ns: 'appDebug' })}</div>
|
||||
<div className="text-text-primary system-xl-semibold">{t('inputs.title', { ns: 'appDebug' })}</div>
|
||||
<div className="flex items-center">
|
||||
{
|
||||
debugWithMultipleModel
|
||||
@@ -539,7 +539,7 @@ const Debug: FC<IDebug> = ({
|
||||
{!completionRes && !isResponding && (
|
||||
<div className="flex grow flex-col items-center justify-center gap-2">
|
||||
<RiSparklingFill className="h-12 w-12 text-text-empty-state-icon" />
|
||||
<div className="system-sm-regular text-text-quaternary">{t('noResult', { ns: 'appDebug' })}</div>
|
||||
<div className="text-text-quaternary system-sm-regular">{t('noResult', { ns: 'appDebug' })}</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -966,10 +966,10 @@ const Configuration: FC = () => {
|
||||
<div className="bg-default-subtle absolute left-0 top-0 h-14 w-full">
|
||||
<div className="flex h-14 items-center justify-between px-6">
|
||||
<div className="flex items-center">
|
||||
<div className="system-xl-semibold text-text-primary">{t('orchestrate', { ns: 'appDebug' })}</div>
|
||||
<div className="text-text-primary system-xl-semibold">{t('orchestrate', { ns: 'appDebug' })}</div>
|
||||
<div className="flex h-[14px] items-center space-x-1 text-xs">
|
||||
{isAdvancedMode && (
|
||||
<div className="system-xs-medium-uppercase ml-1 flex h-5 items-center rounded-md border border-components-button-secondary-border px-1.5 uppercase text-text-tertiary">{t('promptMode.advanced', { ns: 'appDebug' })}</div>
|
||||
<div className="ml-1 flex h-5 items-center rounded-md border border-components-button-secondary-border px-1.5 uppercase text-text-tertiary system-xs-medium-uppercase">{t('promptMode.advanced', { ns: 'appDebug' })}</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
@@ -1030,8 +1030,8 @@ const Configuration: FC = () => {
|
||||
<Config />
|
||||
</div>
|
||||
{!isMobile && (
|
||||
<div className="relative flex h-full w-1/2 grow flex-col overflow-y-auto " style={{ borderColor: 'rgba(0, 0, 0, 0.02)' }}>
|
||||
<div className="flex grow flex-col rounded-tl-2xl border-l-[0.5px] border-t-[0.5px] border-components-panel-border bg-chatbot-bg ">
|
||||
<div className="relative flex h-full w-1/2 grow flex-col overflow-y-auto" style={{ borderColor: 'rgba(0, 0, 0, 0.02)' }}>
|
||||
<div className="flex grow flex-col rounded-tl-2xl border-l-[0.5px] border-t-[0.5px] border-components-panel-border bg-chatbot-bg">
|
||||
<Debug
|
||||
isAPIKeySet={isAPIKeySet}
|
||||
onSetting={() => setShowAccountSettingModal({ payload: ACCOUNT_SETTING_TAB.PROVIDER })}
|
||||
|
||||
@@ -217,7 +217,7 @@ const ExternalDataToolModal: FC<ExternalDataToolModalProps> = ({
|
||||
<AppIcon
|
||||
size="large"
|
||||
onClick={() => { setShowEmojiPicker(true) }}
|
||||
className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border "
|
||||
className="!h-9 !w-9 cursor-pointer rounded-lg border-[0.5px] border-components-panel-border"
|
||||
icon={localeData.icon}
|
||||
background={localeData.icon_background}
|
||||
/>
|
||||
|
||||
@@ -117,10 +117,10 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
<div className="px-10">
|
||||
<div className="h-6 w-full 2xl:h-[139px]" />
|
||||
<div className="pb-6 pt-1">
|
||||
<span className="title-2xl-semi-bold text-text-primary">{t('newApp.startFromBlank', { ns: 'app' })}</span>
|
||||
<span className="text-text-primary title-2xl-semi-bold">{t('newApp.startFromBlank', { ns: 'app' })}</span>
|
||||
</div>
|
||||
<div className="mb-2 leading-6">
|
||||
<span className="system-sm-semibold text-text-secondary">{t('newApp.chooseAppType', { ns: 'app' })}</span>
|
||||
<span className="text-text-secondary system-sm-semibold">{t('newApp.chooseAppType', { ns: 'app' })}</span>
|
||||
</div>
|
||||
<div className="flex w-[660px] flex-col gap-4">
|
||||
<div>
|
||||
@@ -160,7 +160,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
className="flex cursor-pointer items-center border-0 bg-transparent p-0"
|
||||
onClick={() => setIsAppTypeExpanded(!isAppTypeExpanded)}
|
||||
>
|
||||
<span className="system-2xs-medium-uppercase text-text-tertiary">{t('newApp.forBeginners', { ns: 'app' })}</span>
|
||||
<span className="text-text-tertiary system-2xs-medium-uppercase">{t('newApp.forBeginners', { ns: 'app' })}</span>
|
||||
<RiArrowRightSLine className={`ml-1 h-4 w-4 text-text-tertiary transition-transform ${isAppTypeExpanded ? 'rotate-90' : ''}`} />
|
||||
</button>
|
||||
</div>
|
||||
@@ -212,7 +212,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
<div className="flex items-center space-x-3">
|
||||
<div className="flex-1">
|
||||
<div className="mb-1 flex h-6 items-center">
|
||||
<label className="system-sm-semibold text-text-secondary">{t('newApp.captionName', { ns: 'app' })}</label>
|
||||
<label className="text-text-secondary system-sm-semibold">{t('newApp.captionName', { ns: 'app' })}</label>
|
||||
</div>
|
||||
<Input
|
||||
value={name}
|
||||
@@ -243,8 +243,8 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
</div>
|
||||
<div>
|
||||
<div className="mb-1 flex h-6 items-center">
|
||||
<label className="system-sm-semibold text-text-secondary">{t('newApp.captionDescription', { ns: 'app' })}</label>
|
||||
<span className="system-xs-regular ml-1 text-text-tertiary">
|
||||
<label className="text-text-secondary system-sm-semibold">{t('newApp.captionDescription', { ns: 'app' })}</label>
|
||||
<span className="ml-1 text-text-tertiary system-xs-regular">
|
||||
(
|
||||
{t('newApp.optional', { ns: 'app' })}
|
||||
)
|
||||
@@ -260,7 +260,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
</div>
|
||||
{isAppsFull && <AppsFull className="mt-4" loc="app-create" />}
|
||||
<div className="flex items-center justify-between pb-10 pt-5">
|
||||
<div className="system-xs-regular flex cursor-pointer items-center gap-1 text-text-tertiary" onClick={onCreateFromTemplate}>
|
||||
<div className="flex cursor-pointer items-center gap-1 text-text-tertiary system-xs-regular" onClick={onCreateFromTemplate}>
|
||||
<span>{t('newApp.noIdeaTip', { ns: 'app' })}</span>
|
||||
<div className="p-[1px]">
|
||||
<RiArrowRightLine className="h-3.5 w-3.5" />
|
||||
@@ -334,8 +334,8 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP
|
||||
onClick={onClick}
|
||||
>
|
||||
{icon}
|
||||
<div className="system-sm-semibold mb-0.5 mt-2 text-text-secondary">{title}</div>
|
||||
<div className="system-xs-regular line-clamp-2 text-text-tertiary" title={description}>{description}</div>
|
||||
<div className="mb-0.5 mt-2 text-text-secondary system-sm-semibold">{title}</div>
|
||||
<div className="line-clamp-2 text-text-tertiary system-xs-regular" title={description}>{description}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -367,8 +367,8 @@ function AppPreview({ mode }: { mode: AppModeEnum }) {
|
||||
const previewInfo = modeToPreviewInfoMap[mode]
|
||||
return (
|
||||
<div className="px-8 py-4">
|
||||
<h4 className="system-sm-semibold-uppercase text-text-secondary">{previewInfo.title}</h4>
|
||||
<div className="system-xs-regular mt-1 min-h-8 max-w-96 text-text-tertiary">
|
||||
<h4 className="text-text-secondary system-sm-semibold-uppercase">{previewInfo.title}</h4>
|
||||
<div className="mt-1 min-h-8 max-w-96 text-text-tertiary system-xs-regular">
|
||||
<span>{previewInfo.description}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -232,7 +232,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
isShow={show}
|
||||
onClose={noop}
|
||||
>
|
||||
<div className="title-2xl-semi-bold flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary">
|
||||
<div className="flex items-center justify-between pb-3 pl-6 pr-5 pt-6 text-text-primary title-2xl-semi-bold">
|
||||
{t('importFromDSL', { ns: 'app' })}
|
||||
<div
|
||||
className="flex h-8 w-8 cursor-pointer items-center"
|
||||
@@ -241,7 +241,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
<RiCloseLine className="h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="system-md-semibold flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary">
|
||||
<div className="flex h-9 items-center space-x-6 border-b border-divider-subtle px-6 text-text-tertiary system-md-semibold">
|
||||
{
|
||||
tabs.map(tab => (
|
||||
<div
|
||||
@@ -275,7 +275,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
{
|
||||
currentTab === CreateFromDSLModalTab.FROM_URL && (
|
||||
<div>
|
||||
<div className="system-md-semibold mb-1 text-text-secondary">DSL URL</div>
|
||||
<div className="mb-1 text-text-secondary system-md-semibold">DSL URL</div>
|
||||
<Input
|
||||
placeholder={t('importFromDSLUrlPlaceholder', { ns: 'app' }) || ''}
|
||||
value={dslUrlValue}
|
||||
@@ -309,8 +309,8 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
||||
className="w-[480px]"
|
||||
>
|
||||
<div className="flex flex-col items-start gap-2 self-stretch pb-4">
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
|
||||
<div className="system-md-regular flex grow flex-col text-text-secondary">
|
||||
<div className="text-text-primary title-2xl-semi-bold">{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}</div>
|
||||
<div className="flex grow flex-col text-text-secondary system-md-regular">
|
||||
<div>{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}</div>
|
||||
<div>{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}</div>
|
||||
<br />
|
||||
|
||||
@@ -121,7 +121,7 @@ const Uploader: FC<Props> = ({
|
||||
</div>
|
||||
)}
|
||||
{file && (
|
||||
<div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', ' hover:bg-components-panel-on-panel-item-bg-hover')}>
|
||||
<div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', 'hover:bg-components-panel-on-panel-item-bg-hover')}>
|
||||
<div className="flex items-center justify-center p-3">
|
||||
<YamlIcon className="h-6 w-6 shrink-0" />
|
||||
</div>
|
||||
|
||||
@@ -96,7 +96,7 @@ const statusTdRender = (statusCount: StatusCount) => {
|
||||
|
||||
if (statusCount.paused > 0) {
|
||||
return (
|
||||
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
|
||||
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
|
||||
<Indicator color="yellow" />
|
||||
<span className="text-util-colors-warning-warning-600">Pending</span>
|
||||
</div>
|
||||
@@ -104,7 +104,7 @@ const statusTdRender = (statusCount: StatusCount) => {
|
||||
}
|
||||
else if (statusCount.partial_success + statusCount.failed === 0) {
|
||||
return (
|
||||
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
|
||||
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
|
||||
<Indicator color="green" />
|
||||
<span className="text-util-colors-green-green-600">Success</span>
|
||||
</div>
|
||||
@@ -112,7 +112,7 @@ const statusTdRender = (statusCount: StatusCount) => {
|
||||
}
|
||||
else if (statusCount.failed === 0) {
|
||||
return (
|
||||
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
|
||||
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
|
||||
<Indicator color="green" />
|
||||
<span className="text-util-colors-green-green-600">Partial Success</span>
|
||||
</div>
|
||||
@@ -120,7 +120,7 @@ const statusTdRender = (statusCount: StatusCount) => {
|
||||
}
|
||||
else {
|
||||
return (
|
||||
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
|
||||
<div className="inline-flex items-center gap-1 system-xs-semibold-uppercase">
|
||||
<Indicator color="red" />
|
||||
<span className="text-util-colors-red-red-600">
|
||||
{statusCount.failed}
|
||||
@@ -562,9 +562,9 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
{/* Panel Header */}
|
||||
<div className="flex shrink-0 items-center gap-2 rounded-t-xl bg-components-panel-bg pb-2 pl-4 pr-3 pt-3">
|
||||
<div className="shrink-0">
|
||||
<div className="system-xs-semibold-uppercase mb-0.5 text-text-primary">{isChatMode ? t('detail.conversationId', { ns: 'appLog' }) : t('detail.time', { ns: 'appLog' })}</div>
|
||||
<div className="mb-0.5 text-text-primary system-xs-semibold-uppercase">{isChatMode ? t('detail.conversationId', { ns: 'appLog' }) : t('detail.time', { ns: 'appLog' })}</div>
|
||||
{isChatMode && (
|
||||
<div className="system-2xs-regular-uppercase flex items-center text-text-secondary">
|
||||
<div className="flex items-center text-text-secondary system-2xs-regular-uppercase">
|
||||
<Tooltip
|
||||
popupContent={detail.id}
|
||||
>
|
||||
@@ -574,7 +574,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
</div>
|
||||
)}
|
||||
{!isChatMode && (
|
||||
<div className="system-2xs-regular-uppercase text-text-secondary">{formatTime(detail.created_at, t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
|
||||
<div className="text-text-secondary system-2xs-regular-uppercase">{formatTime(detail.created_at, t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex grow flex-wrap items-center justify-end gap-y-1">
|
||||
@@ -600,7 +600,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
? (
|
||||
<div className="px-6 py-4">
|
||||
<div className="flex h-[18px] items-center space-x-3">
|
||||
<div className="system-xs-semibold-uppercase text-text-tertiary">{t('table.header.output', { ns: 'appLog' })}</div>
|
||||
<div className="text-text-tertiary system-xs-semibold-uppercase">{t('table.header.output', { ns: 'appLog' })}</div>
|
||||
<div
|
||||
className="h-px grow"
|
||||
style={{
|
||||
@@ -692,7 +692,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
</div>
|
||||
{hasMore && (
|
||||
<div className="py-3 text-center">
|
||||
<div className="system-xs-regular text-text-tertiary">
|
||||
<div className="text-text-tertiary system-xs-regular">
|
||||
{t('detail.loading', { ns: 'appLog' })}
|
||||
...
|
||||
</div>
|
||||
@@ -950,7 +950,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
)}
|
||||
popupClassName={(isHighlight && !isChatMode) ? '' : '!hidden'}
|
||||
>
|
||||
<div className={cn(isEmptyStyle ? 'text-text-quaternary' : 'text-text-secondary', !isHighlight ? '' : 'bg-orange-100', 'system-sm-regular overflow-hidden text-ellipsis whitespace-nowrap')}>
|
||||
<div className={cn(isEmptyStyle ? 'text-text-quaternary' : 'text-text-secondary', !isHighlight ? '' : 'bg-orange-100', 'overflow-hidden text-ellipsis whitespace-nowrap system-sm-regular')}>
|
||||
{value || '-'}
|
||||
</div>
|
||||
</Tooltip>
|
||||
@@ -963,7 +963,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
return (
|
||||
<div className="relative mt-2 grow overflow-x-auto">
|
||||
<table className={cn('w-full min-w-[440px] border-collapse border-0')}>
|
||||
<thead className="system-xs-medium-uppercase text-text-tertiary">
|
||||
<thead className="text-text-tertiary system-xs-medium-uppercase">
|
||||
<tr>
|
||||
<td className="w-5 whitespace-nowrap rounded-l-lg bg-background-section-burn pl-2 pr-1"></td>
|
||||
<td className="whitespace-nowrap bg-background-section-burn py-1.5 pl-3">{isChatMode ? t('table.header.summary', { ns: 'appLog' }) : t('table.header.input', { ns: 'appLog' })}</td>
|
||||
@@ -976,7 +976,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
||||
<td className="whitespace-nowrap rounded-r-lg bg-background-section-burn py-1.5 pl-3">{t('table.header.time', { ns: 'appLog' })}</td>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="system-sm-regular text-text-secondary">
|
||||
<tbody className="text-text-secondary system-sm-regular">
|
||||
{logs.data.map((log: any) => {
|
||||
const endUser = log.from_end_user_session_id || log.from_account_name
|
||||
const leftValue = get(log, isChatMode ? 'name' : 'message.inputs.query') || (!isChatMode ? (get(log, 'message.query') || get(log, 'message.inputs.default_input')) : '') || ''
|
||||
|
||||
@@ -231,12 +231,12 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
{/* header */}
|
||||
<div className="pb-3 pl-6 pr-5 pt-5">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="title-2xl-semi-bold grow text-text-primary">{t(`${prefixSettings}.title`, { ns: 'appOverview' })}</div>
|
||||
<div className="grow text-text-primary title-2xl-semi-bold">{t(`${prefixSettings}.title`, { ns: 'appOverview' })}</div>
|
||||
<ActionButton className="shrink-0" onClick={onHide}>
|
||||
<RiCloseLine className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
</div>
|
||||
<div className="system-xs-regular mt-0.5 text-text-tertiary">
|
||||
<div className="mt-0.5 text-text-tertiary system-xs-regular">
|
||||
<span>{t(`${prefixSettings}.modalTip`, { ns: 'appOverview' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
@@ -245,7 +245,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
{/* name & icon */}
|
||||
<div className="flex gap-4">
|
||||
<div className="grow">
|
||||
<div className={cn('system-sm-semibold mb-1 py-1 text-text-secondary')}>{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('mb-1 py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.webName`, { ns: 'appOverview' })}</div>
|
||||
<Input
|
||||
className="w-full"
|
||||
value={inputInfo.title}
|
||||
@@ -265,32 +265,32 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
</div>
|
||||
{/* description */}
|
||||
<div className="relative">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.webDesc`, { ns: 'appOverview' })}</div>
|
||||
<Textarea
|
||||
className="mt-1"
|
||||
value={inputInfo.desc}
|
||||
onChange={e => onDesChange(e.target.value)}
|
||||
placeholder={t(`${prefixSettings}.webDescPlaceholder`, { ns: 'appOverview' }) as string}
|
||||
/>
|
||||
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>{t(`${prefixSettings}.webDescTip`, { ns: 'appOverview' })}</p>
|
||||
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.webDescTip`, { ns: 'appOverview' })}</p>
|
||||
</div>
|
||||
<Divider className="my-0 h-px" />
|
||||
{/* answer icon */}
|
||||
{isChat && (
|
||||
<div className="w-full">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t('answerIcon.title', { ns: 'app' })}</div>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t('answerIcon.title', { ns: 'app' })}</div>
|
||||
<Switch
|
||||
value={inputInfo.use_icon_as_answer_icon}
|
||||
onChange={v => setInputInfo({ ...inputInfo, use_icon_as_answer_icon: v })}
|
||||
/>
|
||||
</div>
|
||||
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t('answerIcon.description', { ns: 'app' })}</p>
|
||||
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t('answerIcon.description', { ns: 'app' })}</p>
|
||||
</div>
|
||||
)}
|
||||
{/* language */}
|
||||
<div className="flex items-center">
|
||||
<div className={cn('system-sm-semibold grow py-1 text-text-secondary')}>{t(`${prefixSettings}.language`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('grow py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.language`, { ns: 'appOverview' })}</div>
|
||||
<SimpleSelect
|
||||
wrapperClassName="w-[200px]"
|
||||
items={languages.filter(item => item.supported)}
|
||||
@@ -303,8 +303,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
{isChat && (
|
||||
<div className="flex items-center">
|
||||
<div className="grow">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.chatColorTheme`, { ns: 'appOverview' })}</div>
|
||||
<div className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.chatColorThemeDesc`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.chatColorTheme`, { ns: 'appOverview' })}</div>
|
||||
<div className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.chatColorThemeDesc`, { ns: 'appOverview' })}</div>
|
||||
</div>
|
||||
<div className="shrink-0">
|
||||
<Input
|
||||
@@ -314,7 +314,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
placeholder="E.g #A020F0"
|
||||
/>
|
||||
<div className="flex items-center justify-between">
|
||||
<p className={cn('body-xs-regular text-text-tertiary')}>{t(`${prefixSettings}.chatColorThemeInverted`, { ns: 'appOverview' })}</p>
|
||||
<p className={cn('text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.chatColorThemeInverted`, { ns: 'appOverview' })}</p>
|
||||
<Switch value={inputInfo.chatColorThemeInverted} onChange={v => setInputInfo({ ...inputInfo, chatColorThemeInverted: v })}></Switch>
|
||||
</div>
|
||||
</div>
|
||||
@@ -323,22 +323,22 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
{/* workflow detail */}
|
||||
<div className="w-full">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.workflow.subTitle`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.workflow.subTitle`, { ns: 'appOverview' })}</div>
|
||||
<Switch
|
||||
disabled={!(appInfo.mode === AppModeEnum.WORKFLOW || appInfo.mode === AppModeEnum.ADVANCED_CHAT)}
|
||||
value={inputInfo.show_workflow_steps}
|
||||
onChange={v => setInputInfo({ ...inputInfo, show_workflow_steps: v })}
|
||||
/>
|
||||
</div>
|
||||
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.workflow.showDesc`, { ns: 'appOverview' })}</p>
|
||||
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.workflow.showDesc`, { ns: 'appOverview' })}</p>
|
||||
</div>
|
||||
{/* more settings switch */}
|
||||
<Divider className="my-0 h-px" />
|
||||
{!isShowMore && (
|
||||
<div className="flex cursor-pointer items-center" onClick={() => setIsShowMore(true)}>
|
||||
<div className="grow">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.entry`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.entry`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>
|
||||
{t(`${prefixSettings}.more.copyRightPlaceholder`, { ns: 'appOverview' })}
|
||||
{' '}
|
||||
&
|
||||
@@ -356,7 +356,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
<div className="w-full">
|
||||
<div className="flex items-center">
|
||||
<div className="flex grow items-center">
|
||||
<div className={cn('system-sm-semibold mr-1 py-1 text-text-secondary')}>{t(`${prefixSettings}.more.copyright`, { ns: 'appOverview' })}</div>
|
||||
<div className={cn('mr-1 py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.copyright`, { ns: 'appOverview' })}</div>
|
||||
{/* upgrade button */}
|
||||
{enableBilling && isFreePlan && (
|
||||
<div className="h-[18px] select-none">
|
||||
@@ -385,7 +385,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
/>
|
||||
</Tooltip>
|
||||
</div>
|
||||
<p className="body-xs-regular pb-0.5 text-text-tertiary">{t(`${prefixSettings}.more.copyrightTip`, { ns: 'appOverview' })}</p>
|
||||
<p className="pb-0.5 text-text-tertiary body-xs-regular">{t(`${prefixSettings}.more.copyrightTip`, { ns: 'appOverview' })}</p>
|
||||
{inputInfo.copyrightSwitchValue && (
|
||||
<Input
|
||||
className="mt-2 h-10"
|
||||
@@ -397,8 +397,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
</div>
|
||||
{/* privacy policy */}
|
||||
<div className="w-full">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.privacyPolicy`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.privacyPolicy`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>
|
||||
<Trans
|
||||
i18nKey={`${prefixSettings}.more.privacyPolicyTip`}
|
||||
ns="appOverview"
|
||||
@@ -414,8 +414,8 @@ const SettingsModal: FC<ISettingsModalProps> = ({
|
||||
</div>
|
||||
{/* custom disclaimer */}
|
||||
<div className="w-full">
|
||||
<div className={cn('system-sm-semibold py-1 text-text-secondary')}>{t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>{t(`${prefixSettings}.more.customDisclaimerTip`, { ns: 'appOverview' })}</p>
|
||||
<div className={cn('py-1 text-text-secondary system-sm-semibold')}>{t(`${prefixSettings}.more.customDisclaimer`, { ns: 'appOverview' })}</div>
|
||||
<p className={cn('pb-0.5 text-text-tertiary body-xs-regular')}>{t(`${prefixSettings}.more.customDisclaimerTip`, { ns: 'appOverview' })}</p>
|
||||
<Textarea
|
||||
className="mt-1"
|
||||
value={inputInfo.customDisclaimer}
|
||||
|
||||
@@ -200,14 +200,14 @@ const ChatInputArea = ({
|
||||
<div className="relative flex w-full grow items-center">
|
||||
<div
|
||||
ref={textValueRef}
|
||||
className="body-lg-regular pointer-events-none invisible absolute h-auto w-auto whitespace-pre p-1 leading-6"
|
||||
className="pointer-events-none invisible absolute h-auto w-auto whitespace-pre p-1 leading-6 body-lg-regular"
|
||||
>
|
||||
{query}
|
||||
</div>
|
||||
<Textarea
|
||||
ref={ref => textareaRef.current = ref as any}
|
||||
className={cn(
|
||||
'body-lg-regular w-full resize-none bg-transparent p-1 leading-6 text-text-primary outline-none',
|
||||
'w-full resize-none bg-transparent p-1 leading-6 text-text-primary outline-none body-lg-regular',
|
||||
)}
|
||||
placeholder={decode(t(readonly ? 'chat.inputDisabledPlaceholder' : 'chat.inputPlaceholder', { ns: 'common', botName }) || '')}
|
||||
autoFocus
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
<svg width="10" height="10" viewBox="0 0 10 10" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z" fill="#676F83"/>
|
||||
<path d="M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z" fill="#676F83"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
@@ -1,5 +1,5 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="arrow-down-round-fill">
|
||||
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="#101828"/>
|
||||
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="currentColor"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 380 B After Width: | Height: | Size: 385 B |
@@ -1,3 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path id="Solid" fill-rule="evenodd" clip-rule="evenodd" d="M8.00008 0.666016C3.94999 0.666016 0.666748 3.94926 0.666748 7.99935C0.666748 12.0494 3.94999 15.3327 8.00008 15.3327C12.0502 15.3327 15.3334 12.0494 15.3334 7.99935C15.3334 3.94926 12.0502 0.666016 8.00008 0.666016ZM10.4715 5.52794C10.7318 5.78829 10.7318 6.2104 10.4715 6.47075L8.94289 7.99935L10.4715 9.52794C10.7318 9.78829 10.7318 10.2104 10.4715 10.4708C10.2111 10.7311 9.78903 10.7311 9.52868 10.4708L8.00008 8.94216L6.47149 10.4708C6.21114 10.7311 5.78903 10.7311 5.52868 10.4708C5.26833 10.2104 5.26833 9.78829 5.52868 9.52794L7.05727 7.99935L5.52868 6.47075C5.26833 6.2104 5.26833 5.78829 5.52868 5.52794C5.78903 5.26759 6.21114 5.26759 6.47149 5.52794L8.00008 7.05654L9.52868 5.52794C9.78903 5.26759 10.2111 5.26759 10.4715 5.52794Z" fill="#98A2B3"/>
|
||||
<path id="Solid" fill-rule="evenodd" clip-rule="evenodd" d="M8.00008 0.666016C3.94999 0.666016 0.666748 3.94926 0.666748 7.99935C0.666748 12.0494 3.94999 15.3327 8.00008 15.3327C12.0502 15.3327 15.3334 12.0494 15.3334 7.99935C15.3334 3.94926 12.0502 0.666016 8.00008 0.666016ZM10.4715 5.52794C10.7318 5.78829 10.7318 6.2104 10.4715 6.47075L8.94289 7.99935L10.4715 9.52794C10.7318 9.78829 10.7318 10.2104 10.4715 10.4708C10.2111 10.7311 9.78903 10.7311 9.52868 10.4708L8.00008 8.94216L6.47149 10.4708C6.21114 10.7311 5.78903 10.7311 5.52868 10.4708C5.26833 10.2104 5.26833 9.78829 5.52868 9.52794L7.05727 7.99935L5.52868 6.47075C5.26833 6.2104 5.26833 5.78829 5.52868 5.52794C5.78903 5.26759 6.21114 5.26759 6.47149 5.52794L8.00008 7.05654L9.52868 5.52794C9.78903 5.26759 10.2111 5.26759 10.4715 5.52794Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 925 B After Width: | Height: | Size: 930 B |
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "10",
|
||||
"height": "10",
|
||||
"viewBox": "0 0 10 10",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M2 5C2 3.44487 2.58482 1.98537 3.54004 1.04932C2.17681 1.34034 1 2.90001 1 5C1 7.09996 2.17685 8.65912 3.54004 8.9502C2.58496 8.01413 2 6.55501 2 5ZM3 5C3 7.33338 4.4528 9 6 9C7.5472 9 9 7.33338 9 5C9 2.66664 7.5472 1 6 1C4.4528 1 3 2.66664 3 5ZM10 5C10 7.63722 8.3188 10 6 10H4C1.6812 10 0 7.63722 0 5C0 2.3628 1.6812 0 4 0H6C8.3188 0 10 2.3628 10 5Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M6.71519 4.09259L6.45385 3.18667C6.42141 3.07421 6.34037 3 6.25 3C6.15963 3 6.07859 3.07421 6.04615 3.18667L5.78481 4.09259C5.74675 4.22464 5.66849 4.32899 5.56945 4.37978L4.88999 4.7282C4.80565 4.77146 4.75 4.87951 4.75 5C4.75 5.12049 4.80565 5.22854 4.88999 5.2718L5.56945 5.62022C5.66849 5.67101 5.74675 5.77536 5.78481 5.90741L6.04615 6.81333C6.07859 6.92579 6.15963 7 6.25 7C6.34037 7 6.42141 6.92579 6.45385 6.81333L6.71519 5.90741C6.75325 5.77536 6.83151 5.67101 6.93055 5.62022L7.61001 5.2718C7.69435 5.22854 7.75 5.12049 7.75 5C7.75 4.87951 7.69435 4.77146 7.61001 4.7282L6.93055 4.37978C6.83151 4.32899 6.75325 4.22464 6.71519 4.09259Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "CreditsCoin"
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './CreditsCoin.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'CreditsCoin'
|
||||
|
||||
export default Icon
|
||||
@@ -1,5 +1,6 @@
|
||||
export { default as Balance } from './Balance'
|
||||
export { default as CoinsStacked01 } from './CoinsStacked01'
|
||||
export { default as CreditsCoin } from './CreditsCoin'
|
||||
export { default as GoldCoin } from './GoldCoin'
|
||||
export { default as ReceiptList } from './ReceiptList'
|
||||
export { default as Tag01 } from './Tag01'
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useReactFlow, useStoreApi } from 'reactflow'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { isConversationVar, isENV, isGlobalVar, isRagVariableVar, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
|
||||
import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel'
|
||||
import {
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
UPDATE_WORKFLOW_NODES_MAP,
|
||||
} from './index'
|
||||
import { WorkflowVariableBlockNode } from './node'
|
||||
import { useLlmModelPluginInstalled } from './use-llm-model-plugin-installed'
|
||||
|
||||
type WorkflowVariableBlockComponentProps = {
|
||||
nodeKey: string
|
||||
@@ -68,6 +69,8 @@ const WorkflowVariableBlockComponent = ({
|
||||
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
|
||||
|
||||
const isException = isExceptionVariable(varName, node?.type)
|
||||
const sourceNodeId = variables[isRagVar ? 1 : 0]
|
||||
const isLlmModelInstalled = useLlmModelPluginInstalled(sourceNodeId, localWorkflowNodesMap)
|
||||
const variableValid = useMemo(() => {
|
||||
let variableValid = true
|
||||
const isEnv = isENV(variables)
|
||||
@@ -144,7 +147,13 @@ const WorkflowVariableBlockComponent = ({
|
||||
handleVariableJump()
|
||||
}}
|
||||
isExceptionVariable={isException}
|
||||
errorMsg={!variableValid ? t('errorMsg.invalidVariable', { ns: 'workflow' }) : undefined}
|
||||
errorMsg={
|
||||
!variableValid
|
||||
? t('errorMsg.invalidVariable', { ns: 'workflow' })
|
||||
: !isLlmModelInstalled
|
||||
? t('errorMsg.modelPluginNotInstalled', { ns: 'workflow' })
|
||||
: undefined
|
||||
}
|
||||
isSelected={isSelected}
|
||||
ref={ref}
|
||||
notShowFullPath={isShowAPart}
|
||||
@@ -155,9 +164,9 @@ const WorkflowVariableBlockComponent = ({
|
||||
return Item
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
noDecoration
|
||||
popupContent={(
|
||||
<Tooltip>
|
||||
<TooltipTrigger disabled={!isShowAPart} render={<div>{Item}</div>} />
|
||||
<TooltipContent variant="plain">
|
||||
<VarFullPathPanel
|
||||
nodeName={node.title}
|
||||
path={variables.slice(1)}
|
||||
@@ -169,10 +178,7 @@ const WorkflowVariableBlockComponent = ({
|
||||
: Type.string}
|
||||
nodeType={node?.type}
|
||||
/>
|
||||
)}
|
||||
disabled={!isShowAPart}
|
||||
>
|
||||
<div>{Item}</div>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import type { WorkflowNodesMap } from '@/app/components/base/prompt-editor/types'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { extractPluginId } from '@/app/components/workflow/utils/plugin'
|
||||
import { useProviderContextSelector } from '@/context/provider-context'
|
||||
|
||||
export function useLlmModelPluginInstalled(
|
||||
nodeId: string,
|
||||
workflowNodesMap: WorkflowNodesMap | undefined,
|
||||
): boolean {
|
||||
const node = workflowNodesMap?.[nodeId]
|
||||
const modelProvider = node?.type === BlockEnum.LLM
|
||||
? node.modelProvider
|
||||
: undefined
|
||||
const modelPluginId = modelProvider ? extractPluginId(modelProvider) : undefined
|
||||
|
||||
return useProviderContextSelector((state) => {
|
||||
if (!modelPluginId)
|
||||
return true
|
||||
return state.modelProviders.some(p =>
|
||||
extractPluginId(p.provider) === modelPluginId,
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -73,7 +73,7 @@ export type GetVarType = (payload: {
|
||||
export type WorkflowVariableBlockType = {
|
||||
show?: boolean
|
||||
variables?: NodeOutPutVar[]
|
||||
workflowNodesMap?: Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'>>
|
||||
workflowNodesMap?: WorkflowNodesMap
|
||||
onInsert?: () => void
|
||||
onDelete?: () => void
|
||||
getVarType?: GetVarType
|
||||
@@ -81,12 +81,14 @@ export type WorkflowVariableBlockType = {
|
||||
onManageInputField?: () => void
|
||||
}
|
||||
|
||||
export type WorkflowNodesMap = Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'> & { modelProvider?: string }>
|
||||
|
||||
export type HITLInputBlockType = {
|
||||
show?: boolean
|
||||
nodeId: string
|
||||
formInputs?: FormInputItem[]
|
||||
variables?: NodeOutPutVar[]
|
||||
workflowNodesMap?: Record<string, Pick<Node['data'], 'title' | 'type' | 'height' | 'width' | 'position'>>
|
||||
workflowNodesMap?: WorkflowNodesMap
|
||||
getVarType?: GetVarType
|
||||
onFormInputsChange?: (inputs: FormInputItem[]) => void
|
||||
onFormInputItemRemove: (varName: string) => void
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import type { ChangeEvent, FC, KeyboardEvent } from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import AutosizeInput from 'react-18-input-autosize'
|
||||
import _AutosizeInput from 'react-18-input-autosize'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useToastContext } from '@/app/components/base/toast/context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
// CJS/ESM interop: Turbopack may resolve the module namespace object instead of the default export
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
const AutosizeInput = ('default' in (_AutosizeInput as any) ? (_AutosizeInput as any).default : _AutosizeInput) as typeof _AutosizeInput
|
||||
|
||||
type TagInputProps = {
|
||||
items: string[]
|
||||
onChange: (items: string[]) => void
|
||||
|
||||
@@ -43,20 +43,24 @@ type DialogContentProps = {
|
||||
children: React.ReactNode
|
||||
className?: string
|
||||
overlayClassName?: string
|
||||
backdropProps?: React.ComponentPropsWithoutRef<typeof BaseDialog.Backdrop>
|
||||
}
|
||||
|
||||
export function DialogContent({
|
||||
children,
|
||||
className,
|
||||
overlayClassName,
|
||||
backdropProps,
|
||||
}: DialogContentProps) {
|
||||
return (
|
||||
<DialogPortal>
|
||||
<BaseDialog.Backdrop
|
||||
{...backdropProps}
|
||||
className={cn(
|
||||
'fixed inset-0 z-50 bg-background-overlay',
|
||||
'transition-opacity duration-150 data-[ending-style]:opacity-0 data-[starting-style]:opacity-0 motion-reduce:transition-none',
|
||||
overlayClassName,
|
||||
backdropProps?.className,
|
||||
)}
|
||||
/>
|
||||
<BaseDialog.Popup
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { CategoryEnum } from '..'
|
||||
import Footer from '../footer'
|
||||
import { CategoryEnum } from '../types'
|
||||
|
||||
vi.mock('next/link', () => ({
|
||||
default: ({ children, href, className, target }: { children: React.ReactNode, href: string, className?: string, target?: string }) => (
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { Dialog } from '@/app/components/base/ui/dialog'
|
||||
import Header from '../header'
|
||||
|
||||
function renderHeader(onClose: () => void) {
|
||||
return render(
|
||||
<Dialog open>
|
||||
<Header onClose={onClose} />
|
||||
</Dialog>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Header', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -11,7 +20,7 @@ describe('Header', () => {
|
||||
it('should render title and description translations', () => {
|
||||
const handleClose = vi.fn()
|
||||
|
||||
render(<Header onClose={handleClose} />)
|
||||
renderHeader(handleClose)
|
||||
|
||||
expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument()
|
||||
expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument()
|
||||
@@ -22,7 +31,7 @@ describe('Header', () => {
|
||||
describe('Props', () => {
|
||||
it('should invoke onClose when close button is clicked', () => {
|
||||
const handleClose = vi.fn()
|
||||
render(<Header onClose={handleClose} />)
|
||||
renderHeader(handleClose)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
@@ -32,7 +41,7 @@ describe('Header', () => {
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should render structural elements with translation keys', () => {
|
||||
const { container } = render(<Header onClose={vi.fn()} />)
|
||||
const { container } = renderHeader(vi.fn())
|
||||
|
||||
expect(container.querySelector('span')).toBeInTheDocument()
|
||||
expect(container.querySelector('p')).toBeInTheDocument()
|
||||
|
||||
@@ -74,15 +74,11 @@ describe('Pricing', () => {
|
||||
})
|
||||
|
||||
describe('Props', () => {
|
||||
it('should allow switching categories and handle esc key', () => {
|
||||
const handleCancel = vi.fn()
|
||||
render(<Pricing onCancel={handleCancel} />)
|
||||
it('should allow switching categories', () => {
|
||||
render(<Pricing onCancel={vi.fn()} />)
|
||||
|
||||
fireEvent.click(screen.getByText('billing.plansCommon.self'))
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.keyDown(window, { key: 'Escape', keyCode: 27 })
|
||||
expect(handleCancel).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import type { Category } from '.'
|
||||
import { RiArrowRightUpLine } from '@remixicon/react'
|
||||
import type { Category } from './types'
|
||||
import Link from 'next/link'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { CategoryEnum } from '.'
|
||||
import { CategoryEnum } from './types'
|
||||
|
||||
type FooterProps = {
|
||||
pricingPageURL: string
|
||||
@@ -34,7 +33,7 @@ const Footer = ({
|
||||
>
|
||||
{t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })}
|
||||
</Link>
|
||||
<RiArrowRightUpLine className="size-4" />
|
||||
<span aria-hidden="true" className="i-ri-arrow-right-up-line size-4" />
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog'
|
||||
import Button from '../../base/button'
|
||||
import DifyLogo from '../../base/logo/dify-logo'
|
||||
|
||||
@@ -20,19 +20,19 @@ const Header = ({
|
||||
<div className="py-[5px]">
|
||||
<DifyLogo className="h-[27px] w-[60px]" />
|
||||
</div>
|
||||
<span className="bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
|
||||
<DialogTitle className="m-0 bg-billing-plan-title-bg bg-clip-text px-1.5 font-instrument text-[37px] italic leading-[1.2] text-transparent">
|
||||
{t('plansCommon.title.plans', { ns: 'billing' })}
|
||||
</span>
|
||||
</DialogTitle>
|
||||
</div>
|
||||
<p className="system-sm-regular text-text-tertiary">
|
||||
<DialogDescription className="m-0 text-text-tertiary system-sm-regular">
|
||||
{t('plansCommon.title.description', { ns: 'billing' })}
|
||||
</p>
|
||||
</DialogDescription>
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="absolute bottom-[40.5px] right-[-18px] z-10 size-9 rounded-full p-2"
|
||||
onClick={onClose}
|
||||
>
|
||||
<RiCloseLine className="size-5" />
|
||||
<span aria-hidden="true" className="i-ri-close-line size-5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import { useKeyPress } from 'ahooks'
|
||||
import type { Category } from './types'
|
||||
import * as React from 'react'
|
||||
import { useState } from 'react'
|
||||
import { createPortal } from 'react-dom'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGetPricingPageLanguage } from '@/context/i18n'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
@@ -13,13 +13,7 @@ import Header from './header'
|
||||
import PlanSwitcher from './plan-switcher'
|
||||
import { PlanRange } from './plan-switcher/plan-range-switcher'
|
||||
import Plans from './plans'
|
||||
|
||||
export enum CategoryEnum {
|
||||
CLOUD = 'cloud',
|
||||
SELF = 'self',
|
||||
}
|
||||
|
||||
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
|
||||
import { CategoryEnum } from './types'
|
||||
|
||||
type PricingProps = {
|
||||
onCancel: () => void
|
||||
@@ -33,42 +27,47 @@ const Pricing: FC<PricingProps> = ({
|
||||
const [planRange, setPlanRange] = React.useState<PlanRange>(PlanRange.monthly)
|
||||
const [currentCategory, setCurrentCategory] = useState<Category>(CategoryEnum.CLOUD)
|
||||
const canPay = isCurrentWorkspaceManager
|
||||
useKeyPress(['esc'], onCancel)
|
||||
|
||||
const pricingPageLanguage = useGetPricingPageLanguage()
|
||||
const pricingPageURL = pricingPageLanguage
|
||||
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
|
||||
: 'https://dify.ai/pricing#plans-and-features'
|
||||
|
||||
return createPortal(
|
||||
<div
|
||||
className="fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] overflow-auto bg-saas-background"
|
||||
onClick={e => e.stopPropagation()}
|
||||
return (
|
||||
<Dialog
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}}
|
||||
>
|
||||
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
|
||||
<div className="absolute -top-12 left-0 right-0 -z-10">
|
||||
<NoiseTop />
|
||||
<DialogContent
|
||||
className="inset-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-auto rounded-none border-none bg-saas-background p-0 shadow-none"
|
||||
>
|
||||
<div className="relative grid min-h-full min-w-[1200px] grid-rows-[1fr_auto_auto_1fr] overflow-hidden">
|
||||
<div className="absolute -top-12 left-0 right-0 -z-10">
|
||||
<NoiseTop />
|
||||
</div>
|
||||
<Header onClose={onCancel} />
|
||||
<PlanSwitcher
|
||||
currentCategory={currentCategory}
|
||||
onChangeCategory={setCurrentCategory}
|
||||
currentPlanRange={planRange}
|
||||
onChangePlanRange={setPlanRange}
|
||||
/>
|
||||
<Plans
|
||||
plan={plan}
|
||||
currentPlan={currentCategory}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
|
||||
<div className="absolute -bottom-12 left-0 right-0 -z-10">
|
||||
<NoiseBottom />
|
||||
</div>
|
||||
</div>
|
||||
<Header onClose={onCancel} />
|
||||
<PlanSwitcher
|
||||
currentCategory={currentCategory}
|
||||
onChangeCategory={setCurrentCategory}
|
||||
currentPlanRange={planRange}
|
||||
onChangePlanRange={setPlanRange}
|
||||
/>
|
||||
<Plans
|
||||
plan={plan}
|
||||
currentPlan={currentCategory}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
<Footer pricingPageURL={pricingPageURL} currentCategory={currentCategory} />
|
||||
<div className="absolute -bottom-12 left-0 right-0 -z-10">
|
||||
<NoiseBottom />
|
||||
</div>
|
||||
</div>
|
||||
</div>,
|
||||
document.body,
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
export default React.memo(Pricing)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { CategoryEnum } from '../../index'
|
||||
import { CategoryEnum } from '../../types'
|
||||
import PlanSwitcher from '../index'
|
||||
import { PlanRange } from '../plan-range-switcher'
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { FC } from 'react'
|
||||
import type { Category } from '../index'
|
||||
import type { Category } from '../types'
|
||||
import type { PlanRange } from './plan-range-switcher'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
6
web/app/components/billing/pricing/types.ts
Normal file
6
web/app/components/billing/pricing/types.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
export enum CategoryEnum {
|
||||
CLOUD = 'cloud',
|
||||
SELF = 'self',
|
||||
}
|
||||
|
||||
export type Category = CategoryEnum.CLOUD | CategoryEnum.SELF
|
||||
@@ -204,7 +204,7 @@ const CSVUploader: FC<Props> = ({
|
||||
/>
|
||||
<div ref={dropRef}>
|
||||
{!file && (
|
||||
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-panel-border bg-components-panel-bg-blur text-sm font-normal', dragging && 'border border-divider-subtle bg-components-panel-on-panel-item-bg-hover')}>
|
||||
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-panel-border bg-components-panel-bg-blur text-sm font-normal', dragging && 'border border-divider-subtle bg-components-panel-on-panel-item-bg-hover')}>
|
||||
<div className="flex w-full items-center justify-center space-x-2">
|
||||
<CSVIcon className="shrink-0" />
|
||||
<div className="text-text-secondary">
|
||||
|
||||
@@ -58,7 +58,7 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="system-xs-semibold text-text-accent"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChildChunk?.()
|
||||
@@ -120,11 +120,11 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
<div className="flex h-full flex-col">
|
||||
<div className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
|
||||
<div className="flex flex-col">
|
||||
<div className="system-xl-semibold text-text-primary">{t('segment.addChildChunk', { ns: 'datasetDocuments' })}</div>
|
||||
<div className="text-text-primary system-xl-semibold">{t('segment.addChildChunk', { ns: 'datasetDocuments' })}</div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<SegmentIndexTag label={t('segment.newChildChunk', { ns: 'datasetDocuments' }) as string} />
|
||||
<Dot />
|
||||
<span className="system-xs-medium text-text-tertiary">{wordCountText}</span>
|
||||
<span className="text-text-tertiary system-xs-medium">{wordCountText}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
|
||||
@@ -61,7 +61,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="system-xs-semibold text-text-accent"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChunk()
|
||||
@@ -158,13 +158,13 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}
|
||||
>
|
||||
<div className="flex flex-col">
|
||||
<div className="system-xl-semibold text-text-primary">
|
||||
<div className="text-text-primary system-xl-semibold">
|
||||
{t('segment.addChunk', { ns: 'datasetDocuments' })}
|
||||
</div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<SegmentIndexTag label={t('segment.newChunk', { ns: 'datasetDocuments' })!} />
|
||||
<Dot />
|
||||
<span className="system-xs-medium text-text-tertiary">{wordCountText}</span>
|
||||
<span className="text-text-tertiary system-xs-medium">{wordCountText}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
|
||||
@@ -100,10 +100,10 @@ vi.mock('@/app/components/datasets/create/step-two', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting', () => ({
|
||||
default: ({ activeTab, onCancel }: { activeTab?: string, onCancel?: () => void }) => (
|
||||
default: ({ activeTab, onCancelAction }: { activeTab?: string, onCancelAction?: () => void }) => (
|
||||
<div data-testid="account-setting">
|
||||
<span data-testid="active-tab">{activeTab}</span>
|
||||
<button onClick={onCancel} data-testid="close-setting">Close</button>
|
||||
<button onClick={onCancelAction} data-testid="close-setting">Close</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import type { DataSourceProvider, NotionPage } from '@/models/common'
|
||||
import type {
|
||||
CrawlOptions,
|
||||
@@ -19,6 +20,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import StepTwo from '@/app/components/datasets/create/step-two'
|
||||
import AccountSetting from '@/app/components/header/account-setting'
|
||||
import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import DatasetDetailContext from '@/context/dataset-detail'
|
||||
@@ -33,8 +35,13 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean()
|
||||
const [accountSettingTab, setAccountSettingTab] = React.useState<AccountSettingTab>(ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
const { indexingTechnique, dataset } = useContext(DatasetDetailContext)
|
||||
const { data: embeddingsDefaultModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
const handleOpenAccountSetting = React.useCallback(() => {
|
||||
setAccountSettingTab(ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
showSetAPIKey()
|
||||
}, [showSetAPIKey])
|
||||
|
||||
const invalidDocumentList = useInvalidDocumentList(datasetId)
|
||||
const invalidDocumentDetail = useInvalidDocumentDetail()
|
||||
@@ -135,7 +142,7 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
{dataset && documentDetail && (
|
||||
<StepTwo
|
||||
isAPIKeySet={!!embeddingsDefaultModel}
|
||||
onSetting={showSetAPIKey}
|
||||
onSetting={handleOpenAccountSetting}
|
||||
datasetId={datasetId}
|
||||
dataSourceType={documentDetail.data_source_type as DataSourceType}
|
||||
notionPages={currentPage ? [currentPage as unknown as NotionPage] : []}
|
||||
@@ -155,8 +162,9 @@ const DocumentSettings = ({ datasetId, documentId }: DocumentSettingsProps) => {
|
||||
</div>
|
||||
{isShowSetAPIKey && (
|
||||
<AccountSetting
|
||||
activeTab="provider"
|
||||
onCancel={async () => {
|
||||
activeTab={accountSettingTab}
|
||||
onTabChangeAction={setAccountSettingTab}
|
||||
onCancelAction={async () => {
|
||||
hideSetAPIkey()
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -120,13 +120,13 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
|
||||
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
|
||||
<div className="shadows-shadow-xl relative flex w-[480px] flex-col items-start rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg">
|
||||
<div className="flex flex-col items-start gap-2 self-stretch pb-3 pl-6 pr-14 pt-6">
|
||||
<div className="title-2xl-semi-bold grow self-stretch text-text-primary">
|
||||
<div className="grow self-stretch text-text-primary title-2xl-semi-bold">
|
||||
{
|
||||
isEditMode ? t('editExternalAPIFormTitle', { ns: 'dataset' }) : t('createExternalAPI', { ns: 'dataset' })
|
||||
}
|
||||
</div>
|
||||
{isEditMode && (datasetBindings?.length ?? 0) > 0 && (
|
||||
<div className="system-xs-regular flex items-center text-text-tertiary">
|
||||
<div className="flex items-center text-text-tertiary system-xs-regular">
|
||||
{t('editExternalAPIFormWarning.front', { ns: 'dataset' })}
|
||||
<span className="flex cursor-pointer items-center text-text-accent">
|
||||
|
||||
@@ -139,12 +139,12 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
|
||||
popupContent={(
|
||||
<div className="p-1">
|
||||
<div className="flex items-start self-stretch pb-0.5 pl-2 pr-3 pt-1">
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">{`${datasetBindings?.length} ${t('editExternalAPITooltipTitle', { ns: 'dataset' })}`}</div>
|
||||
<div className="text-text-tertiary system-xs-medium-uppercase">{`${datasetBindings?.length} ${t('editExternalAPITooltipTitle', { ns: 'dataset' })}`}</div>
|
||||
</div>
|
||||
{datasetBindings?.map(binding => (
|
||||
<div key={binding.id} className="flex items-center gap-1 self-stretch px-2 py-1">
|
||||
<RiBook2Line className="h-4 w-4 text-text-secondary" />
|
||||
<div className="system-sm-medium text-text-secondary">{binding.name}</div>
|
||||
<div className="text-text-secondary system-sm-medium">{binding.name}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -188,8 +188,8 @@ const AddExternalAPIModal: FC<AddExternalAPIModalProps> = ({ data, onSave, onCan
|
||||
{t('externalAPIForm.save', { ns: 'dataset' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="system-xs-regular flex items-center justify-center gap-1 self-stretch rounded-b-2xl border-t-[0.5px]
|
||||
border-divider-subtle bg-background-soft px-2 py-3 text-text-tertiary"
|
||||
<div className="flex items-center justify-center gap-1 self-stretch rounded-b-2xl border-t-[0.5px] border-divider-subtle
|
||||
bg-background-soft px-2 py-3 text-text-tertiary system-xs-regular"
|
||||
>
|
||||
<RiLock2Fill className="h-3 w-3 text-text-quaternary" />
|
||||
{t('externalAPIForm.encrypted.front', { ns: 'dataset' })}
|
||||
|
||||
@@ -63,7 +63,7 @@ const SummaryIndexSetting = ({
|
||||
return (
|
||||
<div>
|
||||
<div className="flex h-6 items-center justify-between">
|
||||
<div className="system-sm-semibold-uppercase flex items-center text-text-secondary">
|
||||
<div className="flex items-center text-text-secondary system-sm-semibold-uppercase">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
<Tooltip
|
||||
triggerClassName="ml-1 h-4 w-4 shrink-0"
|
||||
@@ -80,7 +80,7 @@ const SummaryIndexSetting = ({
|
||||
{
|
||||
summaryIndexSetting?.enable && (
|
||||
<div>
|
||||
<div className="system-xs-medium-uppercase mb-1.5 mt-2 flex h-6 items-center text-text-tertiary">
|
||||
<div className="mb-1.5 mt-2 flex h-6 items-center text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<ModelSelector
|
||||
@@ -90,7 +90,7 @@ const SummaryIndexSetting = ({
|
||||
readonly={readonly}
|
||||
showDeprecatedWarnIcon
|
||||
/>
|
||||
<div className="system-xs-medium-uppercase mt-3 flex h-6 items-center text-text-tertiary">
|
||||
<div className="mt-3 flex h-6 items-center text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<Textarea
|
||||
@@ -111,12 +111,12 @@ const SummaryIndexSetting = ({
|
||||
<div className="space-y-4">
|
||||
<div className="flex gap-x-1">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-semibold text-text-secondary">
|
||||
<div className="text-text-secondary system-sm-semibold">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="py-1.5">
|
||||
<div className="system-sm-semibold flex items-center text-text-secondary">
|
||||
<div className="flex items-center text-text-secondary system-sm-semibold">
|
||||
<Switch
|
||||
className="mr-2"
|
||||
value={summaryIndexSetting?.enable ?? false}
|
||||
@@ -127,7 +127,7 @@ const SummaryIndexSetting = ({
|
||||
summaryIndexSetting?.enable ? t('list.status.enabled', { ns: 'datasetDocuments' }) : t('list.status.disabled', { ns: 'datasetDocuments' })
|
||||
}
|
||||
</div>
|
||||
<div className="system-sm-regular mt-2 text-text-tertiary">
|
||||
<div className="mt-2 text-text-tertiary system-sm-regular">
|
||||
{
|
||||
summaryIndexSetting?.enable && t('form.summaryAutoGenTip', { ns: 'datasetSettings' })
|
||||
}
|
||||
@@ -142,7 +142,7 @@ const SummaryIndexSetting = ({
|
||||
<>
|
||||
<div className="flex gap-x-1">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-medium text-text-tertiary">
|
||||
<div className="text-text-tertiary system-sm-medium">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -159,7 +159,7 @@ const SummaryIndexSetting = ({
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex h-7 w-[180px] shrink-0 items-center pt-1">
|
||||
<div className="system-sm-medium text-text-tertiary">
|
||||
<div className="text-text-tertiary system-sm-medium">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -188,7 +188,7 @@ const SummaryIndexSetting = ({
|
||||
onChange={handleSummaryIndexEnableChange}
|
||||
size="md"
|
||||
/>
|
||||
<div className="system-sm-semibold text-text-secondary">
|
||||
<div className="text-text-secondary system-sm-semibold">
|
||||
{t('form.summaryAutoGen', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -196,7 +196,7 @@ const SummaryIndexSetting = ({
|
||||
summaryIndexSetting?.enable && (
|
||||
<>
|
||||
<div>
|
||||
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
|
||||
<div className="mb-1.5 flex h-6 items-center text-text-secondary system-sm-medium">
|
||||
{t('form.summaryModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<ModelSelector
|
||||
@@ -209,7 +209,7 @@ const SummaryIndexSetting = ({
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className="system-sm-medium mb-1.5 flex h-6 items-center text-text-secondary">
|
||||
<div className="mb-1.5 flex h-6 items-center text-text-secondary system-sm-medium">
|
||||
{t('form.summaryInstructions', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<Textarea
|
||||
|
||||
@@ -46,7 +46,7 @@ const WorkplaceSelector = () => {
|
||||
<span className="h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90">{currentWorkspace?.name[0]?.toLocaleUpperCase()}</span>
|
||||
</div>
|
||||
<div className="flex min-w-0 items-center">
|
||||
<div className="system-sm-medium min-w-0 max-w-[149px] truncate text-text-secondary max-[800px]:hidden">{currentWorkspace?.name}</div>
|
||||
<div className="min-w-0 max-w-[149px] truncate text-text-secondary system-sm-medium max-[800px]:hidden">{currentWorkspace?.name}</div>
|
||||
<RiArrowDownSLine className="h-4 w-4 shrink-0 text-text-secondary" />
|
||||
</div>
|
||||
</MenuButton>
|
||||
@@ -68,9 +68,9 @@ const WorkplaceSelector = () => {
|
||||
`,
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full flex-col items-start self-stretch rounded-xl border-[0.5px] border-components-panel-border p-1 pb-2 shadow-lg ">
|
||||
<div className="flex w-full flex-col items-start self-stretch rounded-xl border-[0.5px] border-components-panel-border p-1 pb-2 shadow-lg">
|
||||
<div className="flex items-start self-stretch px-3 pb-0.5 pt-1">
|
||||
<span className="system-xs-medium-uppercase flex-1 text-text-tertiary">{t('userProfile.workspace', { ns: 'common' })}</span>
|
||||
<span className="flex-1 text-text-tertiary system-xs-medium-uppercase">{t('userProfile.workspace', { ns: 'common' })}</span>
|
||||
</div>
|
||||
{
|
||||
workspaces.map(workspace => (
|
||||
@@ -78,7 +78,7 @@ const WorkplaceSelector = () => {
|
||||
<div className="flex h-6 w-6 shrink-0 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px]">
|
||||
<span className="h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90">{workspace?.name[0]?.toLocaleUpperCase()}</span>
|
||||
</div>
|
||||
<div className="system-md-regular line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary">{workspace.name}</div>
|
||||
<div className="line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary system-md-regular">{workspace.name}</div>
|
||||
<PlanBadge plan={workspace.plan as Plan} />
|
||||
</div>
|
||||
))
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
import type { AccountSettingTab } from './constants'
|
||||
import type { AppContextValue } from '@/context/app-context'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { useState } from 'react'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { baseProviderContextValue, useProviderContext } from '@/context/provider-context'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import { ACCOUNT_SETTING_TAB } from './constants'
|
||||
import AccountSetting from './index'
|
||||
|
||||
const mockResetModelProviderListExpanded = vi.fn()
|
||||
|
||||
vi.mock('@/context/provider-context', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/context/provider-context')>()
|
||||
return {
|
||||
@@ -47,10 +51,15 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', ()
|
||||
useDefaultModel: vi.fn(() => ({ data: null, isLoading: false })),
|
||||
useUpdateDefaultModel: vi.fn(() => ({ trigger: vi.fn() })),
|
||||
useUpdateModelList: vi.fn(() => vi.fn()),
|
||||
useInvalidateDefaultModel: vi.fn(() => vi.fn()),
|
||||
useModelList: vi.fn(() => ({ data: [], isLoading: false })),
|
||||
useSystemDefaultModelAndModelList: vi.fn(() => [null, vi.fn()]),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/account-setting/model-provider-page/atoms', () => ({
|
||||
useResetModelProviderListExpanded: () => mockResetModelProviderListExpanded,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-datasource', () => ({
|
||||
useGetDataSourceListAuth: vi.fn(() => ({ data: { result: [] } })),
|
||||
}))
|
||||
@@ -105,6 +114,38 @@ const baseAppContextValue: AppContextValue = {
|
||||
describe('AccountSetting', () => {
|
||||
const mockOnCancel = vi.fn()
|
||||
const mockOnTabChange = vi.fn()
|
||||
const renderAccountSetting = (props?: {
|
||||
initialTab?: AccountSettingTab
|
||||
onCancel?: () => void
|
||||
onTabChange?: (tab: AccountSettingTab) => void
|
||||
}) => {
|
||||
const {
|
||||
initialTab = ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
onCancel = mockOnCancel,
|
||||
onTabChange = mockOnTabChange,
|
||||
} = props ?? {}
|
||||
|
||||
const StatefulAccountSetting = () => {
|
||||
const [activeTab, setActiveTab] = useState<AccountSettingTab>(initialTab)
|
||||
|
||||
return (
|
||||
<AccountSetting
|
||||
onCancelAction={onCancel}
|
||||
activeTab={activeTab}
|
||||
onTabChangeAction={(tab) => {
|
||||
setActiveTab(tab)
|
||||
onTabChange(tab)
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<StatefulAccountSetting />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -120,11 +161,7 @@ describe('AccountSetting', () => {
|
||||
describe('Rendering', () => {
|
||||
it('should render the sidebar with correct menu items', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText('common.userProfile.settings')).toBeInTheDocument()
|
||||
@@ -137,13 +174,9 @@ describe('AccountSetting', () => {
|
||||
expect(screen.getAllByText('common.settings.language').length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should respect the activeTab prop', () => {
|
||||
it('should respect the initial tab', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} activeTab={ACCOUNT_SETTING_TAB.DATA_SOURCE} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.DATA_SOURCE })
|
||||
|
||||
// Assert
|
||||
// Check that the active item title is Data Source
|
||||
@@ -157,11 +190,7 @@ describe('AccountSetting', () => {
|
||||
vi.mocked(useBreakpoints).mockReturnValue(MediaType.mobile)
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
// On mobile, the labels should not be rendered as per the implementation
|
||||
@@ -176,11 +205,7 @@ describe('AccountSetting', () => {
|
||||
})
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.settings.provider')).not.toBeInTheDocument()
|
||||
@@ -197,11 +222,7 @@ describe('AccountSetting', () => {
|
||||
})
|
||||
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Assert
|
||||
expect(screen.queryByText('common.settings.billing')).not.toBeInTheDocument()
|
||||
@@ -212,11 +233,7 @@ describe('AccountSetting', () => {
|
||||
describe('Tab Navigation', () => {
|
||||
it('should change active tab when clicking on menu item', () => {
|
||||
// Arrange
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} onTabChange={mockOnTabChange} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting({ onTabChange: mockOnTabChange })
|
||||
|
||||
// Act
|
||||
fireEvent.click(screen.getByText('common.settings.provider'))
|
||||
@@ -229,11 +246,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should navigate through various tabs and show correct details', () => {
|
||||
// Act & Assert
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
|
||||
// Billing
|
||||
fireEvent.click(screen.getByText('common.settings.billing'))
|
||||
@@ -267,13 +280,11 @@ describe('AccountSetting', () => {
|
||||
describe('Interactions', () => {
|
||||
it('should call onCancel when clicking close button', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
const buttons = screen.getAllByRole('button')
|
||||
fireEvent.click(buttons[0])
|
||||
renderAccountSetting()
|
||||
const closeIcon = document.querySelector('.i-ri-close-line')
|
||||
const closeButton = closeIcon?.closest('button')
|
||||
expect(closeButton).not.toBeNull()
|
||||
fireEvent.click(closeButton!)
|
||||
|
||||
// Assert
|
||||
expect(mockOnCancel).toHaveBeenCalled()
|
||||
@@ -281,11 +292,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should call onCancel when pressing Escape key', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
fireEvent.keyDown(document, { key: 'Escape' })
|
||||
|
||||
// Assert
|
||||
@@ -294,12 +301,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should update search value in provider tab', () => {
|
||||
// Arrange
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
fireEvent.click(screen.getByText('common.settings.provider'))
|
||||
renderAccountSetting({ initialTab: ACCOUNT_SETTING_TAB.PROVIDER })
|
||||
|
||||
// Act
|
||||
const input = screen.getByRole('textbox')
|
||||
@@ -312,11 +314,7 @@ describe('AccountSetting', () => {
|
||||
|
||||
it('should handle scroll event in panel', () => {
|
||||
// Act
|
||||
render(
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
<AccountSetting onCancel={mockOnCancel} />
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
renderAccountSetting()
|
||||
const scrollContainer = screen.getByRole('dialog').querySelector('.overflow-y-auto')
|
||||
|
||||
// Assert
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SearchInput from '@/app/components/base/search-input'
|
||||
import BillingPage from '@/app/components/billing/billing-page'
|
||||
@@ -20,15 +20,16 @@ import DataSourcePage from './data-source-page-new'
|
||||
import LanguagePage from './language-page'
|
||||
import MembersPage from './members-page'
|
||||
import ModelProviderPage from './model-provider-page'
|
||||
import { useResetModelProviderListExpanded } from './model-provider-page/atoms'
|
||||
|
||||
const iconClassName = `
|
||||
w-5 h-5 mr-2
|
||||
`
|
||||
|
||||
type IAccountSettingProps = {
|
||||
onCancel: () => void
|
||||
activeTab?: AccountSettingTab
|
||||
onTabChange?: (tab: AccountSettingTab) => void
|
||||
onCancelAction: () => void
|
||||
activeTab: AccountSettingTab
|
||||
onTabChangeAction: (tab: AccountSettingTab) => void
|
||||
}
|
||||
|
||||
type GroupItem = {
|
||||
@@ -40,14 +41,12 @@ type GroupItem = {
|
||||
}
|
||||
|
||||
export default function AccountSetting({
|
||||
onCancel,
|
||||
activeTab = ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
onTabChange,
|
||||
onCancelAction,
|
||||
activeTab,
|
||||
onTabChangeAction,
|
||||
}: IAccountSettingProps) {
|
||||
const [activeMenu, setActiveMenu] = useState<AccountSettingTab>(activeTab)
|
||||
useEffect(() => {
|
||||
setActiveMenu(activeTab)
|
||||
}, [activeTab])
|
||||
const resetModelProviderListExpanded = useResetModelProviderListExpanded()
|
||||
const activeMenu = activeTab
|
||||
const { t } = useTranslation()
|
||||
const { enableBilling, enableReplaceWebAppLogo } = useProviderContext()
|
||||
const { isCurrentWorkspaceDatasetOperator } = useAppContext()
|
||||
@@ -148,10 +147,22 @@ export default function AccountSetting({
|
||||
|
||||
const [searchValue, setSearchValue] = useState<string>('')
|
||||
|
||||
const handleTabChange = useCallback((tab: AccountSettingTab) => {
|
||||
if (tab === ACCOUNT_SETTING_TAB.PROVIDER)
|
||||
resetModelProviderListExpanded()
|
||||
|
||||
onTabChangeAction(tab)
|
||||
}, [onTabChangeAction, resetModelProviderListExpanded])
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
resetModelProviderListExpanded()
|
||||
onCancelAction()
|
||||
}, [onCancelAction, resetModelProviderListExpanded])
|
||||
|
||||
return (
|
||||
<MenuDialog
|
||||
show
|
||||
onClose={onCancel}
|
||||
onClose={handleClose}
|
||||
>
|
||||
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
|
||||
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
|
||||
@@ -166,21 +177,22 @@ export default function AccountSetting({
|
||||
<div>
|
||||
{
|
||||
menuItem.items.map(item => (
|
||||
<div
|
||||
<button
|
||||
type="button"
|
||||
key={item.key}
|
||||
className={cn(
|
||||
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
|
||||
'mb-0.5 flex h-[37px] w-full items-center rounded-lg p-1 pl-3 text-left text-sm',
|
||||
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
|
||||
)}
|
||||
aria-label={item.name}
|
||||
title={item.name}
|
||||
onClick={() => {
|
||||
setActiveMenu(item.key)
|
||||
onTabChange?.(item.key)
|
||||
handleTabChange(item.key)
|
||||
}}
|
||||
>
|
||||
{activeMenu === item.key ? item.activeIcon : item.icon}
|
||||
{!isMobile && <div className="truncate">{item.name}</div>}
|
||||
</div>
|
||||
</button>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
@@ -195,7 +207,8 @@ export default function AccountSetting({
|
||||
variant="tertiary"
|
||||
size="large"
|
||||
className="px-2"
|
||||
onClick={onCancel}
|
||||
aria-label={t('operation.close', { ns: 'common' })}
|
||||
onClick={handleClose}
|
||||
>
|
||||
<span className="i-ri-close-line h-5 w-5" />
|
||||
</Button>
|
||||
|
||||
@@ -97,7 +97,7 @@ const Operation = ({
|
||||
offset={{ mainAxis: 4 }}
|
||||
>
|
||||
<PortalToFollowElemTrigger asChild onClick={() => setOpen(prev => !prev)}>
|
||||
<div className={cn('system-sm-regular group flex h-full w-full cursor-pointer items-center justify-between px-3 text-text-secondary hover:bg-state-base-hover', open && 'bg-state-base-hover')}>
|
||||
<div className={cn('group flex h-full w-full cursor-pointer items-center justify-between px-3 text-text-secondary system-sm-regular hover:bg-state-base-hover', open && 'bg-state-base-hover')}>
|
||||
{RoleMap[member.role] || RoleMap.normal}
|
||||
<ChevronDownIcon className={cn('h-4 w-4 shrink-0 group-hover:block', open ? 'block' : 'hidden')} />
|
||||
</div>
|
||||
@@ -114,8 +114,8 @@ const Operation = ({
|
||||
: <div className="mr-1 mt-[2px] h-4 w-4 text-text-accent" />
|
||||
}
|
||||
<div>
|
||||
<div className="system-sm-semibold whitespace-nowrap text-text-secondary">{t(roleI18nKeyMap[role].label, { ns: 'common' })}</div>
|
||||
<div className="system-xs-regular whitespace-nowrap text-text-tertiary">{t(roleI18nKeyMap[role].tip, { ns: 'common' })}</div>
|
||||
<div className="whitespace-nowrap text-text-secondary system-sm-semibold">{t(roleI18nKeyMap[role].label, { ns: 'common' })}</div>
|
||||
<div className="whitespace-nowrap text-text-tertiary system-xs-regular">{t(roleI18nKeyMap[role].tip, { ns: 'common' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
@@ -125,8 +125,8 @@ const Operation = ({
|
||||
<div className="flex cursor-pointer rounded-lg px-3 py-2 hover:bg-state-base-hover" onClick={handleDeleteMemberOrCancelInvitation}>
|
||||
<div className="mr-1 mt-[2px] h-4 w-4 text-text-accent" />
|
||||
<div>
|
||||
<div className="system-sm-semibold whitespace-nowrap text-text-secondary">{t('members.removeFromTeam', { ns: 'common' })}</div>
|
||||
<div className="system-xs-regular whitespace-nowrap text-text-tertiary">{t('members.removeFromTeamTip', { ns: 'common' })}</div>
|
||||
<div className="whitespace-nowrap text-text-secondary system-sm-semibold">{t('members.removeFromTeam', { ns: 'common' })}</div>
|
||||
<div className="whitespace-nowrap text-text-tertiary system-xs-regular">{t('members.removeFromTeamTip', { ns: 'common' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -40,8 +40,7 @@ describe('MenuDialog', () => {
|
||||
)
|
||||
|
||||
// Assert
|
||||
const panel = screen.getByRole('dialog').querySelector('.custom-class')
|
||||
expect(panel).toBeInTheDocument()
|
||||
expect(screen.getByRole('dialog')).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { Dialog, DialogPanel, Transition, TransitionChild } from '@headlessui/react'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { Fragment, useCallback, useEffect } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { Dialog, DialogContent } from '@/app/components/base/ui/dialog'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type DialogProps = {
|
||||
@@ -19,42 +18,25 @@ const MenuDialog = ({
|
||||
}: DialogProps) => {
|
||||
const close = useCallback(() => onClose?.(), [onClose])
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.preventDefault()
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown)
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown)
|
||||
}
|
||||
}, [close])
|
||||
|
||||
return (
|
||||
<Transition appear show={show} as={Fragment}>
|
||||
<Dialog as="div" className="relative z-[60]" onClose={noop}>
|
||||
<div className="fixed inset-0">
|
||||
<div className="flex min-h-full flex-col items-center justify-center">
|
||||
<TransitionChild>
|
||||
<DialogPanel className={cn(
|
||||
'relative h-full w-full grow overflow-hidden bg-background-sidenav-bg p-0 text-left align-middle backdrop-blur-md transition-all',
|
||||
'duration-300 ease-in data-[closed]:scale-95 data-[closed]:opacity-0',
|
||||
'data-[enter]:scale-100 data-[enter]:opacity-100',
|
||||
'data-[enter]:scale-95 data-[leave]:opacity-0',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
|
||||
{children}
|
||||
</DialogPanel>
|
||||
</TransitionChild>
|
||||
</div>
|
||||
</div>
|
||||
</Dialog>
|
||||
</Transition>
|
||||
<Dialog
|
||||
open={show}
|
||||
onOpenChange={(open) => {
|
||||
if (!open)
|
||||
close()
|
||||
}}
|
||||
>
|
||||
<DialogContent
|
||||
overlayClassName="bg-transparent"
|
||||
className={cn(
|
||||
'left-0 top-0 h-full max-h-none w-full max-w-none translate-x-0 translate-y-0 overflow-hidden rounded-none border-none bg-background-sidenav-bg p-0 shadow-none backdrop-blur-md',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="absolute right-0 top-0 h-full w-1/2 bg-components-panel-bg" />
|
||||
{children}
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { Provider } from 'jotai'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
import {
|
||||
useExpandModelProviderList,
|
||||
useModelProviderListExpanded,
|
||||
useResetModelProviderListExpanded,
|
||||
useSetModelProviderListExpanded,
|
||||
} from './atoms'
|
||||
|
||||
const createWrapper = () => {
|
||||
return ({ children }: { children: ReactNode }) => (
|
||||
<Provider>{children}</Provider>
|
||||
)
|
||||
}
|
||||
|
||||
describe('atoms', () => {
|
||||
let wrapper: ReturnType<typeof createWrapper>
|
||||
|
||||
beforeEach(() => {
|
||||
wrapper = createWrapper()
|
||||
})
|
||||
|
||||
// Read hook: returns whether a specific provider is expanded
|
||||
describe('useModelProviderListExpanded', () => {
|
||||
it('should return false when provider has not been expanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => useModelProviderListExpanded('openai'),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for any unknown provider name', () => {
|
||||
const { result } = renderHook(
|
||||
() => useModelProviderListExpanded('nonexistent-provider'),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when provider has been expanded via setter', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Setter hook: toggles expanded state for a specific provider
|
||||
describe('useSetModelProviderListExpanded', () => {
|
||||
it('should expand a provider when called with true', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
setExpanded: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should collapse a provider when called with false', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
setExpanded: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setExpanded(false)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should not affect other providers when setting one', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
expect(result.current.anthropicExpanded).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// Expand hook: expands any provider by name
|
||||
describe('useExpandModelProviderList', () => {
|
||||
it('should expand the specified provider', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('google'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('google')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should expand multiple providers independently', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
expect(result.current.anthropicExpanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should not collapse already expanded providers when expanding another', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Reset hook: clears all expanded state back to empty
|
||||
describe('useResetModelProviderListExpanded', () => {
|
||||
it('should reset all expanded providers to false', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openaiExpanded: useModelProviderListExpanded('openai'),
|
||||
anthropicExpanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.openaiExpanded).toBe(false)
|
||||
expect(result.current.anthropicExpanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should be safe to call when no providers are expanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should allow re-expanding providers after reset', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
expand: useExpandModelProviderList(),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Cross-hook interaction: verify hooks cooperate through the shared atom
|
||||
describe('Cross-hook interaction', () => {
|
||||
it('should reflect state set by useSetModelProviderListExpanded in useModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should reflect state set by useExpandModelProviderList in useModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('anthropic'),
|
||||
expand: useExpandModelProviderList(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('anthropic')
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(true)
|
||||
})
|
||||
|
||||
it('should allow useSetModelProviderListExpanded to collapse a provider expanded by useExpandModelProviderList', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
expand: useExpandModelProviderList(),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.expand('openai')
|
||||
})
|
||||
expect(result.current.expanded).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(false)
|
||||
})
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
|
||||
it('should reset state set by useSetModelProviderListExpanded via useResetModelProviderListExpanded', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
reset: useResetModelProviderListExpanded(),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setExpanded(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.reset()
|
||||
})
|
||||
|
||||
expect(result.current.expanded).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// selectAtom granularity: changing one provider should not affect unrelated reads
|
||||
describe('selectAtom granularity', () => {
|
||||
it('should not cause unrelated provider reads to change when one provider is toggled', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openai: useModelProviderListExpanded('openai'),
|
||||
anthropic: useModelProviderListExpanded('anthropic'),
|
||||
google: useModelProviderListExpanded('google'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
const anthropicBefore = result.current.anthropic
|
||||
const googleBefore = result.current.google
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
|
||||
expect(result.current.openai).toBe(true)
|
||||
expect(result.current.anthropic).toBe(anthropicBefore)
|
||||
expect(result.current.google).toBe(googleBefore)
|
||||
})
|
||||
|
||||
it('should keep individual provider states independent across multiple expansions and collapses', () => {
|
||||
const { result } = renderHook(
|
||||
() => ({
|
||||
openai: useModelProviderListExpanded('openai'),
|
||||
anthropic: useModelProviderListExpanded('anthropic'),
|
||||
setOpenai: useSetModelProviderListExpanded('openai'),
|
||||
setAnthropic: useSetModelProviderListExpanded('anthropic'),
|
||||
}),
|
||||
{ wrapper },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.setOpenai(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setAnthropic(true)
|
||||
})
|
||||
act(() => {
|
||||
result.current.setOpenai(false)
|
||||
})
|
||||
|
||||
expect(result.current.openai).toBe(false)
|
||||
expect(result.current.anthropic).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Isolation: separate Provider instances have independent state
|
||||
describe('Provider isolation', () => {
|
||||
it('should have independent state across different Provider instances', () => {
|
||||
const wrapper1 = createWrapper()
|
||||
const wrapper2 = createWrapper()
|
||||
|
||||
const { result: result1 } = renderHook(
|
||||
() => ({
|
||||
expanded: useModelProviderListExpanded('openai'),
|
||||
setExpanded: useSetModelProviderListExpanded('openai'),
|
||||
}),
|
||||
{ wrapper: wrapper1 },
|
||||
)
|
||||
|
||||
const { result: result2 } = renderHook(
|
||||
() => useModelProviderListExpanded('openai'),
|
||||
{ wrapper: wrapper2 },
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result1.current.setExpanded(true)
|
||||
})
|
||||
|
||||
expect(result1.current.expanded).toBe(true)
|
||||
expect(result2.current).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,35 @@
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { selectAtom } from 'jotai/utils'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
|
||||
const expandedAtom = atom<Record<string, boolean>>({})
|
||||
|
||||
export function useModelProviderListExpanded(providerName: string) {
|
||||
return useAtomValue(
|
||||
useMemo(
|
||||
() => selectAtom(expandedAtom, s => !!s[providerName]),
|
||||
[providerName],
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
export function useSetModelProviderListExpanded(providerName: string) {
|
||||
const set = useSetAtom(expandedAtom)
|
||||
return useCallback(
|
||||
(expanded: boolean) => set(prev => ({ ...prev, [providerName]: expanded })),
|
||||
[providerName, set],
|
||||
)
|
||||
}
|
||||
|
||||
export function useExpandModelProviderList() {
|
||||
const set = useSetAtom(expandedAtom)
|
||||
return useCallback(
|
||||
(providerName: string) => set(prev => ({ ...prev, [providerName]: true })),
|
||||
[set],
|
||||
)
|
||||
}
|
||||
|
||||
export function useResetModelProviderListExpanded() {
|
||||
const set = useSetAtom(expandedAtom)
|
||||
return useCallback(() => set({}), [set])
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import type {
|
||||
} from './declarations'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import { fetchDefaultModal, fetchModelList, fetchModelProviderCredentials } from '@/service/common'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
@@ -23,6 +24,7 @@ import {
|
||||
useAnthropicBuyQuota,
|
||||
useCurrentProviderAndModel,
|
||||
useDefaultModel,
|
||||
useInvalidateDefaultModel,
|
||||
useLanguage,
|
||||
useMarketplaceAllPlugins,
|
||||
useModelList,
|
||||
@@ -36,7 +38,6 @@ import {
|
||||
useUpdateModelList,
|
||||
useUpdateModelProviders,
|
||||
} from './hooks'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@tanstack/react-query', () => ({
|
||||
@@ -78,14 +79,6 @@ vi.mock('@/context/modal-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/event-emitter', () => ({
|
||||
useEventEmitterContextContext: vi.fn(() => ({
|
||||
eventEmitter: {
|
||||
emit: vi.fn(),
|
||||
},
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
useMarketplacePlugins: vi.fn(() => ({
|
||||
plugins: [],
|
||||
@@ -99,12 +92,16 @@ vi.mock('@/app/components/plugins/marketplace/hooks', () => ({
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.mock('./atoms', () => ({
|
||||
useExpandModelProviderList: vi.fn(() => vi.fn()),
|
||||
}))
|
||||
|
||||
const { useQuery, useQueryClient } = await import('@tanstack/react-query')
|
||||
const { getPayUrl } = await import('@/service/common')
|
||||
const { useProviderContext } = await import('@/context/provider-context')
|
||||
const { useModalContextSelector } = await import('@/context/modal-context')
|
||||
const { useEventEmitterContextContext } = await import('@/context/event-emitter')
|
||||
const { useMarketplacePlugins, useMarketplacePluginsByCollectionId } = await import('@/app/components/plugins/marketplace/hooks')
|
||||
const { useExpandModelProviderList } = await import('./atoms')
|
||||
|
||||
describe('hooks', () => {
|
||||
beforeEach(() => {
|
||||
@@ -913,6 +910,38 @@ describe('hooks', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('useInvalidateDefaultModel', () => {
|
||||
it('should invalidate default model queries', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const { result } = renderHook(() => useInvalidateDefaultModel())
|
||||
|
||||
act(() => {
|
||||
result.current(ModelTypeEnum.textGeneration)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: ['default-model', ModelTypeEnum.textGeneration],
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle multiple model types', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const { result } = renderHook(() => useInvalidateDefaultModel())
|
||||
|
||||
act(() => {
|
||||
result.current(ModelTypeEnum.textGeneration)
|
||||
result.current(ModelTypeEnum.textEmbedding)
|
||||
result.current(ModelTypeEnum.rerank)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledTimes(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useAnthropicBuyQuota', () => {
|
||||
beforeEach(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
@@ -1275,39 +1304,52 @@ describe('hooks', () => {
|
||||
|
||||
it('should refresh providers and model lists', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-providers'] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textEmbedding] })
|
||||
})
|
||||
|
||||
it('should emit event when refreshModelList is true and custom config is active', () => {
|
||||
it('should expand target provider list when refreshModelList is true and custom config is active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = createMockProvider()
|
||||
const customFields: CustomConfigurationModelFixedFields = {
|
||||
__model_name: 'gpt-4',
|
||||
__model_type: ModelTypeEnum.textGeneration,
|
||||
}
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1315,23 +1357,30 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, customFields, true)
|
||||
})
|
||||
|
||||
expect(emit).toHaveBeenCalledWith({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: 'openai',
|
||||
expect(expandModelProviderList).toHaveBeenCalledWith('openai')
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({ queryKey: ['model-list', ModelTypeEnum.textGeneration] })
|
||||
})
|
||||
|
||||
it('should not emit event when custom config is not active', () => {
|
||||
it('should not expand provider list when custom config is not active', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
const expandModelProviderList = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
; (useExpandModelProviderList as Mock).mockReturnValue(expandModelProviderList)
|
||||
|
||||
const provider = { ...createMockProvider(), custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure } }
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
@@ -1339,17 +1388,43 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(emit).not.toHaveBeenCalled()
|
||||
expect(expandModelProviderList).not.toHaveBeenCalled()
|
||||
expect(invalidateQueries).not.toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should emit event and invalidate all supported model types when __model_type is undefined', () => {
|
||||
it('should refetch active model provider list when custom refresh callback is absent', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
|
||||
const provider = createMockProvider()
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
const { result } = renderHook(() => useRefreshModel())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshModel(provider, undefined, true)
|
||||
})
|
||||
|
||||
expect(invalidateQueries).toHaveBeenCalledWith({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
})
|
||||
|
||||
it('should invalidate all supported model types when __model_type is undefined', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
const emit = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit },
|
||||
})
|
||||
|
||||
const provider = createMockProvider()
|
||||
const customFields = { __model_name: 'my-model', __model_type: undefined } as unknown as CustomConfigurationModelFixedFields
|
||||
@@ -1360,11 +1435,7 @@ describe('hooks', () => {
|
||||
result.current.handleRefreshModel(provider, customFields, true)
|
||||
})
|
||||
|
||||
expect(emit).toHaveBeenCalledWith({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: 'openai',
|
||||
})
|
||||
// When __model_type is undefined, all supported model types are invalidated
|
||||
// When __model_type is undefined, all supported model types are invalidated.
|
||||
const modelListCalls = invalidateQueries.mock.calls.filter(
|
||||
call => call[0]?.queryKey?.[0] === 'model-list',
|
||||
)
|
||||
@@ -1375,9 +1446,6 @@ describe('hooks', () => {
|
||||
const invalidateQueries = vi.fn()
|
||||
|
||||
; (useQueryClient as Mock).mockReturnValue({ invalidateQueries })
|
||||
; (useEventEmitterContextContext as Mock).mockReturnValue({
|
||||
eventEmitter: { emit: vi.fn() },
|
||||
})
|
||||
|
||||
const provider = {
|
||||
...createMockProvider(),
|
||||
|
||||
@@ -21,10 +21,10 @@ import {
|
||||
useMarketplacePluginsByCollectionId,
|
||||
} from '@/app/components/plugins/marketplace/hooks'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useModalContextSelector } from '@/context/modal-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { consoleQuery } from '@/service/client'
|
||||
import {
|
||||
fetchDefaultModal,
|
||||
fetchModelList,
|
||||
@@ -32,12 +32,12 @@ import {
|
||||
getPayUrl,
|
||||
} from '@/service/common'
|
||||
import { commonQueryKeys } from '@/service/use-common'
|
||||
import { useExpandModelProviderList } from './atoms'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CustomConfigurationStatusEnum,
|
||||
ModelStatusEnum,
|
||||
} from './declarations'
|
||||
import { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
|
||||
|
||||
type UseDefaultModelAndModelList = (
|
||||
defaultModel: DefaultModelResponse | undefined,
|
||||
@@ -57,15 +57,21 @@ export const useSystemDefaultModelAndModelList: UseDefaultModelAndModelList = (
|
||||
|
||||
return currentDefaultModel
|
||||
}, [defaultModel, modelList])
|
||||
const currentDefaultModelKey = currentDefaultModel
|
||||
? `${currentDefaultModel.provider}:${currentDefaultModel.model}`
|
||||
: ''
|
||||
const [defaultModelState, setDefaultModelState] = useState<DefaultModel | undefined>(currentDefaultModel)
|
||||
const handleDefaultModelChange = useCallback((model: DefaultModel) => {
|
||||
setDefaultModelState(model)
|
||||
}, [])
|
||||
useEffect(() => {
|
||||
setDefaultModelState(currentDefaultModel)
|
||||
}, [currentDefaultModel])
|
||||
const [defaultModelSourceKey, setDefaultModelSourceKey] = useState(currentDefaultModelKey)
|
||||
const selectedDefaultModel = defaultModelSourceKey === currentDefaultModelKey
|
||||
? defaultModelState
|
||||
: currentDefaultModel
|
||||
|
||||
return [defaultModelState, handleDefaultModelChange]
|
||||
const handleDefaultModelChange = useCallback((model: DefaultModel) => {
|
||||
setDefaultModelSourceKey(currentDefaultModelKey)
|
||||
setDefaultModelState(model)
|
||||
}, [currentDefaultModelKey])
|
||||
|
||||
return [selectedDefaultModel, handleDefaultModelChange]
|
||||
}
|
||||
|
||||
export const useLanguage = () => {
|
||||
@@ -116,7 +122,7 @@ export const useProviderCredentialsAndLoadBalancing = (
|
||||
predefinedFormSchemasValue?.credentials,
|
||||
])
|
||||
|
||||
const mutate = useMemo(() => () => {
|
||||
const mutate = useCallback(() => {
|
||||
if (predefinedEnabled)
|
||||
queryClient.invalidateQueries({ queryKey: ['model-providers', 'credentials', provider, credentialId] })
|
||||
if (customEnabled)
|
||||
@@ -222,6 +228,14 @@ export const useUpdateModelList = () => {
|
||||
return updateModelList
|
||||
}
|
||||
|
||||
export const useInvalidateDefaultModel = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
return useCallback((type: ModelTypeEnum) => {
|
||||
queryClient.invalidateQueries({ queryKey: commonQueryKeys.defaultModel(type) })
|
||||
}, [queryClient])
|
||||
}
|
||||
|
||||
export const useAnthropicBuyQuota = () => {
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
@@ -314,7 +328,8 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
|
||||
}
|
||||
|
||||
export const useRefreshModel = () => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
const expandModelProviderList = useExpandModelProviderList()
|
||||
const queryClient = useQueryClient()
|
||||
const updateModelProviders = useUpdateModelProviders()
|
||||
const updateModelList = useUpdateModelList()
|
||||
const handleRefreshModel = useCallback((
|
||||
@@ -322,6 +337,19 @@ export const useRefreshModel = () => {
|
||||
CustomConfigurationModelFixedFields?: CustomConfigurationModelFixedFields,
|
||||
refreshModelList?: boolean,
|
||||
) => {
|
||||
const modelProviderModelListQueryKey = consoleQuery.modelProviders.models.queryKey({
|
||||
input: {
|
||||
params: {
|
||||
provider: provider.provider,
|
||||
},
|
||||
},
|
||||
})
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'none',
|
||||
})
|
||||
|
||||
updateModelProviders()
|
||||
|
||||
provider.supported_model_types.forEach((type) => {
|
||||
@@ -329,15 +357,17 @@ export const useRefreshModel = () => {
|
||||
})
|
||||
|
||||
if (refreshModelList && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
|
||||
eventEmitter?.emit({
|
||||
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
|
||||
payload: provider.provider,
|
||||
} as any)
|
||||
expandModelProviderList(provider.provider)
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: modelProviderModelListQueryKey,
|
||||
exact: true,
|
||||
refetchType: 'active',
|
||||
})
|
||||
|
||||
if (CustomConfigurationModelFixedFields?.__model_type)
|
||||
updateModelList(CustomConfigurationModelFixedFields.__model_type)
|
||||
}
|
||||
}, [eventEmitter, updateModelList, updateModelProviders])
|
||||
}, [expandModelProviderList, queryClient, updateModelList, updateModelProviders])
|
||||
|
||||
return {
|
||||
handleRefreshModel,
|
||||
|
||||
@@ -7,16 +7,7 @@ import {
|
||||
} from './declarations'
|
||||
import ModelProviderPage from './index'
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
mutateCurrentWorkspace: vi.fn(),
|
||||
isValidatingCurrentWorkspace: false,
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockGlobalState = {
|
||||
systemFeatures: { enable_marketplace: true },
|
||||
}
|
||||
let mockEnableMarketplace = true
|
||||
|
||||
const mockQuotaConfig = {
|
||||
quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
@@ -28,7 +19,11 @@ const mockQuotaConfig = {
|
||||
}
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: (selector: (s: { systemFeatures: { enable_marketplace: boolean } }) => unknown) => selector(mockGlobalState),
|
||||
useSystemFeaturesQuery: () => ({
|
||||
data: {
|
||||
enable_marketplace: mockEnableMarketplace,
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockProviders = [
|
||||
@@ -60,21 +55,16 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
type MockDefaultModelData = {
|
||||
model: string
|
||||
provider?: { provider: string }
|
||||
} | null
|
||||
|
||||
const mockDefaultModelState: {
|
||||
data: MockDefaultModelData
|
||||
isLoading: boolean
|
||||
} = {
|
||||
data: null,
|
||||
isLoading: false,
|
||||
const mockDefaultModels: Record<string, { data: unknown, isLoading: boolean }> = {
|
||||
'llm': { data: null, isLoading: false },
|
||||
'text-embedding': { data: null, isLoading: false },
|
||||
'rerank': { data: null, isLoading: false },
|
||||
'speech2text': { data: null, isLoading: false },
|
||||
'tts': { data: null, isLoading: false },
|
||||
}
|
||||
|
||||
vi.mock('./hooks', () => ({
|
||||
useDefaultModel: () => mockDefaultModelState,
|
||||
useDefaultModel: (type: string) => mockDefaultModels[type] ?? { data: null, isLoading: false },
|
||||
}))
|
||||
|
||||
vi.mock('./install-from-marketplace', () => ({
|
||||
@@ -93,13 +83,18 @@ vi.mock('./system-model-selector', () => ({
|
||||
default: () => <div data-testid="system-model-selector" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useCheckInstalled: () => ({ data: undefined }),
|
||||
}))
|
||||
|
||||
describe('ModelProviderPage', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
mockGlobalState.systemFeatures.enable_marketplace = true
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = false
|
||||
mockEnableMarketplace = true
|
||||
Object.keys(mockDefaultModels).forEach((key) => {
|
||||
mockDefaultModels[key] = { data: null, isLoading: false }
|
||||
})
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'openai',
|
||||
label: { en_US: 'OpenAI' },
|
||||
@@ -157,13 +152,76 @@ describe('ModelProviderPage', () => {
|
||||
})
|
||||
|
||||
it('should hide marketplace section when marketplace feature is disabled', () => {
|
||||
mockGlobalState.systemFeatures.enable_marketplace = false
|
||||
mockEnableMarketplace = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByTestId('install-from-marketplace')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('system model config status', () => {
|
||||
it('should not show top warning when no configured providers exist (empty state card handles it)', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'anthropic',
|
||||
label: { en_US: 'Anthropic' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.emptyProviderTitle')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show none-configured warning when providers exist but no default models set', () => {
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.getByText('common.modelProvider.noneConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show partially-configured warning when some default models are set', () => {
|
||||
mockDefaultModels.llm = {
|
||||
data: { model: 'gpt-4', model_type: 'llm', provider: { provider: 'openai', icon_small: { en_US: '' } } },
|
||||
isLoading: false,
|
||||
}
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show warning when all default models are configured', () => {
|
||||
const makeModel = (model: string, type: string) => ({
|
||||
data: { model, model_type: type, provider: { provider: 'openai', icon_small: { en_US: '' } } },
|
||||
isLoading: false,
|
||||
})
|
||||
mockDefaultModels.llm = makeModel('gpt-4', 'llm')
|
||||
mockDefaultModels['text-embedding'] = makeModel('text-embedding-3', 'text-embedding')
|
||||
mockDefaultModels.rerank = makeModel('rerank-v3', 'rerank')
|
||||
mockDefaultModels.speech2text = makeModel('whisper-1', 'speech2text')
|
||||
mockDefaultModels.tts = makeModel('tts-1', 'tts')
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show warning while loading', () => {
|
||||
Object.keys(mockDefaultModels).forEach((key) => {
|
||||
mockDefaultModels[key] = { data: null, isLoading: true }
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
expect(screen.queryByText('common.modelProvider.noProviderInstalled')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.noneConfigured')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should prioritize fixed providers in visible order', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'zeta-provider',
|
||||
@@ -204,129 +262,4 @@ describe('ModelProviderPage', () => {
|
||||
])
|
||||
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show not configured alert when all default models are absent', () => {
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.notConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when default model is loading', () => {
|
||||
mockDefaultModelState.data = null
|
||||
mockDefaultModelState.isLoading = true
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should filter providers by label text', () => {
|
||||
render(<ModelProviderPage searchText="OpenAI" />)
|
||||
act(() => {
|
||||
vi.advanceTimersByTime(600)
|
||||
})
|
||||
expect(screen.getByText('openai')).toBeInTheDocument()
|
||||
expect(screen.queryByText('anthropic')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should classify system-enabled providers with matching quota as configured', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'sys-provider',
|
||||
label: { en_US: 'System Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: true,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('sys-provider')).toBeInTheDocument()
|
||||
expect(screen.queryByText('common.modelProvider.toBeConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should classify system-enabled provider with no matching quota as not configured', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'sys-no-quota',
|
||||
label: { en_US: 'System No Quota' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.noConfigure },
|
||||
system_configuration: {
|
||||
enabled: true,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.getByText('sys-no-quota')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.toBeConfigured')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should preserve order of two non-fixed providers (sort returns 0)', () => {
|
||||
mockProviders.splice(0, mockProviders.length, {
|
||||
provider: 'alpha-provider',
|
||||
label: { en_US: 'Alpha Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.active },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
}, {
|
||||
provider: 'beta-provider',
|
||||
label: { en_US: 'Beta Provider' },
|
||||
custom_configuration: { status: CustomConfigurationStatusEnum.active },
|
||||
system_configuration: {
|
||||
enabled: false,
|
||||
current_quota_type: CurrentSystemQuotaTypeEnum.free,
|
||||
quota_configurations: [mockQuotaConfig],
|
||||
},
|
||||
})
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
const renderedProviders = screen.getAllByTestId('provider-card').map(item => item.textContent)
|
||||
expect(renderedProviders).toEqual(['alpha-provider', 'beta-provider'])
|
||||
})
|
||||
|
||||
it('should not show not configured alert when shared default model mock has data', () => {
|
||||
mockDefaultModelState.data = { model: 'embed-model' }
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when rerankDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'rerank-model', provider: { provider: 'cohere' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when ttsDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'tts-model', provider: { provider: 'openai' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show not configured alert when speech2textDefaultModel has data', () => {
|
||||
mockDefaultModelState.data = { model: 'whisper', provider: { provider: 'openai' } }
|
||||
mockDefaultModelState.isLoading = false
|
||||
|
||||
render(<ModelProviderPage searchText="" />)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.notConfigured')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import type {
|
||||
ModelProvider,
|
||||
} from './declarations'
|
||||
import {
|
||||
RiAlertFill,
|
||||
RiBrainLine,
|
||||
} from '@remixicon/react'
|
||||
import type { PluginDetail } from '@/app/components/plugins/types'
|
||||
import { useDebounce } from 'ahooks'
|
||||
import { useEffect, useMemo } from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useSystemFeaturesQuery } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useCheckInstalled } from '@/service/use-plugins'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import {
|
||||
CustomConfigurationStatusEnum,
|
||||
@@ -24,6 +21,9 @@ import InstallFromMarketplace from './install-from-marketplace'
|
||||
import ProviderAddedCard from './provider-added-card'
|
||||
import QuotaPanel from './provider-added-card/quota-panel'
|
||||
import SystemModelSelector from './system-model-selector'
|
||||
import { providerToPluginId } from './utils'
|
||||
|
||||
type SystemModelConfigStatus = 'no-provider' | 'none-configured' | 'partially-configured' | 'fully-configured'
|
||||
|
||||
type Props = {
|
||||
searchText: string
|
||||
@@ -34,20 +34,35 @@ const FixedModelProvider = ['langgenius/openai/openai', 'langgenius/anthropic/an
|
||||
const ModelProviderPage = ({ searchText }: Props) => {
|
||||
const debouncedSearchText = useDebounce(searchText, { wait: 500 })
|
||||
const { t } = useTranslation()
|
||||
const { mutateCurrentWorkspace, isValidatingCurrentWorkspace } = useAppContext()
|
||||
const { data: textGenerationDefaultModel, isLoading: isTextGenerationDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textGeneration)
|
||||
const { data: embeddingsDefaultModel, isLoading: isEmbeddingsDefaultModelLoading } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
const { data: rerankDefaultModel, isLoading: isRerankDefaultModelLoading } = useDefaultModel(ModelTypeEnum.rerank)
|
||||
const { data: speech2textDefaultModel, isLoading: isSpeech2textDefaultModelLoading } = useDefaultModel(ModelTypeEnum.speech2text)
|
||||
const { data: ttsDefaultModel, isLoading: isTTSDefaultModelLoading } = useDefaultModel(ModelTypeEnum.tts)
|
||||
const { modelProviders: providers } = useProviderContext()
|
||||
const { enable_marketplace } = useGlobalPublicStore(s => s.systemFeatures)
|
||||
const { data: systemFeatures } = useSystemFeaturesQuery()
|
||||
|
||||
const allPluginIds = useMemo(() => {
|
||||
return [...new Set(providers.map(p => providerToPluginId(p.provider)).filter(Boolean))]
|
||||
}, [providers])
|
||||
const { data: installedPlugins } = useCheckInstalled({
|
||||
pluginIds: allPluginIds,
|
||||
enabled: allPluginIds.length > 0,
|
||||
})
|
||||
const pluginDetailMap = useMemo(() => {
|
||||
const map = new Map<string, PluginDetail>()
|
||||
if (installedPlugins?.plugins) {
|
||||
for (const plugin of installedPlugins.plugins)
|
||||
map.set(plugin.plugin_id, plugin)
|
||||
}
|
||||
return map
|
||||
}, [installedPlugins])
|
||||
const enableMarketplace = systemFeatures?.enable_marketplace ?? false
|
||||
const isDefaultModelLoading = isTextGenerationDefaultModelLoading
|
||||
|| isEmbeddingsDefaultModelLoading
|
||||
|| isRerankDefaultModelLoading
|
||||
|| isSpeech2textDefaultModelLoading
|
||||
|| isTTSDefaultModelLoading
|
||||
const defaultModelNotConfigured = !isDefaultModelLoading && !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel && !ttsDefaultModel
|
||||
const [configuredProviders, notConfiguredProviders] = useMemo(() => {
|
||||
const configuredProviders: ModelProvider[] = []
|
||||
const notConfiguredProviders: ModelProvider[] = []
|
||||
@@ -79,6 +94,26 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
|
||||
return [configuredProviders, notConfiguredProviders]
|
||||
}, [providers])
|
||||
|
||||
const systemModelConfigStatus: SystemModelConfigStatus = useMemo(() => {
|
||||
const defaultModels = [textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel]
|
||||
const configuredCount = defaultModels.filter(Boolean).length
|
||||
if (configuredCount === 0 && configuredProviders.length === 0)
|
||||
return 'no-provider'
|
||||
if (configuredCount === 0)
|
||||
return 'none-configured'
|
||||
if (configuredCount < defaultModels.length)
|
||||
return 'partially-configured'
|
||||
return 'fully-configured'
|
||||
}, [configuredProviders, textGenerationDefaultModel, embeddingsDefaultModel, rerankDefaultModel, speech2textDefaultModel, ttsDefaultModel])
|
||||
const warningTextKey
|
||||
= systemModelConfigStatus === 'none-configured'
|
||||
? 'modelProvider.noneConfigured'
|
||||
: systemModelConfigStatus === 'partially-configured'
|
||||
? 'modelProvider.notConfigured'
|
||||
: null
|
||||
const showWarning = !isDefaultModelLoading && !!warningTextKey
|
||||
|
||||
const [filteredConfiguredProviders, filteredNotConfiguredProviders] = useMemo(() => {
|
||||
const filteredConfiguredProviders = configuredProviders.filter(
|
||||
provider => provider.provider.toLowerCase().includes(debouncedSearchText.toLowerCase())
|
||||
@@ -92,28 +127,24 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
return [filteredConfiguredProviders, filteredNotConfiguredProviders]
|
||||
}, [configuredProviders, debouncedSearchText, notConfiguredProviders])
|
||||
|
||||
useEffect(() => {
|
||||
mutateCurrentWorkspace()
|
||||
}, [mutateCurrentWorkspace])
|
||||
|
||||
return (
|
||||
<div className="relative -mt-2 pt-1">
|
||||
<div className={cn('mb-2 flex items-center')}>
|
||||
<div className="grow text-text-primary system-md-semibold">{t('modelProvider.models', { ns: 'common' })}</div>
|
||||
<div className={cn(
|
||||
'relative flex shrink-0 items-center justify-end gap-2 rounded-lg border border-transparent p-px',
|
||||
defaultModelNotConfigured && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
|
||||
showWarning && 'border-components-panel-border bg-components-panel-bg-blur pl-2 shadow-xs',
|
||||
)}
|
||||
>
|
||||
{defaultModelNotConfigured && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
|
||||
{defaultModelNotConfigured && (
|
||||
{showWarning && <div className="absolute bottom-0 left-0 right-0 top-0 opacity-40" style={{ background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.25) 0%, rgba(255, 255, 255, 0.00) 100%)' }} />}
|
||||
{showWarning && (
|
||||
<div className="flex items-center gap-1 text-text-primary system-xs-medium">
|
||||
<RiAlertFill className="h-4 w-4 text-text-warning-secondary" />
|
||||
<span className="max-w-[460px] truncate" title={t('modelProvider.notConfigured', { ns: 'common' })}>{t('modelProvider.notConfigured', { ns: 'common' })}</span>
|
||||
<span className="i-ri-alert-fill h-4 w-4 text-text-warning-secondary" />
|
||||
<span className="max-w-[460px] truncate" title={t(warningTextKey, { ns: 'common' })}>{t(warningTextKey, { ns: 'common' })}</span>
|
||||
</div>
|
||||
)}
|
||||
<SystemModelSelector
|
||||
notConfigured={defaultModelNotConfigured}
|
||||
notConfigured={showWarning}
|
||||
textGenerationDefaultModel={textGenerationDefaultModel}
|
||||
embeddingsDefaultModel={embeddingsDefaultModel}
|
||||
rerankDefaultModel={rerankDefaultModel}
|
||||
@@ -123,11 +154,11 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} isLoading={isValidatingCurrentWorkspace} />}
|
||||
{IS_CLOUD_EDITION && <QuotaPanel providers={providers} />}
|
||||
{!filteredConfiguredProviders?.length && (
|
||||
<div className="mb-2 rounded-[10px] bg-workflow-process-bg p-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg shadow-lg backdrop-blur">
|
||||
<RiBrainLine className="h-5 w-5 text-text-primary" />
|
||||
<span className="i-ri-brain-line h-5 w-5 text-text-primary" />
|
||||
</div>
|
||||
<div className="mt-2 text-text-secondary system-sm-medium">{t('modelProvider.emptyProviderTitle', { ns: 'common' })}</div>
|
||||
<div className="mt-1 text-text-tertiary system-xs-regular">{t('modelProvider.emptyProviderTip', { ns: 'common' })}</div>
|
||||
@@ -139,6 +170,7 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
<ProviderAddedCard
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -152,13 +184,14 @@ const ModelProviderPage = ({ searchText }: Props) => {
|
||||
notConfigured
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
pluginDetail={pluginDetailMap.get(providerToPluginId(provider.provider))}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{
|
||||
enable_marketplace && (
|
||||
enableMarketplace && (
|
||||
<InstallFromMarketplace
|
||||
providers={providers}
|
||||
searchText={searchText}
|
||||
|
||||
@@ -2,10 +2,6 @@ import type {
|
||||
ModelProvider,
|
||||
} from './declarations'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import {
|
||||
RiArrowDownSLine,
|
||||
RiArrowRightUpLine,
|
||||
} from '@remixicon/react'
|
||||
import { useTheme } from 'next-themes'
|
||||
import Link from 'next/link'
|
||||
import { useCallback, useState } from 'react'
|
||||
@@ -47,15 +43,15 @@ const InstallFromMarketplace = ({
|
||||
<div className="mb-2">
|
||||
<Divider className="!mt-4 h-px" />
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="system-md-semibold flex cursor-pointer items-center gap-1 text-text-primary" onClick={() => setCollapse(!collapse)}>
|
||||
<RiArrowDownSLine className={cn('h-4 w-4', collapse && '-rotate-90')} />
|
||||
<div className="flex cursor-pointer items-center gap-1 text-text-primary system-md-semibold" onClick={() => setCollapse(!collapse)}>
|
||||
<span className={cn('i-ri-arrow-down-s-line h-4 w-4', collapse && '-rotate-90')} />
|
||||
{t('modelProvider.installProvider', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="mb-2 flex items-center pt-2">
|
||||
<span className="system-sm-regular pr-1 text-text-tertiary">{t('modelProvider.discoverMore', { ns: 'common' })}</span>
|
||||
<Link target="_blank" href={getMarketplaceUrl('', { theme })} className="system-sm-medium inline-flex items-center text-text-accent">
|
||||
<span className="pr-1 text-text-tertiary system-sm-regular">{t('modelProvider.discoverMore', { ns: 'common' })}</span>
|
||||
<Link target="_blank" href={getMarketplaceUrl('', { theme })} className="inline-flex items-center text-text-accent system-sm-medium">
|
||||
{t('marketplace.difyMarketplace', { ns: 'plugin' })}
|
||||
<RiArrowRightUpLine className="h-4 w-4" />
|
||||
<span className="i-ri-arrow-right-up-line h-4 w-4" />
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -2,12 +2,6 @@ import type { Credential } from '../../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import CredentialItem from './credential-item'
|
||||
|
||||
vi.mock('@remixicon/react', () => ({
|
||||
RiCheckLine: () => <div data-testid="check-icon" />,
|
||||
RiDeleteBinLine: () => <div data-testid="delete-icon" />,
|
||||
RiEqualizer2Line: () => <div data-testid="edit-icon" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/header/indicator', () => ({
|
||||
default: () => <div data-testid="indicator" />,
|
||||
}))
|
||||
@@ -61,8 +55,12 @@ describe('CredentialItem', () => {
|
||||
|
||||
render(<CredentialItem credential={credential} onEdit={onEdit} onDelete={onDelete} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('edit-icon').closest('button') as HTMLButtonElement)
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
const buttons = screen.getAllByRole('button')
|
||||
const editButton = buttons.find(b => b.querySelector('.i-ri-equalizer-2-line'))!
|
||||
const deleteButton = buttons.find(b => b.querySelector('.i-ri-delete-bin-line'))!
|
||||
|
||||
fireEvent.click(editButton)
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onEdit).toHaveBeenCalledWith(credential)
|
||||
expect(onDelete).toHaveBeenCalledWith(credential)
|
||||
@@ -81,7 +79,10 @@ describe('CredentialItem', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
const deleteButton = screen.getAllByRole('button')
|
||||
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
|
||||
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onDelete).not.toHaveBeenCalled()
|
||||
})
|
||||
@@ -121,14 +122,16 @@ describe('CredentialItem', () => {
|
||||
|
||||
render(<CredentialItem credential={credential} disabled onDelete={onDelete} />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('delete-icon').closest('button') as HTMLButtonElement)
|
||||
const deleteButton = screen.getAllByRole('button')
|
||||
.find(b => b.querySelector('.i-ri-delete-bin-line'))!
|
||||
fireEvent.click(deleteButton)
|
||||
|
||||
expect(onDelete).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// showSelectedIcon=true: check icon area is always rendered; check icon only appears when IDs match
|
||||
it('should render check icon area when showSelectedIcon=true and selectedCredentialId matches', () => {
|
||||
render(
|
||||
const { container } = render(
|
||||
<CredentialItem
|
||||
credential={credential}
|
||||
showSelectedIcon
|
||||
@@ -136,7 +139,7 @@ describe('CredentialItem', () => {
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('check-icon')).toBeInTheDocument()
|
||||
expect(container.querySelector('.i-ri-check-line')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render check icon when showSelectedIcon=true but selectedCredentialId does not match', () => {
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
import type { Credential } from '../../declarations'
|
||||
import {
|
||||
RiCheckLine,
|
||||
RiDeleteBinLine,
|
||||
RiEqualizer2Line,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
memo,
|
||||
useMemo,
|
||||
@@ -11,7 +6,7 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
@@ -56,7 +51,7 @@ const CredentialItem = ({
|
||||
key={credential.credential_id}
|
||||
className={cn(
|
||||
'group flex h-8 items-center rounded-lg p-1 hover:bg-state-base-hover',
|
||||
(disabled || credential.not_allowed_to_use) && 'cursor-not-allowed opacity-50',
|
||||
(disabled || credential.not_allowed_to_use) ? 'cursor-not-allowed opacity-50' : onItemClick && 'cursor-pointer',
|
||||
)}
|
||||
onClick={() => {
|
||||
if (disabled || credential.not_allowed_to_use)
|
||||
@@ -70,7 +65,7 @@ const CredentialItem = ({
|
||||
<div className="h-4 w-4">
|
||||
{
|
||||
selectedCredentialId === credential.credential_id && (
|
||||
<RiCheckLine className="h-4 w-4 text-text-accent" />
|
||||
<span className="i-ri-check-line h-4 w-4 text-text-accent" />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
@@ -78,7 +73,7 @@ const CredentialItem = ({
|
||||
}
|
||||
<Indicator className="ml-2 mr-1.5 shrink-0" />
|
||||
<div
|
||||
className="system-md-regular truncate text-text-secondary"
|
||||
className="truncate text-text-secondary system-md-regular"
|
||||
title={credential.credential_name}
|
||||
>
|
||||
{credential.credential_name}
|
||||
@@ -96,38 +91,50 @@ const CredentialItem = ({
|
||||
<div className="ml-2 hidden shrink-0 items-center group-hover:flex">
|
||||
{
|
||||
!disableEdit && !credential.not_allowed_to_use && (
|
||||
<Tooltip popupContent={t('operation.edit', { ns: 'common' })}>
|
||||
<ActionButton
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onEdit?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
|
||||
</ActionButton>
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
render={(
|
||||
<ActionButton
|
||||
disabled={disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
onEdit?.(credential)
|
||||
}}
|
||||
>
|
||||
<span className="i-ri-equalizer-2-line h-4 w-4 text-text-tertiary" />
|
||||
</ActionButton>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent>{t('operation.edit', { ns: 'common' })}</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
{
|
||||
!disableDelete && (
|
||||
<Tooltip popupContent={disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}>
|
||||
<ActionButton
|
||||
className="hover:bg-transparent"
|
||||
onClick={(e) => {
|
||||
if (disabled || disableDeleteWhenSelected)
|
||||
return
|
||||
e.stopPropagation()
|
||||
onDelete?.(credential)
|
||||
}}
|
||||
>
|
||||
<RiDeleteBinLine className={cn(
|
||||
'h-4 w-4 text-text-tertiary',
|
||||
!disableDeleteWhenSelected && 'hover:text-text-destructive',
|
||||
disableDeleteWhenSelected && 'opacity-50',
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
render={(
|
||||
<ActionButton
|
||||
className="hover:bg-transparent"
|
||||
onClick={(e) => {
|
||||
if (disabled || disableDeleteWhenSelected)
|
||||
return
|
||||
e.stopPropagation()
|
||||
onDelete?.(credential)
|
||||
}}
|
||||
>
|
||||
<span className={cn(
|
||||
'i-ri-delete-bin-line h-4 w-4 text-text-tertiary',
|
||||
!disableDeleteWhenSelected && 'hover:text-text-destructive',
|
||||
disableDeleteWhenSelected && 'opacity-50',
|
||||
)}
|
||||
/>
|
||||
</ActionButton>
|
||||
)}
|
||||
/>
|
||||
</ActionButton>
|
||||
/>
|
||||
<TooltipContent>
|
||||
{disableDeleteWhenSelected ? disableDeleteTip : t('operation.delete', { ns: 'common' })}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
@@ -139,8 +146,9 @@ const CredentialItem = ({
|
||||
|
||||
if (credential.not_allowed_to_use) {
|
||||
return (
|
||||
<Tooltip popupContent={t('auth.customCredentialUnavailable', { ns: 'plugin' })}>
|
||||
{Item}
|
||||
<Tooltip>
|
||||
<TooltipTrigger render={Item} />
|
||||
<TooltipContent>{t('auth.customCredentialUnavailable', { ns: 'plugin' })}</TooltipContent>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -53,4 +53,14 @@ describe('useCredentialStatus', () => {
|
||||
expect(result.current.hasCredential).toBe(false)
|
||||
expect(result.current.available_credentials).toBeUndefined()
|
||||
})
|
||||
|
||||
it('handles undefined provider gracefully', () => {
|
||||
const { result } = renderHook(() => useCredentialStatus(undefined))
|
||||
expect(result.current.hasCredential).toBe(false)
|
||||
expect(result.current.authorized).toBeFalsy()
|
||||
expect(result.current.authRemoved).toBe(false)
|
||||
expect(result.current.available_credentials).toBeUndefined()
|
||||
expect(result.current.current_credential_id).toBeUndefined()
|
||||
expect(result.current.current_credential_name).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,12 +3,12 @@ import type {
|
||||
} from '../../declarations'
|
||||
import { useMemo } from 'react'
|
||||
|
||||
export const useCredentialStatus = (provider: ModelProvider) => {
|
||||
export const useCredentialStatus = (provider: ModelProvider | undefined) => {
|
||||
const {
|
||||
current_credential_id,
|
||||
current_credential_name,
|
||||
available_credentials,
|
||||
} = provider.custom_configuration
|
||||
} = provider?.custom_configuration ?? {}
|
||||
const hasCredential = !!available_credentials?.length
|
||||
const authorized = current_credential_id && current_credential_name
|
||||
const authRemoved = hasCredential && !current_credential_id && !current_credential_name
|
||||
|
||||
@@ -10,7 +10,7 @@ const ModelBadge: FC<ModelBadgeProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
return (
|
||||
<div className={cn('system-2xs-medium-uppercase flex h-[18px] cursor-default items-center rounded-[5px] border border-divider-deep px-1 text-text-tertiary', className)}>
|
||||
<div className={cn('inline-flex h-[18px] shrink-0 items-center justify-center whitespace-nowrap rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-[5px] text-text-tertiary system-2xs-medium-uppercase', className)}>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import type { Credential, CredentialFormSchema, CustomModel, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import type { Credential, CredentialFormSchema, ModelProvider } from '../declarations'
|
||||
import { fireEvent, render, screen, waitFor, within } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
CurrentSystemQuotaTypeEnum,
|
||||
@@ -45,6 +43,15 @@ const mockHandlers = vi.hoisted(() => ({
|
||||
handleActiveCredential: vi.fn(),
|
||||
}))
|
||||
|
||||
type FormResponse = {
|
||||
isCheckValidated: boolean
|
||||
values: Record<string, unknown>
|
||||
}
|
||||
const mockFormState = vi.hoisted(() => ({
|
||||
responses: [] as FormResponse[],
|
||||
setFieldValue: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('../model-auth/hooks', () => ({
|
||||
useCredentialData: () => ({
|
||||
isLoading: mockState.isLoading,
|
||||
@@ -79,6 +86,36 @@ vi.mock('../hooks', () => ({
|
||||
useLanguage: () => 'en_US',
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/form/form-scenarios/auth', async () => {
|
||||
const React = await import('react')
|
||||
const AuthForm = React.forwardRef(({
|
||||
onChange,
|
||||
}: {
|
||||
onChange?: (field: string, value: string) => void
|
||||
}, ref: React.ForwardedRef<{ getFormValues: () => FormResponse, getForm: () => { setFieldValue: (field: string, value: string) => void } }>) => {
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
getFormValues: () => mockFormState.responses.shift() || { isCheckValidated: false, values: {} },
|
||||
getForm: () => ({ setFieldValue: mockFormState.setFieldValue }),
|
||||
}))
|
||||
return (
|
||||
<div>
|
||||
<button type="button" onClick={() => onChange?.('__model_name', 'updated-model')}>Model Name Change</button>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
|
||||
return { default: AuthForm }
|
||||
})
|
||||
|
||||
vi.mock('../model-auth', () => ({
|
||||
CredentialSelector: ({ onSelect }: { onSelect: (credential: Credential & { addNewCredential?: boolean }) => void }) => (
|
||||
<div>
|
||||
<button type="button" onClick={() => onSelect({ credential_id: 'existing' })}>Choose Existing</button>
|
||||
<button type="button" onClick={() => onSelect({ credential_id: 'new', addNewCredential: true })}>Add New</button>
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createI18n = (text: string) => ({ en_US: text, zh_Hans: text })
|
||||
|
||||
const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
|
||||
@@ -121,7 +158,7 @@ const createProvider = (overrides?: Partial<ModelProvider>): ModelProvider => ({
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderModal = (overrides?: Partial<ComponentProps<typeof ModelModal>>) => {
|
||||
const renderModal = (overrides?: Partial<React.ComponentProps<typeof ModelModal>>) => {
|
||||
const provider = createProvider()
|
||||
const props = {
|
||||
provider,
|
||||
@@ -131,50 +168,13 @@ const renderModal = (overrides?: Partial<ComponentProps<typeof ModelModal>>) =>
|
||||
onRemove: vi.fn(),
|
||||
...overrides,
|
||||
}
|
||||
render(<ModelModal {...props} />)
|
||||
return props
|
||||
const view = render(<ModelModal {...props} />)
|
||||
return {
|
||||
...props,
|
||||
unmount: view.unmount,
|
||||
}
|
||||
}
|
||||
|
||||
const mockFormRef1 = {
|
||||
getFormValues: vi.fn(),
|
||||
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
|
||||
}
|
||||
|
||||
const mockFormRef2 = {
|
||||
getFormValues: vi.fn(),
|
||||
getForm: vi.fn(() => ({ setFieldValue: vi.fn() })),
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/base/form/form-scenarios/auth', () => ({
|
||||
default: React.forwardRef((props: { formSchemas: Record<string, unknown>[], onChange?: (f: string, v: string) => void }, ref: React.ForwardedRef<unknown>) => {
|
||||
React.useImperativeHandle(ref, () => {
|
||||
// Return the mock depending on schemas passed (hacky but works for refs)
|
||||
if (props.formSchemas.length > 0 && props.formSchemas[0].name === '__model_name')
|
||||
return mockFormRef1
|
||||
return mockFormRef2
|
||||
})
|
||||
return (
|
||||
<div data-testid="auth-form" onClick={() => props.onChange?.('test-field', 'val')}>
|
||||
AuthForm Mock (
|
||||
{props.formSchemas.length}
|
||||
{' '}
|
||||
fields)
|
||||
</div>
|
||||
)
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../model-auth', () => ({
|
||||
CredentialSelector: ({ onSelect }: { onSelect: (val: unknown) => void }) => (
|
||||
<button onClick={() => onSelect({ addNewCredential: true })} data-testid="credential-selector">
|
||||
Select Credential
|
||||
</button>
|
||||
),
|
||||
useAuth: vi.fn(),
|
||||
useCredentialData: vi.fn(),
|
||||
useModelFormSchemas: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('ModelModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -187,131 +187,168 @@ describe('ModelModal', () => {
|
||||
mockState.formValues = {}
|
||||
mockState.modelNameAndTypeFormSchemas = []
|
||||
mockState.modelNameAndTypeFormValues = {}
|
||||
|
||||
// reset form refs
|
||||
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __model_name: 'test', __model_type: ModelTypeEnum.textGeneration } })
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'test_auth', api_key: 'sk-test' } })
|
||||
mockFormState.responses = []
|
||||
})
|
||||
|
||||
it('should render title and loading state for predefined credential modal', () => {
|
||||
it('should show title, description, and loading state for predefined models', () => {
|
||||
mockState.isLoading = true
|
||||
renderModal()
|
||||
|
||||
const predefined = renderModal()
|
||||
|
||||
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.title')).toBeInTheDocument()
|
||||
expect(screen.getByText('common.modelProvider.auth.apiKeyModal.desc')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByRole('status')).toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeDisabled()
|
||||
|
||||
it('should render model credential title when mode is configModelCredential', () => {
|
||||
renderModal({
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration },
|
||||
})
|
||||
predefined.unmount()
|
||||
const customizable = renderModal({ configurateMethod: ConfigurationMethodEnum.customizableModel })
|
||||
expect(screen.queryByText('common.modelProvider.auth.apiKeyModal.desc')).not.toBeInTheDocument()
|
||||
customizable.unmount()
|
||||
|
||||
mockState.credentialData = { credentials: {}, available_credentials: [] }
|
||||
renderModal({ mode: ModelModalModeEnum.configModelCredential, model: { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration } })
|
||||
expect(screen.getByText('common.modelProvider.auth.addModelCredential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render edit credential title when credential exists', () => {
|
||||
renderModal({
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
credential: { credential_id: '1' } as unknown as Credential,
|
||||
})
|
||||
expect(screen.getByText('common.modelProvider.auth.editModelCredential')).toBeInTheDocument()
|
||||
it('should reveal the credential label when adding a new credential', () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList })
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.auth.modelCredential')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('Add New'))
|
||||
|
||||
expect(screen.getByText('common.modelProvider.auth.modelCredential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should change title to Add Model when mode is configCustomModel', () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
expect(screen.getByText('common.modelProvider.auth.addModel')).toBeInTheDocument()
|
||||
it('should call onCancel when the cancel button is clicked', () => {
|
||||
const { onCancel } = renderModal()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should validate and fail save if form is invalid in configCustomModel mode', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
mockFormRef1.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
|
||||
renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should call onCancel when the escape key is pressed', () => {
|
||||
const { onCancel } = renderModal()
|
||||
|
||||
it('should validate and save new credential and model in configCustomModel mode', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text' } as unknown as CredentialFormSchema]
|
||||
const props = renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'sk-test' },
|
||||
name: 'test_auth',
|
||||
model: 'test',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
expect(props.onSave).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should save credential only in standard configProviderCredential mode', async () => {
|
||||
const { onSave } = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'sk-test' },
|
||||
name: 'test_auth',
|
||||
})
|
||||
expect(onSave).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should save active credential and cancel when picking existing credential in addCustomModelToModelList mode', async () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm1', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
|
||||
// By default selected is undefined so button clicks form
|
||||
// Let's not click credential selector, so it evaluates without it. If selectedCredential is undefined, form validation is checked.
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: false, values: {} })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleSaveCredential).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should save active credential when picking existing credential in addCustomModelToModelList mode', async () => {
|
||||
renderModal({ mode: ModelModalModeEnum.addCustomModelToModelList, model: { model: 'm2', model_type: ModelTypeEnum.textGeneration } as unknown as CustomModel })
|
||||
|
||||
// Select existing credential (addNewCredential: true simulates new but we can simulate false if we just hack the mocked state in the component, but it's internal.
|
||||
// The credential selector sets selectedCredential.
|
||||
fireEvent.click(screen.getByTestId('credential-selector')) // Sets addNewCredential = true internally, so it proceeds to form save
|
||||
|
||||
mockFormRef2.getFormValues.mockReturnValue({ isCheckValidated: true, values: { __authorization_name__: 'auth', api: 'key' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api: 'key' },
|
||||
name: 'auth',
|
||||
model: 'm2',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should open and confirm deletion of credential', () => {
|
||||
mockState.credentialData = { credentials: { api_key: '123' }, available_credentials: [] }
|
||||
mockState.formValues = { api_key: '123' } // To trigger isEditMode = true
|
||||
const credential = { credential_id: 'c1' } as unknown as Credential
|
||||
renderModal({ credential })
|
||||
|
||||
// Open Delete Confirm
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.remove' }))
|
||||
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith(credential, undefined)
|
||||
|
||||
// Simulate the dialog appearing and confirming
|
||||
mockState.deleteCredentialId = 'c1'
|
||||
renderModal({ credential }) // Re-render logic mock
|
||||
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.confirm' })[0])
|
||||
|
||||
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should bind escape key to cancel', () => {
|
||||
const props = renderModal()
|
||||
fireEvent.keyDown(document, { key: 'Escape' })
|
||||
expect(props.onCancel).toHaveBeenCalled()
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should confirm deletion when a delete dialog is shown', () => {
|
||||
mockState.credentialData = { credentials: { api_key: 'secret' }, available_credentials: [] }
|
||||
mockState.deleteCredentialId = 'delete-id'
|
||||
|
||||
const credential: Credential = { credential_id: 'cred-1' }
|
||||
const { onCancel } = renderModal({ credential })
|
||||
|
||||
const alertDialog = screen.getByRole('alertdialog', { hidden: true })
|
||||
expect(alertDialog).toHaveTextContent('common.modelProvider.confirmDelete')
|
||||
|
||||
fireEvent.click(within(alertDialog).getByRole('button', { hidden: true, name: 'common.operation.confirm' }))
|
||||
|
||||
expect(mockHandlers.handleConfirmDelete).toHaveBeenCalledTimes(1)
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should handle save flows for different modal modes', async () => {
|
||||
mockState.modelNameAndTypeFormSchemas = [{ variable: '__model_name', type: 'text-input' } as unknown as CredentialFormSchema]
|
||||
mockState.formSchemas = [{ variable: 'api_key', type: 'secret-input' } as unknown as CredentialFormSchema]
|
||||
mockFormState.responses = [
|
||||
{ isCheckValidated: true, values: { __model_name: 'custom-model', __model_type: ModelTypeEnum.textGeneration } },
|
||||
{ isCheckValidated: true, values: { __authorization_name__: 'Auth Name', api_key: 'secret' } },
|
||||
]
|
||||
const configCustomModel = renderModal({ mode: ModelModalModeEnum.configCustomModel })
|
||||
fireEvent.click(screen.getAllByText('Model Name Change')[0])
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
|
||||
expect(mockFormState.setFieldValue).toHaveBeenCalledWith('__model_name', 'updated-model')
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'secret' },
|
||||
name: 'Auth Name',
|
||||
model: 'custom-model',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
expect(configCustomModel.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Auth Name', api_key: 'secret' })
|
||||
configCustomModel.unmount()
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Model Auth', api_key: 'abc' } }]
|
||||
const model = { model: 'gpt-4', model_type: ModelTypeEnum.textGeneration }
|
||||
const configModelCredential = renderModal({
|
||||
mode: ModelModalModeEnum.configModelCredential,
|
||||
model,
|
||||
credential: { credential_id: 'cred-123' },
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: 'cred-123',
|
||||
credentials: { api_key: 'abc' },
|
||||
name: 'Model Auth',
|
||||
model: 'gpt-4',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
expect(configModelCredential.onSave).toHaveBeenCalledWith({ __authorization_name__: 'Model Auth', api_key: 'abc' })
|
||||
configModelCredential.unmount()
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'Provider Auth', api_key: 'provider-key' } }]
|
||||
const configProviderCredential = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'provider-key' },
|
||||
name: 'Provider Auth',
|
||||
})
|
||||
})
|
||||
configProviderCredential.unmount()
|
||||
|
||||
const addToModelList = renderModal({
|
||||
mode: ModelModalModeEnum.addCustomModelToModelList,
|
||||
model,
|
||||
})
|
||||
fireEvent.click(screen.getByText('Choose Existing'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
expect(mockHandlers.handleActiveCredential).toHaveBeenCalledWith({ credential_id: 'existing' }, model)
|
||||
expect(addToModelList.onCancel).toHaveBeenCalled()
|
||||
addToModelList.unmount()
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: true, values: { __authorization_name__: 'New Auth', api_key: 'new-key' } }]
|
||||
const addToModelListWithNew = renderModal({
|
||||
mode: ModelModalModeEnum.addCustomModelToModelList,
|
||||
model,
|
||||
})
|
||||
fireEvent.click(screen.getByText('Add New'))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledWith({
|
||||
credential_id: undefined,
|
||||
credentials: { api_key: 'new-key' },
|
||||
name: 'New Auth',
|
||||
model: 'gpt-4',
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
})
|
||||
})
|
||||
addToModelListWithNew.unmount()
|
||||
|
||||
mockFormState.responses = [{ isCheckValidated: false, values: {} }]
|
||||
const invalidSave = renderModal({ mode: ModelModalModeEnum.configProviderCredential })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' }))
|
||||
await waitFor(() => {
|
||||
expect(mockHandlers.handleSaveCredential).toHaveBeenCalledTimes(4)
|
||||
})
|
||||
invalidSave.unmount()
|
||||
|
||||
mockState.credentialData = { credentials: { api_key: 'value' }, available_credentials: [] }
|
||||
mockState.formValues = { api_key: 'value' }
|
||||
const removable = renderModal({ credential: { credential_id: 'remove-1' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.remove' }))
|
||||
expect(mockHandlers.openConfirmDelete).toHaveBeenCalledWith({ credential_id: 'remove-1' }, undefined)
|
||||
removable.unmount()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,11 +9,9 @@ import type {
|
||||
FormRefObject,
|
||||
FormSchema,
|
||||
} from '@/app/components/base/form/types'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
@@ -21,15 +19,23 @@ import {
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import AuthForm from '@/app/components/base/form/form-scenarios/auth'
|
||||
import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import {
|
||||
useAuth,
|
||||
useCredentialData,
|
||||
@@ -197,7 +203,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="title-2xl-semi-bold text-text-primary">
|
||||
<div className="text-text-primary title-2xl-semi-bold">
|
||||
{label}
|
||||
</div>
|
||||
)
|
||||
@@ -206,7 +212,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
const modalDesc = useMemo(() => {
|
||||
if (providerFormSchemaPredefined) {
|
||||
return (
|
||||
<div className="system-xs-regular mt-1 text-text-tertiary">
|
||||
<div className="mt-1 text-text-tertiary system-xs-regular">
|
||||
{t('modelProvider.auth.apiKeyModal.desc', { ns: 'common' })}
|
||||
</div>
|
||||
)
|
||||
@@ -223,7 +229,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
className="mr-2 h-4 w-4 shrink-0"
|
||||
provider={provider}
|
||||
/>
|
||||
<div className="system-md-regular mr-1 text-text-secondary">{renderI18nObject(provider.label)}</div>
|
||||
<div className="mr-1 text-text-secondary system-md-regular">{renderI18nObject(provider.label)}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -235,7 +241,7 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
provider={provider}
|
||||
modelName={model.model}
|
||||
/>
|
||||
<div className="system-md-regular mr-1 text-text-secondary">{model.model}</div>
|
||||
<div className="mr-1 text-text-secondary system-md-regular">{model.model}</div>
|
||||
<Badge>{model.model_type}</Badge>
|
||||
</div>
|
||||
)
|
||||
@@ -275,174 +281,171 @@ const ModelModal: FC<ModelModalProps> = ({
|
||||
}, [])
|
||||
const notAllowCustomCredential = provider.allow_custom_token === false
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.key === 'Escape') {
|
||||
event.stopPropagation()
|
||||
onCancel()
|
||||
}
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', handleKeyDown, true)
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown, true)
|
||||
}
|
||||
const handleOpenChange = useCallback((open: boolean) => {
|
||||
if (!open)
|
||||
onCancel()
|
||||
}, [onCancel])
|
||||
|
||||
const handleConfirmOpenChange = useCallback((open: boolean) => {
|
||||
if (!open)
|
||||
closeConfirmDelete()
|
||||
}, [closeConfirmDelete])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem open>
|
||||
<PortalToFollowElemContent className="z-[60] h-full w-full">
|
||||
<div className="fixed inset-0 flex items-center justify-center bg-black/[.25]">
|
||||
<div className="relative w-[640px] rounded-2xl bg-components-panel-bg shadow-xl">
|
||||
<div
|
||||
className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
<div className="p-6 pb-3">
|
||||
{modalTitle}
|
||||
{modalDesc}
|
||||
{modalModel}
|
||||
</div>
|
||||
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
|
||||
{
|
||||
mode === ModelModalModeEnum.configCustomModel && (
|
||||
<AuthForm
|
||||
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={modelNameAndTypeFormValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef1}
|
||||
onChange={handleModelNameAndTypeChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
mode === ModelModalModeEnum.addCustomModelToModelList && (
|
||||
<CredentialSelector
|
||||
credentials={available_credentials || []}
|
||||
onSelect={setSelectedCredential}
|
||||
selectedCredential={selectedCredential}
|
||||
disabled={isLoading}
|
||||
notAllowAddNewCredential={notAllowCustomCredential}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
showCredentialLabel && (
|
||||
<div className="system-xs-medium-uppercase mb-3 mt-6 flex items-center text-text-tertiary">
|
||||
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
|
||||
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-3 flex items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading
|
||||
&& showCredentialForm
|
||||
&& (
|
||||
<AuthForm
|
||||
formSchemas={formSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
showRadioUI: formSchema.type === FormTypeEnum.radio,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={formValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef2}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className="flex justify-between p-6 pt-5">
|
||||
{
|
||||
(provider.help && (provider.help.title || provider.help.url))
|
||||
? (
|
||||
<a
|
||||
href={provider.help?.url[language] || provider.help?.url.en_US}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="system-xs-regular mt-2 inline-block align-middle text-text-accent"
|
||||
onClick={e => !provider.help.url && e.preventDefault()}
|
||||
>
|
||||
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
|
||||
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
|
||||
</a>
|
||||
)
|
||||
: <div />
|
||||
}
|
||||
<div className="ml-2 flex items-center justify-end space-x-2">
|
||||
{
|
||||
isEditMode && (
|
||||
<Button
|
||||
variant="warning"
|
||||
onClick={() => openConfirmDelete(credential, model)}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={isLoading || doingAction}
|
||||
>
|
||||
{saveButtonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
|
||||
<div className="border-t-[0.5px] border-t-divider-regular">
|
||||
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
|
||||
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
|
||||
{t('modelProvider.encrypted.front', { ns: 'common' })}
|
||||
<a
|
||||
className="mx-1 text-text-accent"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
|
||||
>
|
||||
PKCS1_OAEP
|
||||
</a>
|
||||
{t('modelProvider.encrypted.back', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<Dialog open onOpenChange={handleOpenChange}>
|
||||
<DialogContent
|
||||
backdropProps={{ forceRender: true }}
|
||||
className="w-[640px] max-w-[640px] overflow-hidden p-0"
|
||||
>
|
||||
<DialogCloseButton className="right-5 top-5 h-8 w-8" />
|
||||
<div className="p-6 pb-3">
|
||||
{modalTitle}
|
||||
{modalDesc}
|
||||
{modalModel}
|
||||
</div>
|
||||
<div className="max-h-[calc(100vh-320px)] overflow-y-auto px-6 py-3">
|
||||
{
|
||||
deleteCredentialId && (
|
||||
<Confirm
|
||||
isShow
|
||||
title={t('modelProvider.confirmDelete', { ns: 'common' })}
|
||||
isDisabled={doingAction}
|
||||
onCancel={closeConfirmDelete}
|
||||
onConfirm={handleDeleteCredential}
|
||||
mode === ModelModalModeEnum.configCustomModel && (
|
||||
<AuthForm
|
||||
formSchemas={modelNameAndTypeFormSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={modelNameAndTypeFormValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef1}
|
||||
onChange={handleModelNameAndTypeChange}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
mode === ModelModalModeEnum.addCustomModelToModelList && (
|
||||
<CredentialSelector
|
||||
credentials={available_credentials || []}
|
||||
onSelect={setSelectedCredential}
|
||||
selectedCredential={selectedCredential}
|
||||
disabled={isLoading}
|
||||
notAllowAddNewCredential={notAllowCustomCredential}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
showCredentialLabel && (
|
||||
<div className="mb-3 mt-6 flex items-center text-text-tertiary system-xs-medium-uppercase">
|
||||
{t('modelProvider.auth.modelCredential', { ns: 'common' })}
|
||||
<div className="ml-2 h-px grow bg-gradient-to-r from-divider-regular to-background-gradient-mask-transparent" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-3 flex items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading
|
||||
&& showCredentialForm
|
||||
&& (
|
||||
<AuthForm
|
||||
formSchemas={formSchemas.map((formSchema) => {
|
||||
return {
|
||||
...formSchema,
|
||||
name: formSchema.variable,
|
||||
showRadioUI: formSchema.type === FormTypeEnum.radio,
|
||||
}
|
||||
}) as FormSchema[]}
|
||||
defaultValues={formValues}
|
||||
inputClassName="justify-start"
|
||||
ref={formRef2}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
<div className="flex justify-between p-6 pt-5">
|
||||
{
|
||||
(provider.help && (provider.help.title || provider.help.url))
|
||||
? (
|
||||
<a
|
||||
href={provider.help?.url[language] || provider.help?.url.en_US}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mt-2 inline-block align-middle text-text-accent system-xs-regular"
|
||||
onClick={e => !provider.help.url && e.preventDefault()}
|
||||
>
|
||||
{provider.help.title?.[language] || provider.help.url[language] || provider.help.title?.en_US || provider.help.url.en_US}
|
||||
<LinkExternal02 className="ml-1 mt-[-2px] inline-block h-3 w-3" />
|
||||
</a>
|
||||
)
|
||||
: <div />
|
||||
}
|
||||
<div className="ml-2 flex items-center justify-end space-x-2">
|
||||
{
|
||||
isEditMode && (
|
||||
<Button
|
||||
variant="warning"
|
||||
onClick={() => openConfirmDelete(credential, model)}
|
||||
>
|
||||
{t('operation.remove', { ns: 'common' })}
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
onClick={handleSave}
|
||||
disabled={isLoading || doingAction}
|
||||
>
|
||||
{saveButtonText}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
(mode === ModelModalModeEnum.configCustomModel || mode === ModelModalModeEnum.configProviderCredential) && (
|
||||
<div className="border-t-[0.5px] border-t-divider-regular">
|
||||
<div className="flex items-center justify-center rounded-b-2xl bg-background-section-burn py-3 text-xs text-text-tertiary">
|
||||
<Lock01 className="mr-1 h-3 w-3 text-text-tertiary" />
|
||||
{t('modelProvider.encrypted.front', { ns: 'common' })}
|
||||
<a
|
||||
className="mx-1 text-text-accent"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html"
|
||||
>
|
||||
PKCS1_OAEP
|
||||
</a>
|
||||
{t('modelProvider.encrypted.back', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</DialogContent>
|
||||
<AlertDialog open={!!deleteCredentialId} onOpenChange={handleConfirmOpenChange}>
|
||||
<AlertDialogContent backdropProps={{ forceRender: true }}>
|
||||
<div className="flex flex-col gap-2 p-6 pb-4">
|
||||
<AlertDialogTitle className="text-text-primary title-2xl-semi-bold">
|
||||
{t('modelProvider.confirmDelete', { ns: 'common' })}
|
||||
</AlertDialogTitle>
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton>{t('operation.cancel', { ns: 'common' })}</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton
|
||||
disabled={doingAction}
|
||||
onClick={handleDeleteCredential}
|
||||
>
|
||||
{t('operation.confirm', { ns: 'common' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,10 +14,10 @@ import { useTranslation } from 'react-i18next'
|
||||
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/app/components/base/ui/popover'
|
||||
import { PROVIDER_WITH_PRESET_TONE, STOP_PARAMETER_RULE, TONE_LIST } from '@/config'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useModelParameterRules } from '@/service/use-common'
|
||||
@@ -129,117 +129,118 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
<Popover
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement={isInWorkflow ? 'left' : 'bottom-end'}
|
||||
offset={4}
|
||||
onOpenChange={(newOpen) => {
|
||||
if (readonly)
|
||||
return
|
||||
setOpen(newOpen)
|
||||
}}
|
||||
>
|
||||
<div className="relative">
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={() => {
|
||||
if (readonly)
|
||||
return
|
||||
setOpen(v => !v)
|
||||
}}
|
||||
className="block"
|
||||
>
|
||||
{
|
||||
renderTrigger
|
||||
? renderTrigger({
|
||||
open,
|
||||
disabled,
|
||||
modelDisabled,
|
||||
hasDeprecated,
|
||||
currentProvider,
|
||||
currentModel,
|
||||
providerName: provider,
|
||||
modelId,
|
||||
})
|
||||
: (
|
||||
<Trigger
|
||||
disabled={disabled}
|
||||
isInWorkflow={isInWorkflow}
|
||||
modelDisabled={modelDisabled}
|
||||
hasDeprecated={hasDeprecated}
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
providerName={provider}
|
||||
modelId={modelId}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className={cn('z-[60]', portalToFollowElemContentClassName)}>
|
||||
<div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
|
||||
<div className={cn('max-h-[420px] overflow-y-auto p-4 pt-3')}>
|
||||
<div className="relative">
|
||||
<div className={cn('system-sm-semibold mb-1 flex h-6 items-center text-text-secondary')}>
|
||||
{t('modelProvider.model', { ns: 'common' }).toLocaleUpperCase()}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
|
||||
modelList={activeTextGenerationModelList}
|
||||
onSelect={handleChangeModel}
|
||||
/>
|
||||
</div>
|
||||
{
|
||||
!!parameterRules.length && (
|
||||
<div className="my-3 h-px bg-divider-subtle" />
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-5"><Loading /></div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading && !!parameterRules.length && (
|
||||
<div className="mb-2 flex items-center justify-between">
|
||||
<div className={cn('system-sm-semibold flex h-6 items-center text-text-secondary')}>{t('modelProvider.parameters', { ns: 'common' })}</div>
|
||||
{
|
||||
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
|
||||
<PresetsParameter onSelect={handleSelectPresetParameter} />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading && !!parameterRules.length && (
|
||||
[
|
||||
...parameterRules,
|
||||
...(isAdvancedMode ? [STOP_PARAMETER_RULE] : []),
|
||||
].map(parameter => (
|
||||
<ParameterItem
|
||||
key={`${modelId}-${parameter.name}`}
|
||||
parameterRule={parameter}
|
||||
value={completionParams?.[parameter.name]}
|
||||
onChange={v => handleParamChange(parameter.name, v)}
|
||||
onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)}
|
||||
<PopoverTrigger
|
||||
render={(
|
||||
<div className="block">
|
||||
{
|
||||
renderTrigger
|
||||
? renderTrigger({
|
||||
open,
|
||||
disabled,
|
||||
modelDisabled,
|
||||
hasDeprecated,
|
||||
currentProvider,
|
||||
currentModel,
|
||||
providerName: provider,
|
||||
modelId,
|
||||
})
|
||||
: (
|
||||
<Trigger
|
||||
disabled={disabled}
|
||||
isInWorkflow={isInWorkflow}
|
||||
modelDisabled={modelDisabled}
|
||||
hasDeprecated={hasDeprecated}
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
providerName={provider}
|
||||
modelId={modelId}
|
||||
/>
|
||||
))
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{!hideDebugWithMultipleModel && (
|
||||
<div
|
||||
className="bg-components-section-burn system-sm-regular flex h-[50px] cursor-pointer items-center justify-between rounded-b-xl border-t border-t-divider-subtle px-4 text-text-accent"
|
||||
onClick={() => onDebugWithMultipleModelChange?.()}
|
||||
>
|
||||
{
|
||||
debugWithMultipleModel
|
||||
? t('debugAsSingleModel', { ns: 'appDebug' })
|
||||
: t('debugAsMultipleModel', { ns: 'appDebug' })
|
||||
}
|
||||
<ArrowNarrowLeft className="h-3 w-3 rotate-180" />
|
||||
</div>
|
||||
)}
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</div>
|
||||
</PortalToFollowElem>
|
||||
)}
|
||||
/>
|
||||
<PopoverContent
|
||||
placement={isInWorkflow ? 'left' : 'bottom-end'}
|
||||
sideOffset={4}
|
||||
className={portalToFollowElemContentClassName}
|
||||
popupClassName={cn(popupClassName, 'w-[389px] rounded-2xl')}
|
||||
>
|
||||
<div className="max-h-[420px] overflow-y-auto p-4 pt-3">
|
||||
<div className="relative">
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-semibold">
|
||||
{t('modelProvider.model', { ns: 'common' }).toLocaleUpperCase()}
|
||||
</div>
|
||||
<ModelSelector
|
||||
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
|
||||
modelList={activeTextGenerationModelList}
|
||||
onSelect={handleChangeModel}
|
||||
onHide={() => setOpen(false)}
|
||||
/>
|
||||
</div>
|
||||
{
|
||||
!!parameterRules.length && (
|
||||
<div className="my-3 h-px bg-divider-subtle" />
|
||||
)
|
||||
}
|
||||
{
|
||||
isLoading && (
|
||||
<div className="mt-5"><Loading /></div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading && !!parameterRules.length && (
|
||||
<div className="mb-2 flex items-center justify-between">
|
||||
<div className="flex h-6 items-center text-text-secondary system-sm-semibold">{t('modelProvider.parameters', { ns: 'common' })}</div>
|
||||
{
|
||||
PROVIDER_WITH_PRESET_TONE.includes(provider) && (
|
||||
<PresetsParameter onSelect={handleSelectPresetParameter} />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
!isLoading && !!parameterRules.length && (
|
||||
[
|
||||
...parameterRules,
|
||||
...(isAdvancedMode ? [STOP_PARAMETER_RULE] : []),
|
||||
].map(parameter => (
|
||||
<ParameterItem
|
||||
key={`${modelId}-${parameter.name}`}
|
||||
parameterRule={parameter}
|
||||
value={completionParams?.[parameter.name]}
|
||||
onChange={v => handleParamChange(parameter.name, v)}
|
||||
onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)}
|
||||
isInWorkflow={isInWorkflow}
|
||||
/>
|
||||
))
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{!hideDebugWithMultipleModel && (
|
||||
<div
|
||||
className="flex h-[50px] cursor-pointer items-center justify-between rounded-b-xl border-t border-t-divider-subtle px-4 text-text-accent system-sm-regular"
|
||||
onClick={() => onDebugWithMultipleModelChange?.()}
|
||||
>
|
||||
{
|
||||
debugWithMultipleModel
|
||||
? t('debugAsSingleModel', { ns: 'appDebug' })
|
||||
: t('debugAsMultipleModel', { ns: 'appDebug' })
|
||||
}
|
||||
<ArrowNarrowLeft className="h-3 w-3 rotate-180" />
|
||||
</div>
|
||||
)}
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import DeprecatedModelTrigger from './deprecated-model-trigger'
|
||||
|
||||
vi.mock('../model-icon', () => ({
|
||||
default: ({ modelName }: { modelName: string }) => <span>{modelName}</span>,
|
||||
}))
|
||||
|
||||
const mockUseProviderContext = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: mockUseProviderContext,
|
||||
}))
|
||||
|
||||
describe('DeprecatedModelTrigger', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
modelProviders: [{ provider: 'someone-else' }, { provider: 'openai' }],
|
||||
})
|
||||
})
|
||||
|
||||
it('should render model name', () => {
|
||||
render(<DeprecatedModelTrigger modelName="gpt-deprecated" providerName="openai" />)
|
||||
expect(screen.getAllByText('gpt-deprecated').length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should show deprecated tooltip when warn icon is hovered', async () => {
|
||||
const { container } = render(
|
||||
<DeprecatedModelTrigger
|
||||
modelName="gpt-deprecated"
|
||||
providerName="openai"
|
||||
showWarnIcon
|
||||
/>,
|
||||
)
|
||||
|
||||
const tooltipTrigger = container.querySelector('[data-state]') as HTMLElement
|
||||
fireEvent.mouseEnter(tooltipTrigger)
|
||||
|
||||
expect(await screen.findByText('common.modelProvider.deprecated')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render when provider is not found', () => {
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
modelProviders: [{ provider: 'someone-else' }],
|
||||
})
|
||||
|
||||
render(<DeprecatedModelTrigger modelName="gpt-deprecated" providerName="openai" />)
|
||||
expect(screen.getAllByText('gpt-deprecated').length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
it('should not show deprecated tooltip when warn icon is disabled', async () => {
|
||||
render(
|
||||
<DeprecatedModelTrigger
|
||||
modelName="gpt-deprecated"
|
||||
providerName="openai"
|
||||
showWarnIcon={false}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.deprecated')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,54 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import ModelIcon from '../model-icon'
|
||||
|
||||
type ModelTriggerProps = {
|
||||
modelName: string
|
||||
providerName: string
|
||||
className?: string
|
||||
showWarnIcon?: boolean
|
||||
contentClassName?: string
|
||||
}
|
||||
const ModelTrigger: FC<ModelTriggerProps> = ({
|
||||
modelName,
|
||||
providerName,
|
||||
className,
|
||||
showWarnIcon,
|
||||
contentClassName,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { modelProviders } = useProviderContext()
|
||||
const currentProvider = modelProviders.find(provider => provider.provider === providerName)
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn('group box-content flex h-8 grow cursor-pointer items-center gap-1 rounded-lg bg-components-input-bg-disabled p-[3px] pl-1', className)}
|
||||
>
|
||||
<div className={cn('flex w-full items-center', contentClassName)}>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-1 py-[1px]">
|
||||
<ModelIcon
|
||||
className="h-4 w-4"
|
||||
provider={currentProvider}
|
||||
modelName={modelName}
|
||||
/>
|
||||
<div className="system-sm-regular truncate text-components-input-text-filled">
|
||||
{modelName}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center justify-center">
|
||||
{showWarnIcon && (
|
||||
<Tooltip popupContent={t('modelProvider.deprecated', { ns: 'common' })}>
|
||||
<AlertTriangle className="h-4 w-4 text-text-warning-secondary" />
|
||||
</Tooltip>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelTrigger
|
||||
@@ -1,31 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import EmptyTrigger from './empty-trigger'
|
||||
|
||||
describe('EmptyTrigger', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should render configure model text', () => {
|
||||
render(<EmptyTrigger open={false} />)
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// open=true: hover bg class present
|
||||
it('should apply hover background class when open is true', () => {
|
||||
// Act
|
||||
const { container } = render(<EmptyTrigger open={true} />)
|
||||
|
||||
// Assert
|
||||
expect(container.firstChild).toHaveClass('bg-components-input-bg-hover')
|
||||
})
|
||||
|
||||
// className prop truthy: custom className appears on root
|
||||
it('should apply custom className when provided', () => {
|
||||
// Act
|
||||
const { container } = render(<EmptyTrigger open={false} className="custom-class" />)
|
||||
|
||||
// Assert
|
||||
expect(container.firstChild).toHaveClass('custom-class')
|
||||
})
|
||||
})
|
||||
@@ -1,42 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import { RiEqualizer2Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type ModelTriggerProps = {
|
||||
open: boolean
|
||||
className?: string
|
||||
}
|
||||
const ModelTrigger: FC<ModelTriggerProps> = ({
|
||||
open,
|
||||
className,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex cursor-pointer items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1 hover:bg-components-input-bg-hover',
|
||||
open && 'bg-components-input-bg-hover',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="flex grow items-center">
|
||||
<div className="mr-1.5 flex h-4 w-4 items-center justify-center rounded-[5px] border border-dashed border-divider-regular">
|
||||
<CubeOutline className="h-3 w-3 text-text-quaternary" />
|
||||
</div>
|
||||
<div
|
||||
className="truncate text-[13px] text-text-tertiary"
|
||||
title="Configure model"
|
||||
>
|
||||
{t('detailPanel.configureModel', { ns: 'plugin' })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
<RiEqualizer2Line className="h-3.5 w-3.5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelTrigger
|
||||
@@ -7,15 +7,13 @@ import type {
|
||||
} from '../declarations'
|
||||
import { useState } from 'react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from '@/app/components/base/ui/popover'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useCurrentProviderAndModel } from '../hooks'
|
||||
import DeprecatedModelTrigger from './deprecated-model-trigger'
|
||||
import EmptyTrigger from './empty-trigger'
|
||||
import ModelTrigger from './model-trigger'
|
||||
import ModelSelectorTrigger from './model-selector-trigger'
|
||||
import Popup from './popup'
|
||||
|
||||
type ModelSelectorProps = {
|
||||
@@ -24,6 +22,7 @@ type ModelSelectorProps = {
|
||||
triggerClassName?: string
|
||||
popupClassName?: string
|
||||
onSelect?: (model: DefaultModel) => void
|
||||
onHide?: () => void
|
||||
readonly?: boolean
|
||||
scopeFeatures?: ModelFeatureEnum[]
|
||||
deprecatedClassName?: string
|
||||
@@ -35,10 +34,11 @@ const ModelSelector: FC<ModelSelectorProps> = ({
|
||||
triggerClassName,
|
||||
popupClassName,
|
||||
onSelect,
|
||||
onHide,
|
||||
readonly,
|
||||
scopeFeatures = [],
|
||||
deprecatedClassName,
|
||||
showDeprecatedWarnIcon = false,
|
||||
showDeprecatedWarnIcon = true,
|
||||
}) => {
|
||||
const [open, setOpen] = useState(false)
|
||||
const {
|
||||
@@ -56,67 +56,60 @@ const ModelSelector: FC<ModelSelectorProps> = ({
|
||||
onSelect({ provider, model: model.model })
|
||||
}
|
||||
|
||||
const handleToggle = () => {
|
||||
if (readonly)
|
||||
return
|
||||
|
||||
setOpen(v => !v)
|
||||
}
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
<Popover
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement="bottom-start"
|
||||
offset={4}
|
||||
onOpenChange={(newOpen) => {
|
||||
if (readonly)
|
||||
return
|
||||
setOpen(newOpen)
|
||||
}}
|
||||
>
|
||||
<div className={cn('relative')}>
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={handleToggle}
|
||||
className="block"
|
||||
>
|
||||
{
|
||||
currentModel && currentProvider && (
|
||||
<ModelTrigger
|
||||
open={open}
|
||||
provider={currentProvider}
|
||||
model={currentModel}
|
||||
className={triggerClassName}
|
||||
readonly={readonly}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!currentModel && defaultModel && (
|
||||
<DeprecatedModelTrigger
|
||||
modelName={defaultModel?.model || ''}
|
||||
providerName={defaultModel?.provider || ''}
|
||||
className={triggerClassName}
|
||||
showWarnIcon={showDeprecatedWarnIcon}
|
||||
contentClassName={deprecatedClassName}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!defaultModel && (
|
||||
<EmptyTrigger
|
||||
open={open}
|
||||
className={triggerClassName}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className={`z-[1002] ${popupClassName}`}>
|
||||
<Popup
|
||||
defaultModel={defaultModel}
|
||||
modelList={modelList}
|
||||
onSelect={handleSelect}
|
||||
scopeFeatures={scopeFeatures}
|
||||
onHide={() => setOpen(false)}
|
||||
/>
|
||||
</PortalToFollowElemContent>
|
||||
</div>
|
||||
</PortalToFollowElem>
|
||||
<PopoverTrigger
|
||||
render={(
|
||||
<button
|
||||
type="button"
|
||||
className="block w-full border-0 bg-transparent p-0 text-left"
|
||||
disabled={readonly}
|
||||
>
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
defaultModel={defaultModel}
|
||||
open={open}
|
||||
readonly={readonly}
|
||||
className={triggerClassName}
|
||||
deprecatedClassName={deprecatedClassName}
|
||||
showDeprecatedWarnIcon={showDeprecatedWarnIcon}
|
||||
/>
|
||||
</button>
|
||||
)}
|
||||
/>
|
||||
{/*
|
||||
* TODO(overlay-migration): temporary layering hack.
|
||||
* Some callers still render ModelSelector inside legacy high-z modals
|
||||
* (e.g. code/automatic generators at z-[1000]). Keep this selector above
|
||||
* them until those call sites are fully migrated to unified base/ui overlays.
|
||||
*/}
|
||||
<PopoverContent
|
||||
placement="bottom-start"
|
||||
sideOffset={4}
|
||||
className={cn('z-[1002]', popupClassName)}
|
||||
popupClassName="overflow-hidden rounded-lg"
|
||||
popupProps={{ style: { minWidth: '320px', width: 'var(--anchor-width, auto)' } }}
|
||||
>
|
||||
<Popup
|
||||
defaultModel={defaultModel}
|
||||
modelList={modelList}
|
||||
onSelect={handleSelect}
|
||||
scopeFeatures={scopeFeatures}
|
||||
onHide={() => {
|
||||
setOpen(false)
|
||||
onHide?.()
|
||||
}}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
import type { Model, ModelItem } from '../declarations'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelFeatureEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '../declarations'
|
||||
import ModelSelectorTrigger from './model-selector-trigger'
|
||||
|
||||
const mockUseProviderContext = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: mockUseProviderContext,
|
||||
}))
|
||||
|
||||
const createModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
|
||||
model: 'gpt-4',
|
||||
label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' },
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
features: [ModelFeatureEnum.vision],
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status: ModelStatusEnum.active,
|
||||
model_properties: { mode: 'chat', context_size: 4096 },
|
||||
load_balancing_enabled: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
provider: 'openai',
|
||||
icon_small: {
|
||||
en_US: 'https://example.com/openai-light.png',
|
||||
zh_Hans: 'https://example.com/openai-light.png',
|
||||
},
|
||||
icon_small_dark: {
|
||||
en_US: 'https://example.com/openai-dark.png',
|
||||
zh_Hans: 'https://example.com/openai-dark.png',
|
||||
},
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [createModelItem()],
|
||||
status: ModelStatusEnum.active,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('ModelSelectorTrigger', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
modelProviders: [createModel()],
|
||||
})
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render empty state when no model is selected', () => {
|
||||
const { container } = render(<ModelSelectorTrigger />)
|
||||
|
||||
expect(screen.getByText('plugin.detailPanel.configureModel')).toBeInTheDocument()
|
||||
expect(container.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
|
||||
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-normal')
|
||||
})
|
||||
|
||||
it('should render selected model details when model is active', () => {
|
||||
const currentProvider = createModel()
|
||||
const currentModel = createModelItem()
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={currentProvider}
|
||||
currentModel={currentModel}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('GPT-4')).toBeInTheDocument()
|
||||
expect(screen.getByText('CHAT')).toBeInTheDocument()
|
||||
expect(container.querySelector('.i-ri-arrow-down-s-line')).toBeInTheDocument()
|
||||
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-normal')
|
||||
})
|
||||
|
||||
it('should render deprecated default model and disabled style when selection is missing', () => {
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
defaultModel={{ provider: 'openai', model: 'legacy-model' }}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('legacy-model')).toBeInTheDocument()
|
||||
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-disabled')
|
||||
expect(container.querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Props', () => {
|
||||
it('should apply custom className to root element', () => {
|
||||
const { container } = render(<ModelSelectorTrigger className="custom-trigger" />)
|
||||
|
||||
expect(container.firstElementChild).toHaveClass('custom-trigger')
|
||||
})
|
||||
|
||||
it('should apply open background style when open is true and model is active', () => {
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={createModel()}
|
||||
currentModel={createModelItem()}
|
||||
open
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(container.firstElementChild).toHaveClass('bg-components-input-bg-hover')
|
||||
})
|
||||
|
||||
it('should hide the expand arrow when readonly is true', () => {
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={createModel()}
|
||||
currentModel={createModelItem()}
|
||||
readonly
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(container.querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Status Handling', () => {
|
||||
it('should show status badge when selected model is not active and not readonly', () => {
|
||||
render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={createModel()}
|
||||
currentModel={createModelItem({ status: ModelStatusEnum.noConfigure })}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('common.modelProvider.selector.configureRequired')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show status badge when selected model is readonly', () => {
|
||||
render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={createModel()}
|
||||
currentModel={createModelItem({ status: ModelStatusEnum.noConfigure })}
|
||||
readonly
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('common.modelProvider.selector.configureRequired')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show incompatible tooltip when hovering no-permission status badge', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(
|
||||
<ModelSelectorTrigger
|
||||
currentProvider={createModel()}
|
||||
currentModel={createModelItem({ status: ModelStatusEnum.noPermission })}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.hover(screen.getByText('common.modelProvider.selector.incompatible'))
|
||||
|
||||
expect(await screen.findByText('common.modelProvider.selector.incompatibleTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should show deprecated tooltip when hovering warn icon', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
defaultModel={{ provider: 'openai', model: 'legacy-model' }}
|
||||
/>,
|
||||
)
|
||||
|
||||
const warnIcon = container.querySelector('.i-ri-alert-line')
|
||||
expect(warnIcon).toBeInTheDocument()
|
||||
|
||||
await user.hover(warnIcon as HTMLElement)
|
||||
|
||||
expect(await screen.findByText('common.modelProvider.deprecated')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render fallback icon when deprecated provider is not found', () => {
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
modelProviders: [],
|
||||
})
|
||||
const { container } = render(
|
||||
<ModelSelectorTrigger
|
||||
defaultModel={{ provider: 'unknown-provider', model: 'legacy-model' }}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(container.querySelector('img[alt="model-icon"]')).not.toBeInTheDocument()
|
||||
expect(container.querySelector('svg')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,146 @@
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
DefaultModel,
|
||||
Model,
|
||||
ModelItem,
|
||||
} from '../declarations'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { ModelStatusEnum } from '../declarations'
|
||||
import ModelIcon from '../model-icon'
|
||||
import ModelName from '../model-name'
|
||||
|
||||
const STATUS_I18N_KEY: Partial<Record<ModelStatusEnum, string>> = {
|
||||
[ModelStatusEnum.quotaExceeded]: 'modelProvider.selector.creditsExhausted',
|
||||
[ModelStatusEnum.noConfigure]: 'modelProvider.selector.configureRequired',
|
||||
[ModelStatusEnum.noPermission]: 'modelProvider.selector.incompatible',
|
||||
[ModelStatusEnum.disabled]: 'modelProvider.selector.disabled',
|
||||
[ModelStatusEnum.credentialRemoved]: 'modelProvider.selector.apiKeyUnavailable',
|
||||
}
|
||||
|
||||
type ModelSelectorTriggerProps = {
|
||||
currentProvider?: Model
|
||||
currentModel?: ModelItem
|
||||
defaultModel?: DefaultModel
|
||||
open?: boolean
|
||||
readonly?: boolean
|
||||
className?: string
|
||||
deprecatedClassName?: string
|
||||
showDeprecatedWarnIcon?: boolean
|
||||
}
|
||||
|
||||
const ModelSelectorTrigger: FC<ModelSelectorTriggerProps> = ({
|
||||
currentProvider,
|
||||
currentModel,
|
||||
defaultModel,
|
||||
open,
|
||||
readonly,
|
||||
className,
|
||||
deprecatedClassName,
|
||||
showDeprecatedWarnIcon = true,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { modelProviders } = useProviderContext()
|
||||
|
||||
const isSelected = !!currentProvider && !!currentModel
|
||||
const isDeprecated = !isSelected && !!defaultModel
|
||||
const isEmpty = !isSelected && !defaultModel
|
||||
|
||||
const isActive = isSelected && currentModel.status === ModelStatusEnum.active
|
||||
const isDisabled = isDeprecated || (isSelected && !isActive)
|
||||
const statusI18nKey = isSelected ? STATUS_I18N_KEY[currentModel.status] : undefined
|
||||
|
||||
const deprecatedProvider = isDeprecated
|
||||
? modelProviders.find(p => p.provider === defaultModel.provider)
|
||||
: undefined
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'group flex h-8 items-center gap-0.5 rounded-lg p-1',
|
||||
isDisabled
|
||||
? 'bg-components-input-bg-disabled'
|
||||
: 'bg-components-input-bg-normal',
|
||||
!readonly && !isDisabled && 'cursor-pointer hover:bg-components-input-bg-hover',
|
||||
open && !isDisabled && 'bg-components-input-bg-hover',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{isEmpty
|
||||
? (
|
||||
<div className="flex h-6 w-6 items-center justify-center">
|
||||
<div className="flex h-5 w-5 items-center justify-center rounded-md border-[0.5px] border-components-panel-border-subtle bg-background-default-subtle">
|
||||
<span className="i-ri-brain-2-line h-3.5 w-3.5 text-text-quaternary" />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
: (
|
||||
<ModelIcon
|
||||
className="p-0.5"
|
||||
provider={isSelected ? currentProvider : deprecatedProvider}
|
||||
modelName={isSelected ? currentModel.model : defaultModel?.model}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className={cn('flex grow items-center gap-1 truncate px-1 py-[3px]', isDeprecated && deprecatedClassName)}>
|
||||
{isSelected && (
|
||||
<ModelName
|
||||
className="grow"
|
||||
modelItem={currentModel}
|
||||
showMode
|
||||
showFeatures
|
||||
/>
|
||||
)}
|
||||
{isDeprecated && (
|
||||
<div className="grow truncate text-components-input-text-filled system-sm-regular">
|
||||
{defaultModel.model}
|
||||
</div>
|
||||
)}
|
||||
{isEmpty && (
|
||||
<div className="grow truncate text-[13px] text-text-quaternary">
|
||||
{t('detailPanel.configureModel', { ns: 'plugin' })}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isSelected && !readonly && !isActive && statusI18nKey && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger
|
||||
disabled={currentModel.status !== ModelStatusEnum.noPermission}
|
||||
render={(
|
||||
<div className="flex shrink-0 items-center gap-[3px] rounded-md border border-text-warning px-[5px] py-0.5">
|
||||
<span className="i-ri-alert-fill h-3 w-3 text-text-warning" />
|
||||
<span className="whitespace-nowrap text-text-warning system-xs-medium">
|
||||
{t(statusI18nKey as 'modelProvider.selector.creditsExhausted', { ns: 'common' })}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<TooltipContent placement="top">
|
||||
{t('modelProvider.selector.incompatibleTip', { ns: 'common' })}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{isDeprecated && showDeprecatedWarnIcon && (
|
||||
<Tooltip>
|
||||
<TooltipTrigger render={(
|
||||
<span className="i-ri-alert-line h-4 w-4 shrink-0 text-text-warning-secondary" />
|
||||
)}
|
||||
/>
|
||||
<TooltipContent placement="top">
|
||||
{t('modelProvider.deprecated', { ns: 'common' })}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
)}
|
||||
|
||||
{!readonly && (isActive || isEmpty) && (
|
||||
<span className="i-ri-arrow-down-s-line h-3.5 w-3.5 shrink-0 text-text-tertiary" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelSelectorTrigger
|
||||
@@ -1,91 +0,0 @@
|
||||
import type { Model, ModelItem } from '../declarations'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import {
|
||||
ConfigurationMethodEnum,
|
||||
ModelStatusEnum,
|
||||
ModelTypeEnum,
|
||||
} from '../declarations'
|
||||
import ModelTrigger from './model-trigger'
|
||||
|
||||
vi.mock('../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
|
||||
return {
|
||||
...actual,
|
||||
useLanguage: () => 'en_US',
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../model-icon', () => ({
|
||||
default: ({ modelName }: { modelName: string }) => <span>{modelName}</span>,
|
||||
}))
|
||||
|
||||
vi.mock('../model-name', () => ({
|
||||
default: ({ modelItem }: { modelItem: ModelItem }) => <span>{modelItem.label.en_US}</span>,
|
||||
}))
|
||||
|
||||
const makeModelItem = (overrides: Partial<ModelItem> = {}): ModelItem => ({
|
||||
model: 'gpt-4',
|
||||
label: { en_US: 'GPT-4', zh_Hans: 'GPT-4' },
|
||||
model_type: ModelTypeEnum.textGeneration,
|
||||
fetch_from: ConfigurationMethodEnum.predefinedModel,
|
||||
status: ModelStatusEnum.active,
|
||||
model_properties: {},
|
||||
load_balancing_enabled: false,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const makeModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
provider: 'openai',
|
||||
icon_small: { en_US: '', zh_Hans: '' },
|
||||
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
|
||||
models: [makeModelItem()],
|
||||
status: ModelStatusEnum.active,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('ModelTrigger', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should show model name', () => {
|
||||
render(
|
||||
<ModelTrigger
|
||||
open
|
||||
provider={makeModel()}
|
||||
model={makeModelItem()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('GPT-4')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show status tooltip content when model is not active', async () => {
|
||||
const { container } = render(
|
||||
<ModelTrigger
|
||||
open={false}
|
||||
provider={makeModel()}
|
||||
model={makeModelItem({ status: ModelStatusEnum.noConfigure })}
|
||||
/>,
|
||||
)
|
||||
|
||||
const tooltipTrigger = container.querySelector('[data-state]') as HTMLElement
|
||||
fireEvent.mouseEnter(tooltipTrigger)
|
||||
|
||||
expect(await screen.findByText('No Configure')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show status icon when readonly', () => {
|
||||
render(
|
||||
<ModelTrigger
|
||||
open={false}
|
||||
provider={makeModel()}
|
||||
model={makeModelItem({ status: ModelStatusEnum.noConfigure })}
|
||||
readonly
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('GPT-4')).toBeInTheDocument()
|
||||
expect(screen.queryByText('No Configure')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,78 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
Model,
|
||||
ModelItem,
|
||||
} from '../declarations'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import {
|
||||
MODEL_STATUS_TEXT,
|
||||
ModelStatusEnum,
|
||||
} from '../declarations'
|
||||
import { useLanguage } from '../hooks'
|
||||
import ModelIcon from '../model-icon'
|
||||
import ModelName from '../model-name'
|
||||
|
||||
type ModelTriggerProps = {
|
||||
open: boolean
|
||||
provider: Model
|
||||
model: ModelItem
|
||||
className?: string
|
||||
readonly?: boolean
|
||||
}
|
||||
const ModelTrigger: FC<ModelTriggerProps> = ({
|
||||
open,
|
||||
provider,
|
||||
model,
|
||||
className,
|
||||
readonly,
|
||||
}) => {
|
||||
const language = useLanguage()
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'group flex h-8 items-center gap-0.5 rounded-lg bg-components-input-bg-normal p-1',
|
||||
!readonly && 'cursor-pointer hover:bg-components-input-bg-hover',
|
||||
open && 'bg-components-input-bg-hover',
|
||||
model.status !== ModelStatusEnum.active && 'bg-components-input-bg-disabled hover:bg-components-input-bg-disabled',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<ModelIcon
|
||||
className="p-0.5"
|
||||
provider={provider}
|
||||
modelName={model.model}
|
||||
/>
|
||||
<div className="flex grow items-center gap-1 truncate px-1 py-[3px]">
|
||||
<ModelName
|
||||
className="grow"
|
||||
modelItem={model}
|
||||
showMode
|
||||
showFeatures
|
||||
/>
|
||||
{!readonly && (
|
||||
<div className="flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
{
|
||||
model.status !== ModelStatusEnum.active
|
||||
? (
|
||||
<Tooltip popupContent={MODEL_STATUS_TEXT[model.status][language]}>
|
||||
<AlertTriangle className="h-4 w-4 text-text-warning-secondary" />
|
||||
</Tooltip>
|
||||
)
|
||||
: (
|
||||
<RiArrowDownSLine
|
||||
className="h-3.5 w-3.5 text-text-tertiary"
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelTrigger
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user