Compare commits

..

2 Commits

Author SHA1 Message Date
yyh
6c9304d7ef feat: implement parent message validation for service API
Added a new validation method to check parent message IDs in the MessageBasedAppGenerator class, ensuring proper handling of UUID_NIL and conversation existence. Updated related app generators and added unit tests for comprehensive coverage.
2025-12-25 15:36:18 +08:00
yyh
df09acb74b feat: chat messages api support parent message id 2025-12-25 15:21:44 +08:00
36 changed files with 389 additions and 1372 deletions

View File

@@ -4,10 +4,11 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from constants import UUID_NIL
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -33,8 +34,11 @@ from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.conversation_service import ConversationService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
@@ -53,14 +57,18 @@ class ChatRequestPayload(BaseModel):
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
parent_message_id: str | None = Field(default=None, description="Parent message UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
def normalize_uuid_fields(cls, value: str | UUID | None, info: ValidationInfo) -> str | None:
"""Allow missing or blank UUID fields; enforce UUID format when provided."""
if isinstance(value, UUID):
return str(value)
if isinstance(value, str):
value = value.strip()
@@ -70,7 +78,36 @@ class ChatRequestPayload(BaseModel):
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
raise ValueError(f"{info.field_name} must be a valid UUID") from exc
def _validate_parent_message_request(
*,
app_model: App,
end_user: EndUser,
conversation_id: str | None,
parent_message_id: str | None,
) -> None:
if not parent_message_id or parent_message_id == UUID_NIL:
return
if not conversation_id:
raise BadRequest("conversation_id is required when parent_message_id is provided.")
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=end_user
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
try:
parent_message = MessageService.get_message(app_model=app_model, user=end_user, message_id=parent_message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
if parent_message.conversation_id != conversation.id:
raise BadRequest("parent_message_id does not belong to the conversation.")
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@@ -205,6 +242,13 @@ class ChatApi(Resource):
streaming = payload.response_mode == "streaming"
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id=args.get("conversation_id"),
parent_message_id=args.get("parent_message_id"),
)
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -127,6 +126,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
@@ -168,7 +175,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -9,7 +9,6 @@ from flask import Flask, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
@@ -105,6 +104,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -163,7 +170,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -8,7 +8,6 @@ from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -97,6 +96,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -156,7 +163,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
invoke_from=invoke_from,
extras=extras,

View File

@@ -1,11 +1,12 @@
import json
import logging
from collections.abc import Generator
from typing import Union, cast
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -34,6 +35,7 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
@@ -84,6 +86,34 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _resolve_parent_message_id(self, args: Mapping[str, Any], invoke_from: InvokeFrom) -> str | None:
parent_message_id = args.get("parent_message_id")
if invoke_from == InvokeFrom.SERVICE_API and not parent_message_id:
return UUID_NIL
return parent_message_id
def _validate_parent_message_for_service_api(
self,
*,
app_model: App,
user: Union[Account, EndUser],
conversation: Conversation | None,
parent_message_id: str | None,
invoke_from: InvokeFrom,
) -> None:
if invoke_from != InvokeFrom.SERVICE_API:
return
if not parent_message_id or parent_message_id == UUID_NIL:
return
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
parent_message = MessageService.get_message(app_model=app_model, user=user, message_id=parent_message_id)
if parent_message.conversation_id != conversation.id:
raise MessageNotExistsError("Message not exists")
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
stmt = select(AppModelConfig).where(

View File

@@ -2,9 +2,8 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from pydantic import BaseModel, ConfigDict, Field
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
@@ -158,20 +157,12 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
"For service API, we need to ensure its forward compatibility, "
"so passing in the parent_message_id as request arg is not supported for now. "
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
"Starting from v0.9.0, parent_message_id is used to support message regeneration "
"and branching in chat APIs."
"For service API, when it is omitted, the system treats it as UUID_NIL to preserve legacy linear history."
),
)
@field_validator("parent_message_id")
@classmethod
def validate_parent_message_id(cls, v, info: ValidationInfo):
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
raise ValueError("parent_message_id should be UUID_NIL for service API")
return v
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
"""

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.11.2"
version = "1.11.1"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@@ -3458,7 +3458,7 @@ class SegmentService:
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
return paginated_segments.items, paginated_segments.total

View File

@@ -110,5 +110,5 @@ class EnterpriseService:
if not app_id:
raise ValueError("app_id must be provided.")
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
body = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)

View File

@@ -286,12 +286,12 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except Exception as e:
session.rollback()
raise ValueError(str(e))
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id, "builtin")
return {"result": "success"}
@staticmethod

View File

@@ -0,0 +1,130 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from constants import UUID_NIL
from controllers.service_api.app.completion import _validate_parent_message_request
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
def test_validate_parent_message_skips_when_missing():
app_model = object()
end_user = object()
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id=None
)
get_conversation.assert_not_called()
get_message.assert_not_called()
def test_validate_parent_message_skips_uuid_nil():
app_model = object()
end_user = object()
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id=None,
parent_message_id=UUID_NIL,
)
get_conversation.assert_not_called()
get_message.assert_not_called()
def test_validate_parent_message_requires_conversation_id():
app_model = object()
end_user = object()
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id="parent-id"
)
def test_validate_parent_message_missing_conversation_raises_not_found():
app_model = object()
end_user = object()
with patch(
"controllers.service_api.app.completion.ConversationService.get_conversation",
side_effect=ConversationNotExistsError(),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_missing_message_raises_not_found():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch(
"controllers.service_api.app.completion.MessageService.get_message",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_mismatch_conversation_raises_bad_request():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="different-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_matches_conversation():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)

View File

@@ -23,3 +23,24 @@ def test_chat_request_payload_validates_uuid():
def test_chat_request_payload_rejects_invalid_uuid():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})
def test_chat_request_payload_accepts_blank_parent_message_id():
payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": ""})
assert payload.parent_message_id is None
def test_chat_request_payload_validates_parent_message_id_uuid():
parent_message_id = str(uuid.uuid4())
payload = ChatRequestPayload.model_validate(
{"inputs": {}, "query": "hello", "parent_message_id": parent_message_id}
)
assert payload.parent_message_id == parent_message_id
def test_chat_request_payload_rejects_invalid_parent_message_id():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": "invalid"})

View File

@@ -0,0 +1,90 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from constants import UUID_NIL
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
def test_validate_parent_message_service_api_skips_missing():
generator = MessageBasedAppGenerator()
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id=None,
invoke_from=InvokeFrom.SERVICE_API,
)
get_message.assert_not_called()
def test_validate_parent_message_service_api_skips_uuid_nil():
generator = MessageBasedAppGenerator()
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id=UUID_NIL,
invoke_from=InvokeFrom.SERVICE_API,
)
get_message.assert_not_called()
def test_validate_parent_message_service_api_requires_conversation():
generator = MessageBasedAppGenerator()
with pytest.raises(ConversationNotExistsError):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)
def test_validate_parent_message_service_api_mismatch_conversation():
generator = MessageBasedAppGenerator()
conversation = SimpleNamespace(id="conversation-id")
parent_message = SimpleNamespace(conversation_id="different-id")
with patch(
"core.app.apps.message_based_app_generator.MessageService.get_message",
return_value=parent_message,
):
with pytest.raises(MessageNotExistsError):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=conversation,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)
def test_validate_parent_message_service_api_matches_conversation():
generator = MessageBasedAppGenerator()
conversation = SimpleNamespace(id="conversation-id")
parent_message = SimpleNamespace(conversation_id="conversation-id")
with patch(
"core.app.apps.message_based_app_generator.MessageService.get_message",
return_value=parent_message,
):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=conversation,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)

View File

@@ -1,472 +0,0 @@
"""
Unit tests for SegmentService.get_segments method.
Tests the retrieval of document segments with pagination and filtering:
- Basic pagination (page, limit)
- Status filtering
- Keyword search
- Ordering by position and id (to avoid duplicate data)
"""
from unittest.mock import Mock, create_autospec, patch
import pytest
from models.dataset import DocumentSegment
class SegmentServiceTestDataFactory:
"""
Factory class for creating test data and mock objects for segment tests.
"""
@staticmethod
def create_segment_mock(
segment_id: str = "segment-123",
document_id: str = "doc-123",
tenant_id: str = "tenant-123",
dataset_id: str = "dataset-123",
position: int = 1,
content: str = "Test content",
status: str = "completed",
**kwargs,
) -> Mock:
"""
Create a mock document segment.
Args:
segment_id: Unique identifier for the segment
document_id: Parent document ID
tenant_id: Tenant ID the segment belongs to
dataset_id: Parent dataset ID
position: Position within the document
content: Segment text content
status: Indexing status
**kwargs: Additional attributes
Returns:
Mock: DocumentSegment mock object
"""
segment = create_autospec(DocumentSegment, instance=True)
segment.id = segment_id
segment.document_id = document_id
segment.tenant_id = tenant_id
segment.dataset_id = dataset_id
segment.position = position
segment.content = content
segment.status = status
for key, value in kwargs.items():
setattr(segment, key, value)
return segment
class TestSegmentServiceGetSegments:
"""
Comprehensive unit tests for SegmentService.get_segments method.
Tests cover:
- Basic pagination functionality
- Status list filtering
- Keyword search filtering
- Ordering (position + id for uniqueness)
- Empty results
- Combined filters
"""
@pytest.fixture
def mock_segment_service_dependencies(self):
"""
Common mock setup for segment service dependencies.
Patches:
- db: Database operations and pagination
- select: SQLAlchemy query builder
"""
with (
patch("services.dataset_service.db") as mock_db,
patch("services.dataset_service.select") as mock_select,
):
yield {
"db": mock_db,
"select": mock_select,
}
def test_get_segments_basic_pagination(self, mock_segment_service_dependencies):
"""
Test basic pagination functionality.
Verifies:
- Query is built with document_id and tenant_id filters
- Pagination uses correct page and limit parameters
- Returns segments and total count
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
page = 1
limit = 20
# Create mock segments
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="First segment"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=2, content="Second segment"
)
# Mock pagination result
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
# Mock select builder
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, page=page, limit=limit)
# Assert
assert len(items) == 2
assert total == 2
assert items[0].id == "seg-1"
assert items[1].id == "seg-2"
mock_segment_service_dependencies["db"].paginate.assert_called_once()
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
assert call_kwargs["max_per_page"] == 100
assert call_kwargs["error_out"] is False
def test_get_segments_with_status_filter(self, mock_segment_service_dependencies):
"""
Test filtering by status list.
Verifies:
- Status list filter is applied to query
- Only segments with matching status are returned
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed", "indexing"]
segment1 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1", status="completed")
segment2 = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-2", status="indexing")
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2]
mock_paginated.total = 2
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 2
assert total == 2
# Verify where was called multiple times (base filters + status filter)
assert mock_query.where.call_count >= 2
def test_get_segments_with_empty_status_list(self, mock_segment_service_dependencies):
"""
Test with empty status list.
Verifies:
- Empty status list is handled correctly
- No status filter is applied to avoid WHERE false condition
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = []
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id, tenant_id=tenant_id, status_list=status_list
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_with_keyword_search(self, mock_segment_service_dependencies):
"""
Test keyword search functionality.
Verifies:
- Keyword filter uses ilike for case-insensitive search
- Search pattern includes wildcards (%keyword%)
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
keyword = "search term"
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", content="This contains search term"
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id, keyword=keyword)
# Assert
assert len(items) == 1
assert total == 1
# Verify where was called for base filters + keyword filter
assert mock_query.where.call_count == 2
def test_get_segments_ordering_by_position_and_id(self, mock_segment_service_dependencies):
"""
Test ordering by position and id.
Verifies:
- Results are ordered by position ASC
- Results are secondarily ordered by id ASC to ensure uniqueness
- This prevents duplicate data across pages when positions are not unique
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
# Create segments with same position but different ids
segment1 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1", position=1, content="Content 1"
)
segment2 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-2", position=1, content="Content 2"
)
segment3 = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-3", position=2, content="Content 3"
)
mock_paginated = Mock()
mock_paginated.items = [segment1, segment2, segment3]
mock_paginated.total = 3
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert len(items) == 3
assert total == 3
mock_query.order_by.assert_called_once()
def test_get_segments_empty_results(self, mock_segment_service_dependencies):
"""
Test when no segments match the criteria.
Verifies:
- Empty list is returned for items
- Total count is 0
"""
# Arrange
document_id = "non-existent-doc"
tenant_id = "tenant-123"
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(document_id=document_id, tenant_id=tenant_id)
# Assert
assert items == []
assert total == 0
def test_get_segments_combined_filters(self, mock_segment_service_dependencies):
"""
Test with multiple filters combined.
Verifies:
- All filters work together correctly
- Status list and keyword search both applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
status_list = ["completed"]
keyword = "important"
page = 2
limit = 10
segment = SegmentServiceTestDataFactory.create_segment_mock(
segment_id="seg-1",
status="completed",
content="This is important information",
)
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=status_list,
keyword=keyword,
page=page,
limit=limit,
)
# Assert
assert len(items) == 1
assert total == 1
# Verify filters: base + status + keyword
assert mock_query.where.call_count == 3
# Verify pagination parameters
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["page"] == page
assert call_kwargs["per_page"] == limit
def test_get_segments_with_none_status_list(self, mock_segment_service_dependencies):
"""
Test with None status list.
Verifies:
- None status list is handled correctly
- No status filter is applied
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
segment = SegmentServiceTestDataFactory.create_segment_mock(segment_id="seg-1")
mock_paginated = Mock()
mock_paginated.items = [segment]
mock_paginated.total = 1
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
items, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
status_list=None,
)
# Assert
assert len(items) == 1
assert total == 1
# Should only be called once (base filters only, no status filter)
assert mock_query.where.call_count == 1
def test_get_segments_pagination_max_per_page_limit(self, mock_segment_service_dependencies):
"""
Test that max_per_page is correctly set to 100.
Verifies:
- max_per_page parameter is set to 100
- This prevents excessive page sizes
"""
# Arrange
document_id = "doc-123"
tenant_id = "tenant-123"
limit = 200 # Request more than max_per_page
mock_paginated = Mock()
mock_paginated.items = []
mock_paginated.total = 0
mock_segment_service_dependencies["db"].paginate.return_value = mock_paginated
mock_query = Mock()
mock_segment_service_dependencies["select"].return_value = mock_query
mock_query.where.return_value = mock_query
mock_query.order_by.return_value = mock_query
# Act
from services.dataset_service import SegmentService
SegmentService.get_segments(
document_id=document_id,
tenant_id=tenant_id,
limit=limit,
)
# Assert
call_kwargs = mock_segment_service_dependencies["db"].paginate.call_args[1]
assert call_kwargs["max_per_page"] == 100

2
api/uv.lock generated
View File

@@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.11.2"
version = "1.11.1"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },

View File

@@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -692,7 +692,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -734,7 +734,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -773,7 +773,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.11.2
image: langgenius/dify-api:1.11.1
restart: always
environment:
# Use the shared environment variables.
@@ -803,7 +803,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.11.2
image: langgenius/dify-web:1.11.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -1,84 +0,0 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Billing from './index'
let currentBillingUrl: string | null = 'https://billing'
let fetching = false
let isManager = true
let enableBilling = true
const refetchMock = vi.fn()
const openAsyncWindowMock = vi.fn()
vi.mock('@/service/use-billing', () => ({
useBillingUrl: () => ({
data: currentBillingUrl,
isFetching: fetching,
refetch: refetchMock,
}),
}))
vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: () => openAsyncWindowMock,
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
isCurrentWorkspaceManager: isManager,
}),
}))
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
enableBilling,
}),
}))
vi.mock('../plan', () => ({
__esModule: true,
default: ({ loc }: { loc: string }) => <div data-testid="plan-component" data-loc={loc} />,
}))
describe('Billing', () => {
beforeEach(() => {
vi.clearAllMocks()
currentBillingUrl = 'https://billing'
fetching = false
isManager = true
enableBilling = true
refetchMock.mockResolvedValue({ data: 'https://billing' })
})
it('hides the billing action when user is not manager or billing is disabled', () => {
isManager = false
render(<Billing />)
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
vi.clearAllMocks()
isManager = true
enableBilling = false
render(<Billing />)
expect(screen.queryByRole('button', { name: /billing\.viewBillingTitle/ })).not.toBeInTheDocument()
})
it('opens the billing window with the immediate url when the button is clicked', async () => {
render(<Billing />)
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
fireEvent.click(actionButton)
await waitFor(() => expect(openAsyncWindowMock).toHaveBeenCalled())
const [, options] = openAsyncWindowMock.mock.calls[0]
expect(options).toMatchObject({
immediateUrl: currentBillingUrl,
features: 'noopener,noreferrer',
})
})
it('disables the button while billing url is fetching', () => {
fetching = true
render(<Billing />)
const actionButton = screen.getByRole('button', { name: /billing\.viewBillingTitle/ })
expect(actionButton).toBeDisabled()
})
})

