Compare commits

...

2 Commits

Author SHA1 Message Date
yyh
6c9304d7ef feat: implement parent message validation for service API
Added a new validation method to check parent message IDs in the MessageBasedAppGenerator class, ensuring proper handling of UUID_NIL and conversation existence. Updated related app generators and added unit tests for comprehensive coverage.
2025-12-25 15:36:18 +08:00
yyh
df09acb74b feat: chat messages api support parent message id 2025-12-25 15:21:44 +08:00
15 changed files with 371 additions and 26 deletions

View File

@@ -4,10 +4,11 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from constants import UUID_NIL
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -33,8 +34,11 @@ from libs import helper
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.conversation_service import ConversationService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
@@ -53,14 +57,18 @@ class ChatRequestPayload(BaseModel):
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: str | None = Field(default=None, description="Conversation UUID")
parent_message_id: str | None = Field(default=None, description="Parent message UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
@field_validator("conversation_id", mode="before")
@field_validator("conversation_id", "parent_message_id", mode="before")
@classmethod
def normalize_conversation_id(cls, value: str | UUID | None) -> str | None:
"""Allow missing or blank conversation IDs; enforce UUID format when provided."""
def normalize_uuid_fields(cls, value: str | UUID | None, info: ValidationInfo) -> str | None:
"""Allow missing or blank UUID fields; enforce UUID format when provided."""
if isinstance(value, UUID):
return str(value)
if isinstance(value, str):
value = value.strip()
@@ -70,7 +78,36 @@ class ChatRequestPayload(BaseModel):
try:
return helper.uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
raise ValueError(f"{info.field_name} must be a valid UUID") from exc
def _validate_parent_message_request(
*,
app_model: App,
end_user: EndUser,
conversation_id: str | None,
parent_message_id: str | None,
) -> None:
if not parent_message_id or parent_message_id == UUID_NIL:
return
if not conversation_id:
raise BadRequest("conversation_id is required when parent_message_id is provided.")
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=end_user
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
try:
parent_message = MessageService.get_message(app_model=app_model, user=end_user, message_id=parent_message_id)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
if parent_message.conversation_id != conversation.id:
raise BadRequest("parent_message_id does not belong to the conversation.")
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
@@ -205,6 +242,13 @@ class ChatApi(Resource):
streaming = payload.response_mode == "streaming"
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id=args.get("conversation_id"),
parent_message_id=args.get("parent_message_id"),
)
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -127,6 +126,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
@@ -168,7 +175,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -9,7 +9,6 @@ from flask import Flask, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
@@ -105,6 +104,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -163,7 +170,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@@ -8,7 +8,6 @@ from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -97,6 +96,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
self._validate_parent_message_for_service_api(
app_model=app_model,
user=user,
conversation=conversation,
parent_message_id=args.get("parent_message_id"),
invoke_from=invoke_from,
)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
@@ -156,7 +163,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=self._resolve_parent_message_id(args, invoke_from),
user_id=user.id,
invoke_from=invoke_from,
extras=extras,

View File

@@ -1,11 +1,12 @@
import json
import logging
from collections.abc import Generator
from typing import Union, cast
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -34,6 +35,7 @@ from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Me
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
@@ -84,6 +86,34 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
raise e
def _resolve_parent_message_id(self, args: Mapping[str, Any], invoke_from: InvokeFrom) -> str | None:
parent_message_id = args.get("parent_message_id")
if invoke_from == InvokeFrom.SERVICE_API and not parent_message_id:
return UUID_NIL
return parent_message_id
def _validate_parent_message_for_service_api(
self,
*,
app_model: App,
user: Union[Account, EndUser],
conversation: Conversation | None,
parent_message_id: str | None,
invoke_from: InvokeFrom,
) -> None:
if invoke_from != InvokeFrom.SERVICE_API:
return
if not parent_message_id or parent_message_id == UUID_NIL:
return
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
parent_message = MessageService.get_message(app_model=app_model, user=user, message_id=parent_message_id)
if parent_message.conversation_id != conversation.id:
raise MessageNotExistsError("Message not exists")
def _get_app_model_config(self, app_model: App, conversation: Conversation | None = None) -> AppModelConfig:
if conversation:
stmt = select(AppModelConfig).where(

View File

@@ -2,9 +2,8 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from pydantic import BaseModel, ConfigDict, Field
from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig
@@ -158,20 +157,12 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
parent_message_id: str | None = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
"For service API, we need to ensure its forward compatibility, "
"so passing in the parent_message_id as request arg is not supported for now. "
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
"Starting from v0.9.0, parent_message_id is used to support message regeneration "
"and branching in chat APIs."
"For service API, when it is omitted, the system treats it as UUID_NIL to preserve legacy linear history."
),
)
@field_validator("parent_message_id")
@classmethod
def validate_parent_message_id(cls, v, info: ValidationInfo):
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
raise ValueError("parent_message_id should be UUID_NIL for service API")
return v
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
"""

View File

@@ -0,0 +1,130 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from constants import UUID_NIL
from controllers.service_api.app.completion import _validate_parent_message_request
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
def test_validate_parent_message_skips_when_missing():
app_model = object()
end_user = object()
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id=None
)
get_conversation.assert_not_called()
get_message.assert_not_called()
def test_validate_parent_message_skips_uuid_nil():
app_model = object()
end_user = object()
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation") as get_conversation,
patch("controllers.service_api.app.completion.MessageService.get_message") as get_message,
):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id=None,
parent_message_id=UUID_NIL,
)
get_conversation.assert_not_called()
get_message.assert_not_called()
def test_validate_parent_message_requires_conversation_id():
app_model = object()
end_user = object()
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model, end_user=end_user, conversation_id=None, parent_message_id="parent-id"
)
def test_validate_parent_message_missing_conversation_raises_not_found():
app_model = object()
end_user = object()
with patch(
"controllers.service_api.app.completion.ConversationService.get_conversation",
side_effect=ConversationNotExistsError(),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_missing_message_raises_not_found():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch(
"controllers.service_api.app.completion.MessageService.get_message",
side_effect=MessageNotExistsError(),
),
):
with pytest.raises(NotFound):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_mismatch_conversation_raises_bad_request():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="different-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
with pytest.raises(BadRequest):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)
def test_validate_parent_message_matches_conversation():
app_model = object()
end_user = object()
conversation = SimpleNamespace(id="conversation-id")
message = SimpleNamespace(conversation_id="conversation-id")
with (
patch("controllers.service_api.app.completion.ConversationService.get_conversation", return_value=conversation),
patch("controllers.service_api.app.completion.MessageService.get_message", return_value=message),
):
_validate_parent_message_request(
app_model=app_model,
end_user=end_user,
conversation_id="conversation-id",
parent_message_id="parent-id",
)

View File

@@ -23,3 +23,24 @@ def test_chat_request_payload_validates_uuid():
def test_chat_request_payload_rejects_invalid_uuid():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "conversation_id": "invalid"})
def test_chat_request_payload_accepts_blank_parent_message_id():
payload = ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": ""})
assert payload.parent_message_id is None
def test_chat_request_payload_validates_parent_message_id_uuid():
parent_message_id = str(uuid.uuid4())
payload = ChatRequestPayload.model_validate(
{"inputs": {}, "query": "hello", "parent_message_id": parent_message_id}
)
assert payload.parent_message_id == parent_message_id
def test_chat_request_payload_rejects_invalid_parent_message_id():
with pytest.raises(ValidationError):
ChatRequestPayload.model_validate({"inputs": {}, "query": "hello", "parent_message_id": "invalid"})

View File

@@ -0,0 +1,90 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from constants import UUID_NIL
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
def test_validate_parent_message_service_api_skips_missing():
generator = MessageBasedAppGenerator()
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id=None,
invoke_from=InvokeFrom.SERVICE_API,
)
get_message.assert_not_called()
def test_validate_parent_message_service_api_skips_uuid_nil():
generator = MessageBasedAppGenerator()
with patch("core.app.apps.message_based_app_generator.MessageService.get_message") as get_message:
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id=UUID_NIL,
invoke_from=InvokeFrom.SERVICE_API,
)
get_message.assert_not_called()
def test_validate_parent_message_service_api_requires_conversation():
generator = MessageBasedAppGenerator()
with pytest.raises(ConversationNotExistsError):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=None,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)
def test_validate_parent_message_service_api_mismatch_conversation():
generator = MessageBasedAppGenerator()
conversation = SimpleNamespace(id="conversation-id")
parent_message = SimpleNamespace(conversation_id="different-id")
with patch(
"core.app.apps.message_based_app_generator.MessageService.get_message",
return_value=parent_message,
):
with pytest.raises(MessageNotExistsError):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=conversation,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)
def test_validate_parent_message_service_api_matches_conversation():
generator = MessageBasedAppGenerator()
conversation = SimpleNamespace(id="conversation-id")
parent_message = SimpleNamespace(conversation_id="conversation-id")
with patch(
"core.app.apps.message_based_app_generator.MessageService.get_message",
return_value=parent_message,
):
generator._validate_parent_message_for_service_api(
app_model=object(),
user=object(),
conversation=conversation,
parent_message_id="parent-id",
invoke_from=InvokeFrom.SERVICE_API,
)

View File

@@ -56,6 +56,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='conversation_id' type='string' key='conversation_id'>
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
</Property>
<Property name='files' type='array[object]' key='files'>
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
- `type` (string) Supported type:

View File

@@ -56,6 +56,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='conversation_id' type='string' key='conversation_id'>
会話ID、以前のチャット記録に基づいて会話を続けるには、以前のメッセージのconversation_idを渡す必要があります。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
</Property>
<Property name='files' type='array[object]' key='files'>
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
- `type` (string) サポートされるタイプ:

View File

@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='conversation_id' type='string' key='conversation_id'>
(选填)会话 ID需要基于之前的聊天记录继续对话必须传之前消息的 conversation_id。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
</Property>
<Property name='files' type='array[object]' key='files'>
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
- `type` (string) 支持类型:

View File

@@ -55,6 +55,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='conversation_id' type='string' key='conversation_id'>
Conversation ID, to continue the conversation based on previous chat records, it is necessary to pass the previous message's conversation_id.
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
Parent message ID to continue from a specific message or regenerate. Requires `conversation_id` and must belong to that conversation.
</Property>
<Property name='files' type='array[object]' key='files'>
File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision/Video capability.
- `type` (string) Supported type:

View File

@@ -55,6 +55,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='conversation_id' type='string' key='conversation_id'>
会話ID、以前のチャット記録に基づいて会話を続けるには、前のメッセージのconversation_idを渡す必要があります。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
特定のメッセージから続けたり再生成するための親メッセージID。`conversation_id` が必須で、その会話に属している必要があります。
</Property>
<Property name='files' type='array[object]' key='files'>
ファイルリスト、モデルが Vision/Video 機能をサポートしている場合に限り、ファイルをテキスト理解および質問応答に組み合わせて入力するのに適しています。
- `type` (string) サポートされるタイプ:

View File

@@ -54,6 +54,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='conversation_id' type='string' key='conversation_id'>
(选填)会话 ID需要基于之前的聊天记录继续对话必须传之前消息的 conversation_id。
</Property>
<Property name='parent_message_id' type='string' key='parent_message_id'>
用于从特定消息继续或重新生成的父消息 ID。需要提供 `conversation_id`,且必须属于该对话。
</Property>
<Property name='files' type='array[object]' key='files'>
文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision/Video 能力时可用。
- `type` (string) 支持类型: