Compare commits

..

2 Commits

Author SHA1 Message Date
Yanli 盐粒
8ab15c9ba1 fix(api): release non-stream plugin chunk iterator 2026-01-27 00:56:40 +08:00
wangxiaolei
e48419937b feat: chatflow support multimodal (#31293)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-27 00:24:48 +08:00
28 changed files with 1622 additions and 1019 deletions

View File

@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@@ -1,6 +1,8 @@
import base64
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
from core.file.models import File
@@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
@@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
def _handle_invoke_result_direct(
self,
invoke_result: LLMResult,
queue_manager: AppQueueManager,
):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
queue_manager.publish(
@@ -235,13 +270,22 @@ class AppRunner:
)
def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
self,
invoke_result: Generator[LLMResultChunk, None, None],
queue_manager: AppQueueManager,
agent: bool,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
model: str = ""
@@ -259,12 +303,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, str):
# TODO(QuantumGhost): Add multimodal output support for easy ui.
_logger.warning("received multimodal output, type=%s", type(content))
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content # failback to str
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
@@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER,
)
def _handle_multimodal_image_content(
self,
content: ImagePromptMessageContent,
message_id: str,
user_id: str,
tenant_id: str,
queue_manager: AppQueueManager,
):
"""
Handle multimodal image content from LLM response.
Save the image and create a MessageFile record.
:param content: ImagePromptMessageContent instance
:param message_id: message id
:param user_id: user id
:param tenant_id: tenant id
:param queue_manager: queue manager
:return:
"""
_logger.info("Handling multimodal image content for message %s", message_id)
image_url = content.url
base64_data = content.base64_data
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
if not image_url and not base64_data:
_logger.warning("Image content has neither URL nor base64 data")
return
tool_file_manager = ToolFileManager()
# Save the image file
try:
if image_url:
# Download image from URL
_logger.info("Downloading image from URL: %s", image_url)
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=image_url,
conversation_id=None,
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
elif base64_data:
if base64_data.startswith("data:"):
base64_data = base64_data.split(",", 1)[1]
image_binary = base64.b64decode(base64_data)
mimetype = content.mime_type or "image/png"
extension = guess_extension(mimetype) or ".png"
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=image_binary,
mimetype=mimetype,
filename=f"generated_image{extension}",
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
else:
return
except Exception:
_logger.exception("Failed to save image file")
return
# Create MessageFile record
message_file = MessageFile(
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(
CreatorUserRole.ACCOUNT
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
else CreatorUserRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
# Publish QueueMessageFileEvent
queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file.id),
PublishFrom.APPLICATION_MANAGER,
)
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
def moderation_for_inputs(
self,
*,

View File

@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamEvent,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import exists, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent,
WorkflowTaskState,
)
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
@@ -57,13 +58,15 @@ class MessageCycleManager:
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
# Fast path: cached determination from prior QueueMessageFileEvent
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
if has_file:
if message_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
@@ -199,6 +202,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
self._message_has_file.add(message_file.message_id)
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension

View File

@@ -92,6 +92,10 @@ def _build_llm_result_from_first_chunk(
Build a single `LLMResult` from the first returned chunk.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@@ -99,18 +103,25 @@ def _build_llm_result_from_first_chunk(
system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
try:
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
return LLMResult(
model=model,

View File

@@ -0,0 +1,454 @@
"""Test multimodal image output handling in BaseAppRunner."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from core.file.enums import FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from models.enums import CreatorUserRole
class TestBaseAppRunnerMultimodal:
"""Test that BaseAppRunner correctly handles multimodal image content."""
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return str(uuid4())
@pytest.fixture
def mock_tenant_id(self):
"""Mock tenant ID."""
return str(uuid4())
@pytest.fixture
def mock_message_id(self):
"""Mock message ID."""
return str(uuid4())
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = MagicMock()
manager.invoke_from = InvokeFrom.SERVICE_API
return manager
@pytest.fixture
def mock_tool_file(self):
"""Create a mock tool file."""
tool_file = MagicMock()
tool_file.id = str(uuid4())
return tool_file
@pytest.fixture
def mock_message_file(self):
"""Create a mock message file."""
message_file = MagicMock()
message_file.id = str(uuid4())
return message_file
def test_handle_multimodal_image_content_with_url(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from URL."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from URL
mock_mgr.create_file_by_url.assert_called_once_with(
user_id=mock_user_id,
tenant_id=mock_tenant_id,
file_url=image_url,
conversation_id=None,
)
# Verify message file was created with correct parameters
mock_msg_file_class.assert_called_once()
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["message_id"] == mock_message_id
assert call_kwargs["type"] == FileType.IMAGE
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
assert call_kwargs["belongs_to"] == "assistant"
assert call_kwargs["created_by"] == mock_user_id
# Verify database operations
mock_session.add.assert_called_once_with(mock_message_file)
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once_with(mock_message_file)
# Verify event was published
mock_queue_manager.publish.assert_called_once()
publish_call = mock_queue_manager.publish.call_args
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
assert publish_call[0][0].message_file_id == mock_message_file.id
# publish_from might be passed as positional or keyword argument
assert (
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
)
def test_handle_multimodal_image_content_with_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data."""
# Arrange
import base64
# Create a small test image (1x1 PNG)
test_image_data = base64.b64encode(
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
).decode()
content = ImagePromptMessageContent(
base64_data=test_image_data,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from base64
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
assert call_kwargs["user_id"] == mock_user_id
assert call_kwargs["tenant_id"] == mock_tenant_id
assert call_kwargs["conversation_id"] is None
assert "file_binary" in call_kwargs
assert call_kwargs["mimetype"] == "image/png"
assert call_kwargs["filename"].startswith("generated_image")
assert call_kwargs["filename"].endswith(".png")
# Verify message file was created
mock_msg_file_class.assert_called_once()
# Verify database operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify event was published
mock_queue_manager.publish.assert_called_once()
def test_handle_multimodal_image_content_with_base64_data_uri(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data with URI prefix."""
# Arrange
# Data URI format: data:image/png;base64,<base64_data>
test_image_data = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
)
content = ImagePromptMessageContent(
base64_data=f"data:image/png;base64,{test_image_data}",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify that base64 data was extracted correctly (without prefix)
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
# The base64 data should be decoded, so we check the binary was passed
assert "file_binary" in call_kwargs
def test_handle_multimodal_image_content_without_url_or_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content without URL or base64 data."""
# Arrange
content = ImagePromptMessageContent(
url="",
base64_data="",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create any files or publish events
mock_mgr_class.assert_not_called()
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_with_error(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content when an error occurs."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock to raise exception
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
# Should not raise exception, just log it
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create message file or publish event on error
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_debugger_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that debugger mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is ACCOUNT for debugger mode
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
def test_handle_multimodal_image_content_service_api_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that service API mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is END_USER for service API
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER

View File

@@ -1,7 +1,6 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
from unittest.mock import Mock, patch
import pytest
from flask import current_app
@@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = mock_message_file
# Execute
with current_app.app_context():
@@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
# Assert
assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and no message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = None
# Execute
with current_app.app_context():
@@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
# Assert
assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context():
@@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Execute with event_type provided
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
@@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided
mock_session_class.assert_not_called()
# Should not open a session when event_type is provided
mock_session_factory.create_session.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter."""
@@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database)
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None # No files
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = None # No files
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
# Should open session once
mock_session_factory.create_session.assert_called_once()
assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
mock_session_class.return_value.__enter__.return_value = Mock()
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type
@@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
answer="Chunk 2", message_id="test-message-id", event_type=event_type
)
# Should not query database again
mock_session_class.assert_not_called()
# Should not open session again when event_type provided
mock_session_factory.create_session.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE

View File

@@ -101,3 +101,26 @@ def test__normalize_non_stream_plugin_result__empty_iterator_defaults():
assert result.message.tool_calls == []
assert result.usage == LLMUsage.empty_usage()
assert result.system_fingerprint is None
def test__normalize_non_stream_plugin_result__closes_chunk_iterator():
prompt_messages = [UserPromptMessage(content="hi")]
chunk = _make_chunk(content="hello", usage=LLMUsage.empty_usage())
closed: list[bool] = []
def _chunk_iter():
try:
yield chunk
yield _make_chunk(content="ignored", usage=LLMUsage.empty_usage())
finally:
closed.append(True)
result = _normalize_non_stream_plugin_result(
model="test-model",
prompt_messages=prompt_messages,
result=_chunk_iter(),
)
assert result.message.content == "hello"
assert closed == [True]

View File

@@ -0,0 +1,178 @@
/**
* Tests for multimodal image file handling in chat hooks.
* Tests the file object conversion logic without full hook integration.
*/
describe('Multimodal File Handling', () => {
describe('File type to MIME type mapping', () => {
it('should map image to image/png', () => {
const fileType: string = 'image'
const expectedMime = 'image/png'
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map video to video/mp4', () => {
const fileType: string = 'video'
const expectedMime = 'video/mp4'
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map audio to audio/mpeg', () => {
const fileType: string = 'audio'
const expectedMime = 'audio/mpeg'
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map unknown to application/octet-stream', () => {
const fileType: string = 'unknown'
const expectedMime = 'application/octet-stream'
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
})
describe('TransferMethod selection', () => {
it('should select remote_url for images', () => {
const fileType: string = 'image'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('remote_url')
})
it('should select local_file for non-images', () => {
const fileType: string = 'video'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('local_file')
})
})
describe('File extension mapping', () => {
it('should use .png extension for images', () => {
const fileType: string = 'image'
const expectedExtension = '.png'
const extension = fileType === 'image' ? 'png' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp4 extension for videos', () => {
const fileType: string = 'video'
const expectedExtension = '.mp4'
const extension = fileType === 'video' ? 'mp4' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp3 extension for audio', () => {
const fileType: string = 'audio'
const expectedExtension = '.mp3'
const extension = fileType === 'audio' ? 'mp3' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
})
describe('File name generation', () => {
it('should generate correct file name for images', () => {
const fileType: string = 'image'
const expectedName = 'generated_image.png'
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for videos', () => {
const fileType: string = 'video'
const expectedName = 'generated_video.mp4'
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for audio', () => {
const fileType: string = 'audio'
const expectedName = 'generated_audio.mp3'
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
expect(fileName).toBe(expectedName)
})
})
describe('SupportFileType mapping', () => {
it('should map image type to image supportFileType', () => {
const fileType: string = 'image'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('image')
})
it('should map video type to video supportFileType', () => {
const fileType: string = 'video'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('video')
})
it('should map audio type to audio supportFileType', () => {
const fileType: string = 'audio'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('audio')
})
it('should map unknown type to document supportFileType', () => {
const fileType: string = 'unknown'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('document')
})
})
describe('File conversion logic', () => {
it('should detect existing transferMethod', () => {
const fileWithTransferMethod = {
id: 'file-123',
transferMethod: 'remote_url' as const,
type: 'image/png',
name: 'test.png',
size: 1024,
supportFileType: 'image',
progress: 100,
}
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
expect(hasTransferMethod).toBe(true)
})
it('should detect missing transferMethod', () => {
const fileWithoutTransferMethod = {
id: 'file-456',
type: 'image',
url: 'http://example.com/image.png',
belongs_to: 'assistant',
}
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
expect(hasTransferMethod).toBe(false)
})
it('should create file with size 0 for generated files', () => {
const expectedSize = 0
expect(expectedSize).toBe(0)
})
})
describe('Agent vs Non-Agent mode logic', () => {
it('should check for agent_thoughts to determine mode', () => {
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [{}],
}
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(true)
})
it('should detect non-agent mode when agent_thoughts is empty', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [],
}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(false)
})
it('should detect non-agent mode when agent_thoughts is undefined', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBeFalsy()
})
})
})

View File

@@ -419,9 +419,40 @@ export const useChat = (
}
},
onFile(file) {
// Convert simple file type to MIME type for non-agent mode
// Backend sends: { id, type: "image", belongs_to, url }
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
// Determine file type for MIME conversion
const fileType = (file as { type?: string }).type || 'image'
// If file already has transferMethod, use it as base and ensure all required fields exist
// Otherwise, create a new complete file object
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
const convertedFile: FileEntity = {
id: baseFile?.id || (file as { id: string }).id,
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
progress: baseFile?.progress ?? 100,
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
url: baseFile?.url || (file as { url?: string }).url,
size: baseFile?.size ?? 0, // Generated files don't have a known size
}
// For agent mode, add files to the last thought
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
if (lastThought)
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
if (lastThought) {
const thought = lastThought as { message_files?: FileEntity[] }
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
}
// For non-agent mode, add files directly to responseItem.message_files
else {
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
responseItem.message_files = [...currentFiles, convertedFile]
}
updateCurrentQAOnTree({
placeholderQuestionId,

View File

@@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => {
renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })
// Find submit button by class
@@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })
// Component should still be functional - check for the main container
@@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox to be rendered with timeout for CI environment
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })
// Submit
@@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)
// Verify the component is still rendered after submission
expect(container.firstChild).toBeInTheDocument()
// Wait for the mutation to complete
await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
})
it('should render ResultItem components for non-external results', async () => {
@@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => {
isLoading: false,
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for component to be fully rendered with longer timeout
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })
const buttons = screen.getAllByRole('button')
@@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)
// Verify component is rendered after submission
expect(container.firstChild).toBeInTheDocument()
// Wait for mutation to complete with longer timeout
await waitFor(
() => {
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
},
{ timeout: 3000 },
)
})
it('should render external results when dataset is external', async () => {
@@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => {
// Component should render
expect(container.firstChild).toBeInTheDocument()
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Type in textarea to verify component is functional
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'Test query' } })
const buttons = screen.getAllByRole('button')
@@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => {
if (submitButton)
fireEvent.click(submitButton)
await waitFor(() => {
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
// Verify component is still functional after submission
await waitFor(
() => {
expect(screen.getByRole('textbox')).toBeInTheDocument()
},
{ timeout: 3000 },
)
})
})
@@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Enter query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } })
// Submit
@@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Enter query and submit
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test query' } })
const buttons = screen.getAllByRole('button')
@@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Wait for state updates
await waitFor(() => {
expect(container.firstChild).toBeInTheDocument()
}, { timeout: 2000 })
}, { timeout: 3000 })
// Verify mutation was called
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
@@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
// Wait for textbox with timeout for CI
const textarea = await waitFor(
() => screen.getByRole('textbox'),
{ timeout: 3000 },
)
// Submit a query
const textarea = screen.getByRole('textbox')
fireEvent.change(textarea, { target: { value: 'test' } })
const buttons = screen.getAllByRole('button')
@@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
// Verify the component renders
await waitFor(() => {
expect(container.firstChild).toBeInTheDocument()
})
}, { timeout: 3000 })
})
})

View File

@@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
}))
// Mock marketplace client used by marketplace utils
vi.mock('@/service/client', () => ({
marketplaceClient: {
collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
collections: [
{
name: 'collection-1',
label: { 'en-US': 'Collection 1' },
description: { 'en-US': 'Desc' },
rule: '',
created_at: '2024-01-01',
updated_at: '2024-01-01',
searchable: true,
search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' },
},
],
},
})),
collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
},
})),
// Some utils paths may call searchAdvanced; provide a minimal stub
searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
data: {
plugins: [
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
],
total: 1,
},
})),
},
}))
// Mock context/query-client
vi.mock('@/context/query-client', () => ({
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
@@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => {
// ================================
// Async Utils Tests
// ================================
// Narrow mock surface and avoid any in tests
// Types are local to this spec to keep scope minimal
type FnMock = ReturnType<typeof vi.fn>
type MarketplaceClientMock = {
collectionPlugins: FnMock
collections: FnMock
}
describe('Async Utils', () => {
let marketplaceClientMock: MarketplaceClientMock
beforeAll(async () => {
const mod = await import('@/service/client')
marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock
})
beforeEach(() => {
vi.clearAllMocks()
})
@@ -1490,12 +1545,10 @@ describe('Async Utils', () => {
{ type: 'plugin', org: 'test', name: 'plugin2' },
]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Adjusted to our mocked marketplaceClient instead of fetch
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
data: { plugins: mockPlugins },
})
const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection', {
@@ -1504,12 +1557,13 @@ describe('Async Utils', () => {
type: 'plugin',
})
expect(globalThis.fetch).toHaveBeenCalled()
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
expect(result).toHaveLength(2)
})
it('should handle fetch error and return empty array', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
// Simulate error from client
marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error'))
const { getMarketplacePluginsByCollectionId } = await import('./utils')
const result = await getMarketplacePluginsByCollectionId('test-collection')
@@ -1519,25 +1573,18 @@ describe('Async Utils', () => {
it('should pass abort signal when provided', async () => {
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Our client mock receives the signal as second arg
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
data: { plugins: mockPlugins },
})
const controller = new AbortController()
const { getMarketplacePluginsByCollectionId } = await import('./utils')
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalledWith(
expect.any(Request),
expect.any(Object),
)
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('test-collection')
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
const call = marketplaceClientMock.collectionPlugins.mock.calls[0]
expect(call[1]).toMatchObject({ signal: controller.signal })
})
})
@@ -1548,23 +1595,17 @@ describe('Async Utils', () => {
]
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
let callCount = 0
globalThis.fetch = vi.fn().mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.resolve(
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Simulate two-step client calls: collections then collectionPlugins
let stage = 0
marketplaceClientMock.collections.mockImplementationOnce(async () => {
stage = 1
return { data: { collections: mockCollections } }
})
marketplaceClientMock.collectionPlugins.mockImplementation(async () => {
if (stage === 1) {
return { data: { plugins: mockPlugins } }
}
return Promise.resolve(
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
return { data: { plugins: [] } }
})
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
@@ -1578,7 +1619,8 @@ describe('Async Utils', () => {
})
it('should handle fetch error and return empty data', async () => {
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
// Simulate client error
marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error'))
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
const result = await getMarketplaceCollectionsAndPlugins()
@@ -1588,24 +1630,16 @@ describe('Async Utils', () => {
})
it('should append condition and type to URL when provided', async () => {
globalThis.fetch = vi.fn().mockResolvedValue(
new Response(JSON.stringify({ data: { collections: [] } }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
// Assert that the client was called with query containing condition/type
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
await getMarketplaceCollectionsAndPlugins({
condition: 'category=tool',
type: 'bundle',
})
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
expect(globalThis.fetch).toHaveBeenCalled()
const call = vi.mocked(globalThis.fetch).mock.calls[0]
const request = call[0] as Request
expect(request.url).toContain('condition=category%3Dtool')
expect(marketplaceClientMock.collections).toHaveBeenCalled()
const call = marketplaceClientMock.collections.mock.calls[0]
expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) })
})
})
})

View File

@@ -18,7 +18,7 @@ import {
useWorkflowReadOnly,
} from '../hooks'
import { useStore, useWorkflowStore } from '../store'
import { BlockEnum, ControlMode, WorkflowRunningStatus } from '../types'
import { BlockEnum, ControlMode } from '../types'
import {
getLayoutByDagre,
getLayoutForChildNodes,
@@ -36,17 +36,12 @@ export const useWorkflowInteractions = () => {
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
const handleCancelDebugAndPreviewPanel = useCallback(() => {
const { workflowRunningData } = workflowStore.getState()
const runningStatus = workflowRunningData?.result?.status
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
workflowStore.setState({
showDebugAndPreviewPanel: false,
...(isActiveRun ? {} : { workflowRunningData: undefined }),
workflowRunningData: undefined,
})
if (!isActiveRun) {
handleNodeCancelRunningStatus()
handleEdgeCancelRunningStatus()
}
handleNodeCancelRunningStatus()
handleEdgeCancelRunningStatus()
}, [workflowStore, handleNodeCancelRunningStatus, handleEdgeCancelRunningStatus])
return {

View File

@@ -21,7 +21,7 @@ import {
import { BlockEnum } from '../../types'
import ConversationVariableModal from './conversation-variable-modal'
import Empty from './empty'
import { useChat } from './hooks/use-chat'
import { useChat } from './hooks'
import UserInput from './user-input'
type ChatWrapperProps = {

View File

@@ -0,0 +1,516 @@
import type { InputForm } from '@/app/components/base/chat/chat/type'
import type {
ChatItem,
ChatItemInTree,
Inputs,
} from '@/app/components/base/chat/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import { uniqBy } from 'es-toolkit/compat'
import { produce, setAutoFreeze } from 'immer'
import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import {
getProcessedInputs,
processOpeningStatement,
} from '@/app/components/base/chat/chat/utils'
import { getThreadMessages } from '@/app/components/base/chat/utils'
import {
getProcessedFiles,
getProcessedFilesFromResponse,
} from '@/app/components/base/file-uploader/utils'
import { useToastContext } from '@/app/components/base/toast'
import { useInvalidAllLastRun } from '@/service/use-workflow'
import { TransferMethod } from '@/types/app'
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../constants'
import {
useSetWorkflowVarsWithValue,
useWorkflowRun,
} from '../../hooks'
import { useHooksStore } from '../../hooks-store'
import { useWorkflowStore } from '../../store'
import { NodeRunningStatus, WorkflowRunningStatus } from '../../types'
type GetAbortController = (abortController: AbortController) => void
type SendCallback = {
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<any>
}
export const useChat = (
config: any,
formSettings?: {
inputs: Inputs
inputsForm: InputForm[]
},
prevChatTree?: ChatItemInTree[],
stopChat?: (taskId: string) => void,
) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const { handleRun } = useWorkflowRun()
const hasStopResponded = useRef(false)
const workflowStore = useWorkflowStore()
const conversationId = useRef('')
const taskIdRef = useRef('')
const [isResponding, setIsResponding] = useState(false)
const isRespondingRef = useRef(false)
const configsMap = useHooksStore(s => s.configsMap)
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([])
const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null)
const {
setIterTimes,
setLoopTimes,
} = workflowStore.getState()
const handleResponding = useCallback((isResponding: boolean) => {
setIsResponding(isResponding)
isRespondingRef.current = isResponding
}, [])
const [chatTree, setChatTree] = useState<ChatItemInTree[]>(prevChatTree || [])
const chatTreeRef = useRef<ChatItemInTree[]>(chatTree)
const [targetMessageId, setTargetMessageId] = useState<string>()
const threadMessages = useMemo(() => getThreadMessages(chatTree, targetMessageId), [chatTree, targetMessageId])
const getIntroduction = useCallback((str: string) => {
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
}, [formSettings?.inputs, formSettings?.inputsForm])
/** Final chat list that will be rendered */
const chatList = useMemo(() => {
const ret = [...threadMessages]
if (config?.opening_statement) {
const index = threadMessages.findIndex(item => item.isOpeningStatement)
if (index > -1) {
ret[index] = {
...ret[index],
content: getIntroduction(config.opening_statement),
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
}
}
else {
ret.unshift({
id: `${Date.now()}`,
content: getIntroduction(config.opening_statement),
isAnswer: true,
isOpeningStatement: true,
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
})
}
}
return ret
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
useEffect(() => {
setAutoFreeze(false)
return () => {
setAutoFreeze(true)
}
}, [])
/** Find the target node by bfs and then operate on it */
const produceChatTreeNode = useCallback((targetId: string, operation: (node: ChatItemInTree) => void) => {
return produce(chatTreeRef.current, (draft) => {
const queue: ChatItemInTree[] = [...draft]
while (queue.length > 0) {
const current = queue.shift()!
if (current.id === targetId) {
operation(current)
break
}
if (current.children)
queue.push(...current.children)
}
})
}, [])
const handleStop = useCallback(() => {
hasStopResponded.current = true
handleResponding(false)
if (stopChat && taskIdRef.current)
stopChat(taskIdRef.current)
setIterTimes(DEFAULT_ITER_TIMES)
setLoopTimes(DEFAULT_LOOP_TIMES)
if (suggestedQuestionsAbortControllerRef.current)
suggestedQuestionsAbortControllerRef.current.abort()
}, [handleResponding, setIterTimes, setLoopTimes, stopChat])
const handleRestart = useCallback(() => {
conversationId.current = ''
taskIdRef.current = ''
handleStop()
setIterTimes(DEFAULT_ITER_TIMES)
setLoopTimes(DEFAULT_LOOP_TIMES)
setChatTree([])
setSuggestQuestions([])
}, [
handleStop,
setIterTimes,
setLoopTimes,
])
const updateCurrentQAOnTree = useCallback(({
parentId,
responseItem,
placeholderQuestionId,
questionItem,
}: {
parentId?: string
responseItem: ChatItem
placeholderQuestionId: string
questionItem: ChatItem
}) => {
let nextState: ChatItemInTree[]
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] }
if (!parentId && !chatTree.some(item => [placeholderQuestionId, questionItem.id].includes(item.id))) {
// QA whose parent is not provided is considered as a first message of the conversation,
// and it should be a root node of the chat tree
nextState = produce(chatTree, (draft) => {
draft.push(currentQA)
})
}
else {
// find the target QA in the tree and update it; if not found, insert it to its parent node
nextState = produceChatTreeNode(parentId!, (parentNode) => {
const questionNodeIndex = parentNode.children!.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
if (questionNodeIndex === -1)
parentNode.children!.push(currentQA)
else
parentNode.children![questionNodeIndex] = currentQA
})
}
setChatTree(nextState)
chatTreeRef.current = nextState
}, [chatTree, produceChatTreeNode])
const handleSend = useCallback((
params: {
query: string
files?: FileEntity[]
parent_message_id?: string
[key: string]: any
},
{
onGetSuggestedQuestions,
}: SendCallback,
) => {
if (isRespondingRef.current) {
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
return false
}
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
const placeholderQuestionId = `question-${Date.now()}`
const questionItem = {
id: placeholderQuestionId,
content: params.query,
isAnswer: false,
message_files: params.files,
parentMessageId: params.parent_message_id,
}
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
const placeholderAnswerItem = {
id: placeholderAnswerId,
content: '',
isAnswer: true,
parentMessageId: questionItem.id,
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
}
setTargetMessageId(parentMessage?.id)
updateCurrentQAOnTree({
parentId: params.parent_message_id,
responseItem: placeholderAnswerItem,
placeholderQuestionId,
questionItem,
})
// answer
const responseItem: ChatItem = {
id: placeholderAnswerId,
content: '',
agent_thoughts: [],
message_files: [],
isAnswer: true,
parentMessageId: questionItem.id,
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
}
handleResponding(true)
const { files, inputs, ...restParams } = params
const bodyParams = {
files: getProcessedFiles(files || []),
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
...restParams,
}
if (bodyParams?.files?.length) {
bodyParams.files = bodyParams.files.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
let hasSetResponseId = false
handleRun(
bodyParams,
{
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
responseItem.content = responseItem.content + message
if (messageId && !hasSetResponseId) {
questionItem.id = `question-${messageId}`
responseItem.id = messageId
responseItem.parentMessageId = questionItem.id
hasSetResponseId = true
}
if (isFirstMessage && newConversationId)
conversationId.current = newConversationId
taskIdRef.current = taskId
if (messageId)
responseItem.id = messageId
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
async onCompleted(hasError?: boolean, errorMessage?: string) {
handleResponding(false)
fetchInspectVars({})
invalidAllLastRun()
if (hasError) {
if (errorMessage) {
responseItem.content = errorMessage
responseItem.isError = true
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
return
}
if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) {
try {
const { data }: any = await onGetSuggestedQuestions(
responseItem.id,
newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController,
)
setSuggestQuestions(data)
}
// eslint-disable-next-line unused-imports/no-unused-vars
catch (error) {
setSuggestQuestions([])
}
}
},
onMessageEnd: (messageEnd) => {
responseItem.citation = messageEnd.metadata?.retriever_resources || []
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onMessageReplace: (messageReplace) => {
responseItem.content = messageReplace.answer
},
onError() {
handleResponding(false)
},
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
taskIdRef.current = task_id
responseItem.workflow_run_id = workflow_run_id
responseItem.workflowProcess = {
status: WorkflowRunningStatus.Running,
tracing: [],
}
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onWorkflowFinished: ({ data }) => {
responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onIterationStart: ({ data }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
})
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onIterationFinish: ({ data }) => {
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
if (currentTracingIndex > -1) {
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
...responseItem.workflowProcess!.tracing[currentTracingIndex],
...data,
}
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
},
onLoopStart: ({ data }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
})
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onLoopFinish: ({ data }) => {
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
if (currentTracingIndex > -1) {
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
...responseItem.workflowProcess!.tracing[currentTracingIndex],
...data,
}
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
},
onNodeStarted: ({ data }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
} as any)
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onNodeRetry: ({ data }) => {
responseItem.workflowProcess!.tracing!.push(data)
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onNodeFinished: ({ data }) => {
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
if (currentTracingIndex > -1) {
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
...responseItem.workflowProcess!.tracing[currentTracingIndex],
...data,
}
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
},
onAgentLog: ({ data }) => {
const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
if (currentNodeIndex > -1) {
const current = responseItem.workflowProcess!.tracing![currentNodeIndex]
if (current.execution_metadata) {
if (current.execution_metadata.agent_log) {
const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
if (currentLogIndex > -1) {
current.execution_metadata.agent_log[currentLogIndex] = {
...current.execution_metadata.agent_log[currentLogIndex],
...data,
}
}
else {
current.execution_metadata.agent_log.push(data)
}
}
else {
current.execution_metadata.agent_log = [data]
}
}
else {
current.execution_metadata = {
agent_log: [data],
} as any
}
responseItem.workflowProcess!.tracing[currentNodeIndex] = {
...current,
}
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
},
},
)
}, [threadMessages, chatTree.length, updateCurrentQAOnTree, handleResponding, formSettings?.inputsForm, handleRun, notify, t, config?.suggested_questions_after_answer?.enabled, fetchInspectVars, invalidAllLastRun])
return {
conversationId: conversationId.current,
chatList,
setTargetMessageId,
handleSend,
handleStop,
handleRestart,
isResponding,
suggestedQuestions,
}
}

View File

@@ -1,38 +0,0 @@
import type { ChatItem, ChatItemInTree } from '@/app/components/base/chat/types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
export type ChatConfig = {
opening_statement?: string
suggested_questions?: string[]
suggested_questions_after_answer?: {
enabled?: boolean
}
text_to_speech?: unknown
speech_to_text?: unknown
retriever_resource?: unknown
sensitive_word_avoidance?: unknown
file_upload?: unknown
}
export type GetAbortController = (abortController: AbortController) => void
export type SendCallback = {
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<unknown>
}
export type SendParams = {
query: string
files?: FileEntity[]
parent_message_id?: string
inputs?: Record<string, unknown>
conversation_id?: string
}
export type UpdateCurrentQAParams = {
parentId?: string
responseItem: ChatItem
placeholderQuestionId: string
questionItem: ChatItem
}
export type ChatTreeUpdater = (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void

View File

@@ -1,90 +0,0 @@
import { useCallback } from 'react'
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../../constants'
import { useEdgesInteractionsWithoutSync } from '../../../hooks/use-edges-interactions-without-sync'
import { useNodesInteractionsWithoutSync } from '../../../hooks/use-nodes-interactions-without-sync'
import { useStore, useWorkflowStore } from '../../../store'
import { WorkflowRunningStatus } from '../../../types'
type UseChatFlowControlParams = {
stopChat?: (taskId: string) => void
}
export function useChatFlowControl({
stopChat,
}: UseChatFlowControlParams) {
const workflowStore = useWorkflowStore()
const setIsResponding = useStore(s => s.setIsResponding)
const resetChatPreview = useStore(s => s.resetChatPreview)
const setActiveTaskId = useStore(s => s.setActiveTaskId)
const setHasStopResponded = useStore(s => s.setHasStopResponded)
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
const invalidateRun = useStore(s => s.invalidateRun)
const { handleNodeCancelRunningStatus } = useNodesInteractionsWithoutSync()
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
const { setIterTimes, setLoopTimes } = workflowStore.getState()
const handleResponding = useCallback((responding: boolean) => {
setIsResponding(responding)
}, [setIsResponding])
const handleStop = useCallback(() => {
const {
activeTaskId,
suggestedQuestionsAbortController,
workflowRunningData,
setWorkflowRunningData,
} = workflowStore.getState()
const runningStatus = workflowRunningData?.result?.status
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
setHasStopResponded(true)
handleResponding(false)
if (stopChat && activeTaskId)
stopChat(activeTaskId)
setIterTimes(DEFAULT_ITER_TIMES)
setLoopTimes(DEFAULT_LOOP_TIMES)
if (suggestedQuestionsAbortController)
suggestedQuestionsAbortController.abort()
setSuggestedQuestionsAbortController(null)
setActiveTaskId('')
invalidateRun()
if (isActiveRun && workflowRunningData) {
setWorkflowRunningData({
...workflowRunningData,
result: {
...workflowRunningData.result,
status: WorkflowRunningStatus.Stopped,
},
})
}
if (isActiveRun) {
handleNodeCancelRunningStatus()
handleEdgeCancelRunningStatus()
}
}, [
handleResponding,
setIterTimes,
setLoopTimes,
stopChat,
workflowStore,
setHasStopResponded,
setSuggestedQuestionsAbortController,
setActiveTaskId,
invalidateRun,
handleNodeCancelRunningStatus,
handleEdgeCancelRunningStatus,
])
const handleRestart = useCallback(() => {
handleStop()
resetChatPreview()
setIterTimes(DEFAULT_ITER_TIMES)
setLoopTimes(DEFAULT_LOOP_TIMES)
}, [handleStop, setIterTimes, setLoopTimes, resetChatPreview])
return {
handleResponding,
handleStop,
handleRestart,
}
}

View File

@@ -1,65 +0,0 @@
import type { InputForm } from '@/app/components/base/chat/chat/type'
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
import { useCallback, useMemo } from 'react'
import { processOpeningStatement } from '@/app/components/base/chat/chat/utils'
import { getThreadMessages } from '@/app/components/base/chat/utils'
type UseChatListParams = {
chatTree: ChatItemInTree[]
targetMessageId: string | undefined
config: {
opening_statement?: string
suggested_questions?: string[]
} | undefined
formSettings?: {
inputs: Inputs
inputsForm: InputForm[]
}
}
export function useChatList({
chatTree,
targetMessageId,
config,
formSettings,
}: UseChatListParams) {
const threadMessages = useMemo(
() => getThreadMessages(chatTree, targetMessageId),
[chatTree, targetMessageId],
)
const getIntroduction = useCallback((str: string) => {
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
}, [formSettings?.inputs, formSettings?.inputsForm])
const chatList = useMemo(() => {
const ret = [...threadMessages]
if (config?.opening_statement) {
const index = threadMessages.findIndex(item => item.isOpeningStatement)
if (index > -1) {
ret[index] = {
...ret[index],
content: getIntroduction(config.opening_statement),
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
}
}
else {
ret.unshift({
id: `${Date.now()}`,
content: getIntroduction(config.opening_statement),
isAnswer: true,
isOpeningStatement: true,
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
})
}
}
return ret
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
return {
threadMessages,
chatList,
getIntroduction,
}
}

View File

@@ -1,306 +0,0 @@
import type { SendCallback, SendParams, UpdateCurrentQAParams } from './types'
import type { InputForm } from '@/app/components/base/chat/chat/type'
import type { ChatItem, ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
import { uniqBy } from 'es-toolkit/compat'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { getProcessedInputs } from '@/app/components/base/chat/chat/utils'
import { getProcessedFiles, getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import { useToastContext } from '@/app/components/base/toast'
import { useInvalidAllLastRun } from '@/service/use-workflow'
import { TransferMethod } from '@/types/app'
import { useSetWorkflowVarsWithValue, useWorkflowRun } from '../../../hooks'
import { useHooksStore } from '../../../hooks-store'
import { useStore, useWorkflowStore } from '../../../store'
import { createWorkflowEventHandlers } from './use-workflow-event-handlers'
type UseChatMessageSenderParams = {
threadMessages: ChatItemInTree[]
config?: {
suggested_questions_after_answer?: {
enabled?: boolean
}
}
formSettings?: {
inputs: Inputs
inputsForm: InputForm[]
}
handleResponding: (responding: boolean) => void
updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
}
export function useChatMessageSender({
threadMessages,
config,
formSettings,
handleResponding,
updateCurrentQAOnTree,
}: UseChatMessageSenderParams) {
const { t } = useTranslation()
const { notify } = useToastContext()
const { handleRun } = useWorkflowRun()
const workflowStore = useWorkflowStore()
const configsMap = useHooksStore(s => s.configsMap)
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
const setConversationId = useStore(s => s.setConversationId)
const setTargetMessageId = useStore(s => s.setTargetMessageId)
const setSuggestedQuestions = useStore(s => s.setSuggestedQuestions)
const setActiveTaskId = useStore(s => s.setActiveTaskId)
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
const startRun = useStore(s => s.startRun)
const handleSend = useCallback((
params: SendParams,
{ onGetSuggestedQuestions }: SendCallback,
) => {
if (workflowStore.getState().isResponding) {
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
return false
}
const { suggestedQuestionsAbortController } = workflowStore.getState()
if (suggestedQuestionsAbortController)
suggestedQuestionsAbortController.abort()
setSuggestedQuestionsAbortController(null)
const runId = startRun()
const isCurrentRun = () => runId === workflowStore.getState().activeRunId
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
const placeholderQuestionId = `question-${Date.now()}`
const questionItem: ChatItem = {
id: placeholderQuestionId,
content: params.query,
isAnswer: false,
message_files: params.files,
parentMessageId: params.parent_message_id,
}
const siblingIndex = parentMessage?.children?.length ?? workflowStore.getState().chatTree.length
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
const placeholderAnswerItem: ChatItem = {
id: placeholderAnswerId,
content: '',
isAnswer: true,
parentMessageId: questionItem.id,
siblingIndex,
}
setTargetMessageId(parentMessage?.id)
updateCurrentQAOnTree({
parentId: params.parent_message_id,
responseItem: placeholderAnswerItem,
placeholderQuestionId,
questionItem,
})
const responseItem: ChatItem = {
id: placeholderAnswerId,
content: '',
agent_thoughts: [],
message_files: [],
isAnswer: true,
parentMessageId: questionItem.id,
siblingIndex,
}
handleResponding(true)
const { files, inputs, ...restParams } = params
const bodyParams = {
files: getProcessedFiles(files || []),
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
...restParams,
}
if (bodyParams?.files?.length) {
bodyParams.files = bodyParams.files.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
let hasSetResponseId = false
const workflowHandlers = createWorkflowEventHandlers({
responseItem,
questionItem,
placeholderQuestionId,
parentMessageId: params.parent_message_id,
updateCurrentQAOnTree,
})
handleRun(
bodyParams,
{
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }) => {
if (!isCurrentRun())
return
responseItem.content = responseItem.content + message
if (messageId && !hasSetResponseId) {
questionItem.id = `question-${messageId}`
responseItem.id = messageId
responseItem.parentMessageId = questionItem.id
hasSetResponseId = true
}
if (isFirstMessage && newConversationId)
setConversationId(newConversationId)
if (taskId)
setActiveTaskId(taskId)
if (messageId)
responseItem.id = messageId
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
async onCompleted(hasError?: boolean, errorMessage?: string) {
if (!isCurrentRun())
return
handleResponding(false)
fetchInspectVars({})
invalidAllLastRun()
if (hasError) {
if (errorMessage) {
responseItem.content = errorMessage
responseItem.isError = true
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
}
return
}
if (config?.suggested_questions_after_answer?.enabled && !workflowStore.getState().hasStopResponded && onGetSuggestedQuestions) {
try {
const result = await onGetSuggestedQuestions(
responseItem.id,
newAbortController => setSuggestedQuestionsAbortController(newAbortController),
) as { data: string[] }
setSuggestedQuestions(result.data)
}
catch {
setSuggestedQuestions([])
}
finally {
setSuggestedQuestionsAbortController(null)
}
}
},
onMessageEnd: (messageEnd) => {
if (!isCurrentRun())
return
responseItem.citation = messageEnd.metadata?.retriever_resources || []
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: params.parent_message_id,
})
},
onMessageReplace: (messageReplace) => {
if (!isCurrentRun())
return
responseItem.content = messageReplace.answer
},
onError() {
if (!isCurrentRun())
return
handleResponding(false)
},
onWorkflowStarted: (event) => {
if (!isCurrentRun())
return
const taskId = workflowHandlers.onWorkflowStarted(event)
if (taskId)
setActiveTaskId(taskId)
},
onWorkflowFinished: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onWorkflowFinished(event)
},
onIterationStart: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onIterationStart(event)
},
onIterationFinish: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onIterationFinish(event)
},
onLoopStart: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onLoopStart(event)
},
onLoopFinish: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onLoopFinish(event)
},
onNodeStarted: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onNodeStarted(event)
},
onNodeRetry: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onNodeRetry(event)
},
onNodeFinished: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onNodeFinished(event)
},
onAgentLog: (event) => {
if (!isCurrentRun())
return
workflowHandlers.onAgentLog(event)
},
},
)
}, [
threadMessages,
updateCurrentQAOnTree,
handleResponding,
formSettings?.inputsForm,
handleRun,
notify,
t,
config?.suggested_questions_after_answer?.enabled,
setTargetMessageId,
setConversationId,
setSuggestedQuestions,
setActiveTaskId,
setSuggestedQuestionsAbortController,
startRun,
fetchInspectVars,
invalidAllLastRun,
workflowStore,
])
return { handleSend }
}

View File

@@ -1,59 +0,0 @@
import type { ChatTreeUpdater, UpdateCurrentQAParams } from './types'
import type { ChatItemInTree } from '@/app/components/base/chat/types'
import { produce } from 'immer'
import { useCallback } from 'react'
export function useChatTreeOperations(updateChatTree: ChatTreeUpdater) {
const produceChatTreeNode = useCallback(
(tree: ChatItemInTree[], targetId: string, operation: (node: ChatItemInTree) => void) => {
return produce(tree, (draft) => {
const queue: ChatItemInTree[] = [...draft]
while (queue.length > 0) {
const current = queue.shift()!
if (current.id === targetId) {
operation(current)
break
}
if (current.children)
queue.push(...current.children)
}
})
},
[],
)
const updateCurrentQAOnTree = useCallback(({
parentId,
responseItem,
placeholderQuestionId,
questionItem,
}: UpdateCurrentQAParams) => {
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] } as ChatItemInTree
updateChatTree((currentChatTree) => {
if (!parentId) {
const questionIndex = currentChatTree.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
return produce(currentChatTree, (draft) => {
if (questionIndex === -1)
draft.push(currentQA)
else
draft[questionIndex] = currentQA
})
}
return produceChatTreeNode(currentChatTree, parentId, (parentNode) => {
if (!parentNode.children)
parentNode.children = []
const questionNodeIndex = parentNode.children.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
if (questionNodeIndex === -1)
parentNode.children.push(currentQA)
else
parentNode.children[questionNodeIndex] = currentQA
})
})
}, [produceChatTreeNode, updateChatTree])
return {
produceChatTreeNode,
updateCurrentQAOnTree,
}
}

View File

@@ -1,74 +0,0 @@
import type { ChatConfig } from './types'
import type { InputForm } from '@/app/components/base/chat/chat/type'
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
import { useEffect, useRef } from 'react'
import { useStore } from '../../../store'
import { useChatFlowControl } from './use-chat-flow-control'
import { useChatList } from './use-chat-list'
import { useChatMessageSender } from './use-chat-message-sender'
import { useChatTreeOperations } from './use-chat-tree-operations'
export function useChat(
config: ChatConfig | undefined,
formSettings?: {
inputs: Inputs
inputsForm: InputForm[]
},
prevChatTree?: ChatItemInTree[],
stopChat?: (taskId: string) => void,
) {
const chatTree = useStore(s => s.chatTree)
const conversationId = useStore(s => s.conversationId)
const isResponding = useStore(s => s.isResponding)
const suggestedQuestions = useStore(s => s.suggestedQuestions)
const targetMessageId = useStore(s => s.targetMessageId)
const updateChatTree = useStore(s => s.updateChatTree)
const setTargetMessageId = useStore(s => s.setTargetMessageId)
const initialChatTreeRef = useRef(prevChatTree)
useEffect(() => {
const initialChatTree = initialChatTreeRef.current
if (!initialChatTree || initialChatTree.length === 0)
return
updateChatTree(currentChatTree => (currentChatTree.length === 0 ? initialChatTree : currentChatTree))
}, [updateChatTree])
const { updateCurrentQAOnTree } = useChatTreeOperations(updateChatTree)
const {
handleResponding,
handleStop,
handleRestart,
} = useChatFlowControl({
stopChat,
})
const {
threadMessages,
chatList,
} = useChatList({
chatTree,
targetMessageId,
config,
formSettings,
})
const { handleSend } = useChatMessageSender({
threadMessages,
config,
formSettings,
handleResponding,
updateCurrentQAOnTree,
})
return {
conversationId,
chatList,
setTargetMessageId,
handleSend,
handleStop,
handleRestart,
isResponding,
suggestedQuestions,
}
}

View File

@@ -1,133 +0,0 @@
import type { UpdateCurrentQAParams } from './types'
import type { ChatItem } from '@/app/components/base/chat/types'
import type { AgentLogItem, NodeTracing } from '@/types/workflow'
import { NodeRunningStatus, WorkflowRunningStatus } from '../../../types'
type WorkflowEventHandlersContext = {
responseItem: ChatItem
questionItem: ChatItem
placeholderQuestionId: string
parentMessageId?: string
updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
}
type TracingData = Partial<NodeTracing> & { id: string }
type AgentLogData = Partial<AgentLogItem> & { node_id: string, message_id: string }
export function createWorkflowEventHandlers(ctx: WorkflowEventHandlersContext) {
const { responseItem, questionItem, placeholderQuestionId, parentMessageId, updateCurrentQAOnTree } = ctx
const updateTree = () => {
updateCurrentQAOnTree({
placeholderQuestionId,
questionItem,
responseItem,
parentId: parentMessageId,
})
}
const updateTracingItem = (data: TracingData) => {
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
if (currentTracingIndex > -1) {
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
...responseItem.workflowProcess!.tracing[currentTracingIndex],
...data,
}
updateTree()
}
}
return {
onWorkflowStarted: ({ workflow_run_id, task_id }: { workflow_run_id: string, task_id: string }) => {
responseItem.workflow_run_id = workflow_run_id
responseItem.workflowProcess = {
status: WorkflowRunningStatus.Running,
tracing: [],
}
updateTree()
return task_id
},
onWorkflowFinished: ({ data }: { data: { status: string } }) => {
responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
updateTree()
},
onIterationStart: ({ data }: { data: Partial<NodeTracing> }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
} as NodeTracing)
updateTree()
},
onIterationFinish: ({ data }: { data: TracingData }) => {
updateTracingItem(data)
},
onLoopStart: ({ data }: { data: Partial<NodeTracing> }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
} as NodeTracing)
updateTree()
},
onLoopFinish: ({ data }: { data: TracingData }) => {
updateTracingItem(data)
},
onNodeStarted: ({ data }: { data: Partial<NodeTracing> }) => {
responseItem.workflowProcess!.tracing!.push({
...data,
status: NodeRunningStatus.Running,
} as NodeTracing)
updateTree()
},
onNodeRetry: ({ data }: { data: NodeTracing }) => {
responseItem.workflowProcess!.tracing!.push(data)
updateTree()
},
onNodeFinished: ({ data }: { data: TracingData }) => {
updateTracingItem(data)
},
onAgentLog: ({ data }: { data: AgentLogData }) => {
const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
if (currentNodeIndex > -1) {
const current = responseItem.workflowProcess!.tracing![currentNodeIndex]
if (current.execution_metadata) {
if (current.execution_metadata.agent_log) {
const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
if (currentLogIndex > -1) {
current.execution_metadata.agent_log[currentLogIndex] = {
...current.execution_metadata.agent_log[currentLogIndex],
...data,
} as AgentLogItem
}
else {
current.execution_metadata.agent_log.push(data as AgentLogItem)
}
}
else {
current.execution_metadata.agent_log = [data as AgentLogItem]
}
}
else {
current.execution_metadata = {
agent_log: [data as AgentLogItem],
} as NodeTracing['execution_metadata']
}
responseItem.workflowProcess!.tracing[currentNodeIndex] = {
...current,
}
updateTree()
}
},
}
}

View File

@@ -1,96 +0,0 @@
import type { StateCreator } from 'zustand'
import type { ChatItemInTree } from '@/app/components/base/chat/types'
type ChatPreviewState = {
chatTree: ChatItemInTree[]
targetMessageId: string | undefined
suggestedQuestions: string[]
conversationId: string
isResponding: boolean
activeRunId: number
activeTaskId: string
hasStopResponded: boolean
suggestedQuestionsAbortController: AbortController | null
}
type ChatPreviewActions = {
setChatTree: (chatTree: ChatItemInTree[]) => void
updateChatTree: (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void
setTargetMessageId: (messageId: string | undefined) => void
setSuggestedQuestions: (questions: string[]) => void
setConversationId: (conversationId: string) => void
setIsResponding: (isResponding: boolean) => void
setActiveTaskId: (taskId: string) => void
setHasStopResponded: (hasStopResponded: boolean) => void
setSuggestedQuestionsAbortController: (controller: AbortController | null) => void
startRun: () => number
invalidateRun: () => number
resetChatPreview: () => void
}
export type ChatPreviewSliceShape = ChatPreviewState & ChatPreviewActions
const initialState: ChatPreviewState = {
chatTree: [],
targetMessageId: undefined,
suggestedQuestions: [],
conversationId: '',
isResponding: false,
activeRunId: 0,
activeTaskId: '',
hasStopResponded: false,
suggestedQuestionsAbortController: null,
}
export const createChatPreviewSlice: StateCreator<ChatPreviewSliceShape> = (set, get) => ({
...initialState,
setChatTree: chatTree => set({ chatTree }),
updateChatTree: updater => set((state) => {
const nextChatTree = updater(state.chatTree)
if (nextChatTree === state.chatTree)
return state
return { chatTree: nextChatTree }
}),
setTargetMessageId: targetMessageId => set({ targetMessageId }),
setSuggestedQuestions: suggestedQuestions => set({ suggestedQuestions }),
setConversationId: conversationId => set({ conversationId }),
setIsResponding: isResponding => set({ isResponding }),
setActiveTaskId: activeTaskId => set({ activeTaskId }),
setHasStopResponded: hasStopResponded => set({ hasStopResponded }),
setSuggestedQuestionsAbortController: suggestedQuestionsAbortController => set({ suggestedQuestionsAbortController }),
startRun: () => {
const activeRunId = get().activeRunId + 1
set({
activeRunId,
activeTaskId: '',
hasStopResponded: false,
suggestedQuestionsAbortController: null,
})
return activeRunId
},
invalidateRun: () => {
const activeRunId = get().activeRunId + 1
set({
activeRunId,
activeTaskId: '',
suggestedQuestionsAbortController: null,
})
return activeRunId
},
resetChatPreview: () => set(state => ({
...initialState,
activeRunId: state.activeRunId + 1,
})),
})

View File

@@ -1,7 +1,6 @@
import type {
StateCreator,
} from 'zustand'
import type { ChatPreviewSliceShape } from './chat-preview-slice'
import type { ChatVariableSliceShape } from './chat-variable-slice'
import type { InspectVarsSliceShape } from './debug/inspect-vars-slice'
import type { EnvVariableSliceShape } from './env-variable-slice'
@@ -23,7 +22,6 @@ import {
} from 'zustand'
import { createStore } from 'zustand/vanilla'
import { WorkflowContext } from '@/app/components/workflow/context'
import { createChatPreviewSlice } from './chat-preview-slice'
import { createChatVariableSlice } from './chat-variable-slice'
import { createInspectVarsSlice } from './debug/inspect-vars-slice'
import { createEnvVariableSlice } from './env-variable-slice'
@@ -44,8 +42,7 @@ export type SliceFromInjection
& Partial<RagPipelineSliceShape>
export type Shape
= ChatPreviewSliceShape
& ChatVariableSliceShape
= ChatVariableSliceShape
& EnvVariableSliceShape
& FormSliceShape
& HelpLineSliceShape
@@ -70,7 +67,6 @@ export const createWorkflowStore = (params: CreateWorkflowStoreParams) => {
const { injectWorkflowStoreSliceFn } = params || {}
return createStore<Shape>((...args) => ({
...createChatPreviewSlice(...args),
...createChatVariableSlice(...args),
...createEnvVariableSlice(...args),
...createFormSlice(...args),

View File

@@ -822,7 +822,7 @@
"count": 2
},
"ts/no-explicit-any": {
"count": 15
"count": 14
}
},
"app/components/base/chat/chat/index.tsx": {
@@ -3769,6 +3769,11 @@
"count": 2
}
},
"app/components/workflow/panel/debug-and-preview/hooks.ts": {
"ts/no-explicit-any": {
"count": 7
}
},
"app/components/workflow/panel/env-panel/variable-modal.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 4

View File

@@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => {
: `${formatted}${units[unitIndex].symbol}`
}
}
// Fallback: if no threshold matched, return the number string
return num.toString()
}
export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {