mirror of
https://github.com/langgenius/dify.git
synced 2026-01-04 13:37:22 +00:00
Compare commits
2 Commits
release/e-
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c9304d7ef | ||
|
|
df09acb74b |
@@ -4,10 +4,11 @@ from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from constants import UUID_NIL
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@@ -33,8 +34,11 @@ from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,14 +57,18 @@ class ChatRequestPayload(BaseModel):
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: str | None = Field(default=None, description="Conversation UUID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message UUID")
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@field_validator("conversation_id", "parent_message_id", mode="before")
|
||||
@classmethod
|
||||
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
|
||||
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
|
||||
def normalize_uuid_fields(cls, value: str | UUID | None, info: ValidationInfo) -> str | None:
|
||||
"""Allow missing or blank UUID fields; enforce UUID format when provided."""
|
||||
if isinstance(value, UUID):
|
||||
return str(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
|
||||
@@ -70,7 +78,36 @@ class ChatRequestPayload(BaseModel):
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
raise ValueError(f"{info.field_name} must be a valid UUID") from exc
|
||||
|
||||
|
||||
def _validate_parent_message_request(
|
||||
*,
|
||||
app_model: App,
|
||||
end_user: EndUser,
|
||||
conversation_id: str | None,
|
||||
parent_message_id: str | None,
|
||||
) -> None:
|
||||
if not parent_message_id or parent_message_id == UUID_NIL:
|
||||
return
|
||||
|
||||
if not conversation_id:
|
||||
raise BadRequest("conversation_id is required when parent_message_id is provided.")
|
||||
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=end_user
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
try:
|
||||
parent_message = MessageService.get_message(app_model=app_model, user=end_user, message_id=parent_message_id)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
if parent_message.conversation_id != conversation.id:
|
||||
raise BadRequest("parent_message_id does not belong to the conversation.")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
|
||||
@@ -205,6 +242,13 @@ class ChatApi(Resource):
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id=args.get("conversation_id"),
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
@@ -127,6 +126,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
self._validate_parent_message_for_service_api(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
# for better separation of concerns.
|
||||
@@ -168,7 +175,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -9,7 +9,6 @@ from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
@@ -105,6 +104,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
self._validate_parent_message_for_service_api(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
@@ -163,7 +170,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -8,7 +8,6 @@ from flask import Flask, copy_current_request_context, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from configs import dify_config
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
@@ -97,6 +96,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
self._validate_parent_message_for_service_api(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation=conversation,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
invoke_from=invoke_from,
|
||||
)
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
@@ -156,7 +163,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Union, cast
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@@ -34,6 +35,7 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -84,6 +86,34 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
def _resolve_parent_message_id(self, args: Mapping[str, Any], invoke_from: InvokeFrom) -> str | None:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if invoke_from == InvokeFrom.SERVICE_API and not parent_message_id:
|
||||
return UUID_NIL
|
||||
return parent_message_id
|
||||
|
||||
def _validate_parent_message_for_service_api(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
conversation: Conversation | None,
|
||||
parent_message_id: str | None,
|
||||
invoke_from: InvokeFrom,
|
||||
) -> None:
|
||||
if invoke_from != InvokeFrom.SERVICE_API:
|
||||
return
|
||||
|
||||
if not parent_message_id or parent_message_id == UUID_NIL:
|
||||
return
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError("Conversation not exists")
|
||||
|
||||
parent_message = MessageService.get_message(app_model=app_model, user=user, message_id=parent_message_id)
|
||||
if parent_message.conversation_id != conversation.id:
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
|
||||
if conversation:
|
||||
stmt = select(AppModelConfig).where(
|
||||
|
||||
@@ -2,9 +2,8 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file import File, FileUploadConfig
|
||||
@@ -158,20 +157,12 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||
"For service API, we need to ensure its forward compatibility, "
|
||||
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration "
|
||||
"and branching in chat APIs."
|
||||
"For service API, when it is omitted, the system treats it as UUID_NIL to preserve legacy linear history."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("parent_message_id")
|
||||
@classmethod
|
||||
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||
return v
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.1"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -3458,7 +3458,7 @@ class SegmentService:
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
|
||||
|
||||
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
|
||||
query = query.order_by(DocumentSegment.position.asc())
|
||||
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
return paginated_segments.items, paginated_segments.total
|
||||
|
||||
@@ -110,5 +110,5 @@ class EnterpriseService:
|
||||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
body = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
||||
|
||||
@@ -286,12 +286,12 @@ class BuiltinToolManageService:
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from constants import UUID_NIL
|
||||
from controllers.service_api.app.completion import _validate_parent_message_request
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def test_validate_parent_message_skips_when_missing():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
|
||||
):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id=None
|
||||
)
|
||||
|
||||
get_conversation.assert_not_called()
|
||||
get_message.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parent_message_skips_uuid_nil():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
|
||||
):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id=None,
|
||||
parent_message_id=UUID_NIL,
|
||||
)
|
||||
|
||||
get_conversation.assert_not_called()
|
||||
get_message.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parent_message_requires_conversation_id():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id="parent-id"
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_missing_conversation_raises_not_found():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
|
||||
with patch(
|
||||
"controllers.service_api.app.completion.ConversationService.get_conversation",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_missing_message_raises_not_found():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch(
|
||||
"controllers.service_api.app.completion.MessageService.get_message",
|
||||
side_effect=MessageNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_mismatch_conversation_raises_bad_request():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
message = SimpleNamespace(conversation_id="different-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_matches_conversation():
|
||||
app_model = object()
|
||||
end_user = object()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
message = SimpleNamespace(conversation_id="conversation-id")
|
||||
|
||||
with (
|
||||
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
|
||||
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
|
||||
):
|
||||
_validate_parent_message_request(
|
||||
app_model=app_model,
|
||||
end_user=end_user,
|
||||
conversation_id="conversation-id",
|
||||
parent_message_id="parent-id",
|
||||
)
|
||||
@@ -23,3 +23,24 @@ def test_chat_request_payload_validates_uuid():
|
||||
def test_chat_request_payload_rejects_invalid_uuid():
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})
|
||||
|
||||
|
||||
def test_chat_request_payload_accepts_blank_parent_message_id():
|
||||
payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": ""})
|
||||
|
||||
assert payload.parent_message_id is None
|
||||
|
||||
|
||||
def test_chat_request_payload_validates_parent_message_id_uuid():
|
||||
parent_message_id = str(uuid.uuid4())
|
||||
|
||||
payload = ChatRequestPayload.model_validate(
|
||||
{"inputs": {}, "query": "hello", "parent_message_id": parent_message_id}
|
||||
)
|
||||
|
||||
assert payload.parent_message_id == parent_message_id
|
||||
|
||||
|
||||
def test_chat_request_payload_rejects_invalid_parent_message_id():
|
||||
with pytest.raises(ValidationError):
|
||||
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": "invalid"})
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def test_validate_parent_message_service_api_skips_missing():
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
|
||||
generator._validate_parent_message_for_service_api(
|
||||
app_model=object(),
|
||||
user=object(),
|
||||
conversation=None,
|
||||
parent_message_id=None,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
get_message.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parent_message_service_api_skips_uuid_nil():
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
|
||||
generator._validate_parent_message_for_service_api(
|
||||
app_model=object(),
|
||||
user=object(),
|
||||
conversation=None,
|
||||
parent_message_id=UUID_NIL,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
get_message.assert_not_called()
|
||||
|
||||
|
||||
def test_validate_parent_message_service_api_requires_conversation():
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
with pytest.raises(ConversationNotExistsError):
|
||||
generator._validate_parent_message_for_service_api(
|
||||
app_model=object(),
|
||||
user=object(),
|
||||
conversation=None,
|
||||
parent_message_id="parent-id",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_service_api_mismatch_conversation():
|
||||
generator = MessageBasedAppGenerator()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
parent_message = SimpleNamespace(conversation_id="different-id")
|
||||
|
||||
with patch(
|
||||
"core.app.apps.message_based_app_generator.MessageService.get_message",
|
||||
return_value=parent_message,
|
||||
):
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
generator._validate_parent_message_for_service_api(
|
||||
app_model=object(),
|
||||
user=object(),
|
||||
conversation=conversation,
|
||||
parent_message_id="parent-id",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_parent_message_service_api_matches_conversation():
|
||||
generator = MessageBasedAppGenerator()
|
||||
conversation = SimpleNamespace(id="conversation-id")
|
||||
parent_message = SimpleNamespace(conversation_id="conversation-id")
|
||||
|
||||
with patch(
|
||||
"core.app.apps.message_based_app_generator.MessageService.get_message",
|
||||
return_value=parent_message,
|
||||
):
|
||||
generator._validate_parent_message_for_service_api(
|
||||
app_model=object(),
|
||||
user=object(),
|
||||
conversation=conversation,
|
||||
parent_message_id="parent-id",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
@@ -1,472 +0,0 @@
|
||||
"""
|
||||
Unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests the retrieval of document segments with pagination and filtering:
|
||||
- Basic pagination (page, limit)
|
||||
- Status filtering
|
||||
- Keyword search
|
||||
- Ordering by position and id (to avoid duplicate data)
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
|
||||
class SegmentServiceTestDataFactory:
|
||||
"""
|
||||
Factory class for creating test data and mock objects for segment tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_segment_mock(
|
||||
segment_id: str = "segment-123",
|
||||
document_id: str = "doc-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
dataset_id: str = "dataset-123",
|
||||
position: int = 1,
|
||||
content: str = "Test content",
|
||||
status: str = "completed",
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a mock document segment.
|
||||
|
||||
Args:
|
||||
segment_id: Unique identifier for the segment
|
||||
document_id: Parent document ID
|
||||
tenant_id: Tenant ID the segment belongs to
|
||||
dataset_id: Parent dataset ID
|
||||
position: Position within the document
|
||||
content: Segment text content
|
||||
status: Indexing status
|
||||
**kwargs: Additional attributes
|
||||
|
||||
Returns:
|
||||
Mock: DocumentSegment mock object
|
||||
"""
|
||||
segment = create_autospec(DocumentSegment, instance=True)
|
||||
segment.id = segment_id
|
||||
segment.document_id = document_id
|
||||
segment.tenant_id = tenant_id
|
||||
segment.dataset_id = dataset_id
|
||||
segment.position = position
|
||||
segment.content = content
|
||||
segment.status = status
|
||||
for key, value in kwargs.items():
|
||||
setattr(segment, key, value)
|
||||
return segment
|
||||
|
||||
|
||||
class TestSegmentServiceGetSegments:
|
||||
"""
|
||||
Comprehensive unit tests for SegmentService.get_segments method.
|
||||
|
||||
Tests cover:
|
||||
- Basic pagination functionality
|
||||
- Status list filtering
|
||||
- Keyword search filtering
|
||||
- Ordering (position + id for uniqueness)
|
||||
- Empty results
|
||||
- Combined filters
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_segment_service_dependencies(self):
|
||||
"""
|
||||
Common mock setup for segment service dependencies.
|
||||
|
||||
Patches:
|
||||
- db: Database operations and pagination
|
||||
- select: SQLAlchemy query builder
|
||||
"""
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.select") as mock_select,
|
||||
):
|
||||
yield {
|
||||
"db": mock_db,
|
||||
"select": mock_select,
|
||||
}
|
||||
|
||||
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test basic pagination functionality.
|
||||
|
||||
Verifies:
|
||||
- Query is built with document_id and tenant_id filters
|
||||
- Pagination uses correct page and limit parameters
|
||||
- Returns segments and total count
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
page = 1
|
||||
limit = 20
|
||||
|
||||
# Create mock segments
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="First segment"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=2, content="Second segment"
|
||||
)
|
||||
|
||||
# Mock pagination result
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
# Mock select builder
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
assert items[0].id == "seg-1"
|
||||
assert items[1].id == "seg-2"
|
||||
mock_segment_service_dependencies["db"].paginate.assert_called_once()
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
assert call_kwargs["error_out"] is False
|
||||
|
||||
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test filtering by status list.
|
||||
|
||||
Verifies:
|
||||
- Status list filter is applied to query
|
||||
- Only segments with matching status are returned
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed", "indexing"]
|
||||
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2]
|
||||
mock_paginated.total = 2
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
# Verify where was called multiple times (base filters + status filter)
|
||||
assert mock_query.where.call_count >= 2
|
||||
|
||||
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with empty status list.
|
||||
|
||||
Verifies:
|
||||
- Empty status list is handled correctly
|
||||
- No status filter is applied to avoid WHERE false condition
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = []
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id, tenant_id=tenant_id, status_list=status_list
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test keyword search functionality.
|
||||
|
||||
Verifies:
|
||||
- Keyword filter uses ilike for case-insensitive search
|
||||
- Search pattern includes wildcards (%keyword%)
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
keyword = "search term"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", content="This contains search term"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify where was called for base filters + keyword filter
|
||||
assert mock_query.where.call_count == 2
|
||||
|
||||
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test ordering by position and id.
|
||||
|
||||
Verifies:
|
||||
- Results are ordered by position ASC
|
||||
- Results are secondarily ordered by id ASC to ensure uniqueness
|
||||
- This prevents duplicate data across pages when positions are not unique
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
# Create segments with same position but different ids
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1", position=1, content="Content 1"
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-2", position=1, content="Content 2"
|
||||
)
|
||||
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-3", position=2, content="Content 3"
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment1, segment2, segment3]
|
||||
mock_paginated.total = 3
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 3
|
||||
assert total == 3
|
||||
mock_query.order_by.assert_called_once()
|
||||
|
||||
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test when no segments match the criteria.
|
||||
|
||||
Verifies:
|
||||
- Empty list is returned for items
|
||||
- Total count is 0
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "non-existent-doc"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert items == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with multiple filters combined.
|
||||
|
||||
Verifies:
|
||||
- All filters work together correctly
|
||||
- Status list and keyword search both applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
status_list = ["completed"]
|
||||
keyword = "important"
|
||||
page = 2
|
||||
limit = 10
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(
|
||||
segment_id="seg-1",
|
||||
status="completed",
|
||||
content="This is important information",
|
||||
)
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=status_list,
|
||||
keyword=keyword,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Verify filters: base + status + keyword
|
||||
assert mock_query.where.call_count == 3
|
||||
# Verify pagination parameters
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["page"] == page
|
||||
assert call_kwargs["per_page"] == limit
|
||||
|
||||
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test with None status list.
|
||||
|
||||
Verifies:
|
||||
- None status list is handled correctly
|
||||
- No status filter is applied
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
|
||||
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = [segment]
|
||||
mock_paginated.total = 1
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
items, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
status_list=None,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(items) == 1
|
||||
assert total == 1
|
||||
# Should only be called once (base filters only, no status filter)
|
||||
assert mock_query.where.call_count == 1
|
||||
|
||||
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
|
||||
"""
|
||||
Test that max_per_page is correctly set to 100.
|
||||
|
||||
Verifies:
|
||||
- max_per_page parameter is set to 100
|
||||
- This prevents excessive page sizes
|
||||
"""
|
||||
# Arrange
|
||||
document_id = "doc-123"
|
||||
tenant_id = "tenant-123"
|
||||
limit = 200 # Request more than max_per_page
|
||||
|
||||
mock_paginated = Mock()
|
||||
mock_paginated.items = []
|
||||
mock_paginated.total = 0
|
||||
|
||||
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
|
||||
|
||||
mock_query = Mock()
|
||||
mock_segment_service_dependencies["select"].return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
|
||||
# Act
|
||||
from services.dataset_service import SegmentService
|
||||
|
||||
SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=tenant_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Assert
|
||||
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
|
||||
assert call_kwargs["max_per_page"] == 100
|
||||
2
api/uv.lock
generated
2
api/uv.lock
generated
@@ -1368,7 +1368,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.11.2"
|
||||
version = "1.11.1"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -692,7 +692,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -734,7 +734,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -773,7 +773,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.11.2
|
||||
image: langgenius/dify-api:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -803,7 +803,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.11.2
|
||||
image: langgenius/dify-web:1.11.1
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Billing from './index'
|
||||
|
||||
let currentBillingUrl: string | null = 'https://billing'
|
||||
let fetching = false
|
||||
let isManager = true
|
||||
let enableBilling = true
|
||||
|
||||
const refetchMock = vi.fn()
|
||||
const openAsyncWindowMock = vi.fn()
|
||||
|
||||
vi.mock('@/service/use-billing', () => ({
|
||||
useBillingUrl: () => ({
|
||||
data: currentBillingUrl,
|
||||
isFetching: fetching,
|
||||
refetch: refetchMock,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => openAsyncWindowMock,
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
isCurrentWorkspaceManager: isManager,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
enableBilling,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../plan', () => ({
|
||||
__esModule: true,
|
||||
default: ({ loc }: { loc: string }) => <div data-testid="plan-component" data-loc={loc} />,
|
||||
}))
|
||||
|
||||
describe('Billing', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
currentBillingUrl = 'https://billing'
|
||||
fetching = false
|
||||
isManager = true
|
||||
enableBilling = true
|
||||
refetchMock.mockResolvedValue({ data: 'https://billing' })
|
||||
})
|
||||
|
||||
it('hides the billing action when user is not manager or billing is disabled', () => {
|
||||
isManager = false
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
|
||||
vi.clearAllMocks()
|
||||
isManager = true
|
||||
enableBilling = false
|
||||
render(<Billing />)
|
||||
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens the billing window with the immediate url when the button is clicked', async () => {
|
||||
render(<Billing />)
|
||||
|
||||
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
|
||||
fireEvent.click(actionButton)
|
||||
|
||||
await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
|
||||
const [, options] = openAsyncWindowMock.mock.calls[0]
|
||||
expect(options).toMatchObject({
|
||||
immediateUrl: currentBillingUrl,
|
||||
features: 'noopener,noreferrer',
|
||||
})
|
||||
})
|
||||
|
||||
it('disables the button while billing url is fetching', () => {
|
||||
fetching = true
|
||||
render(<Billing />)
|
||||
|
||||
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
|
||||
expect(actionButton).toBeDisabled()
|
||||
})
|
||||
})
|
||||
@@ -1,92 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { Plan } from '../type'
|
||||
import HeaderBillingBtn from './index'
|
||||
|
||||
type HeaderGlobal = typeof globalThis & {
|
||||
__mockProviderContext?: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
function getHeaderGlobal(): HeaderGlobal {
|
||||
return globalThis as HeaderGlobal
|
||||
}
|
||||
|
||||
const ensureProviderContextMock = () => {
|
||||
const globals = getHeaderGlobal()
|
||||
if (!globals.__mockProviderContext)
|
||||
throw new Error('Provider context mock not set')
|
||||
return globals.__mockProviderContext
|
||||
}
|
||||
|
||||
vi.mock('@/context/provider-context', () => {
|
||||
const mock = vi.fn()
|
||||
const globals = getHeaderGlobal()
|
||||
globals.__mockProviderContext = mock
|
||||
return {
|
||||
useProviderContext: () => mock(),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../upgrade-btn', () => ({
|
||||
__esModule: true,
|
||||
default: () => <button data-testid="upgrade-btn" type="button">Upgrade</button>,
|
||||
}))
|
||||
|
||||
describe('HeaderBillingBtn', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
ensureProviderContextMock().mockReturnValue({
|
||||
plan: {
|
||||
type: Plan.professional,
|
||||
},
|
||||
enableBilling: true,
|
||||
isFetchedPlan: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('renders nothing when billing is disabled or plan is not fetched', () => {
|
||||
ensureProviderContextMock().mockReturnValueOnce({
|
||||
plan: {
|
||||
type: Plan.professional,
|
||||
},
|
||||
enableBilling: false,
|
||||
isFetchedPlan: true,
|
||||
})
|
||||
|
||||
const { container } = render(<HeaderBillingBtn />)
|
||||
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('renders upgrade button for sandbox plan', () => {
|
||||
ensureProviderContextMock().mockReturnValueOnce({
|
||||
plan: {
|
||||
type: Plan.sandbox,
|
||||
},
|
||||
enableBilling: true,
|
||||
isFetchedPlan: true,
|
||||
})
|
||||
|
||||
render(<HeaderBillingBtn />)
|
||||
|
||||
expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders plan badge and forwards clicks when not display-only', () => {
|
||||
const onClick = vi.fn()
|
||||
|
||||
const { rerender } = render(<HeaderBillingBtn onClick={onClick} />)
|
||||
|
||||
const badge = screen.getByText('pro').closest('div')
|
||||
|
||||
expect(badge).toHaveClass('cursor-pointer')
|
||||
|
||||
fireEvent.click(badge!)
|
||||
expect(onClick).toHaveBeenCalledTimes(1)
|
||||
|
||||
rerender(<HeaderBillingBtn onClick={onClick} isDisplayOnly />)
|
||||
expect(screen.getByText('pro').closest('div')).toHaveClass('cursor-default')
|
||||
|
||||
fireEvent.click(screen.getByText('pro').closest('div')!)
|
||||
expect(onClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,44 +0,0 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import PartnerStack from './index'
|
||||
|
||||
let isCloudEdition = true
|
||||
|
||||
const saveOrUpdate = vi.fn()
|
||||
const bind = vi.fn()
|
||||
|
||||
vi.mock('@/config', () => ({
|
||||
get IS_CLOUD_EDITION() {
|
||||
return isCloudEdition
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('./use-ps-info', () => ({
|
||||
__esModule: true,
|
||||
default: () => ({
|
||||
saveOrUpdate,
|
||||
bind,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('PartnerStack', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
isCloudEdition = true
|
||||
})
|
||||
|
||||
it('does not call partner stack helpers when not in cloud edition', () => {
|
||||
isCloudEdition = false
|
||||
|
||||
render(<PartnerStack />)
|
||||
|
||||
expect(saveOrUpdate).not.toHaveBeenCalled()
|
||||
expect(bind).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('calls saveOrUpdate and bind once when running in cloud edition', () => {
|
||||
render(<PartnerStack />)
|
||||
|
||||
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
|
||||
expect(bind).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,197 +0,0 @@
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { PARTNER_STACK_CONFIG } from '@/config'
|
||||
import usePSInfo from './use-ps-info'
|
||||
|
||||
let searchParamsValues: Record<string, string | null> = {}
|
||||
const setSearchParams = (values: Record<string, string | null>) => {
|
||||
searchParamsValues = values
|
||||
}
|
||||
|
||||
type PartnerStackGlobal = typeof globalThis & {
|
||||
__partnerStackCookieMocks?: {
|
||||
get: ReturnType<typeof vi.fn>
|
||||
set: ReturnType<typeof vi.fn>
|
||||
remove: ReturnType<typeof vi.fn>
|
||||
}
|
||||
__partnerStackMutateAsync?: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
function getPartnerStackGlobal(): PartnerStackGlobal {
|
||||
return globalThis as PartnerStackGlobal
|
||||
}
|
||||
|
||||
const ensureCookieMocks = () => {
|
||||
const globals = getPartnerStackGlobal()
|
||||
if (!globals.__partnerStackCookieMocks)
|
||||
throw new Error('Cookie mocks not initialized')
|
||||
return globals.__partnerStackCookieMocks
|
||||
}
|
||||
|
||||
const ensureMutateAsync = () => {
|
||||
const globals = getPartnerStackGlobal()
|
||||
if (!globals.__partnerStackMutateAsync)
|
||||
throw new Error('Mutate mock not initialized')
|
||||
return globals.__partnerStackMutateAsync
|
||||
}
|
||||
|
||||
vi.mock('js-cookie', () => {
|
||||
const get = vi.fn()
|
||||
const set = vi.fn()
|
||||
const remove = vi.fn()
|
||||
const globals = getPartnerStackGlobal()
|
||||
globals.__partnerStackCookieMocks = { get, set, remove }
|
||||
const cookieApi = { get, set, remove }
|
||||
return {
|
||||
__esModule: true,
|
||||
default: cookieApi,
|
||||
get,
|
||||
set,
|
||||
remove,
|
||||
}
|
||||
})
|
||||
vi.mock('next/navigation', () => ({
|
||||
useSearchParams: () => ({
|
||||
get: (key: string) => searchParamsValues[key] ?? null,
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/service/use-billing', () => {
|
||||
const mutateAsync = vi.fn()
|
||||
const globals = getPartnerStackGlobal()
|
||||
globals.__partnerStackMutateAsync = mutateAsync
|
||||
return {
|
||||
useBindPartnerStackInfo: () => ({
|
||||
mutateAsync,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
describe('usePSInfo', () => {
|
||||
const originalLocationDescriptor = Object.getOwnPropertyDescriptor(globalThis, 'location')
|
||||
|
||||
beforeAll(() => {
|
||||
Object.defineProperty(globalThis, 'location', {
|
||||
value: { hostname: 'cloud.dify.ai' },
|
||||
configurable: true,
|
||||
})
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
setSearchParams({})
|
||||
const { get, set, remove } = ensureCookieMocks()
|
||||
get.mockReset()
|
||||
set.mockReset()
|
||||
remove.mockReset()
|
||||
const mutate = ensureMutateAsync()
|
||||
mutate.mockReset()
|
||||
mutate.mockResolvedValue(undefined)
|
||||
get.mockReturnValue('{}')
|
||||
})
|
||||
|
||||
afterAll(() => {
|
||||
if (originalLocationDescriptor)
|
||||
Object.defineProperty(globalThis, 'location', originalLocationDescriptor)
|
||||
})
|
||||
|
||||
it('saves partner info when query params change', () => {
|
||||
const { get, set } = ensureCookieMocks()
|
||||
get.mockReturnValue(JSON.stringify({ partnerKey: 'old', clickId: 'old-click' }))
|
||||
setSearchParams({
|
||||
ps_partner_key: 'new-partner',
|
||||
ps_xid: 'new-click',
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => usePSInfo())
|
||||
|
||||
expect(result.current.psPartnerKey).toBe('new-partner')
|
||||
expect(result.current.psClickId).toBe('new-click')
|
||||
|
||||
act(() => {
|
||||
result.current.saveOrUpdate()
|
||||
})
|
||||
|
||||
expect(set).toHaveBeenCalledWith(
|
||||
PARTNER_STACK_CONFIG.cookieName,
|
||||
JSON.stringify({
|
||||
partnerKey: 'new-partner',
|
||||
clickId: 'new-click',
|
||||
}),
|
||||
{
|
||||
expires: PARTNER_STACK_CONFIG.saveCookieDays,
|
||||
path: '/',
|
||||
domain: '.dify.ai',
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
it('does not overwrite cookie when params do not change', () => {
|
||||
setSearchParams({
|
||||
ps_partner_key: 'existing',
|
||||
ps_xid: 'existing-click',
|
||||
})
|
||||
const { get } = ensureCookieMocks()
|
||||
get.mockReturnValue(JSON.stringify({
|
||||
partnerKey: 'existing',
|
||||
clickId: 'existing-click',
|
||||
}))
|
||||
|
||||
const { result } = renderHook(() => usePSInfo())
|
||||
|
||||
act(() => {
|
||||
result.current.saveOrUpdate()
|
||||
})
|
||||
|
||||
const { set } = ensureCookieMocks()
|
||||
expect(set).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('binds partner info and clears cookie once', async () => {
|
||||
setSearchParams({
|
||||
ps_partner_key: 'bind-partner',
|
||||
ps_xid: 'bind-click',
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => usePSInfo())
|
||||
|
||||
const mutate = ensureMutateAsync()
|
||||
const { remove } = ensureCookieMocks()
|
||||
await act(async () => {
|
||||
await result.current.bind()
|
||||
})
|
||||
|
||||
expect(mutate).toHaveBeenCalledWith({
|
||||
partnerKey: 'bind-partner',
|
||||
clickId: 'bind-click',
|
||||
})
|
||||
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
|
||||
path: '/',
|
||||
domain: '.dify.ai',
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
await result.current.bind()
|
||||
})
|
||||
|
||||
expect(mutate).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('still removes cookie when bind fails with status 400', async () => {
|
||||
const mutate = ensureMutateAsync()
|
||||
mutate.mockRejectedValueOnce({ status: 400 })
|
||||
setSearchParams({
|
||||
ps_partner_key: 'bind-partner',
|
||||
ps_xid: 'bind-click',
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => usePSInfo())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.bind()
|
||||
})
|
||||
|
||||
const { remove } = ensureCookieMocks()
|
||||
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
|
||||
path: '/',
|
||||
domain: '.dify.ai',
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,130 +0,0 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants'
|
||||
import { Plan } from '../type'
|
||||
import PlanComp from './index'
|
||||
|
||||
let currentPath = '/billing'
|
||||
|
||||
const push = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push }),
|
||||
usePathname: () => currentPath,
|
||||
}))
|
||||
|
||||
const setShowAccountSettingModalMock = vi.fn()
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
useModalContextSelector: (selector: any) => selector({
|
||||
setShowAccountSettingModal: setShowAccountSettingModalMock,
|
||||
}),
|
||||
}))
|
||||
|
||||
const providerContextMock = vi.fn()
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => providerContextMock(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
userProfile: { email: 'user@example.com' },
|
||||
isCurrentWorkspaceManager: true,
|
||||
}),
|
||||
}))
|
||||
|
||||
const mutateAsyncMock = vi.fn()
|
||||
let isPending = false
|
||||
vi.mock('@/service/use-education', () => ({
|
||||
useEducationVerify: () => ({
|
||||
mutateAsync: mutateAsyncMock,
|
||||
isPending,
|
||||
}),
|
||||
}))
|
||||
|
||||
const verifyStateModalMock = vi.fn(props => (
|
||||
<div data-testid="verify-modal" data-is-show={props.isShow ? 'true' : 'false'}>
|
||||
{props.isShow ? 'visible' : 'hidden'}
|
||||
</div>
|
||||
))
|
||||
vi.mock('@/app/education-apply/verify-state-modal', () => ({
|
||||
__esModule: true,
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
default: (props: any) => verifyStateModalMock(props),
|
||||
}))
|
||||
|
||||
vi.mock('../upgrade-btn', () => ({
|
||||
__esModule: true,
|
||||
default: () => <button data-testid="plan-upgrade-btn" type="button">Upgrade</button>,
|
||||
}))
|
||||
|
||||
describe('PlanComp', () => {
|
||||
const planMock = {
|
||||
type: Plan.professional,
|
||||
usage: {
|
||||
teamMembers: 4,
|
||||
documentsUploadQuota: 3,
|
||||
vectorSpace: 8,
|
||||
annotatedResponse: 5,
|
||||
triggerEvents: 60,
|
||||
apiRateLimit: 100,
|
||||
},
|
||||
total: {
|
||||
teamMembers: 10,
|
||||
documentsUploadQuota: 20,
|
||||
vectorSpace: 10,
|
||||
annotatedResponse: 500,
|
||||
triggerEvents: 100,
|
||||
apiRateLimit: 200,
|
||||
},
|
||||
reset: {
|
||||
triggerEvents: 2,
|
||||
apiRateLimit: 1,
|
||||
},
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
currentPath = '/billing'
|
||||
isPending = false
|
||||
providerContextMock.mockReturnValue({
|
||||
plan: planMock,
|
||||
enableEducationPlan: true,
|
||||
allowRefreshEducationVerify: false,
|
||||
isEducationAccount: false,
|
||||
})
|
||||
mutateAsyncMock.mockReset()
|
||||
mutateAsyncMock.mockResolvedValue({ token: 'token' })
|
||||
})
|
||||
|
||||
it('renders plan info and handles education verify success', async () => {
|
||||
render(<PlanComp loc="billing-page" />)
|
||||
|
||||
expect(screen.getByText('billing.plans.professional.name')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('plan-upgrade-btn')).toBeInTheDocument()
|
||||
|
||||
const verifyBtn = screen.getByText('education.toVerified')
|
||||
fireEvent.click(verifyBtn)
|
||||
|
||||
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
|
||||
await waitFor(() => expect(push).toHaveBeenCalledWith('/education-apply?token=token'))
|
||||
expect(localStorage.removeItem).toHaveBeenCalledWith(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
|
||||
})
|
||||
|
||||
it('shows modal when education verify fails', async () => {
|
||||
mutateAsyncMock.mockRejectedValueOnce(new Error('boom'))
|
||||
render(<PlanComp loc="billing-page" />)
|
||||
|
||||
const verifyBtn = screen.getByText('education.toVerified')
|
||||
fireEvent.click(verifyBtn)
|
||||
|
||||
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
|
||||
await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true'))
|
||||
})
|
||||
|
||||
it('resets modal context when on education apply path', () => {
|
||||
currentPath = '/education-apply/setup'
|
||||
render(<PlanComp loc="billing-page" />)
|
||||
|
||||
expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null)
|
||||
})
|
||||
})
|
||||
@@ -1,25 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import ProgressBar from './index'
|
||||
|
||||
describe('ProgressBar', () => {
|
||||
it('renders with provided percent and color', () => {
|
||||
render(<ProgressBar percent={42} color="bg-test-color" />)
|
||||
|
||||
const bar = screen.getByTestId('billing-progress-bar')
|
||||
expect(bar).toHaveClass('bg-test-color')
|
||||
expect(bar.getAttribute('style')).toContain('width: 42%')
|
||||
})
|
||||
|
||||
it('caps width at 100% when percent exceeds max', () => {
|
||||
render(<ProgressBar percent={150} color="bg-test-color" />)
|
||||
|
||||
const bar = screen.getByTestId('billing-progress-bar')
|
||||
expect(bar.getAttribute('style')).toContain('width: 100%')
|
||||
})
|
||||
|
||||
it('uses the default color when no color prop is provided', () => {
|
||||
render(<ProgressBar percent={20} color={undefined as unknown as string} />)
|
||||
|
||||
expect(screen.getByTestId('billing-progress-bar')).toHaveClass('#2970FF')
|
||||
})
|
||||
})
|
||||
@@ -1,70 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import TriggerEventsLimitModal from './index'
|
||||
|
||||
const mockOnClose = vi.fn()
|
||||
const mockOnUpgrade = vi.fn()
|
||||
|
||||
const planUpgradeModalMock = vi.fn((props: { show: boolean, title: string, description: string, extraInfo?: React.ReactNode, onClose: () => void, onUpgrade: () => void }) => (
|
||||
<div
|
||||
data-testid="plan-upgrade-modal"
|
||||
data-show={props.show}
|
||||
data-title={props.title}
|
||||
data-description={props.description}
|
||||
>
|
||||
{props.extraInfo}
|
||||
</div>
|
||||
))
|
||||
|
||||
vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({
|
||||
__esModule: true,
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
default: (props: any) => planUpgradeModalMock(props),
|
||||
}))
|
||||
|
||||
describe('TriggerEventsLimitModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('passes the trigger usage props to the upgrade modal', () => {
|
||||
render(
|
||||
<TriggerEventsLimitModal
|
||||
show
|
||||
onClose={mockOnClose}
|
||||
onUpgrade={mockOnUpgrade}
|
||||
usage={12}
|
||||
total={20}
|
||||
resetInDays={5}
|
||||
/>,
|
||||
)
|
||||
|
||||
const modal = screen.getByTestId('plan-upgrade-modal')
|
||||
expect(modal.getAttribute('data-show')).toBe('true')
|
||||
expect(modal.getAttribute('data-title')).toContain('billing.triggerLimitModal.title')
|
||||
expect(modal.getAttribute('data-description')).toContain('billing.triggerLimitModal.description')
|
||||
expect(planUpgradeModalMock).toHaveBeenCalled()
|
||||
|
||||
const passedProps = planUpgradeModalMock.mock.calls[0][0]
|
||||
expect(passedProps.onClose).toBe(mockOnClose)
|
||||
expect(passedProps.onUpgrade).toBe(mockOnUpgrade)
|
||||
|
||||
expect(screen.getByText('billing.triggerLimitModal.usageTitle')).toBeInTheDocument()
|
||||
expect(screen.getByText('12')).toBeInTheDocument()
|
||||
expect(screen.getByText('20')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders even when trigger modal is hidden', () => {
|
||||
render(
|
||||
<TriggerEventsLimitModal
|
||||
show={false}
|
||||
onClose={mockOnClose}
|
||||
onUpgrade={mockOnUpgrade}
|
||||
usage={0}
|
||||
total={0}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(planUpgradeModalMock).toHaveBeenCalled()
|
||||
expect(screen.getByTestId('plan-upgrade-modal').getAttribute('data-show')).toBe('false')
|
||||
})
|
||||
})
|
||||
@@ -1,35 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { defaultPlan } from '../config'
|
||||
import AppsInfo from './apps-info'
|
||||
|
||||
const appsUsage = 7
|
||||
const appsTotal = 15
|
||||
|
||||
const mockPlan = {
|
||||
...defaultPlan,
|
||||
usage: {
|
||||
...defaultPlan.usage,
|
||||
buildApps: appsUsage,
|
||||
},
|
||||
total: {
|
||||
...defaultPlan.total,
|
||||
buildApps: appsTotal,
|
||||
},
|
||||
}
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
plan: mockPlan,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('AppsInfo', () => {
|
||||
it('renders build apps usage information with context data', () => {
|
||||
render(<AppsInfo className="apps-info-class" />)
|
||||
|
||||
expect(screen.getByText('billing.usagePage.buildApps')).toBeInTheDocument()
|
||||
expect(screen.getByText(`${appsUsage}`)).toBeInTheDocument()
|
||||
expect(screen.getByText(`${appsTotal}`)).toBeInTheDocument()
|
||||
expect(screen.getByText('billing.usagePage.buildApps').closest('.apps-info-class')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,114 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { NUM_INFINITE } from '../config'
|
||||
import UsageInfo from './index'
|
||||
|
||||
const TestIcon = () => <span data-testid="usage-icon" />
|
||||
|
||||
describe('UsageInfo', () => {
|
||||
it('renders the metric with a suffix unit and tooltip text', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Apps"
|
||||
usage={30}
|
||||
total={100}
|
||||
unit="GB"
|
||||
tooltip="tooltip text"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('usage-icon')).toBeInTheDocument()
|
||||
expect(screen.getByText('Apps')).toBeInTheDocument()
|
||||
expect(screen.getByText('30')).toBeInTheDocument()
|
||||
expect(screen.getByText('100')).toBeInTheDocument()
|
||||
expect(screen.getByText('GB')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders inline unit when unitPosition is inline', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={20}
|
||||
total={100}
|
||||
unit="GB"
|
||||
unitPosition="inline"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('100GB')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows reset hint text instead of the unit when resetHint is provided', () => {
|
||||
const resetHint = 'Resets in 3 days'
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={20}
|
||||
total={100}
|
||||
unit="GB"
|
||||
resetHint={resetHint}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText(resetHint)).toBeInTheDocument()
|
||||
expect(screen.queryByText('GB')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('displays unlimited text when total is infinite', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={10}
|
||||
total={NUM_INFINITE}
|
||||
unit="GB"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('applies warning color when usage is close to the limit', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={85}
|
||||
total={100}
|
||||
/>,
|
||||
)
|
||||
|
||||
const progressBar = screen.getByTestId('billing-progress-bar')
|
||||
expect(progressBar).toHaveClass('bg-components-progress-warning-progress')
|
||||
})
|
||||
|
||||
it('applies error color when usage exceeds the limit', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={120}
|
||||
total={100}
|
||||
/>,
|
||||
)
|
||||
|
||||
const progressBar = screen.getByTestId('billing-progress-bar')
|
||||
expect(progressBar).toHaveClass('bg-components-progress-error-progress')
|
||||
})
|
||||
|
||||
it('does not render the icon when hideIcon is true', () => {
|
||||
render(
|
||||
<UsageInfo
|
||||
Icon={TestIcon}
|
||||
name="Storage"
|
||||
usage={5}
|
||||
total={100}
|
||||
hideIcon
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,58 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import VectorSpaceFull from './index'
|
||||
|
||||
type VectorProviderGlobal = typeof globalThis & {
|
||||
__vectorProviderContext?: ReturnType<typeof vi.fn>
|
||||
}
|
||||
|
||||
function getVectorGlobal(): VectorProviderGlobal {
|
||||
return globalThis as VectorProviderGlobal
|
||||
}
|
||||
|
||||
vi.mock('@/context/provider-context', () => {
|
||||
const mock = vi.fn()
|
||||
getVectorGlobal().__vectorProviderContext = mock
|
||||
return {
|
||||
useProviderContext: () => mock(),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../upgrade-btn', () => ({
|
||||
__esModule: true,
|
||||
default: () => <button data-testid="vector-upgrade-btn" type="button">Upgrade</button>,
|
||||
}))
|
||||
|
||||
describe('VectorSpaceFull', () => {
|
||||
const planMock = {
|
||||
type: 'team',
|
||||
usage: {
|
||||
vectorSpace: 8,
|
||||
},
|
||||
total: {
|
||||
vectorSpace: 10,
|
||||
},
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
const globals = getVectorGlobal()
|
||||
globals.__vectorProviderContext?.mockReturnValue({
|
||||
plan: planMock,
|
||||
})
|
||||
})
|
||||
|
||||
it('renders tip text and upgrade button', () => {
|
||||
render(<VectorSpaceFull />)
|
||||
|
||||
expect(screen.getByText('billing.vectorSpace.fullTip')).toBeInTheDocument()
|
||||
expect(screen.getByText('billing.vectorSpace.fullSolution')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('vector-upgrade-btn')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows vector usage and total', () => {
|
||||
render(<VectorSpaceFull />)
|
||||
|
||||
expect(screen.getByText('8')).toBeInTheDocument()
|
||||
expect(screen.getByText('10MB')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -56,6 +56,9 @@ Chat applications support session persistence, allowing previous chat history to
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
|
||||
- `type` (string) Supported type:
|
||||
|
||||
@@ -56,6 +56,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
会話ID、以前のチャット記録に基づいて会話を続けるには、以前のメッセージのconversation_idを渡す必要があります。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
|
||||
- `type` (string) サポートされるタイプ:
|
||||
|
||||
@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
(选填)会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
|
||||
- `type` (string) 支持类型:
|
||||
|
||||
@@ -55,6 +55,9 @@ Chat applications support session persistence, allowing previous chat history to
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
|
||||
- `type` (string) Supported type:
|
||||
|
||||
@@ -55,6 +55,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
会話ID、以前のチャット記録に基づいて会話を続けるには、前のメッセージのconversation_idを渡す必要があります。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
|
||||
- `type` (string) サポートされるタイプ:
|
||||
|
||||
@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
<Property name='conversation_id' type='string' key='conversation_id'>
|
||||
(选填)会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。
|
||||
</Property>
|
||||
<Property name='parent_message_id' type='string' key='parent_message_id'>
|
||||
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
|
||||
</Property>
|
||||
<Property name='files' type='array[object]' key='files'>
|
||||
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
|
||||
- `type` (string) 支持类型:
|
||||
|
||||
@@ -63,11 +63,6 @@ export const useShortcuts = (): void => {
|
||||
return !isEventTargetInputArea(e.target as HTMLElement)
|
||||
}, [])
|
||||
|
||||
const shouldHandleCopy = useCallback(() => {
|
||||
const selection = document.getSelection()
|
||||
return !selection || selection.isCollapsed
|
||||
}, [])
|
||||
|
||||
useKeyPress(['delete', 'backspace'], (e) => {
|
||||
if (shouldHandleShortcut(e)) {
|
||||
e.preventDefault()
|
||||
@@ -78,7 +73,7 @@ export const useShortcuts = (): void => {
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => {
|
||||
const { showDebugAndPreviewPanel } = workflowStore.getState()
|
||||
if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel) {
|
||||
if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel) {
|
||||
e.preventDefault()
|
||||
handleNodesCopy()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"type": "module",
|
||||
"version": "1.11.2",
|
||||
"version": "1.11.1",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.26.1+sha512.664074abc367d2c9324fdc18037097ce0a8f126034160f709928e9e9f95d98714347044e5c3164d65bd5da6c59c6be362b107546292a8eecb7999196e5ce58fa",
|
||||
"engines": {
|
||||
|
||||
@@ -371,8 +371,6 @@ export const useTriggerPluginDynamicOptions = (payload: {
|
||||
},
|
||||
enabled: enabled && !!payload.plugin_id && !!payload.provider && !!payload.action && !!payload.parameter && !!payload.credential_id,
|
||||
retry: 0,
|
||||
staleTime: 0,
|
||||
gcTime: 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user