mirror of https://github.com/langgenius/dify.git, synced 2026-01-27 18:44:39 +00:00

Compare commits: 2 commits, refactor/d... → yanli/api-...

| Author | SHA1 | Date |
|---|---|---|
|  | 8ab15c9ba1 |  |
|  | e48419937b |  |
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)
@@ -1,6 +1,8 @@
import base64
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union

from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile

if TYPE_CHECKING:
from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")

def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
def _handle_invoke_result_direct(
self,
invoke_result: LLMResult,
queue_manager: AppQueueManager,
):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
)

def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
self,
invoke_result: Generator[LLMResultChunk, None, None],
queue_manager: AppQueueManager,
agent: bool,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, str):
# TODO(QuantumGhost): Add multimodal output support for easy ui.
_logger.warning("received multimodal output, type=%s", type(content))
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content  # failback to str
text += content.data if hasattr(content, "data") else str(content)

if not model:
model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER,
)

def _handle_multimodal_image_content(
self,
content: ImagePromptMessageContent,
message_id: str,
user_id: str,
tenant_id: str,
queue_manager: AppQueueManager,
):
"""
Handle multimodal image content from LLM response.
Save the image and create a MessageFile record.

:param content: ImagePromptMessageContent instance
:param message_id: message id
:param user_id: user id
:param tenant_id: tenant id
:param queue_manager: queue manager
:return:
"""
_logger.info("Handling multimodal image content for message %s", message_id)

image_url = content.url
base64_data = content.base64_data

_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)

if not image_url and not base64_data:
_logger.warning("Image content has neither URL nor base64 data")
return

tool_file_manager = ToolFileManager()

# Save the image file
try:
if image_url:
# Download image from URL
_logger.info("Downloading image from URL: %s", image_url)
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=image_url,
conversation_id=None,
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
elif base64_data:
if base64_data.startswith("data:"):
base64_data = base64_data.split(",", 1)[1]

image_binary = base64.b64decode(base64_data)
mimetype = content.mime_type or "image/png"
extension = guess_extension(mimetype) or ".png"

tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=image_binary,
mimetype=mimetype,
filename=f"generated_image{extension}",
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
else:
return
except Exception:
_logger.exception("Failed to save image file")
return

# Create MessageFile record
message_file = MessageFile(
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(
CreatorUserRole.ACCOUNT
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
else CreatorUserRole.END_USER
),
created_by=user_id,
)

db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)

# Publish QueueMessageFileEvent
queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file.id),
PublishFrom.APPLICATION_MANAGER,
)

_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)

def moderation_for_inputs(
self,
*,
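Note (illustrative, not part of the diff): the base64 branch of _handle_multimodal_image_content above accepts either a bare payload or a full data URI. A minimal sketch of that normalization, using only the stdlib helpers the diff already imports (base64, guess_extension); the sample value is made up:

# sketch only, mirrors the data-URI handling shown above
import base64
from mimetypes import guess_extension

raw = "data:image/png;base64,iVBORw0KGgo="  # hypothetical model output
if raw.startswith("data:"):
    raw = raw.split(",", 1)[1]  # drop the "data:image/png;base64," prefix
image_binary = base64.b64decode(raw)  # bytes handed to create_file_by_raw(...)
extension = guess_extension("image/png") or ".png"  # filename suffix, here ".png"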
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):

# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):

# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)
@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamEvent,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):

_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_precomputed_event_type: StreamEvent | None = None

def __init__(
self,
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content

if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(
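Note (illustrative, not part of the diff): caching _precomputed_event_type means a stream of N chunks pays for at most one get_message_event_type() lookup instead of one per chunk. A self-contained sketch of the pattern with stand-in names:

# sketch only: the first chunk resolves the event type, later chunks reuse it
class _FakeCycleManager:
    def __init__(self):
        self.lookups = 0

    def get_message_event_type(self, message_id):
        self.lookups += 1  # stands in for the database check in MessageCycleManager
        return "message"

manager, event_type, responses = _FakeCycleManager(), None, []
for delta_text in ["Hel", "lo", "!"]:  # stand-in for queued LLM chunks
    if event_type is None:  # only the first chunk pays the lookup
        event_type = manager.get_message_event_type(message_id="msg-1")
    responses.append((event_type, delta_text))
assert manager.lookups == 1  # the pre-change code would have done 3 lookups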
@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union

from flask import Flask, current_app
from sqlalchemy import exists, select
from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent,
WorkflowTaskState,
)
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
@@ -57,13 +58,15 @@ class MessageCycleManager:
self._message_has_file: set[str] = set()

def get_message_event_type(self, message_id: str) -> StreamEvent:
# Fast path: cached determination from prior QueueMessageFileEvent
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE

with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))

if has_file:
if message_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE

@@ -199,6 +202,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))

if message_file and message_file.url is not None:
self._message_has_file.add(message_file.message_id)

# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension
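Note (illustrative, not part of the diff): _message_has_file acts as a small positive cache in front of the database, so once a QueueMessageFileEvent has marked a message, later lookups skip the SELECT. A self-contained sketch of the same idea with a stubbed lookup:

# sketch only: positive cache in front of a (stubbed) MessageFile lookup
def make_event_type_resolver(db_has_file):
    seen: set[str] = set()

    def resolve(message_id: str) -> str:
        if message_id in seen:  # fast path, no session opened
            return "message_file"
        if db_has_file(message_id):  # stands in for session.scalar(select(MessageFile)...)
            seen.add(message_id)
            return "message_file"
        return "message"

    return resolve

resolve = make_event_type_resolver(lambda mid: mid == "msg-with-file")
assert resolve("msg-with-file") == "message_file"  # first call hits the stub "database"
assert resolve("msg-plain") == "message"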
@@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk(
Build a single `LLMResult` from the first returned chunk.

This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.

Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@@ -99,18 +103,25 @@
system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []

first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
try:
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)

if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)

usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)

return LLMResult(
model=model,
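Note (illustrative, not part of the diff): the try/finally added above is the generic read-the-first-item-then-drain pattern for generators; draining guarantees that the producer's own cleanup (its finally block, for example closing an HTTP connection) runs even when only one chunk is needed. A minimal standalone version:

# sketch only: generic first-item-then-drain helper
def first_and_drain(gen):
    try:
        return next(gen, None)
    finally:
        for _ in gen:  # exhaust the generator so its own finally/cleanup runs
            pass

def producer():
    try:
        yield "first"
        yield "ignored"
    finally:
        print("resources released")  # e.g. a connection owned by the plugin runtime

assert first_and_drain(producer()) == "first"  # also prints "resources released"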
@@ -0,0 +1,454 @@
|
||||
"""Test multimodal image output handling in BaseAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
class TestBaseAppRunnerMultimodal:
|
||||
"""Test that BaseAppRunner correctly handles multimodal image content."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant_id(self):
|
||||
"""Mock tenant ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_id(self):
|
||||
"""Mock message ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = MagicMock()
|
||||
manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_file(self):
|
||||
"""Create a mock tool file."""
|
||||
tool_file = MagicMock()
|
||||
tool_file.id = str(uuid4())
|
||||
return tool_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file(self):
|
||||
"""Create a mock message file."""
|
||||
message_file = MagicMock()
|
||||
message_file.id = str(uuid4())
|
||||
return message_file
|
||||
|
||||
def test_handle_multimodal_image_content_with_url(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from URL."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from URL
|
||||
mock_mgr.create_file_by_url.assert_called_once_with(
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
file_url=image_url,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
# Verify message file was created with correct parameters
|
||||
mock_msg_file_class.assert_called_once()
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["message_id"] == mock_message_id
|
||||
assert call_kwargs["type"] == FileType.IMAGE
|
||||
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
|
||||
assert call_kwargs["belongs_to"] == "assistant"
|
||||
assert call_kwargs["created_by"] == mock_user_id
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once_with(mock_message_file)
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once_with(mock_message_file)
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
publish_call = mock_queue_manager.publish.call_args
|
||||
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
|
||||
assert publish_call[0][0].message_file_id == mock_message_file.id
|
||||
# publish_from might be passed as positional or keyword argument
|
||||
assert (
|
||||
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
|
||||
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data."""
|
||||
# Arrange
|
||||
import base64
|
||||
|
||||
# Create a small test image (1x1 PNG)
|
||||
test_image_data = base64.b64encode(
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
|
||||
).decode()
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=test_image_data,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from base64
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
assert call_kwargs["user_id"] == mock_user_id
|
||||
assert call_kwargs["tenant_id"] == mock_tenant_id
|
||||
assert call_kwargs["conversation_id"] is None
|
||||
assert "file_binary" in call_kwargs
|
||||
assert call_kwargs["mimetype"] == "image/png"
|
||||
assert call_kwargs["filename"].startswith("generated_image")
|
||||
assert call_kwargs["filename"].endswith(".png")
|
||||
|
||||
# Verify message file was created
|
||||
mock_msg_file_class.assert_called_once()
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64_data_uri(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data with URI prefix."""
|
||||
# Arrange
|
||||
# Data URI format: data:image/png;base64,<base64_data>
|
||||
test_image_data = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
)
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=f"data:image/png;base64,{test_image_data}",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify that base64 data was extracted correctly (without prefix)
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
# The base64 data should be decoded, so we check the binary was passed
|
||||
assert "file_binary" in call_kwargs
|
||||
|
||||
def test_handle_multimodal_image_content_without_url_or_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content without URL or base64 data."""
|
||||
# Arrange
|
||||
content = ImagePromptMessageContent(
|
||||
url="",
|
||||
base64_data="",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create any files or publish events
|
||||
mock_mgr_class.assert_not_called()
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_with_error(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content when an error occurs."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock to raise exception
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
# Should not raise exception, just log it
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create message file or publish event on error
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_debugger_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that debugger mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is ACCOUNT for debugger mode
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_handle_multimodal_image_content_service_api_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that service API mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is END_USER for service API
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and no message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute: compute event type once, then pass to message_to_stream_response
|
||||
with current_app.app_context():
|
||||
@@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Execute with event_type provided
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
@@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
# Should not query database when event_type is provided
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open a session when event_type is provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||
@@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
|
||||
def test_optimization_usage_example(self, message_cycle_manager):
|
||||
"""Test the optimization pattern that should be used by callers."""
|
||||
# Step 1: Get event type once (this queries database)
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None # No files
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None # No files
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Should query database once
|
||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
||||
# Should open session once
|
||||
mock_session_factory.create_session.assert_called_once()
|
||||
assert event_type == StreamEvent.MESSAGE
|
||||
|
||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
|
||||
|
||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||
@@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
|
||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Should not query database again
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open session again when event_type provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||
|
||||
@@ -101,3 +101,26 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
|
||||
assert result.message.tool_calls == []
|
||||
assert result.usage == LLMUsage.empty_usage()
|
||||
assert result.system_fingerprint is None
|
||||
|
||||
|
||||
def test__normalize_non_stream_plugin_result__closes_chunk_iterator():
|
||||
prompt_messages = [UserPromptMessage(content="hi")]
|
||||
|
||||
chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage())
|
||||
closed: list[bool] = []
|
||||
|
||||
def _chunk_iter():
|
||||
try:
|
||||
yield chunk
|
||||
yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage())
|
||||
finally:
|
||||
closed.append(True)
|
||||
|
||||
result = _normalize_non_stream_plugin_result(
|
||||
model="test-model",
|
||||
prompt_messages=prompt_messages,
|
||||
result=_chunk_iter(),
|
||||
)
|
||||
|
||||
assert result.message.content == "hello"
|
||||
assert closed == [True]
|
||||
|
||||
web/app/components/base/chat/chat/hooks.multimodal.spec.ts (new file, 178 lines added)
@@ -0,0 +1,178 @@
|
||||
/**
|
||||
* Tests for multimodal image file handling in chat hooks.
|
||||
* Tests the file object conversion logic without full hook integration.
|
||||
*/
|
||||
|
||||
describe('Multimodal File Handling', () => {
|
||||
describe('File type to MIME type mapping', () => {
|
||||
it('should map image to image/png', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedMime = 'image/png'
|
||||
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map video to video/mp4', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedMime = 'video/mp4'
|
||||
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map audio to audio/mpeg', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedMime = 'audio/mpeg'
|
||||
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map unknown to application/octet-stream', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const expectedMime = 'application/octet-stream'
|
||||
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
})
|
||||
|
||||
describe('TransferMethod selection', () => {
|
||||
it('should select remote_url for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('remote_url')
|
||||
})
|
||||
|
||||
it('should select local_file for non-images', () => {
|
||||
const fileType: string = 'video'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('local_file')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File extension mapping', () => {
|
||||
it('should use .png extension for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedExtension = '.png'
|
||||
const extension = fileType === 'image' ? 'png' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp4 extension for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedExtension = '.mp4'
|
||||
const extension = fileType === 'video' ? 'mp4' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp3 extension for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedExtension = '.mp3'
|
||||
const extension = fileType === 'audio' ? 'mp3' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
})
|
||||
|
||||
describe('File name generation', () => {
|
||||
it('should generate correct file name for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedName = 'generated_image.png'
|
||||
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedName = 'generated_video.mp4'
|
||||
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedName = 'generated_audio.mp3'
|
||||
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
})
|
||||
|
||||
describe('SupportFileType mapping', () => {
|
||||
it('should map image type to image supportFileType', () => {
|
||||
const fileType: string = 'image'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('image')
|
||||
})
|
||||
|
||||
it('should map video type to video supportFileType', () => {
|
||||
const fileType: string = 'video'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('video')
|
||||
})
|
||||
|
||||
it('should map audio type to audio supportFileType', () => {
|
||||
const fileType: string = 'audio'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('audio')
|
||||
})
|
||||
|
||||
it('should map unknown type to document supportFileType', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('document')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File conversion logic', () => {
|
||||
it('should detect existing transferMethod', () => {
|
||||
const fileWithTransferMethod = {
|
||||
id: 'file-123',
|
||||
transferMethod: 'remote_url' as const,
|
||||
type: 'image/png',
|
||||
name: 'test.png',
|
||||
size: 1024,
|
||||
supportFileType: 'image',
|
||||
progress: 100,
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
|
||||
expect(hasTransferMethod).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect missing transferMethod', () => {
|
||||
const fileWithoutTransferMethod = {
|
||||
id: 'file-456',
|
||||
type: 'image',
|
||||
url: 'http://example.com/image.png',
|
||||
belongs_to: 'assistant',
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
|
||||
expect(hasTransferMethod).toBe(false)
|
||||
})
|
||||
|
||||
it('should create file with size 0 for generated files', () => {
|
||||
const expectedSize = 0
|
||||
expect(expectedSize).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Agent vs Non-Agent mode logic', () => {
|
||||
it('should check for agent_thoughts to determine mode', () => {
|
||||
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [{}],
|
||||
}
|
||||
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is empty', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [],
|
||||
}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is undefined', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBeFalsy()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -419,9 +419,40 @@ export const useChat = (
}
},
onFile(file) {
// Convert simple file type to MIME type for non-agent mode
// Backend sends: { id, type: "image", belongs_to, url }
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }

// Determine file type for MIME conversion
const fileType = (file as { type?: string }).type || 'image'

// If file already has transferMethod, use it as base and ensure all required fields exist
// Otherwise, create a new complete file object
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null

const convertedFile: FileEntity = {
id: baseFile?.id || (file as { id: string }).id,
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
progress: baseFile?.progress ?? 100,
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
url: baseFile?.url || (file as { url?: string }).url,
size: baseFile?.size ?? 0, // Generated files don't have a known size
}

// For agent mode, add files to the last thought
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
if (lastThought)
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
if (lastThought) {
const thought = lastThought as { message_files?: FileEntity[] }
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
}
// For non-agent mode, add files directly to responseItem.message_files
else {
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
responseItem.message_files = [...currentFiles, convertedFile]
}

updateCurrentQAOnTree({
placeholderQuestionId,
@@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => {

renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })

// Find submit button by class
@@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => {

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })

// Component should still be functional - check for the main container
@@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>)

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox to be rendered with timeout for CI environment
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })

// Submit
@@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)

// Verify the component is still rendered after submission
expect(container.firstChild).toBeInTheDocument()
// Wait for the mutation to complete
await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
})

it('should render ResultItem components for non-external results', async () => {
@@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>)

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for component to be fully rendered with longer timeout
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })

const buttons = screen.getAllByRole('button')
@@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)

// Verify component is rendered after submission
expect(container.firstChild).toBeInTheDocument()
// Wait for mutation to complete with longer timeout
await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
})

it('should render external results when dataset is external', async () => {
@@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => {

// Component should render
expect(container.firstChild).toBeInTheDocument()

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Type in textarea to verify component is functional
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })

const buttons = screen.getAllByRole('button')
@@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)

await waitFor(() => {
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
// Verify component is still functional after submission
await waitFor(
() => {
expect(screen.getByRole('textbox')).toBeInTheDocument()
},
{ timeout: 3000 },
)
})
})

@@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => {

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Enter query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } })

// Submit
@@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Enter query and submit
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } })

const buttons = screen.getAllByRole('button')
@@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Wait for state updates
await waitFor(() => {
expect(container.firstChild).toBeInTheDocument()
}, { timeout: 2000 })
}, { timeout: 3000 })

// Verify mutation was called
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
@@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {

const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)

// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)

// Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test' } })

const buttons = screen.getAllByRole('button')
@@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Verify the component renders
await waitFor(() => {
expect(container.firstChild).toBeInTheDocument()
})
}, { timeout: 3000 })
})
})
@@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
}))

// Mock marketplace client used by marketplace utils
vi.mock('@/service/client', () => ({
marketplaceClient: {
collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
collections: [
{
name: 'collection-1',
label: { 'en-US': 'Collection 1' },
description: { 'en-US': 'Desc' },
rule: '',
created_at: '2024-01-01',
updated_at: '2024-01-01',
searchable: true,
search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' },
},
],
},
})),
collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
},
})),
// Some utils paths may call searchAdvanced; provide a minimal stub
searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
total: 1,
},
})),
},
}))

// Mock context/query-client
vi.mock('@/context/query-client', () => ({
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
@@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => {
// ================================
// Async Utils Tests
// ================================

// Narrow mock surface and avoid any in tests
// Types are local to this spec to keep scope minimal

type FnMock = ReturnType<typeof vi.fn>

type MarketplaceClientMock = {
collectionPlugins: FnMock
collections: FnMock
}

describe('Async Utils', () => {
let marketplaceClientMock: MarketplaceClientMock

beforeAll(async () => {
const mod = await import('@/service/client')
marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock
})
beforeEach(() => {
vi.clearAllMocks()
})
@@ -1490,12 +1545,10 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' },
]

globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Adjusted to our mocked marketplaceClient instead of fetch
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
data: { plugins: mockPlugins },
})

const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', {
@@ -1504,12 +1557,13 @@ describe('Async Utils', () => {
type: 'plugin',
})

expect(globalThis.fetch).toHaveBeenCalled()
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
expect(result).toHaveLength(2)
})

it('should handle fetch error and return empty array', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
// Simulate error from client
marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error'))

const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection')
@@ -1519,25 +1573,18 @@ describe('Async Utils', () => {

it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Our client mock receives the signal as second arg
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
data: { plugins: mockPlugins },
})

const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })

// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(Request),
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
const call = marketplaceClientMock.collectionPlugins.mock.calls[0]
expect(call[1]).toMatchObject({ signal: controller.signal })
})
})

@@ -1548,23 +1595,17 @@ describe('Async Utils', () => {
]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]

let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.resolve(
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Simulate two-step client calls: collections then collectionPlugins
let stage = 0
marketplaceClientMock.collections.mockImplementationOnce(async () => {
stage = 1
return { data: { collections: mockCollections } }
})
marketplaceClientMock.collectionPlugins.mockImplementation(async () => {
if (stage === 1) {
return { data: { plugins: mockPlugins } }
}
return Promise.resolve(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
return { data: { plugins: [] } }
})

const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@@ -1578,7 +1619,8 @@ describe('Async Utils', () => {
})

it('should handle fetch error and return empty data', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
// Simulate client error
marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error'))

const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
const result = await getMarketplaceCollectionsAndPlugins()
@@ -1588,24 +1630,16 @@ describe('Async Utils', () => {
})

it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)

// Assert that the client was called with query containing condition/type
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({
condition: 'category=tool',
type: 'bundle',
})

// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalled()
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
expect(marketplaceClientMock.collections).toHaveBeenCalled()
const call = marketplaceClientMock.collections.mock.calls[0]
expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) })
})
})
})
@@ -18,7 +18,7 @@ import {
useWorkflowReadOnly,
} from '../hooks'
import { useStore, useWorkflowStore } from '../store'
import { BlockEnum, ControlMode, WorkflowRunningStatus } from '../types'
import { BlockEnum, ControlMode } from '../types'
import {
getLayoutByDagre,
getLayoutForChildNodes,
@@ -36,17 +36,12 @@ export const useWorkflowInteractions = () => {
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()

const handleCancelDebugAndPreviewPanel = useCallback(() => {
const { workflowRunningData } = workflowStore.getState()
const runningStatus = workflowRunningData?.result?.status
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
workflowStore.setState({
showDebugAndPreviewPanel: false,
...(isActiveRun ? {} : { workflowRunningData: undefined }),
workflowRunningData: undefined,
})
if (!isActiveRun) {
handleNodeCancelRunningStatus()
handleEdgeCancelRunningStatus()
}
handleNodeCancelRunningStatus()
handleEdgeCancelRunningStatus()
}, [workflowStore, handleNodeCancelRunningStatus, handleEdgeCancelRunningStatus])

return {

@@ -21,7 +21,7 @@ import {
import { BlockEnum } from '../../types'
import ConversationVariableModal from './conversation-variable-modal'
import Empty from './empty'
import { useChat } from './hooks/use-chat'
import { useChat } from './hooks'
import UserInput from './user-input'

type ChatWrapperProps = {
516 web/app/components/workflow/panel/debug-and-preview/hooks.ts (new file)
@@ -0,0 +1,516 @@
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type {
|
||||
ChatItem,
|
||||
ChatItemInTree,
|
||||
Inputs,
|
||||
} from '@/app/components/base/chat/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { uniqBy } from 'es-toolkit/compat'
|
||||
import { produce, setAutoFreeze } from 'immer'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
getProcessedInputs,
|
||||
processOpeningStatement,
|
||||
} from '@/app/components/base/chat/chat/utils'
|
||||
import { getThreadMessages } from '@/app/components/base/chat/utils'
|
||||
import {
|
||||
getProcessedFiles,
|
||||
getProcessedFilesFromResponse,
|
||||
} from '@/app/components/base/file-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { useInvalidAllLastRun } from '@/service/use-workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../constants'
|
||||
import {
|
||||
useSetWorkflowVarsWithValue,
|
||||
useWorkflowRun,
|
||||
} from '../../hooks'
|
||||
import { useHooksStore } from '../../hooks-store'
|
||||
import { useWorkflowStore } from '../../store'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '../../types'
|
||||
|
||||
type GetAbortController = (abortController: AbortController) => void
|
||||
type SendCallback = {
|
||||
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<any>
|
||||
}
|
||||
export const useChat = (
|
||||
config: any,
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
},
|
||||
prevChatTree?: ChatItemInTree[],
|
||||
stopChat?: (taskId: string) => void,
|
||||
) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const hasStopResponded = useRef(false)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const conversationId = useRef('')
|
||||
const taskIdRef = useRef('')
|
||||
const [isResponding, setIsResponding] = useState(false)
|
||||
const isRespondingRef = useRef(false)
|
||||
const configsMap = useHooksStore(s => s.configsMap)
|
||||
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
|
||||
const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([])
|
||||
const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null)
|
||||
const {
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
} = workflowStore.getState()
|
||||
|
||||
const handleResponding = useCallback((isResponding: boolean) => {
|
||||
setIsResponding(isResponding)
|
||||
isRespondingRef.current = isResponding
|
||||
}, [])
|
||||
|
||||
const [chatTree, setChatTree] = useState<ChatItemInTree[]>(prevChatTree || [])
|
||||
const chatTreeRef = useRef<ChatItemInTree[]>(chatTree)
|
||||
const [targetMessageId, setTargetMessageId] = useState<string>()
|
||||
const threadMessages = useMemo(() => getThreadMessages(chatTree, targetMessageId), [chatTree, targetMessageId])
|
||||
|
||||
const getIntroduction = useCallback((str: string) => {
|
||||
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
|
||||
}, [formSettings?.inputs, formSettings?.inputsForm])
|
||||
|
||||
/** Final chat list that will be rendered */
|
||||
const chatList = useMemo(() => {
|
||||
const ret = [...threadMessages]
|
||||
if (config?.opening_statement) {
|
||||
const index = threadMessages.findIndex(item => item.isOpeningStatement)
|
||||
|
||||
if (index > -1) {
|
||||
ret[index] = {
|
||||
...ret[index],
|
||||
content: getIntroduction(config.opening_statement),
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
}
|
||||
}
|
||||
else {
|
||||
ret.unshift({
|
||||
id: `${Date.now()}`,
|
||||
content: getIntroduction(config.opening_statement),
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
})
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
|
||||
|
||||
useEffect(() => {
|
||||
setAutoFreeze(false)
|
||||
return () => {
|
||||
setAutoFreeze(true)
|
||||
}
|
||||
}, [])
|
||||
|
||||
/** Find the target node by bfs and then operate on it */
|
||||
const produceChatTreeNode = useCallback((targetId: string, operation: (node: ChatItemInTree) => void) => {
|
||||
return produce(chatTreeRef.current, (draft) => {
|
||||
const queue: ChatItemInTree[] = [...draft]
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!
|
||||
if (current.id === targetId) {
|
||||
operation(current)
|
||||
break
|
||||
}
|
||||
if (current.children)
|
||||
queue.push(...current.children)
|
||||
}
|
||||
})
|
||||
}, [])
|
||||
|
||||
const handleStop = useCallback(() => {
|
||||
hasStopResponded.current = true
|
||||
handleResponding(false)
|
||||
if (stopChat && taskIdRef.current)
|
||||
stopChat(taskIdRef.current)
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
if (suggestedQuestionsAbortControllerRef.current)
|
||||
suggestedQuestionsAbortControllerRef.current.abort()
|
||||
}, [handleResponding, setIterTimes, setLoopTimes, stopChat])
|
||||
|
||||
const handleRestart = useCallback(() => {
|
||||
conversationId.current = ''
|
||||
taskIdRef.current = ''
|
||||
handleStop()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
setChatTree([])
|
||||
setSuggestQuestions([])
|
||||
}, [
|
||||
handleStop,
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
])
|
||||
|
||||
const updateCurrentQAOnTree = useCallback(({
|
||||
parentId,
|
||||
responseItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
}: {
|
||||
parentId?: string
|
||||
responseItem: ChatItem
|
||||
placeholderQuestionId: string
|
||||
questionItem: ChatItem
|
||||
}) => {
|
||||
let nextState: ChatItemInTree[]
|
||||
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] }
|
||||
if (!parentId && !chatTree.some(item => [placeholderQuestionId, questionItem.id].includes(item.id))) {
|
||||
// QA whose parent is not provided is considered as a first message of the conversation,
|
||||
// and it should be a root node of the chat tree
|
||||
nextState = produce(chatTree, (draft) => {
|
||||
draft.push(currentQA)
|
||||
})
|
||||
}
|
||||
else {
|
||||
// find the target QA in the tree and update it; if not found, insert it to its parent node
|
||||
nextState = produceChatTreeNode(parentId!, (parentNode) => {
|
||||
const questionNodeIndex = parentNode.children!.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
if (questionNodeIndex === -1)
|
||||
parentNode.children!.push(currentQA)
|
||||
else
|
||||
parentNode.children![questionNodeIndex] = currentQA
|
||||
})
|
||||
}
|
||||
setChatTree(nextState)
|
||||
chatTreeRef.current = nextState
|
||||
}, [chatTree, produceChatTreeNode])
|
||||
|
||||
const handleSend = useCallback((
|
||||
params: {
|
||||
query: string
|
||||
files?: FileEntity[]
|
||||
parent_message_id?: string
|
||||
[key: string]: any
|
||||
},
|
||||
{
|
||||
onGetSuggestedQuestions,
|
||||
}: SendCallback,
|
||||
) => {
|
||||
if (isRespondingRef.current) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
|
||||
|
||||
const placeholderQuestionId = `question-${Date.now()}`
|
||||
const questionItem = {
|
||||
id: placeholderQuestionId,
|
||||
content: params.query,
|
||||
isAnswer: false,
|
||||
message_files: params.files,
|
||||
parentMessageId: params.parent_message_id,
|
||||
}
|
||||
|
||||
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
|
||||
const placeholderAnswerItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
|
||||
}
|
||||
|
||||
setTargetMessageId(parentMessage?.id)
|
||||
updateCurrentQAOnTree({
|
||||
parentId: params.parent_message_id,
|
||||
responseItem: placeholderAnswerItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
})
|
||||
|
||||
// answer
|
||||
const responseItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
agent_thoughts: [],
|
||||
message_files: [],
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
|
||||
}
|
||||
|
||||
handleResponding(true)
|
||||
|
||||
const { files, inputs, ...restParams } = params
|
||||
const bodyParams = {
|
||||
files: getProcessedFiles(files || []),
|
||||
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
|
||||
...restParams,
|
||||
}
|
||||
if (bodyParams?.files?.length) {
|
||||
bodyParams.files = bodyParams.files.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file) {
|
||||
return {
|
||||
...item,
|
||||
url: '',
|
||||
}
|
||||
}
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
let hasSetResponseId = false
|
||||
|
||||
handleRun(
|
||||
bodyParams,
|
||||
{
|
||||
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
|
||||
responseItem.content = responseItem.content + message
|
||||
|
||||
if (messageId && !hasSetResponseId) {
|
||||
questionItem.id = `question-${messageId}`
|
||||
responseItem.id = messageId
|
||||
responseItem.parentMessageId = questionItem.id
|
||||
hasSetResponseId = true
|
||||
}
|
||||
|
||||
if (isFirstMessage && newConversationId)
|
||||
conversationId.current = newConversationId
|
||||
|
||||
taskIdRef.current = taskId
|
||||
if (messageId)
|
||||
responseItem.id = messageId
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
async onCompleted(hasError?: boolean, errorMessage?: string) {
|
||||
handleResponding(false)
|
||||
fetchInspectVars({})
|
||||
invalidAllLastRun()
|
||||
|
||||
if (hasError) {
|
||||
if (errorMessage) {
|
||||
responseItem.content = errorMessage
|
||||
responseItem.isError = true
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) {
|
||||
try {
|
||||
const { data }: any = await onGetSuggestedQuestions(
|
||||
responseItem.id,
|
||||
newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController,
|
||||
)
|
||||
setSuggestQuestions(data)
|
||||
}
|
||||
// eslint-disable-next-line unused-imports/no-unused-vars
|
||||
catch (error) {
|
||||
setSuggestQuestions([])
|
||||
}
|
||||
}
|
||||
},
|
||||
onMessageEnd: (messageEnd) => {
|
||||
responseItem.citation = messageEnd.metadata?.retriever_resources || []
|
||||
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
|
||||
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
responseItem.content = messageReplace.answer
|
||||
},
|
||||
onError() {
|
||||
handleResponding(false)
|
||||
},
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
taskIdRef.current = task_id
|
||||
responseItem.workflow_run_id = workflow_run_id
|
||||
responseItem.workflowProcess = {
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onWorkflowFinished: ({ data }) => {
|
||||
responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onIterationStart: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
})
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onIterationFinish: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onLoopStart: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
})
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onLoopFinish: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onNodeStarted: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as any)
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onNodeRetry: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push(data)
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onNodeFinished: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onAgentLog: ({ data }) => {
|
||||
const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentNodeIndex > -1) {
|
||||
const current = responseItem.workflowProcess!.tracing![currentNodeIndex]
|
||||
|
||||
if (current.execution_metadata) {
|
||||
if (current.execution_metadata.agent_log) {
|
||||
const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
|
||||
if (currentLogIndex > -1) {
|
||||
current.execution_metadata.agent_log[currentLogIndex] = {
|
||||
...current.execution_metadata.agent_log[currentLogIndex],
|
||||
...data,
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log.push(data)
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log = [data]
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata = {
|
||||
agent_log: [data],
|
||||
} as any
|
||||
}
|
||||
|
||||
responseItem.workflowProcess!.tracing[currentNodeIndex] = {
|
||||
...current,
|
||||
}
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [threadMessages, chatTree.length, updateCurrentQAOnTree, handleResponding, formSettings?.inputsForm, handleRun, notify, t, config?.suggested_questions_after_answer?.enabled, fetchInspectVars, invalidAllLastRun])
|
||||
|
||||
return {
|
||||
conversationId: conversationId.current,
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
handleSend,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
isResponding,
|
||||
suggestedQuestions,
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { ChatItem, ChatItemInTree } from '@/app/components/base/chat/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
|
||||
export type ChatConfig = {
|
||||
opening_statement?: string
|
||||
suggested_questions?: string[]
|
||||
suggested_questions_after_answer?: {
|
||||
enabled?: boolean
|
||||
}
|
||||
text_to_speech?: unknown
|
||||
speech_to_text?: unknown
|
||||
retriever_resource?: unknown
|
||||
sensitive_word_avoidance?: unknown
|
||||
file_upload?: unknown
|
||||
}
|
||||
|
||||
export type GetAbortController = (abortController: AbortController) => void
|
||||
|
||||
export type SendCallback = {
|
||||
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<unknown>
|
||||
}
|
||||
|
||||
export type SendParams = {
|
||||
query: string
|
||||
files?: FileEntity[]
|
||||
parent_message_id?: string
|
||||
inputs?: Record<string, unknown>
|
||||
conversation_id?: string
|
||||
}
|
||||
|
||||
export type UpdateCurrentQAParams = {
|
||||
parentId?: string
|
||||
responseItem: ChatItem
|
||||
placeholderQuestionId: string
|
||||
questionItem: ChatItem
|
||||
}
|
||||
|
||||
export type ChatTreeUpdater = (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void
|
||||
@@ -1,90 +0,0 @@
|
||||
import { useCallback } from 'react'
|
||||
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../../constants'
|
||||
import { useEdgesInteractionsWithoutSync } from '../../../hooks/use-edges-interactions-without-sync'
|
||||
import { useNodesInteractionsWithoutSync } from '../../../hooks/use-nodes-interactions-without-sync'
|
||||
import { useStore, useWorkflowStore } from '../../../store'
|
||||
import { WorkflowRunningStatus } from '../../../types'
|
||||
|
||||
type UseChatFlowControlParams = {
|
||||
stopChat?: (taskId: string) => void
|
||||
}
|
||||
|
||||
export function useChatFlowControl({
|
||||
stopChat,
|
||||
}: UseChatFlowControlParams) {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const setIsResponding = useStore(s => s.setIsResponding)
|
||||
const resetChatPreview = useStore(s => s.resetChatPreview)
|
||||
const setActiveTaskId = useStore(s => s.setActiveTaskId)
|
||||
const setHasStopResponded = useStore(s => s.setHasStopResponded)
|
||||
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
|
||||
const invalidateRun = useStore(s => s.invalidateRun)
|
||||
const { handleNodeCancelRunningStatus } = useNodesInteractionsWithoutSync()
|
||||
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
|
||||
|
||||
const { setIterTimes, setLoopTimes } = workflowStore.getState()
|
||||
|
||||
const handleResponding = useCallback((responding: boolean) => {
|
||||
setIsResponding(responding)
|
||||
}, [setIsResponding])
|
||||
|
||||
const handleStop = useCallback(() => {
|
||||
const {
|
||||
activeTaskId,
|
||||
suggestedQuestionsAbortController,
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
} = workflowStore.getState()
|
||||
const runningStatus = workflowRunningData?.result?.status
|
||||
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
|
||||
setHasStopResponded(true)
|
||||
handleResponding(false)
|
||||
if (stopChat && activeTaskId)
|
||||
stopChat(activeTaskId)
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
if (suggestedQuestionsAbortController)
|
||||
suggestedQuestionsAbortController.abort()
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
setActiveTaskId('')
|
||||
invalidateRun()
|
||||
if (isActiveRun && workflowRunningData) {
|
||||
setWorkflowRunningData({
|
||||
...workflowRunningData,
|
||||
result: {
|
||||
...workflowRunningData.result,
|
||||
status: WorkflowRunningStatus.Stopped,
|
||||
},
|
||||
})
|
||||
}
|
||||
if (isActiveRun) {
|
||||
handleNodeCancelRunningStatus()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}
|
||||
}, [
|
||||
handleResponding,
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
stopChat,
|
||||
workflowStore,
|
||||
setHasStopResponded,
|
||||
setSuggestedQuestionsAbortController,
|
||||
setActiveTaskId,
|
||||
invalidateRun,
|
||||
handleNodeCancelRunningStatus,
|
||||
handleEdgeCancelRunningStatus,
|
||||
])
|
||||
|
||||
const handleRestart = useCallback(() => {
|
||||
handleStop()
|
||||
resetChatPreview()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
}, [handleStop, setIterTimes, setLoopTimes, resetChatPreview])
|
||||
|
||||
return {
|
||||
handleResponding,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { processOpeningStatement } from '@/app/components/base/chat/chat/utils'
|
||||
import { getThreadMessages } from '@/app/components/base/chat/utils'
|
||||
|
||||
type UseChatListParams = {
|
||||
chatTree: ChatItemInTree[]
|
||||
targetMessageId: string | undefined
|
||||
config: {
|
||||
opening_statement?: string
|
||||
suggested_questions?: string[]
|
||||
} | undefined
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
}
|
||||
}
|
||||
|
||||
export function useChatList({
|
||||
chatTree,
|
||||
targetMessageId,
|
||||
config,
|
||||
formSettings,
|
||||
}: UseChatListParams) {
|
||||
const threadMessages = useMemo(
|
||||
() => getThreadMessages(chatTree, targetMessageId),
|
||||
[chatTree, targetMessageId],
|
||||
)
|
||||
|
||||
const getIntroduction = useCallback((str: string) => {
|
||||
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
|
||||
}, [formSettings?.inputs, formSettings?.inputsForm])
|
||||
|
||||
const chatList = useMemo(() => {
|
||||
const ret = [...threadMessages]
|
||||
if (config?.opening_statement) {
|
||||
const index = threadMessages.findIndex(item => item.isOpeningStatement)
|
||||
|
||||
if (index > -1) {
|
||||
ret[index] = {
|
||||
...ret[index],
|
||||
content: getIntroduction(config.opening_statement),
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
}
|
||||
}
|
||||
else {
|
||||
ret.unshift({
|
||||
id: `${Date.now()}`,
|
||||
content: getIntroduction(config.opening_statement),
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
})
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
|
||||
|
||||
return {
|
||||
threadMessages,
|
||||
chatList,
|
||||
getIntroduction,
|
||||
}
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
import type { SendCallback, SendParams, UpdateCurrentQAParams } from './types'
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItem, ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { uniqBy } from 'es-toolkit/compat'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { getProcessedInputs } from '@/app/components/base/chat/chat/utils'
|
||||
import { getProcessedFiles, getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { useInvalidAllLastRun } from '@/service/use-workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { useSetWorkflowVarsWithValue, useWorkflowRun } from '../../../hooks'
|
||||
import { useHooksStore } from '../../../hooks-store'
|
||||
import { useStore, useWorkflowStore } from '../../../store'
|
||||
import { createWorkflowEventHandlers } from './use-workflow-event-handlers'
|
||||
|
||||
type UseChatMessageSenderParams = {
|
||||
threadMessages: ChatItemInTree[]
|
||||
config?: {
|
||||
suggested_questions_after_answer?: {
|
||||
enabled?: boolean
|
||||
}
|
||||
}
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
}
|
||||
handleResponding: (responding: boolean) => void
|
||||
updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
|
||||
}
|
||||
|
||||
export function useChatMessageSender({
|
||||
threadMessages,
|
||||
config,
|
||||
formSettings,
|
||||
handleResponding,
|
||||
updateCurrentQAOnTree,
|
||||
}: UseChatMessageSenderParams) {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const workflowStore = useWorkflowStore()
|
||||
|
||||
const configsMap = useHooksStore(s => s.configsMap)
|
||||
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
|
||||
const setConversationId = useStore(s => s.setConversationId)
|
||||
const setTargetMessageId = useStore(s => s.setTargetMessageId)
|
||||
const setSuggestedQuestions = useStore(s => s.setSuggestedQuestions)
|
||||
const setActiveTaskId = useStore(s => s.setActiveTaskId)
|
||||
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
|
||||
const startRun = useStore(s => s.startRun)
|
||||
|
||||
const handleSend = useCallback((
|
||||
params: SendParams,
|
||||
{ onGetSuggestedQuestions }: SendCallback,
|
||||
) => {
|
||||
if (workflowStore.getState().isResponding) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const { suggestedQuestionsAbortController } = workflowStore.getState()
|
||||
if (suggestedQuestionsAbortController)
|
||||
suggestedQuestionsAbortController.abort()
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
|
||||
const runId = startRun()
|
||||
const isCurrentRun = () => runId === workflowStore.getState().activeRunId
|
||||
|
||||
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
|
||||
|
||||
const placeholderQuestionId = `question-${Date.now()}`
|
||||
const questionItem: ChatItem = {
|
||||
id: placeholderQuestionId,
|
||||
content: params.query,
|
||||
isAnswer: false,
|
||||
message_files: params.files,
|
||||
parentMessageId: params.parent_message_id,
|
||||
}
|
||||
|
||||
const siblingIndex = parentMessage?.children?.length ?? workflowStore.getState().chatTree.length
|
||||
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
|
||||
const placeholderAnswerItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex,
|
||||
}
|
||||
|
||||
setTargetMessageId(parentMessage?.id)
|
||||
updateCurrentQAOnTree({
|
||||
parentId: params.parent_message_id,
|
||||
responseItem: placeholderAnswerItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
})
|
||||
|
||||
const responseItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
agent_thoughts: [],
|
||||
message_files: [],
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex,
|
||||
}
|
||||
|
||||
handleResponding(true)
|
||||
|
||||
const { files, inputs, ...restParams } = params
|
||||
const bodyParams = {
|
||||
files: getProcessedFiles(files || []),
|
||||
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
|
||||
...restParams,
|
||||
}
|
||||
if (bodyParams?.files?.length) {
|
||||
bodyParams.files = bodyParams.files.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file) {
|
||||
return {
|
||||
...item,
|
||||
url: '',
|
||||
}
|
||||
}
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
let hasSetResponseId = false
|
||||
|
||||
const workflowHandlers = createWorkflowEventHandlers({
|
||||
responseItem,
|
||||
questionItem,
|
||||
placeholderQuestionId,
|
||||
parentMessageId: params.parent_message_id,
|
||||
updateCurrentQAOnTree,
|
||||
})
|
||||
|
||||
handleRun(
|
||||
bodyParams,
|
||||
{
|
||||
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.content = responseItem.content + message
|
||||
|
||||
if (messageId && !hasSetResponseId) {
|
||||
questionItem.id = `question-${messageId}`
|
||||
responseItem.id = messageId
|
||||
responseItem.parentMessageId = questionItem.id
|
||||
hasSetResponseId = true
|
||||
}
|
||||
|
||||
if (isFirstMessage && newConversationId)
|
||||
setConversationId(newConversationId)
|
||||
|
||||
if (taskId)
|
||||
setActiveTaskId(taskId)
|
||||
if (messageId)
|
||||
responseItem.id = messageId
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
async onCompleted(hasError?: boolean, errorMessage?: string) {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
handleResponding(false)
|
||||
fetchInspectVars({})
|
||||
invalidAllLastRun()
|
||||
|
||||
if (hasError) {
|
||||
if (errorMessage) {
|
||||
responseItem.content = errorMessage
|
||||
responseItem.isError = true
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (config?.suggested_questions_after_answer?.enabled && !workflowStore.getState().hasStopResponded && onGetSuggestedQuestions) {
|
||||
try {
|
||||
const result = await onGetSuggestedQuestions(
|
||||
responseItem.id,
|
||||
newAbortController => setSuggestedQuestionsAbortController(newAbortController),
|
||||
) as { data: string[] }
|
||||
setSuggestedQuestions(result.data)
|
||||
}
|
||||
catch {
|
||||
setSuggestedQuestions([])
|
||||
}
|
||||
finally {
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
}
|
||||
}
|
||||
},
|
||||
onMessageEnd: (messageEnd) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.citation = messageEnd.metadata?.retriever_resources || []
|
||||
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
|
||||
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.content = messageReplace.answer
|
||||
},
|
||||
onError() {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
handleResponding(false)
|
||||
},
|
||||
onWorkflowStarted: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
const taskId = workflowHandlers.onWorkflowStarted(event)
|
||||
if (taskId)
|
||||
setActiveTaskId(taskId)
|
||||
},
|
||||
onWorkflowFinished: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onWorkflowFinished(event)
|
||||
},
|
||||
onIterationStart: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onIterationStart(event)
|
||||
},
|
||||
onIterationFinish: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onIterationFinish(event)
|
||||
},
|
||||
onLoopStart: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onLoopStart(event)
|
||||
},
|
||||
onLoopFinish: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onLoopFinish(event)
|
||||
},
|
||||
onNodeStarted: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeStarted(event)
|
||||
},
|
||||
onNodeRetry: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeRetry(event)
|
||||
},
|
||||
onNodeFinished: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeFinished(event)
|
||||
},
|
||||
onAgentLog: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onAgentLog(event)
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [
|
||||
threadMessages,
|
||||
updateCurrentQAOnTree,
|
||||
handleResponding,
|
||||
formSettings?.inputsForm,
|
||||
handleRun,
|
||||
notify,
|
||||
t,
|
||||
config?.suggested_questions_after_answer?.enabled,
|
||||
setTargetMessageId,
|
||||
setConversationId,
|
||||
setSuggestedQuestions,
|
||||
setActiveTaskId,
|
||||
setSuggestedQuestionsAbortController,
|
||||
startRun,
|
||||
fetchInspectVars,
|
||||
invalidAllLastRun,
|
||||
workflowStore,
|
||||
])
|
||||
|
||||
return { handleSend }
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import type { ChatTreeUpdater, UpdateCurrentQAParams } from './types'
|
||||
import type { ChatItemInTree } from '@/app/components/base/chat/types'
|
||||
import { produce } from 'immer'
|
||||
import { useCallback } from 'react'
|
||||
|
||||
export function useChatTreeOperations(updateChatTree: ChatTreeUpdater) {
|
||||
const produceChatTreeNode = useCallback(
|
||||
(tree: ChatItemInTree[], targetId: string, operation: (node: ChatItemInTree) => void) => {
|
||||
return produce(tree, (draft) => {
|
||||
const queue: ChatItemInTree[] = [...draft]
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!
|
||||
if (current.id === targetId) {
|
||||
operation(current)
|
||||
break
|
||||
}
|
||||
if (current.children)
|
||||
queue.push(...current.children)
|
||||
}
|
||||
})
|
||||
},
|
||||
[],
|
||||
)
|
||||
|
||||
const updateCurrentQAOnTree = useCallback(({
|
||||
parentId,
|
||||
responseItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
}: UpdateCurrentQAParams) => {
|
||||
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] } as ChatItemInTree
|
||||
updateChatTree((currentChatTree) => {
|
||||
if (!parentId) {
|
||||
const questionIndex = currentChatTree.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
return produce(currentChatTree, (draft) => {
|
||||
if (questionIndex === -1)
|
||||
draft.push(currentQA)
|
||||
else
|
||||
draft[questionIndex] = currentQA
|
||||
})
|
||||
}
|
||||
|
||||
return produceChatTreeNode(currentChatTree, parentId, (parentNode) => {
|
||||
if (!parentNode.children)
|
||||
parentNode.children = []
|
||||
const questionNodeIndex = parentNode.children.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
if (questionNodeIndex === -1)
|
||||
parentNode.children.push(currentQA)
|
||||
else
|
||||
parentNode.children[questionNodeIndex] = currentQA
|
||||
})
|
||||
})
|
||||
}, [produceChatTreeNode, updateChatTree])
|
||||
|
||||
return {
|
||||
produceChatTreeNode,
|
||||
updateCurrentQAOnTree,
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
import type { ChatConfig } from './types'
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useStore } from '../../../store'
|
||||
import { useChatFlowControl } from './use-chat-flow-control'
|
||||
import { useChatList } from './use-chat-list'
|
||||
import { useChatMessageSender } from './use-chat-message-sender'
|
||||
import { useChatTreeOperations } from './use-chat-tree-operations'
|
||||
|
||||
export function useChat(
|
||||
config: ChatConfig | undefined,
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
},
|
||||
prevChatTree?: ChatItemInTree[],
|
||||
stopChat?: (taskId: string) => void,
|
||||
) {
|
||||
const chatTree = useStore(s => s.chatTree)
|
||||
const conversationId = useStore(s => s.conversationId)
|
||||
const isResponding = useStore(s => s.isResponding)
|
||||
const suggestedQuestions = useStore(s => s.suggestedQuestions)
|
||||
const targetMessageId = useStore(s => s.targetMessageId)
|
||||
const updateChatTree = useStore(s => s.updateChatTree)
|
||||
const setTargetMessageId = useStore(s => s.setTargetMessageId)
|
||||
|
||||
const initialChatTreeRef = useRef(prevChatTree)
|
||||
useEffect(() => {
|
||||
const initialChatTree = initialChatTreeRef.current
|
||||
if (!initialChatTree || initialChatTree.length === 0)
|
||||
return
|
||||
updateChatTree(currentChatTree => (currentChatTree.length === 0 ? initialChatTree : currentChatTree))
|
||||
}, [updateChatTree])
|
||||
|
||||
const { updateCurrentQAOnTree } = useChatTreeOperations(updateChatTree)
|
||||
|
||||
const {
|
||||
handleResponding,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
} = useChatFlowControl({
|
||||
stopChat,
|
||||
})
|
||||
|
||||
const {
|
||||
threadMessages,
|
||||
chatList,
|
||||
} = useChatList({
|
||||
chatTree,
|
||||
targetMessageId,
|
||||
config,
|
||||
formSettings,
|
||||
})
|
||||
|
||||
const { handleSend } = useChatMessageSender({
|
||||
threadMessages,
|
||||
config,
|
||||
formSettings,
|
||||
handleResponding,
|
||||
updateCurrentQAOnTree,
|
||||
})
|
||||
|
||||
return {
|
||||
conversationId,
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
handleSend,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
isResponding,
|
||||
suggestedQuestions,
|
||||
}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
import type { UpdateCurrentQAParams } from './types'
import type { ChatItem } from '@/app/components/base/chat/types'
import type { AgentLogItem, NodeTracing } from '@/types/workflow'
import { NodeRunningStatus, WorkflowRunningStatus } from '../../../types'

type WorkflowEventHandlersContext = {
  responseItem: ChatItem
  questionItem: ChatItem
  placeholderQuestionId: string
  parentMessageId?: string
  updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
}

type TracingData = Partial<NodeTracing> & { id: string }
type AgentLogData = Partial<AgentLogItem> & { node_id: string, message_id: string }

export function createWorkflowEventHandlers(ctx: WorkflowEventHandlersContext) {
  const { responseItem, questionItem, placeholderQuestionId, parentMessageId, updateCurrentQAOnTree } = ctx

  const updateTree = () => {
    updateCurrentQAOnTree({
      placeholderQuestionId,
      questionItem,
      responseItem,
      parentId: parentMessageId,
    })
  }

  const updateTracingItem = (data: TracingData) => {
    const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
    if (currentTracingIndex > -1) {
      responseItem.workflowProcess!.tracing[currentTracingIndex] = {
        ...responseItem.workflowProcess!.tracing[currentTracingIndex],
        ...data,
      }
      updateTree()
    }
  }

  return {
    onWorkflowStarted: ({ workflow_run_id, task_id }: { workflow_run_id: string, task_id: string }) => {
      responseItem.workflow_run_id = workflow_run_id
      responseItem.workflowProcess = {
        status: WorkflowRunningStatus.Running,
        tracing: [],
      }
      updateTree()
      return task_id
    },

    onWorkflowFinished: ({ data }: { data: { status: string } }) => {
      responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
      updateTree()
    },

    onIterationStart: ({ data }: { data: Partial<NodeTracing> }) => {
      responseItem.workflowProcess!.tracing!.push({
        ...data,
        status: NodeRunningStatus.Running,
      } as NodeTracing)
      updateTree()
    },

    onIterationFinish: ({ data }: { data: TracingData }) => {
      updateTracingItem(data)
    },

    onLoopStart: ({ data }: { data: Partial<NodeTracing> }) => {
      responseItem.workflowProcess!.tracing!.push({
        ...data,
        status: NodeRunningStatus.Running,
      } as NodeTracing)
      updateTree()
    },

    onLoopFinish: ({ data }: { data: TracingData }) => {
      updateTracingItem(data)
    },

    onNodeStarted: ({ data }: { data: Partial<NodeTracing> }) => {
      responseItem.workflowProcess!.tracing!.push({
        ...data,
        status: NodeRunningStatus.Running,
      } as NodeTracing)
      updateTree()
    },

    onNodeRetry: ({ data }: { data: NodeTracing }) => {
      responseItem.workflowProcess!.tracing!.push(data)
      updateTree()
    },

    onNodeFinished: ({ data }: { data: TracingData }) => {
      updateTracingItem(data)
    },

    onAgentLog: ({ data }: { data: AgentLogData }) => {
      const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
      if (currentNodeIndex > -1) {
        const current = responseItem.workflowProcess!.tracing![currentNodeIndex]

        if (current.execution_metadata) {
          if (current.execution_metadata.agent_log) {
            const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
            if (currentLogIndex > -1) {
              current.execution_metadata.agent_log[currentLogIndex] = {
                ...current.execution_metadata.agent_log[currentLogIndex],
                ...data,
              } as AgentLogItem
            }
            else {
              current.execution_metadata.agent_log.push(data as AgentLogItem)
            }
          }
          else {
            current.execution_metadata.agent_log = [data as AgentLogItem]
          }
        }
        else {
          current.execution_metadata = {
            agent_log: [data as AgentLogItem],
          } as NodeTracing['execution_metadata']
        }

        responseItem.workflowProcess!.tracing[currentNodeIndex] = {
          ...current,
        }

        updateTree()
      }
    },
  }
}
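The factory above only mutates `responseItem` and pushes each change back through `updateCurrentQAOnTree`, so the returned handlers can be passed straight to whatever streaming layer emits the workflow events. A minimal usage sketch, assuming a hypothetical `startStreaming` call site and callback names that are not part of this diff:

// Hypothetical wiring (names assumed for illustration only):
// spread the handlers into the streaming callbacks so each SSE event
// updates the tracing tree of the in-flight response item.
const handlers = createWorkflowEventHandlers({
  responseItem,
  questionItem,
  placeholderQuestionId,
  parentMessageId,
  updateCurrentQAOnTree,
})
startStreaming(url, {
  ...otherCallbacks,
  ...handlers,
})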
@@ -1,96 +0,0 @@
import type { StateCreator } from 'zustand'
import type { ChatItemInTree } from '@/app/components/base/chat/types'

type ChatPreviewState = {
  chatTree: ChatItemInTree[]
  targetMessageId: string | undefined
  suggestedQuestions: string[]
  conversationId: string
  isResponding: boolean
  activeRunId: number
  activeTaskId: string
  hasStopResponded: boolean
  suggestedQuestionsAbortController: AbortController | null
}

type ChatPreviewActions = {
  setChatTree: (chatTree: ChatItemInTree[]) => void
  updateChatTree: (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void
  setTargetMessageId: (messageId: string | undefined) => void
  setSuggestedQuestions: (questions: string[]) => void
  setConversationId: (conversationId: string) => void
  setIsResponding: (isResponding: boolean) => void
  setActiveTaskId: (taskId: string) => void
  setHasStopResponded: (hasStopResponded: boolean) => void
  setSuggestedQuestionsAbortController: (controller: AbortController | null) => void
  startRun: () => number
  invalidateRun: () => number
  resetChatPreview: () => void
}

export type ChatPreviewSliceShape = ChatPreviewState & ChatPreviewActions

const initialState: ChatPreviewState = {
  chatTree: [],
  targetMessageId: undefined,
  suggestedQuestions: [],
  conversationId: '',
  isResponding: false,
  activeRunId: 0,
  activeTaskId: '',
  hasStopResponded: false,
  suggestedQuestionsAbortController: null,
}

export const createChatPreviewSlice: StateCreator<ChatPreviewSliceShape> = (set, get) => ({
  ...initialState,

  setChatTree: chatTree => set({ chatTree }),

  updateChatTree: updater => set((state) => {
    const nextChatTree = updater(state.chatTree)
    if (nextChatTree === state.chatTree)
      return state
    return { chatTree: nextChatTree }
  }),

  setTargetMessageId: targetMessageId => set({ targetMessageId }),

  setSuggestedQuestions: suggestedQuestions => set({ suggestedQuestions }),

  setConversationId: conversationId => set({ conversationId }),

  setIsResponding: isResponding => set({ isResponding }),

  setActiveTaskId: activeTaskId => set({ activeTaskId }),

  setHasStopResponded: hasStopResponded => set({ hasStopResponded }),

  setSuggestedQuestionsAbortController: suggestedQuestionsAbortController => set({ suggestedQuestionsAbortController }),

  startRun: () => {
    const activeRunId = get().activeRunId + 1
    set({
      activeRunId,
      activeTaskId: '',
      hasStopResponded: false,
      suggestedQuestionsAbortController: null,
    })
    return activeRunId
  },

  invalidateRun: () => {
    const activeRunId = get().activeRunId + 1
    set({
      activeRunId,
      activeTaskId: '',
      suggestedQuestionsAbortController: null,
    })
    return activeRunId
  },

  resetChatPreview: () => set(state => ({
    ...initialState,
    activeRunId: state.activeRunId + 1,
  })),
})
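This deleted slice kept a monotonically increasing `activeRunId` so that results from a superseded run could be dropped instead of overwriting newer state. A minimal sketch of that guard, assuming a hypothetical consumer and a made-up `fetchSuggestedQuestions` helper (neither appears in this diff):

// Hypothetical consumer of the run-id guard: capture the id returned by
// startRun(), then ignore any async result that lands after
// invalidateRun()/resetChatPreview() has bumped activeRunId.
const runId = store.getState().startRun()
const questions = await fetchSuggestedQuestions(messageId)
if (store.getState().activeRunId === runId)
  store.getState().setSuggestedQuestions(questions)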
@@ -1,7 +1,6 @@
import type {
  StateCreator,
} from 'zustand'
import type { ChatPreviewSliceShape } from './chat-preview-slice'
import type { ChatVariableSliceShape } from './chat-variable-slice'
import type { InspectVarsSliceShape } from './debug/inspect-vars-slice'
import type { EnvVariableSliceShape } from './env-variable-slice'
@@ -23,7 +22,6 @@ import {
} from 'zustand'
import { createStore } from 'zustand/vanilla'
import { WorkflowContext } from '@/app/components/workflow/context'
import { createChatPreviewSlice } from './chat-preview-slice'
import { createChatVariableSlice } from './chat-variable-slice'
import { createInspectVarsSlice } from './debug/inspect-vars-slice'
import { createEnvVariableSlice } from './env-variable-slice'
@@ -44,8 +42,7 @@ export type SliceFromInjection
  & Partial<RagPipelineSliceShape>

export type Shape
  = ChatPreviewSliceShape
  & ChatVariableSliceShape
  = ChatVariableSliceShape
  & EnvVariableSliceShape
  & FormSliceShape
  & HelpLineSliceShape
@@ -70,7 +67,6 @@ export const createWorkflowStore = (params: CreateWorkflowStoreParams) => {
  const { injectWorkflowStoreSliceFn } = params || {}

  return createStore<Shape>((...args) => ({
    ...createChatPreviewSlice(...args),
    ...createChatVariableSlice(...args),
    ...createEnvVariableSlice(...args),
    ...createFormSlice(...args),
@@ -822,7 +822,7 @@
      "count": 2
    },
    "ts/no-explicit-any": {
      "count": 15
      "count": 14
    }
  },
  "app/components/base/chat/chat/index.tsx": {
@@ -3769,6 +3769,11 @@
      "count": 2
    }
  },
  "app/components/workflow/panel/debug-and-preview/hooks.ts": {
    "ts/no-explicit-any": {
      "count": 7
    }
  },
  "app/components/workflow/panel/env-panel/variable-modal.tsx": {
    "react-hooks-extra/no-direct-set-state-in-use-effect": {
      "count": 4
@@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => {
        : `${formatted}${units[unitIndex].symbol}`
    }
  }
  // Fallback: if no threshold matched, return the number string
  return num.toString()
}

export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {
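The added fallback only matters for inputs below every abbreviation threshold. A tiny illustration, with threshold values assumed rather than taken from this diff:

// Assuming the smallest unit in `units` kicks in at 1_000, small values
// now fall through to the new fallback and come back unabbreviated.
formatNumberAbbreviated(532)     // '532'  (fallback path, assumed threshold)
formatNumberAbbreviated(15_300)  // e.g. '15.3K' (symbol depends on `units`)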