View File

@@ -1,92 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { Plan } from '../type'
import HeaderBillingBtn from './index'
type HeaderGlobal = typeof globalThis & {
__mockProviderContext?: ReturnType<typeof vi.fn>
}
function getHeaderGlobal(): HeaderGlobal {
return globalThis as HeaderGlobal
}
const ensureProviderContextMock = () => {
const globals = getHeaderGlobal()
if (!globals.__mockProviderContext)
throw new Error('Provider context mock not set')
return globals.__mockProviderContext
}
vi.mock('@/context/provider-context', () => {
const mock = vi.fn()
const globals = getHeaderGlobal()
globals.__mockProviderContext = mock
return {
useProviderContext: () => mock(),
}
})
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="upgrade-btn" type="button">Upgrade</button>,
}))
describe('HeaderBillingBtn', () => {
beforeEach(() => {
vi.clearAllMocks()
ensureProviderContextMock().mockReturnValue({
plan: {
type: Plan.professional,
},
enableBilling: true,
isFetchedPlan: true,
})
})
it('renders nothing when billing is disabled or plan is not fetched', () => {
ensureProviderContextMock().mockReturnValueOnce({
plan: {
type: Plan.professional,
},
enableBilling: false,
isFetchedPlan: true,
})
const { container } = render(<HeaderBillingBtn />)
expect(container.firstChild).toBeNull()
})
it('renders upgrade button for sandbox plan', () => {
ensureProviderContextMock().mockReturnValueOnce({
plan: {
type: Plan.sandbox,
},
enableBilling: true,
isFetchedPlan: true,
})
render(<HeaderBillingBtn />)
expect(screen.getByTestId('upgrade-btn')).toBeInTheDocument()
})
it('renders plan badge and forwards clicks when not display-only', () => {
const onClick = vi.fn()
const { rerender } = render(<HeaderBillingBtn onClick={onClick} />)
const badge = screen.getByText('pro').closest('div')
expect(badge).toHaveClass('cursor-pointer')
fireEvent.click(badge!)
expect(onClick).toHaveBeenCalledTimes(1)
rerender(<HeaderBillingBtn onClick={onClick} isDisplayOnly />)
expect(screen.getByText('pro').closest('div')).toHaveClass('cursor-default')
fireEvent.click(screen.getByText('pro').closest('div')!)
expect(onClick).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,44 +0,0 @@
import { render } from '@testing-library/react'
import PartnerStack from './index'
let isCloudEdition = true
const saveOrUpdate = vi.fn()
const bind = vi.fn()
vi.mock('@/config', () => ({
get IS_CLOUD_EDITION() {
return isCloudEdition
},
}))
vi.mock('./use-ps-info', () => ({
__esModule: true,
default: () => ({
saveOrUpdate,
bind,
}),
}))
describe('PartnerStack', () => {
beforeEach(() => {
vi.clearAllMocks()
isCloudEdition = true
})
it('does not call partner stack helpers when not in cloud edition', () => {
isCloudEdition = false
render(<PartnerStack />)
expect(saveOrUpdate).not.toHaveBeenCalled()
expect(bind).not.toHaveBeenCalled()
})
it('calls saveOrUpdate and bind once when running in cloud edition', () => {
render(<PartnerStack />)
expect(saveOrUpdate).toHaveBeenCalledTimes(1)
expect(bind).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,197 +0,0 @@
import { act, renderHook } from '@testing-library/react'
import { PARTNER_STACK_CONFIG } from '@/config'
import usePSInfo from './use-ps-info'
let searchParamsValues: Record<string, string | null> = {}
const setSearchParams = (values: Record<string, string | null>) => {
searchParamsValues = values
}
type PartnerStackGlobal = typeof globalThis & {
__partnerStackCookieMocks?: {
get: ReturnType<typeof vi.fn>
set: ReturnType<typeof vi.fn>
remove: ReturnType<typeof vi.fn>
}
__partnerStackMutateAsync?: ReturnType<typeof vi.fn>
}
function getPartnerStackGlobal(): PartnerStackGlobal {
return globalThis as PartnerStackGlobal
}
const ensureCookieMocks = () => {
const globals = getPartnerStackGlobal()
if (!globals.__partnerStackCookieMocks)
throw new Error('Cookie mocks not initialized')
return globals.__partnerStackCookieMocks
}
const ensureMutateAsync = () => {
const globals = getPartnerStackGlobal()
if (!globals.__partnerStackMutateAsync)
throw new Error('Mutate mock not initialized')
return globals.__partnerStackMutateAsync
}
vi.mock('js-cookie', () => {
const get = vi.fn()
const set = vi.fn()
const remove = vi.fn()
const globals = getPartnerStackGlobal()
globals.__partnerStackCookieMocks = { get, set, remove }
const cookieApi = { get, set, remove }
return {
__esModule: true,
default: cookieApi,
get,
set,
remove,
}
})
vi.mock('next/navigation', () => ({
useSearchParams: () => ({
get: (key: string) => searchParamsValues[key] ?? null,
}),
}))
vi.mock('@/service/use-billing', () => {
const mutateAsync = vi.fn()
const globals = getPartnerStackGlobal()
globals.__partnerStackMutateAsync = mutateAsync
return {
useBindPartnerStackInfo: () => ({
mutateAsync,
}),
}
})
describe('usePSInfo', () => {
const originalLocationDescriptor = Object.getOwnPropertyDescriptor(globalThis, 'location')
beforeAll(() => {
Object.defineProperty(globalThis, 'location', {
value: { hostname: 'cloud.dify.ai' },
configurable: true,
})
})
beforeEach(() => {
setSearchParams({})
const { get, set, remove } = ensureCookieMocks()
get.mockReset()
set.mockReset()
remove.mockReset()
const mutate = ensureMutateAsync()
mutate.mockReset()
mutate.mockResolvedValue(undefined)
get.mockReturnValue('{}')
})
afterAll(() => {
if (originalLocationDescriptor)
Object.defineProperty(globalThis, 'location', originalLocationDescriptor)
})
it('saves partner info when query params change', () => {
const { get, set } = ensureCookieMocks()
get.mockReturnValue(JSON.stringify({ partnerKey: 'old', clickId: 'old-click' }))
setSearchParams({
ps_partner_key: 'new-partner',
ps_xid: 'new-click',
})
const { result } = renderHook(() => usePSInfo())
expect(result.current.psPartnerKey).toBe('new-partner')
expect(result.current.psClickId).toBe('new-click')
act(() => {
result.current.saveOrUpdate()
})
expect(set).toHaveBeenCalledWith(
PARTNER_STACK_CONFIG.cookieName,
JSON.stringify({
partnerKey: 'new-partner',
clickId: 'new-click',
}),
{
expires: PARTNER_STACK_CONFIG.saveCookieDays,
path: '/',
domain: '.dify.ai',
},
)
})
it('does not overwrite cookie when params do not change', () => {
setSearchParams({
ps_partner_key: 'existing',
ps_xid: 'existing-click',
})
const { get } = ensureCookieMocks()
get.mockReturnValue(JSON.stringify({
partnerKey: 'existing',
clickId: 'existing-click',
}))
const { result } = renderHook(() => usePSInfo())
act(() => {
result.current.saveOrUpdate()
})
const { set } = ensureCookieMocks()
expect(set).not.toHaveBeenCalled()
})
it('binds partner info and clears cookie once', async () => {
setSearchParams({
ps_partner_key: 'bind-partner',
ps_xid: 'bind-click',
})
const { result } = renderHook(() => usePSInfo())
const mutate = ensureMutateAsync()
const { remove } = ensureCookieMocks()
await act(async () => {
await result.current.bind()
})
expect(mutate).toHaveBeenCalledWith({
partnerKey: 'bind-partner',
clickId: 'bind-click',
})
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
path: '/',
domain: '.dify.ai',
})
await act(async () => {
await result.current.bind()
})
expect(mutate).toHaveBeenCalledTimes(1)
})
it('still removes cookie when bind fails with status 400', async () => {
const mutate = ensureMutateAsync()
mutate.mockRejectedValueOnce({ status: 400 })
setSearchParams({
ps_partner_key: 'bind-partner',
ps_xid: 'bind-click',
})
const { result } = renderHook(() => usePSInfo())
await act(async () => {
await result.current.bind()
})
const { remove } = ensureCookieMocks()
expect(remove).toHaveBeenCalledWith(PARTNER_STACK_CONFIG.cookieName, {
path: '/',
domain: '.dify.ai',
})
})
})

View File

@@ -1,130 +0,0 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants'
import { Plan } from '../type'
import PlanComp from './index'
let currentPath = '/billing'
const push = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({ push }),
usePathname: () => currentPath,
}))
const setShowAccountSettingModalMock = vi.fn()
vi.mock('@/context/modal-context', () => ({
// eslint-disable-next-line ts/no-explicit-any
useModalContextSelector: (selector: any) => selector({
setShowAccountSettingModal: setShowAccountSettingModalMock,
}),
}))
const providerContextMock = vi.fn()
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => providerContextMock(),
}))
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
userProfile: { email: 'user@example.com' },
isCurrentWorkspaceManager: true,
}),
}))
const mutateAsyncMock = vi.fn()
let isPending = false
vi.mock('@/service/use-education', () => ({
useEducationVerify: () => ({
mutateAsync: mutateAsyncMock,
isPending,
}),
}))
const verifyStateModalMock = vi.fn(props => (
<div data-testid="verify-modal" data-is-show={props.isShow ? 'true' : 'false'}>
{props.isShow ? 'visible' : 'hidden'}
</div>
))
vi.mock('@/app/education-apply/verify-state-modal', () => ({
__esModule: true,
// eslint-disable-next-line ts/no-explicit-any
default: (props: any) => verifyStateModalMock(props),
}))
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="plan-upgrade-btn" type="button">Upgrade</button>,
}))
describe('PlanComp', () => {
const planMock = {
type: Plan.professional,
usage: {
teamMembers: 4,
documentsUploadQuota: 3,
vectorSpace: 8,
annotatedResponse: 5,
triggerEvents: 60,
apiRateLimit: 100,
},
total: {
teamMembers: 10,
documentsUploadQuota: 20,
vectorSpace: 10,
annotatedResponse: 500,
triggerEvents: 100,
apiRateLimit: 200,
},
reset: {
triggerEvents: 2,
apiRateLimit: 1,
},
}
beforeEach(() => {
vi.clearAllMocks()
currentPath = '/billing'
isPending = false
providerContextMock.mockReturnValue({
plan: planMock,
enableEducationPlan: true,
allowRefreshEducationVerify: false,
isEducationAccount: false,
})
mutateAsyncMock.mockReset()
mutateAsyncMock.mockResolvedValue({ token: 'token' })
})
it('renders plan info and handles education verify success', async () => {
render(<PlanComp loc="billing-page" />)
expect(screen.getByText('billing.plans.professional.name')).toBeInTheDocument()
expect(screen.getByTestId('plan-upgrade-btn')).toBeInTheDocument()
const verifyBtn = screen.getByText('education.toVerified')
fireEvent.click(verifyBtn)
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
await waitFor(() => expect(push).toHaveBeenCalledWith('/education-apply?token=token'))
expect(localStorage.removeItem).toHaveBeenCalledWith(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM)
})
it('shows modal when education verify fails', async () => {
mutateAsyncMock.mockRejectedValueOnce(new Error('boom'))
render(<PlanComp loc="billing-page" />)
const verifyBtn = screen.getByText('education.toVerified')
fireEvent.click(verifyBtn)
await waitFor(() => expect(mutateAsyncMock).toHaveBeenCalled())
await waitFor(() => expect(screen.getByTestId('verify-modal').getAttribute('data-is-show')).toBe('true'))
})
it('resets modal context when on education apply path', () => {
currentPath = '/education-apply/setup'
render(<PlanComp loc="billing-page" />)
expect(setShowAccountSettingModalMock).toHaveBeenCalledWith(null)
})
})

View File

@@ -1,25 +0,0 @@
import { render, screen } from '@testing-library/react'
import ProgressBar from './index'
describe('ProgressBar', () => {
it('renders with provided percent and color', () => {
render(<ProgressBar percent={42} color="bg-test-color" />)
const bar = screen.getByTestId('billing-progress-bar')
expect(bar).toHaveClass('bg-test-color')
expect(bar.getAttribute('style')).toContain('width: 42%')
})
it('caps width at 100% when percent exceeds max', () => {
render(<ProgressBar percent={150} color="bg-test-color" />)
const bar = screen.getByTestId('billing-progress-bar')
expect(bar.getAttribute('style')).toContain('width: 100%')
})
it('uses the default color when no color prop is provided', () => {
render(<ProgressBar percent={20} color={undefined as unknown as string} />)
expect(screen.getByTestId('billing-progress-bar')).toHaveClass('#2970FF')
})
})

View File

@@ -1,70 +0,0 @@
import { render, screen } from '@testing-library/react'
import TriggerEventsLimitModal from './index'
const mockOnClose = vi.fn()
const mockOnUpgrade = vi.fn()
const planUpgradeModalMock = vi.fn((props: { show: boolean, title: string, description: string, extraInfo?: React.ReactNode, onClose: () => void, onUpgrade: () => void }) => (
<div
data-testid="plan-upgrade-modal"
data-show={props.show}
data-title={props.title}
data-description={props.description}
>
{props.extraInfo}
</div>
))
vi.mock('@/app/components/billing/plan-upgrade-modal', () => ({
__esModule: true,
// eslint-disable-next-line ts/no-explicit-any
default: (props: any) => planUpgradeModalMock(props),
}))
describe('TriggerEventsLimitModal', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('passes the trigger usage props to the upgrade modal', () => {
render(
<TriggerEventsLimitModal
show
onClose={mockOnClose}
onUpgrade={mockOnUpgrade}
usage={12}
total={20}
resetInDays={5}
/>,
)
const modal = screen.getByTestId('plan-upgrade-modal')
expect(modal.getAttribute('data-show')).toBe('true')
expect(modal.getAttribute('data-title')).toContain('billing.triggerLimitModal.title')
expect(modal.getAttribute('data-description')).toContain('billing.triggerLimitModal.description')
expect(planUpgradeModalMock).toHaveBeenCalled()
const passedProps = planUpgradeModalMock.mock.calls[0][0]
expect(passedProps.onClose).toBe(mockOnClose)
expect(passedProps.onUpgrade).toBe(mockOnUpgrade)
expect(screen.getByText('billing.triggerLimitModal.usageTitle')).toBeInTheDocument()
expect(screen.getByText('12')).toBeInTheDocument()
expect(screen.getByText('20')).toBeInTheDocument()
})
it('renders even when trigger modal is hidden', () => {
render(
<TriggerEventsLimitModal
show={false}
onClose={mockOnClose}
onUpgrade={mockOnUpgrade}
usage={0}
total={0}
/>,
)
expect(planUpgradeModalMock).toHaveBeenCalled()
expect(screen.getByTestId('plan-upgrade-modal').getAttribute('data-show')).toBe('false')
})
})

View File

@@ -1,35 +0,0 @@
import { render, screen } from '@testing-library/react'
import { defaultPlan } from '../config'
import AppsInfo from './apps-info'
const appsUsage = 7
const appsTotal = 15
const mockPlan = {
...defaultPlan,
usage: {
...defaultPlan.usage,
buildApps: appsUsage,
},
total: {
...defaultPlan.total,
buildApps: appsTotal,
},
}
vi.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
plan: mockPlan,
}),
}))
describe('AppsInfo', () => {
it('renders build apps usage information with context data', () => {
render(<AppsInfo className="apps-info-class" />)
expect(screen.getByText('billing.usagePage.buildApps')).toBeInTheDocument()
expect(screen.getByText(`${appsUsage}`)).toBeInTheDocument()
expect(screen.getByText(`${appsTotal}`)).toBeInTheDocument()
expect(screen.getByText('billing.usagePage.buildApps').closest('.apps-info-class')).toBeInTheDocument()
})
})

View File

@@ -1,114 +0,0 @@
import { render, screen } from '@testing-library/react'
import { NUM_INFINITE } from '../config'
import UsageInfo from './index'
const TestIcon = () => <span data-testid="usage-icon" />
describe('UsageInfo', () => {
it('renders the metric with a suffix unit and tooltip text', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Apps"
usage={30}
total={100}
unit="GB"
tooltip="tooltip text"
/>,
)
expect(screen.getByTestId('usage-icon')).toBeInTheDocument()
expect(screen.getByText('Apps')).toBeInTheDocument()
expect(screen.getByText('30')).toBeInTheDocument()
expect(screen.getByText('100')).toBeInTheDocument()
expect(screen.getByText('GB')).toBeInTheDocument()
})
it('renders inline unit when unitPosition is inline', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={20}
total={100}
unit="GB"
unitPosition="inline"
/>,
)
expect(screen.getByText('100GB')).toBeInTheDocument()
})
it('shows reset hint text instead of the unit when resetHint is provided', () => {
const resetHint = 'Resets in 3 days'
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={20}
total={100}
unit="GB"
resetHint={resetHint}
/>,
)
expect(screen.getByText(resetHint)).toBeInTheDocument()
expect(screen.queryByText('GB')).not.toBeInTheDocument()
})
it('displays unlimited text when total is infinite', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={10}
total={NUM_INFINITE}
unit="GB"
/>,
)
expect(screen.getByText('billing.plansCommon.unlimited')).toBeInTheDocument()
})
it('applies warning color when usage is close to the limit', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={85}
total={100}
/>,
)
const progressBar = screen.getByTestId('billing-progress-bar')
expect(progressBar).toHaveClass('bg-components-progress-warning-progress')
})
it('applies error color when usage exceeds the limit', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={120}
total={100}
/>,
)
const progressBar = screen.getByTestId('billing-progress-bar')
expect(progressBar).toHaveClass('bg-components-progress-error-progress')
})
it('does not render the icon when hideIcon is true', () => {
render(
<UsageInfo
Icon={TestIcon}
name="Storage"
usage={5}
total={100}
hideIcon
/>,
)
expect(screen.queryByTestId('usage-icon')).not.toBeInTheDocument()
})
})

View File

@@ -1,58 +0,0 @@
import { render, screen } from '@testing-library/react'
import VectorSpaceFull from './index'
type VectorProviderGlobal = typeof globalThis & {
__vectorProviderContext?: ReturnType<typeof vi.fn>
}
function getVectorGlobal(): VectorProviderGlobal {
return globalThis as VectorProviderGlobal
}
vi.mock('@/context/provider-context', () => {
const mock = vi.fn()
getVectorGlobal().__vectorProviderContext = mock
return {
useProviderContext: () => mock(),
}
})
vi.mock('../upgrade-btn', () => ({
__esModule: true,
default: () => <button data-testid="vector-upgrade-btn" type="button">Upgrade</button>,
}))
describe('VectorSpaceFull', () => {
const planMock = {
type: 'team',
usage: {
vectorSpace: 8,
},
total: {
vectorSpace: 10,
},
}
beforeEach(() => {
vi.clearAllMocks()
const globals = getVectorGlobal()
globals.__vectorProviderContext?.mockReturnValue({
plan: planMock,
})
})
it('renders tip text and upgrade button', () => {
render(<VectorSpaceFull />)
expect(screen.getByText('billing.vectorSpace.fullTip')).toBeInTheDocument()
expect(screen.getByText('billing.vectorSpace.fullSolution')).toBeInTheDocument()
expect(screen.getByTestId('vector-upgrade-btn')).toBeInTheDocument()
})
it('shows vector usage and total', () => {
render(<VectorSpaceFull />)
expect(screen.getByText('8')).toBeInTheDocument()
expect(screen.getByText('10MB')).toBeInTheDocument()
})
})

View File

@@ -56,6 +56,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='conversation_id' type='string' key='conversation_id'>
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
</Property>
<Property name='files' type='array[object]' key='files'>
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
- `type` (string) Supported type:

View File

@@ -56,6 +56,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='conversation_id' type='string' key='conversation_id'>
会話ID、以前のチャット記録に基づいて会話を続けるには、以前のメッセージのconversation_idを渡す必要があります。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
</Property>
<Property name='files' type='array[object]' key='files'>
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
- `type` (string) サポートされるタイプ:

View File

@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='conversation_id' type='string' key='conversation_id'>
(选填)会话 ID需要基于之前的聊天记录继续对话必须传之前消息的 conversation_id。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
</Property>
<Property name='files' type='array[object]' key='files'>
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
- `type` (string) 支持类型:

View File

@@ -55,6 +55,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='conversation_id' type='string' key='conversation_id'>
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
</Property>
<Property name='files' type='array[object]' key='files'>
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
- `type` (string) Supported type:

View File

@@ -55,6 +55,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='conversation_id' type='string' key='conversation_id'>
会話ID、以前のチャット記録に基づいて会話を続けるには、前のメッセージのconversation_idを渡す必要があります。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
</Property>
<Property name='files' type='array[object]' key='files'>
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
- `type` (string) サポートされるタイプ:

View File

@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='conversation_id' type='string' key='conversation_id'>
(选填)会话 ID需要基于之前的聊天记录继续对话必须传之前消息的 conversation_id。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
</Property>
<Property name='files' type='array[object]' key='files'>
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
- `type` (string) 支持类型:

View File

@@ -63,11 +63,6 @@ export const useShortcuts = (): void => {
return !isEventTargetInputArea(e.target as HTMLElement)
}, [])
const shouldHandleCopy = useCallback(() => {
const selection = document.getSelection()
return !selection || selection.isCollapsed
}, [])
useKeyPress(['delete', 'backspace'], (e) => {
if (shouldHandleShortcut(e)) {
e.preventDefault()
@@ -78,7 +73,7 @@ export const useShortcuts = (): void => {
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => {
const { showDebugAndPreviewPanel } = workflowStore.getState()
if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel) {
if (shouldHandleShortcut(e) && !showDebugAndPreviewPanel) {
e.preventDefault()
handleNodesCopy()
}

View File

@@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.11.2",
"version": "1.11.1",
"private": true,
"packageManager": "pnpm@10.26.1+sha512.664074abc367d2c9324fdc18037097ce0a8f126034160f709928e9e9f95d98714347044e5c3164d65bd5da6c59c6be362b107546292a8eecb7999196e5ce58fa",
"engines": {

View File

@@ -371,8 +371,6 @@ export const useTriggerPluginDynamicOptions = (payload: {
},
enabled: enabled && !!payload.plugin_id && !!payload.provider && !!payload.action && !!payload.parameter && !!payload.credential_id,
retry: 0,
staleTime: 0,
gcTime: 0,
})
}