mirror of
https://github.com/langgenius/dify.git
synced 2026-01-28 19:14:46 +00:00
Compare commits
5 Commits
refactor/d
...
fix/workfl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4275aa729f | ||
|
|
8519b16cfc | ||
|
|
f00d823f9f | ||
|
|
e48419937b | ||
|
|
0ed0a31ed6 |
24
AGENTS.md
24
AGENTS.md
@@ -25,6 +25,30 @@ pnpm type-check:tsgo
|
||||
pnpm test
|
||||
```
|
||||
|
||||
### Frontend Linting
|
||||
|
||||
ESLint is used for frontend code quality. Available commands:
|
||||
|
||||
```bash
|
||||
# Lint all files (report only)
|
||||
pnpm lint
|
||||
|
||||
# Lint and auto-fix issues
|
||||
pnpm lint:fix
|
||||
|
||||
# Lint specific files or directories
|
||||
pnpm lint:fix app/components/base/button/
|
||||
pnpm lint:fix app/components/base/button/index.tsx
|
||||
|
||||
# Lint quietly (errors only, no warnings)
|
||||
pnpm lint:quiet
|
||||
|
||||
# Check code complexity
|
||||
pnpm lint:complexity
|
||||
```
|
||||
|
||||
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
|
||||
|
||||
## Testing & Quality Practices
|
||||
|
||||
- Follow TDD: red → green → refactor.
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Notes: `large_language_model.py`
|
||||
|
||||
## Purpose
|
||||
|
||||
Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to
|
||||
bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`).
|
||||
|
||||
## Key behaviors / invariants
|
||||
|
||||
- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from
|
||||
the first yielded `LLMResultChunk`.
|
||||
- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by
|
||||
`_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`.
|
||||
- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking
|
||||
patterns (IDs anchored to first chunk, every chunk, or missing entirely).
|
||||
- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to
|
||||
surface invalid delta sequences explicitly.
|
||||
- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging.
|
||||
- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must
|
||||
be re-attached in this layer before callbacks/consumers use them.
|
||||
- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation
|
||||
unless `callback.raise_error` is true.
|
||||
|
||||
## Test focus
|
||||
|
||||
- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and
|
||||
patches `_gen_tool_call_id` for deterministic IDs.
|
||||
120
api/AGENTS.md
120
api/AGENTS.md
@@ -1,97 +1,47 @@
|
||||
# API Agent Guide
|
||||
|
||||
## Agent Notes (must-check)
|
||||
## Notes for Agent (must-check)
|
||||
|
||||
Before you start work on any backend file under `api/`, you MUST check whether a related note exists under:
|
||||
Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec.
|
||||
|
||||
- `agent-notes/<same-relative-path-as-target-file>.md`
|
||||
Look for:
|
||||
|
||||
Rules:
|
||||
- The module (file) docstring at the top of a source code file
|
||||
- Docstrings on classes and functions/methods
|
||||
- Paragraph/block comments for non-obvious logic
|
||||
|
||||
- **Path mapping**: for a target file `<path>/<name>.py`, the note must be `agent-notes/<path>/<name>.py.md` (same folder structure, same filename, plus `.md`).
|
||||
- **Before working**:
|
||||
- If the note exists, read it first and follow any constraints/decisions recorded there.
|
||||
- If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality.
|
||||
- If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases.
|
||||
- **During working**:
|
||||
- Keep the note in sync as you discover constraints, make decisions, or change approach.
|
||||
- If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note).
|
||||
- Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end).
|
||||
- Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog.
|
||||
- **When finishing work**:
|
||||
- Update the related note(s) to reflect what changed, why, and any new edge cases/tests.
|
||||
- If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance.
|
||||
- Keep notes concise and accurate; they are meant to prevent repeated rediscovery.
|
||||
### What to write where
|
||||
|
||||
## Skill Index
|
||||
- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse.
|
||||
- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing.
|
||||
- Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard.
|
||||
- Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes.
|
||||
- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used).
|
||||
- If the class is intentionally stateful, note what state exists and what methods mutate it.
|
||||
- If concurrency/async assumptions matter, state them explicitly.
|
||||
- **Function/method docstring**: behavioural contract.
|
||||
- Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions.
|
||||
- Add examples only when they prevent misuse.
|
||||
- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states.
|
||||
- Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality.
|
||||
|
||||
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
|
||||
### Rules (must follow)
|
||||
|
||||
### Platform Foundations
|
||||
In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments.
|
||||
|
||||
#### [Infrastructure Overview](agent_skills/infra.md)
|
||||
|
||||
- **When to read this**
|
||||
- You need to understand where a feature belongs in the architecture.
|
||||
- You’re wiring storage, Redis, vector stores, or OTEL.
|
||||
- You’re about to add CLI commands or async jobs.
|
||||
- **What it covers**
|
||||
- Configuration stack (`configs/app_config.py`, remote settings)
|
||||
- Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`)
|
||||
- Redis conventions (`extensions/ext_redis.py`)
|
||||
- Plugin runtime topology
|
||||
- Vector-store factory (`core/rag/datasource/vdb/*`)
|
||||
- Observability hooks
|
||||
- SSRF proxy usage
|
||||
- Core CLI commands
|
||||
|
||||
### Plugin & Extension Development
|
||||
|
||||
#### [Plugin Systems](agent_skills/plugin.md)
|
||||
|
||||
- **When to read this**
|
||||
- You’re building or debugging a marketplace plugin.
|
||||
- You need to know how manifests, providers, daemons, and migrations fit together.
|
||||
- **What it covers**
|
||||
- Plugin manifests (`core/plugin/entities/plugin.py`)
|
||||
- Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands)
|
||||
- Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent)
|
||||
- Daemon coordination (`core/plugin/entities/plugin_daemon.py`)
|
||||
- How provider registries surface capabilities to the rest of the platform
|
||||
|
||||
#### [Plugin OAuth](agent_skills/plugin_oauth.md)
|
||||
|
||||
- **When to read this**
|
||||
- You must integrate OAuth for a plugin or datasource.
|
||||
- You’re handling credential encryption or refresh flows.
|
||||
- **Topics**
|
||||
- Credential storage
|
||||
- Encryption helpers (`core/helper/provider_encryption.py`)
|
||||
- OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`)
|
||||
- How console/API layers expose the flows
|
||||
|
||||
### Workflow Entry & Execution
|
||||
|
||||
#### [Trigger Concepts](agent_skills/trigger.md)
|
||||
|
||||
- **When to read this**
|
||||
- You’re debugging why a workflow didn’t start.
|
||||
- You’re adding a new trigger type or hook.
|
||||
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.
|
||||
- **Details**
|
||||
- Start-node taxonomy
|
||||
- Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`)
|
||||
- Async orchestration (`services/async_workflow_service.py`, Celery queues)
|
||||
- Debug event bus
|
||||
- Storage/logging interactions
|
||||
|
||||
## General Reminders
|
||||
|
||||
- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes.
|
||||
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
|
||||
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
|
||||
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
|
||||
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
|
||||
- **Before working**
|
||||
- Read the notes in the area you’ll touch; treat them as part of the spec.
|
||||
- If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality.
|
||||
- If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour).
|
||||
- **During working**
|
||||
- Keep the notes in sync as you discover constraints, make decisions, or change approach.
|
||||
- If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants.
|
||||
- Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct.
|
||||
- Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions.
|
||||
- **When finishing**
|
||||
- Update the notes to reflect what changed, why, and any new edge cases/tests.
|
||||
- Remove or rewrite any comments that could be mistaken as current guidance but no longer apply.
|
||||
- Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery.
|
||||
|
||||
## Coding Style
|
||||
|
||||
@@ -226,7 +176,7 @@ Before opening a PR / submitting:
|
||||
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise comments.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
|
||||
### Miscellaneous
|
||||
|
||||
|
||||
@@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
agent=True,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from mimetypes import guess_extension
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
@@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
)
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
@@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
@@ -203,6 +215,9 @@ class AppRunner:
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool,
|
||||
agent: bool = False,
|
||||
message_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Handle invoke result
|
||||
@@ -210,21 +225,41 @@ class AppRunner:
|
||||
:param queue_manager: application queue manager
|
||||
:param stream: stream
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
if not stream and isinstance(invoke_result, LLMResult):
|
||||
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
self._handle_invoke_result_direct(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
elif stream and isinstance(invoke_result, Generator):
|
||||
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
|
||||
self._handle_invoke_result_stream(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
agent=agent,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
|
||||
|
||||
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
|
||||
def _handle_invoke_result_direct(
|
||||
self,
|
||||
invoke_result: LLMResult,
|
||||
queue_manager: AppQueueManager,
|
||||
):
|
||||
"""
|
||||
Handle invoke result direct
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
queue_manager.publish(
|
||||
@@ -235,13 +270,22 @@ class AppRunner:
|
||||
)
|
||||
|
||||
def _handle_invoke_result_stream(
|
||||
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
|
||||
self,
|
||||
invoke_result: Generator[LLMResultChunk, None, None],
|
||||
queue_manager: AppQueueManager,
|
||||
agent: bool,
|
||||
message_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Handle invoke result
|
||||
:param invoke_result: invoke result
|
||||
:param queue_manager: application queue manager
|
||||
:param agent: agent
|
||||
:param message_id: message id for multimodal output
|
||||
:param user_id: user id for multimodal output
|
||||
:param tenant_id: tenant id for multimodal output
|
||||
:return:
|
||||
"""
|
||||
model: str = ""
|
||||
@@ -259,12 +303,26 @@ class AppRunner:
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if not isinstance(content, str):
|
||||
# TODO(QuantumGhost): Add multimodal output support for easy ui.
|
||||
_logger.warning("received multimodal output, type=%s", type(content))
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content # failback to str
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@@ -289,6 +347,101 @@ class AppRunner:
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_multimodal_image_content(
|
||||
self,
|
||||
content: ImagePromptMessageContent,
|
||||
message_id: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
queue_manager: AppQueueManager,
|
||||
):
|
||||
"""
|
||||
Handle multimodal image content from LLM response.
|
||||
Save the image and create a MessageFile record.
|
||||
|
||||
:param content: ImagePromptMessageContent instance
|
||||
:param message_id: message id
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id
|
||||
:param queue_manager: queue manager
|
||||
:return:
|
||||
"""
|
||||
_logger.info("Handling multimodal image content for message %s", message_id)
|
||||
|
||||
image_url = content.url
|
||||
base64_data = content.base64_data
|
||||
|
||||
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
|
||||
|
||||
if not image_url and not base64_data:
|
||||
_logger.warning("Image content has neither URL nor base64 data")
|
||||
return
|
||||
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
# Save the image file
|
||||
try:
|
||||
if image_url:
|
||||
# Download image from URL
|
||||
_logger.info("Downloading image from URL: %s", image_url)
|
||||
tool_file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=image_url,
|
||||
conversation_id=None,
|
||||
)
|
||||
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||
elif base64_data:
|
||||
if base64_data.startswith("data:"):
|
||||
base64_data = base64_data.split(",", 1)[1]
|
||||
|
||||
image_binary = base64.b64decode(base64_data)
|
||||
mimetype = content.mime_type or "image/png"
|
||||
extension = guess_extension(mimetype) or ".png"
|
||||
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=image_binary,
|
||||
mimetype=mimetype,
|
||||
filename=f"generated_image{extension}",
|
||||
)
|
||||
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
|
||||
else:
|
||||
return
|
||||
except Exception:
|
||||
_logger.exception("Failed to save image file")
|
||||
return
|
||||
|
||||
# Create MessageFile record
|
||||
message_file = MessageFile(
|
||||
message_id=message_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
url=f"/files/tools/{tool_file.id}",
|
||||
upload_file_id=tool_file.id,
|
||||
created_by_role=(
|
||||
CreatorUserRole.ACCOUNT
|
||||
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
|
||||
else CreatorUserRole.END_USER
|
||||
),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
# Publish QueueMessageFileEvent
|
||||
queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file.id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
|
||||
|
||||
def moderation_for_inputs(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
@@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamEvent,
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
_precomputed_event_type: StreamEvent | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
self._task_state.llm_result.message.content = current_content
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
|
||||
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
|
||||
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
|
||||
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
|
||||
message_id=self._message_id
|
||||
)
|
||||
yield self._message_cycle_manager.message_to_stream_response(
|
||||
answer=cast(str, delta_text),
|
||||
message_id=self._message_id,
|
||||
event_type=event_type,
|
||||
event_type=self._precomputed_event_type,
|
||||
)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(
|
||||
|
||||
@@ -5,7 +5,7 @@ from threading import Thread
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
|
||||
StreamEvent,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
@@ -57,13 +58,15 @@ class MessageCycleManager:
|
||||
self._message_has_file: set[str] = set()
|
||||
|
||||
def get_message_event_type(self, message_id: str) -> StreamEvent:
|
||||
# Fast path: cached determination from prior QueueMessageFileEvent
|
||||
if message_id in self._message_has_file:
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
|
||||
# Use SQLAlchemy 2.x style session.scalar(select(...))
|
||||
with session_factory.create_session() as session:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
|
||||
|
||||
if has_file:
|
||||
if message_file:
|
||||
self._message_has_file.add(message_id)
|
||||
return StreamEvent.MESSAGE_FILE
|
||||
|
||||
@@ -199,6 +202,8 @@ class MessageCycleManager:
|
||||
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
|
||||
|
||||
if message_file and message_file.url is not None:
|
||||
self._message_has_file.add(message_file.message_id)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
|
||||
@@ -0,0 +1,454 @@
|
||||
"""Test multimodal image output handling in BaseAppRunner."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
class TestBaseAppRunnerMultimodal:
|
||||
"""Test that BaseAppRunner correctly handles multimodal image content."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tenant_id(self):
|
||||
"""Mock tenant ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_id(self):
|
||||
"""Mock message ID."""
|
||||
return str(uuid4())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = MagicMock()
|
||||
manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_file(self):
|
||||
"""Create a mock tool file."""
|
||||
tool_file = MagicMock()
|
||||
tool_file.id = str(uuid4())
|
||||
return tool_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file(self):
|
||||
"""Create a mock message file."""
|
||||
message_file = MagicMock()
|
||||
message_file.id = str(uuid4())
|
||||
return message_file
|
||||
|
||||
def test_handle_multimodal_image_content_with_url(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from URL."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from URL
|
||||
mock_mgr.create_file_by_url.assert_called_once_with(
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
file_url=image_url,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
# Verify message file was created with correct parameters
|
||||
mock_msg_file_class.assert_called_once()
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["message_id"] == mock_message_id
|
||||
assert call_kwargs["type"] == FileType.IMAGE
|
||||
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
|
||||
assert call_kwargs["belongs_to"] == "assistant"
|
||||
assert call_kwargs["created_by"] == mock_user_id
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once_with(mock_message_file)
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once_with(mock_message_file)
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
publish_call = mock_queue_manager.publish.call_args
|
||||
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
|
||||
assert publish_call[0][0].message_file_id == mock_message_file.id
|
||||
# publish_from might be passed as positional or keyword argument
|
||||
assert (
|
||||
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
|
||||
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data."""
|
||||
# Arrange
|
||||
import base64
|
||||
|
||||
# Create a small test image (1x1 PNG)
|
||||
test_image_data = base64.b64encode(
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
|
||||
).decode()
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=test_image_data,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Verify tool file was created from base64
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
assert call_kwargs["user_id"] == mock_user_id
|
||||
assert call_kwargs["tenant_id"] == mock_tenant_id
|
||||
assert call_kwargs["conversation_id"] is None
|
||||
assert "file_binary" in call_kwargs
|
||||
assert call_kwargs["mimetype"] == "image/png"
|
||||
assert call_kwargs["filename"].startswith("generated_image")
|
||||
assert call_kwargs["filename"].endswith(".png")
|
||||
|
||||
# Verify message file was created
|
||||
mock_msg_file_class.assert_called_once()
|
||||
|
||||
# Verify database operations
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
# Verify event was published
|
||||
mock_queue_manager.publish.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_with_base64_data_uri(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test handling image from base64 data with URI prefix."""
|
||||
# Arrange
|
||||
# Data URI format: data:image/png;base64,<base64_data>
|
||||
test_image_data = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
)
|
||||
content = ImagePromptMessageContent(
|
||||
base64_data=f"data:image/png;base64,{test_image_data}",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify that base64 data was extracted correctly (without prefix)
|
||||
mock_mgr.create_file_by_raw.assert_called_once()
|
||||
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
|
||||
# The base64 data should be decoded, so we check the binary was passed
|
||||
assert "file_binary" in call_kwargs
|
||||
|
||||
def test_handle_multimodal_image_content_without_url_or_base64(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content without URL or base64 data."""
|
||||
# Arrange
|
||||
content = ImagePromptMessageContent(
|
||||
url="",
|
||||
base64_data="",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create any files or publish events
|
||||
mock_mgr_class.assert_not_called()
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_with_error(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
):
|
||||
"""Test handling image content when an error occurs."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock to raise exception
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
# Should not raise exception, just log it
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - should not create message file or publish event on error
|
||||
mock_msg_file_class.assert_not_called()
|
||||
mock_session.add.assert_not_called()
|
||||
mock_queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_handle_multimodal_image_content_debugger_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that debugger mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is ACCOUNT for debugger mode
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
|
||||
|
||||
def test_handle_multimodal_image_content_service_api_mode(
|
||||
self,
|
||||
mock_user_id,
|
||||
mock_tenant_id,
|
||||
mock_message_id,
|
||||
mock_queue_manager,
|
||||
mock_tool_file,
|
||||
mock_message_file,
|
||||
):
|
||||
"""Test that service API mode sets correct created_by_role."""
|
||||
# Arrange
|
||||
image_url = "http://example.com/image.png"
|
||||
content = ImagePromptMessageContent(
|
||||
url=image_url,
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
method = AppRunner._handle_multimodal_image_content
|
||||
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=mock_message_id,
|
||||
user_id=mock_user_id,
|
||||
tenant_id=mock_tenant_id,
|
||||
queue_manager=mock_queue_manager,
|
||||
)
|
||||
|
||||
# Assert - verify created_by_role is END_USER for service API
|
||||
call_kwargs = mock_msg_file_class.call_args[1]
|
||||
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Unit tests for the message cycle manager optimization."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
@@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
|
||||
"""Test get_message_event_type returns MESSAGE when message has no files."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and no message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Execute
|
||||
with current_app.app_context():
|
||||
@@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
|
||||
|
||||
# Assert
|
||||
assert result == StreamEvent.MESSAGE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
|
||||
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Setup mock session and message file
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
|
||||
mock_message_file = Mock()
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = mock_message_file
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = mock_message_file
|
||||
|
||||
# Execute: compute event type once, then pass to message_to_stream_response
|
||||
with current_app.app_context():
|
||||
@@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE_FILE
|
||||
mock_session.query.return_value.scalar.assert_called_once()
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
|
||||
"""Test that message_to_stream_response skips database query when event_type is provided."""
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
# Execute with event_type provided
|
||||
result = message_cycle_manager.message_to_stream_response(
|
||||
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
|
||||
@@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
|
||||
assert result.answer == "Hello world"
|
||||
assert result.id == "test-message-id"
|
||||
assert result.event == StreamEvent.MESSAGE
|
||||
# Should not query database when event_type is provided
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open a session when event_type is provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
|
||||
"""Test message_to_stream_response with from_variable_selector parameter."""
|
||||
@@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
|
||||
def test_optimization_usage_example(self, message_cycle_manager):
|
||||
"""Test the optimization pattern that should be used by callers."""
|
||||
# Step 1: Get event type once (this queries database)
|
||||
with (
|
||||
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
|
||||
):
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session = Mock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.query(...).scalar()
|
||||
mock_session.query.return_value.scalar.return_value = None # No files
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
|
||||
# Current implementation uses session.scalar(select(...))
|
||||
mock_session.scalar.return_value = None # No files
|
||||
with current_app.app_context():
|
||||
event_type = message_cycle_manager.get_message_event_type("test-message-id")
|
||||
|
||||
# Should query database once
|
||||
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
|
||||
# Should open session once
|
||||
mock_session_factory.create_session.assert_called_once()
|
||||
assert event_type == StreamEvent.MESSAGE
|
||||
|
||||
# Step 2: Use event_type for multiple calls (no additional queries)
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
|
||||
mock_session_class.return_value.__enter__.return_value = Mock()
|
||||
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
|
||||
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
|
||||
|
||||
chunk1_response = message_cycle_manager.message_to_stream_response(
|
||||
answer="Chunk 1", message_id="test-message-id", event_type=event_type
|
||||
@@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
|
||||
answer="Chunk 2", message_id="test-message-id", event_type=event_type
|
||||
)
|
||||
|
||||
# Should not query database again
|
||||
mock_session_class.assert_not_called()
|
||||
# Should not open session again when event_type provided
|
||||
mock_session_factory.create_session.assert_not_called()
|
||||
|
||||
assert chunk1_response.event == StreamEvent.MESSAGE
|
||||
assert chunk2_response.event == StreamEvent.MESSAGE
|
||||
|
||||
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
178
web/app/components/base/chat/chat/hooks.multimodal.spec.ts
Normal file
@@ -0,0 +1,178 @@
|
||||
/**
|
||||
* Tests for multimodal image file handling in chat hooks.
|
||||
* Tests the file object conversion logic without full hook integration.
|
||||
*/
|
||||
|
||||
describe('Multimodal File Handling', () => {
|
||||
describe('File type to MIME type mapping', () => {
|
||||
it('should map image to image/png', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedMime = 'image/png'
|
||||
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map video to video/mp4', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedMime = 'video/mp4'
|
||||
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map audio to audio/mpeg', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedMime = 'audio/mpeg'
|
||||
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
|
||||
it('should map unknown to application/octet-stream', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const expectedMime = 'application/octet-stream'
|
||||
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
|
||||
expect(mimeType).toBe(expectedMime)
|
||||
})
|
||||
})
|
||||
|
||||
describe('TransferMethod selection', () => {
|
||||
it('should select remote_url for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('remote_url')
|
||||
})
|
||||
|
||||
it('should select local_file for non-images', () => {
|
||||
const fileType: string = 'video'
|
||||
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
|
||||
expect(transferMethod).toBe('local_file')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File extension mapping', () => {
|
||||
it('should use .png extension for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedExtension = '.png'
|
||||
const extension = fileType === 'image' ? 'png' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp4 extension for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedExtension = '.mp4'
|
||||
const extension = fileType === 'video' ? 'mp4' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
|
||||
it('should use .mp3 extension for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedExtension = '.mp3'
|
||||
const extension = fileType === 'audio' ? 'mp3' : 'bin'
|
||||
expect(extension).toBe(expectedExtension.replace('.', ''))
|
||||
})
|
||||
})
|
||||
|
||||
describe('File name generation', () => {
|
||||
it('should generate correct file name for images', () => {
|
||||
const fileType: string = 'image'
|
||||
const expectedName = 'generated_image.png'
|
||||
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for videos', () => {
|
||||
const fileType: string = 'video'
|
||||
const expectedName = 'generated_video.mp4'
|
||||
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
|
||||
it('should generate correct file name for audio', () => {
|
||||
const fileType: string = 'audio'
|
||||
const expectedName = 'generated_audio.mp3'
|
||||
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
|
||||
expect(fileName).toBe(expectedName)
|
||||
})
|
||||
})
|
||||
|
||||
describe('SupportFileType mapping', () => {
|
||||
it('should map image type to image supportFileType', () => {
|
||||
const fileType: string = 'image'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('image')
|
||||
})
|
||||
|
||||
it('should map video type to video supportFileType', () => {
|
||||
const fileType: string = 'video'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('video')
|
||||
})
|
||||
|
||||
it('should map audio type to audio supportFileType', () => {
|
||||
const fileType: string = 'audio'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('audio')
|
||||
})
|
||||
|
||||
it('should map unknown type to document supportFileType', () => {
|
||||
const fileType: string = 'unknown'
|
||||
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
|
||||
expect(supportFileType).toBe('document')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File conversion logic', () => {
|
||||
it('should detect existing transferMethod', () => {
|
||||
const fileWithTransferMethod = {
|
||||
id: 'file-123',
|
||||
transferMethod: 'remote_url' as const,
|
||||
type: 'image/png',
|
||||
name: 'test.png',
|
||||
size: 1024,
|
||||
supportFileType: 'image',
|
||||
progress: 100,
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
|
||||
expect(hasTransferMethod).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect missing transferMethod', () => {
|
||||
const fileWithoutTransferMethod = {
|
||||
id: 'file-456',
|
||||
type: 'image',
|
||||
url: 'http://example.com/image.png',
|
||||
belongs_to: 'assistant',
|
||||
}
|
||||
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
|
||||
expect(hasTransferMethod).toBe(false)
|
||||
})
|
||||
|
||||
it('should create file with size 0 for generated files', () => {
|
||||
const expectedSize = 0
|
||||
expect(expectedSize).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Agent vs Non-Agent mode logic', () => {
|
||||
it('should check for agent_thoughts to determine mode', () => {
|
||||
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [{}],
|
||||
}
|
||||
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(true)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is empty', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
|
||||
agent_thoughts: [],
|
||||
}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBe(false)
|
||||
})
|
||||
|
||||
it('should detect non-agent mode when agent_thoughts is undefined', () => {
|
||||
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
|
||||
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
|
||||
expect(isAgentMode).toBeFalsy()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -419,9 +419,40 @@ export const useChat = (
|
||||
}
|
||||
},
|
||||
onFile(file) {
|
||||
// Convert simple file type to MIME type for non-agent mode
|
||||
// Backend sends: { id, type: "image", belongs_to, url }
|
||||
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
|
||||
|
||||
// Determine file type for MIME conversion
|
||||
const fileType = (file as { type?: string }).type || 'image'
|
||||
|
||||
// If file already has transferMethod, use it as base and ensure all required fields exist
|
||||
// Otherwise, create a new complete file object
|
||||
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
|
||||
|
||||
const convertedFile: FileEntity = {
|
||||
id: baseFile?.id || (file as { id: string }).id,
|
||||
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
|
||||
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
|
||||
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
|
||||
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
|
||||
progress: baseFile?.progress ?? 100,
|
||||
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
|
||||
url: baseFile?.url || (file as { url?: string }).url,
|
||||
size: baseFile?.size ?? 0, // Generated files don't have a known size
|
||||
}
|
||||
|
||||
// For agent mode, add files to the last thought
|
||||
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
|
||||
if (lastThought)
|
||||
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
|
||||
if (lastThought) {
|
||||
const thought = lastThought as { message_files?: FileEntity[] }
|
||||
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
|
||||
}
|
||||
// For non-agent mode, add files directly to responseItem.message_files
|
||||
else {
|
||||
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
|
||||
responseItem.message_files = [...currentFiles, convertedFile]
|
||||
}
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
|
||||
@@ -2039,8 +2039,13 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
|
||||
renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Type query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||
|
||||
// Find submit button by class
|
||||
@@ -2054,8 +2059,13 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Type query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||
|
||||
// Component should still be functional - check for the main container
|
||||
@@ -2089,10 +2099,15 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
isLoading: false,
|
||||
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox to be rendered with timeout for CI environment
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Type query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||
|
||||
// Submit
|
||||
@@ -2101,8 +2116,13 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
if (submitButton)
|
||||
fireEvent.click(submitButton)
|
||||
|
||||
// Verify the component is still rendered after submission
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
// Wait for the mutation to complete
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||
},
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
})
|
||||
|
||||
it('should render ResultItem components for non-external results', async () => {
|
||||
@@ -2127,10 +2147,15 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
isLoading: false,
|
||||
} as unknown as ReturnType<typeof useDatasetTestingRecords>)
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
const { container: _container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for component to be fully rendered with longer timeout
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Submit a query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
@@ -2138,8 +2163,13 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
if (submitButton)
|
||||
fireEvent.click(submitButton)
|
||||
|
||||
// Verify component is rendered after submission
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
// Wait for mutation to complete with longer timeout
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||
},
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
})
|
||||
|
||||
it('should render external results when dataset is external', async () => {
|
||||
@@ -2165,8 +2195,14 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
|
||||
// Component should render
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Type in textarea to verify component is functional
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'Test query' } })
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
@@ -2174,9 +2210,13 @@ describe('Integration: Hit Testing Flow', () => {
|
||||
if (submitButton)
|
||||
fireEvent.click(submitButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
})
|
||||
// Verify component is still functional after submission
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
},
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2260,8 +2300,13 @@ describe('renderHitResults Coverage', () => {
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Enter query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'test query' } })
|
||||
|
||||
// Submit
|
||||
@@ -2386,8 +2431,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Enter query and submit
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'test query' } })
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
@@ -2400,7 +2450,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
||||
// Wait for state updates
|
||||
await waitFor(() => {
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
}, { timeout: 2000 })
|
||||
}, { timeout: 3000 })
|
||||
|
||||
// Verify mutation was called
|
||||
expect(mockHitTestingMutateAsync).toHaveBeenCalled()
|
||||
@@ -2445,8 +2495,13 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
||||
|
||||
const { container } = renderWithProviders(<HitTestingPage datasetId="dataset-1" />)
|
||||
|
||||
// Wait for textbox with timeout for CI
|
||||
const textarea = await waitFor(
|
||||
() => screen.getByRole('textbox'),
|
||||
{ timeout: 3000 },
|
||||
)
|
||||
|
||||
// Submit a query
|
||||
const textarea = screen.getByRole('textbox')
|
||||
fireEvent.change(textarea, { target: { value: 'test' } })
|
||||
|
||||
const buttons = screen.getAllByRole('button')
|
||||
@@ -2458,7 +2513,7 @@ describe('HitTestingPage Internal Functions Coverage', () => {
|
||||
// Verify the component renders
|
||||
await waitFor(() => {
|
||||
expect(container.firstChild).toBeInTheDocument()
|
||||
})
|
||||
}, { timeout: 3000 })
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -162,6 +162,44 @@ vi.mock('@/utils/var', () => ({
|
||||
getMarketplaceUrl: (path: string, _params?: Record<string, string | undefined>) => `https://marketplace.dify.ai${path}`,
|
||||
}))
|
||||
|
||||
// Mock marketplace client used by marketplace utils
|
||||
vi.mock('@/service/client', () => ({
|
||||
marketplaceClient: {
|
||||
collections: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||
data: {
|
||||
collections: [
|
||||
{
|
||||
name: 'collection-1',
|
||||
label: { 'en-US': 'Collection 1' },
|
||||
description: { 'en-US': 'Desc' },
|
||||
rule: '',
|
||||
created_at: '2024-01-01',
|
||||
updated_at: '2024-01-01',
|
||||
searchable: true,
|
||||
search_params: { query: '', sort_by: 'install_count', sort_order: 'DESC' },
|
||||
},
|
||||
],
|
||||
},
|
||||
})),
|
||||
collectionPlugins: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||
data: {
|
||||
plugins: [
|
||||
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
|
||||
],
|
||||
},
|
||||
})),
|
||||
// Some utils paths may call searchAdvanced; provide a minimal stub
|
||||
searchAdvanced: vi.fn(async (_args?: unknown, _opts?: { signal?: AbortSignal }) => ({
|
||||
data: {
|
||||
plugins: [
|
||||
{ type: 'plugin', org: 'test', name: 'plugin1', tags: [] },
|
||||
],
|
||||
total: 1,
|
||||
},
|
||||
})),
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock context/query-client
|
||||
vi.mock('@/context/query-client', () => ({
|
||||
TanstackQueryInitializer: ({ children }: { children: React.ReactNode }) => <div data-testid="query-initializer">{children}</div>,
|
||||
@@ -1474,7 +1512,24 @@ describe('flatMap Coverage', () => {
|
||||
// ================================
|
||||
// Async Utils Tests
|
||||
// ================================
|
||||
|
||||
// Narrow mock surface and avoid any in tests
|
||||
// Types are local to this spec to keep scope minimal
|
||||
|
||||
type FnMock = ReturnType<typeof vi.fn>
|
||||
|
||||
type MarketplaceClientMock = {
|
||||
collectionPlugins: FnMock
|
||||
collections: FnMock
|
||||
}
|
||||
|
||||
describe('Async Utils', () => {
|
||||
let marketplaceClientMock: MarketplaceClientMock
|
||||
|
||||
beforeAll(async () => {
|
||||
const mod = await import('@/service/client')
|
||||
marketplaceClientMock = mod.marketplaceClient as unknown as MarketplaceClientMock
|
||||
})
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
@@ -1490,12 +1545,10 @@ describe('Async Utils', () => {
|
||||
{ type: 'plugin', org: 'test', name: 'plugin2' },
|
||||
]
|
||||
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
// Adjusted to our mocked marketplaceClient instead of fetch
|
||||
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
|
||||
data: { plugins: mockPlugins },
|
||||
})
|
||||
|
||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||
const result = await getMarketplacePluginsByCollectionId('test-collection', {
|
||||
@@ -1504,12 +1557,13 @@ describe('Async Utils', () => {
|
||||
type: 'plugin',
|
||||
})
|
||||
|
||||
expect(globalThis.fetch).toHaveBeenCalled()
|
||||
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
|
||||
expect(result).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('should handle fetch error and return empty array', async () => {
|
||||
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
|
||||
// Simulate error from client
|
||||
marketplaceClientMock.collectionPlugins.mockRejectedValueOnce(new Error('Network error'))
|
||||
|
||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||
const result = await getMarketplacePluginsByCollectionId('test-collection')
|
||||
@@ -1519,25 +1573,18 @@ describe('Async Utils', () => {
|
||||
|
||||
it('should pass abort signal when provided', async () => {
|
||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
// Our client mock receives the signal as second arg
|
||||
marketplaceClientMock.collectionPlugins.mockResolvedValueOnce({
|
||||
data: { plugins: mockPlugins },
|
||||
})
|
||||
|
||||
const controller = new AbortController()
|
||||
const { getMarketplacePluginsByCollectionId } = await import('./utils')
|
||||
await getMarketplacePluginsByCollectionId('test-collection', {}, { signal: controller.signal })
|
||||
|
||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
||||
expect.any(Request),
|
||||
expect.any(Object),
|
||||
)
|
||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
||||
const request = call[0] as Request
|
||||
expect(request.url).toContain('test-collection')
|
||||
expect(marketplaceClientMock.collectionPlugins).toHaveBeenCalled()
|
||||
const call = marketplaceClientMock.collectionPlugins.mock.calls[0]
|
||||
expect(call[1]).toMatchObject({ signal: controller.signal })
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1548,23 +1595,17 @@ describe('Async Utils', () => {
|
||||
]
|
||||
const mockPlugins = [{ type: 'plugins', org: 'test', name: 'plugin1' }]
|
||||
|
||||
let callCount = 0
|
||||
globalThis.fetch = vi.fn().mockImplementation(() => {
|
||||
callCount++
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve(
|
||||
new Response(JSON.stringify({ data: { collections: mockCollections } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
// Simulate two-step client calls: collections then collectionPlugins
|
||||
let stage = 0
|
||||
marketplaceClientMock.collections.mockImplementationOnce(async () => {
|
||||
stage = 1
|
||||
return { data: { collections: mockCollections } }
|
||||
})
|
||||
marketplaceClientMock.collectionPlugins.mockImplementation(async () => {
|
||||
if (stage === 1) {
|
||||
return { data: { plugins: mockPlugins } }
|
||||
}
|
||||
return Promise.resolve(
|
||||
new Response(JSON.stringify({ data: { plugins: mockPlugins } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
return { data: { plugins: [] } }
|
||||
})
|
||||
|
||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||
@@ -1578,7 +1619,8 @@ describe('Async Utils', () => {
|
||||
})
|
||||
|
||||
it('should handle fetch error and return empty data', async () => {
|
||||
globalThis.fetch = vi.fn().mockRejectedValue(new Error('Network error'))
|
||||
// Simulate client error
|
||||
marketplaceClientMock.collections.mockRejectedValueOnce(new Error('Network error'))
|
||||
|
||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||
const result = await getMarketplaceCollectionsAndPlugins()
|
||||
@@ -1588,24 +1630,16 @@ describe('Async Utils', () => {
|
||||
})
|
||||
|
||||
it('should append condition and type to URL when provided', async () => {
|
||||
globalThis.fetch = vi.fn().mockResolvedValue(
|
||||
new Response(JSON.stringify({ data: { collections: [] } }), {
|
||||
status: 200,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
)
|
||||
|
||||
// Assert that the client was called with query containing condition/type
|
||||
const { getMarketplaceCollectionsAndPlugins } = await import('./utils')
|
||||
await getMarketplaceCollectionsAndPlugins({
|
||||
condition: 'category=tool',
|
||||
type: 'bundle',
|
||||
})
|
||||
|
||||
// oRPC uses Request objects, so check that fetch was called with a Request containing the right URL
|
||||
expect(globalThis.fetch).toHaveBeenCalled()
|
||||
const call = vi.mocked(globalThis.fetch).mock.calls[0]
|
||||
const request = call[0] as Request
|
||||
expect(request.url).toContain('condition=category%3Dtool')
|
||||
expect(marketplaceClientMock.collections).toHaveBeenCalled()
|
||||
const call = marketplaceClientMock.collections.mock.calls[0]
|
||||
expect(call[0]).toMatchObject({ query: expect.objectContaining({ condition: 'category=tool', type: 'bundle' }) })
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,705 @@
|
||||
/**
|
||||
* Test Suite for useNodesSyncDraft Hook
|
||||
*
|
||||
* PURPOSE:
|
||||
* This hook handles syncing workflow draft to the server. The key fix being tested
|
||||
* is the error handling behavior when `draft_workflow_not_sync` error occurs.
|
||||
*
|
||||
* MULTI-TAB PROBLEM SCENARIO:
|
||||
* 1. User opens the same workflow in Tab A and Tab B (both have hash: v1)
|
||||
* 2. Tab A saves successfully, server returns new hash: v2
|
||||
* 3. Tab B tries to save with old hash: v1, server returns 400 error with code
|
||||
* 'draft_workflow_not_sync'
|
||||
* 4. BEFORE FIX: handleRefreshWorkflowDraft() was called without args, which fetched
|
||||
* draft AND overwrote canvas - user lost unsaved changes in Tab B
|
||||
* 5. AFTER FIX: handleRefreshWorkflowDraft(true) is called, which fetches draft but
|
||||
* only updates hash (notUpdateCanvas=true), preserving user's canvas changes
|
||||
*
|
||||
* TESTING STRATEGY:
|
||||
* We don't simulate actual tab switching UI behavior. Instead, we mock the API to
|
||||
* return `draft_workflow_not_sync` error and verify:
|
||||
* - The hook calls handleRefreshWorkflowDraft(true) - not handleRefreshWorkflowDraft()
|
||||
* - This ensures canvas data is preserved while hash is updated for retry
|
||||
*
|
||||
* This is behavior-driven testing - we verify "what the code does when receiving
|
||||
* specific API errors" rather than simulating complete user interaction flows.
|
||||
* True multi-tab integration testing would require E2E frameworks like Playwright.
|
||||
*/
|
||||
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useNodesSyncDraft } from './use-nodes-sync-draft'
|
||||
|
||||
// Mock reactflow store
|
||||
const mockGetNodes = vi.fn()
|
||||
|
||||
type MockEdge = {
|
||||
id: string
|
||||
source: string
|
||||
target: string
|
||||
data: Record<string, unknown>
|
||||
}
|
||||
|
||||
const mockStoreState: {
|
||||
getNodes: ReturnType<typeof vi.fn>
|
||||
edges: MockEdge[]
|
||||
transform: number[]
|
||||
} = {
|
||||
getNodes: mockGetNodes,
|
||||
edges: [],
|
||||
transform: [0, 0, 1],
|
||||
}
|
||||
vi.mock('reactflow', () => ({
|
||||
useStoreApi: () => ({
|
||||
getState: () => mockStoreState,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock features store
|
||||
const mockFeaturesState = {
|
||||
features: {
|
||||
opening: { enabled: false, opening_statement: '', suggested_questions: [] },
|
||||
suggested: {},
|
||||
text2speech: {},
|
||||
speech2text: {},
|
||||
citation: {},
|
||||
moderation: {},
|
||||
file: {},
|
||||
},
|
||||
}
|
||||
vi.mock('@/app/components/base/features/hooks', () => ({
|
||||
useFeaturesStore: () => ({
|
||||
getState: () => mockFeaturesState,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock workflow service
|
||||
const mockSyncWorkflowDraft = vi.fn()
|
||||
vi.mock('@/service/workflow', () => ({
|
||||
syncWorkflowDraft: (...args: unknown[]) => mockSyncWorkflowDraft(...args),
|
||||
}))
|
||||
|
||||
// Mock useNodesReadOnly
|
||||
const mockGetNodesReadOnly = vi.fn()
|
||||
vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({
|
||||
useNodesReadOnly: () => ({
|
||||
getNodesReadOnly: mockGetNodesReadOnly,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock useSerialAsyncCallback - pass through the callback
|
||||
vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({
|
||||
useSerialAsyncCallback: (callback: (...args: unknown[]) => unknown) => callback,
|
||||
}))
|
||||
|
||||
// Mock workflow store
|
||||
const mockSetSyncWorkflowDraftHash = vi.fn()
|
||||
const mockSetDraftUpdatedAt = vi.fn()
|
||||
|
||||
const createMockWorkflowStoreState = (overrides = {}) => ({
|
||||
appId: 'test-app-id',
|
||||
conversationVariables: [],
|
||||
environmentVariables: [],
|
||||
syncWorkflowDraftHash: 'current-hash-123',
|
||||
isWorkflowDataLoaded: true,
|
||||
setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash,
|
||||
setDraftUpdatedAt: mockSetDraftUpdatedAt,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const mockWorkflowStoreGetState = vi.fn()
|
||||
vi.mock('@/app/components/workflow/store', () => ({
|
||||
useWorkflowStore: () => ({
|
||||
getState: mockWorkflowStoreGetState,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock useWorkflowRefreshDraft (THE KEY DEPENDENCY FOR THIS TEST)
|
||||
const mockHandleRefreshWorkflowDraft = vi.fn()
|
||||
vi.mock('.', () => ({
|
||||
useWorkflowRefreshDraft: () => ({
|
||||
handleRefreshWorkflowDraft: mockHandleRefreshWorkflowDraft,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock API_PREFIX
|
||||
vi.mock('@/config', () => ({
|
||||
API_PREFIX: '/api',
|
||||
}))
|
||||
|
||||
// Create a mock error response that mimics the actual API error
|
||||
const createMockErrorResponse = (code: string) => {
|
||||
const errorBody = { code, message: 'Draft not in sync' }
|
||||
let bodyUsed = false
|
||||
|
||||
return {
|
||||
json: vi.fn().mockImplementation(() => {
|
||||
bodyUsed = true
|
||||
return Promise.resolve(errorBody)
|
||||
}),
|
||||
get bodyUsed() {
|
||||
return bodyUsed
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
describe('useNodesSyncDraft', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockGetNodesReadOnly.mockReturnValue(false)
|
||||
mockGetNodes.mockReturnValue([
|
||||
{ id: 'node-1', type: 'start', data: { type: 'start' } },
|
||||
{ id: 'node-2', type: 'llm', data: { type: 'llm' } },
|
||||
])
|
||||
mockStoreState.edges = [
|
||||
{ id: 'edge-1', source: 'node-1', target: 'node-2', data: {} },
|
||||
]
|
||||
mockWorkflowStoreGetState.mockReturnValue(createMockWorkflowStoreState())
|
||||
mockSyncWorkflowDraft.mockResolvedValue({
|
||||
hash: 'new-hash-456',
|
||||
updated_at: Date.now(),
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('doSyncWorkflowDraft function', () => {
|
||||
it('should return doSyncWorkflowDraft function', () => {
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
expect(result.current.doSyncWorkflowDraft).toBeDefined()
|
||||
expect(typeof result.current.doSyncWorkflowDraft).toBe('function')
|
||||
})
|
||||
|
||||
it('should return syncWorkflowDraftWhenPageClose function', () => {
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
expect(result.current.syncWorkflowDraftWhenPageClose).toBeDefined()
|
||||
expect(typeof result.current.syncWorkflowDraftWhenPageClose).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('successful sync', () => {
|
||||
it('should call syncWorkflowDraft service on successful sync', async () => {
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).toHaveBeenCalledWith({
|
||||
url: '/apps/test-app-id/workflows/draft',
|
||||
params: expect.objectContaining({
|
||||
hash: 'current-hash-123',
|
||||
graph: expect.objectContaining({
|
||||
nodes: expect.any(Array),
|
||||
edges: expect.any(Array),
|
||||
viewport: expect.any(Object),
|
||||
}),
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
it('should update syncWorkflowDraftHash on success', async () => {
|
||||
mockSyncWorkflowDraft.mockResolvedValue({
|
||||
hash: 'new-hash-789',
|
||||
updated_at: 1234567890,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-789')
|
||||
})
|
||||
|
||||
it('should update draftUpdatedAt on success', async () => {
|
||||
const updatedAt = 1234567890
|
||||
mockSyncWorkflowDraft.mockResolvedValue({
|
||||
hash: 'new-hash',
|
||||
updated_at: updatedAt,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSetDraftUpdatedAt).toHaveBeenCalledWith(updatedAt)
|
||||
})
|
||||
|
||||
it('should call onSuccess callback on success', async () => {
|
||||
const onSuccess = vi.fn()
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSuccess })
|
||||
})
|
||||
|
||||
expect(onSuccess).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onSettled callback after success', async () => {
|
||||
const onSettled = vi.fn()
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSettled })
|
||||
})
|
||||
|
||||
expect(onSettled).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('sync error handling - draft_workflow_not_sync (THE KEY FIX)', () => {
|
||||
/**
|
||||
* This is THE KEY TEST for the bug fix.
|
||||
*
|
||||
* SCENARIO: Multi-tab editing
|
||||
* 1. User opens workflow in Tab A and Tab B
|
||||
* 2. Tab A saves draft successfully, gets new hash
|
||||
* 3. Tab B tries to save with old hash
|
||||
* 4. Server returns 400 with code 'draft_workflow_not_sync'
|
||||
*
|
||||
* BEFORE FIX:
|
||||
* - handleRefreshWorkflowDraft() was called without arguments
|
||||
* - This would fetch draft AND overwrite the canvas
|
||||
* - User loses their unsaved changes in Tab B
|
||||
*
|
||||
* AFTER FIX:
|
||||
* - handleRefreshWorkflowDraft(true) is called
|
||||
* - This fetches draft but DOES NOT overwrite canvas
|
||||
* - Only hash is updated for the next sync attempt
|
||||
* - User's unsaved changes are preserved
|
||||
*/
|
||||
it('should call handleRefreshWorkflowDraft with notUpdateCanvas=true when draft_workflow_not_sync error occurs', async () => {
|
||||
const mockError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// THE KEY ASSERTION: handleRefreshWorkflowDraft must be called with true
|
||||
await waitFor(() => {
|
||||
expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('should NOT call handleRefreshWorkflowDraft when notRefreshWhenSyncError is true', async () => {
|
||||
const mockError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
// First parameter is notRefreshWhenSyncError
|
||||
await result.current.doSyncWorkflowDraft(true)
|
||||
})
|
||||
|
||||
// Wait a bit for async operations
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
|
||||
expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onError callback when draft_workflow_not_sync error occurs', async () => {
|
||||
const mockError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
const onError = vi.fn()
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onError })
|
||||
})
|
||||
|
||||
expect(onError).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onSettled callback after error', async () => {
|
||||
const mockError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
const onSettled = vi.fn()
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSettled })
|
||||
})
|
||||
|
||||
expect(onSettled).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('other error handling', () => {
|
||||
it('should NOT call handleRefreshWorkflowDraft for non-draft_workflow_not_sync errors', async () => {
|
||||
const mockError = createMockErrorResponse('some_other_error')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// Wait a bit for async operations
|
||||
await new Promise(resolve => setTimeout(resolve, 100))
|
||||
|
||||
expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle error without json method', async () => {
|
||||
const mockError = new Error('Network error')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
const onError = vi.fn()
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onError })
|
||||
})
|
||||
|
||||
expect(onError).toHaveBeenCalled()
|
||||
expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle error with bodyUsed already true', async () => {
|
||||
const mockError = {
|
||||
json: vi.fn(),
|
||||
bodyUsed: true,
|
||||
}
|
||||
mockSyncWorkflowDraft.mockRejectedValue(mockError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// Should not call json() when bodyUsed is true
|
||||
expect(mockError.json).not.toHaveBeenCalled()
|
||||
expect(mockHandleRefreshWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('read-only mode', () => {
|
||||
it('should not sync when nodes are read-only', async () => {
|
||||
mockGetNodesReadOnly.mockReturnValue(true)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not sync on page close when nodes are read-only', () => {
|
||||
mockGetNodesReadOnly.mockReturnValue(true)
|
||||
|
||||
// Mock sendBeacon
|
||||
const mockSendBeacon = vi.fn()
|
||||
Object.defineProperty(navigator, 'sendBeacon', {
|
||||
value: mockSendBeacon,
|
||||
writable: true,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.syncWorkflowDraftWhenPageClose()
|
||||
})
|
||||
|
||||
expect(mockSendBeacon).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('workflow data not loaded', () => {
|
||||
it('should not sync when workflow data is not loaded', async () => {
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockWorkflowStoreState({ isWorkflowDataLoaded: false }),
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('no appId', () => {
|
||||
it('should not sync when appId is not set', async () => {
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockWorkflowStoreState({ appId: null }),
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('node filtering', () => {
|
||||
it('should filter out temp nodes', async () => {
|
||||
mockGetNodes.mockReturnValue([
|
||||
{ id: 'node-1', type: 'start', data: { type: 'start' } },
|
||||
{ id: 'node-temp', type: 'custom', data: { type: 'custom', _isTempNode: true } },
|
||||
{ id: 'node-2', type: 'llm', data: { type: 'llm' } },
|
||||
])
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
params: expect.objectContaining({
|
||||
graph: expect.objectContaining({
|
||||
nodes: expect.not.arrayContaining([
|
||||
expect.objectContaining({ id: 'node-temp' }),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
it('should remove internal underscore properties from nodes', async () => {
|
||||
mockGetNodes.mockReturnValue([
|
||||
{
|
||||
id: 'node-1',
|
||||
type: 'start',
|
||||
data: {
|
||||
type: 'start',
|
||||
_internalProp: 'should be removed',
|
||||
_anotherInternal: true,
|
||||
publicProp: 'should remain',
|
||||
},
|
||||
},
|
||||
])
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
const callArgs = mockSyncWorkflowDraft.mock.calls[0][0]
|
||||
const sentNode = callArgs.params.graph.nodes[0]
|
||||
|
||||
expect(sentNode.data).not.toHaveProperty('_internalProp')
|
||||
expect(sentNode.data).not.toHaveProperty('_anotherInternal')
|
||||
expect(sentNode.data).toHaveProperty('publicProp', 'should remain')
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge filtering', () => {
|
||||
it('should filter out temp edges', async () => {
|
||||
mockStoreState.edges = [
|
||||
{ id: 'edge-1', source: 'node-1', target: 'node-2', data: {} },
|
||||
{ id: 'edge-temp', source: 'node-1', target: 'node-3', data: { _isTemp: true } },
|
||||
]
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
const callArgs = mockSyncWorkflowDraft.mock.calls[0][0]
|
||||
const sentEdges = callArgs.params.graph.edges
|
||||
|
||||
expect(sentEdges).toHaveLength(1)
|
||||
expect(sentEdges[0].id).toBe('edge-1')
|
||||
})
|
||||
|
||||
it('should remove internal underscore properties from edges', async () => {
|
||||
mockStoreState.edges = [
|
||||
{
|
||||
id: 'edge-1',
|
||||
source: 'node-1',
|
||||
target: 'node-2',
|
||||
data: {
|
||||
_internalEdgeProp: 'should be removed',
|
||||
publicEdgeProp: 'should remain',
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
const callArgs = mockSyncWorkflowDraft.mock.calls[0][0]
|
||||
const sentEdge = callArgs.params.graph.edges[0]
|
||||
|
||||
expect(sentEdge.data).not.toHaveProperty('_internalEdgeProp')
|
||||
expect(sentEdge.data).toHaveProperty('publicEdgeProp', 'should remain')
|
||||
})
|
||||
})
|
||||
|
||||
describe('viewport handling', () => {
|
||||
it('should send current viewport from transform', async () => {
|
||||
mockStoreState.transform = [100, 200, 1.5]
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSyncWorkflowDraft).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
params: expect.objectContaining({
|
||||
graph: expect.objectContaining({
|
||||
viewport: { x: 100, y: 200, zoom: 1.5 },
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('multi-tab concurrent editing scenario (END-TO-END TEST)', () => {
|
||||
/**
|
||||
* Simulates the complete multi-tab scenario to verify the fix works correctly.
|
||||
*
|
||||
* Scenario:
|
||||
* 1. Tab A and Tab B both have the workflow open with hash 'hash-v1'
|
||||
* 2. Tab A saves successfully, server returns 'hash-v2'
|
||||
* 3. Tab B tries to save with 'hash-v1', gets 'draft_workflow_not_sync' error
|
||||
* 4. Tab B should only update hash to 'hash-v2', not overwrite canvas
|
||||
* 5. Tab B can now retry save with correct hash
|
||||
*/
|
||||
it('should preserve canvas data during hash conflict resolution', async () => {
|
||||
// Initial state: both tabs have hash-v1
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockWorkflowStoreState({ syncWorkflowDraftHash: 'hash-v1' }),
|
||||
)
|
||||
|
||||
// Tab B tries to save with old hash, server returns error
|
||||
const syncError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(syncError)
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
// Tab B attempts to sync
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// Verify the sync was attempted with old hash
|
||||
expect(mockSyncWorkflowDraft).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
params: expect.objectContaining({
|
||||
hash: 'hash-v1',
|
||||
}),
|
||||
}),
|
||||
)
|
||||
|
||||
// Verify handleRefreshWorkflowDraft was called with true (not overwrite canvas)
|
||||
await waitFor(() => {
|
||||
expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
// The key assertion: only one argument (true) was passed
|
||||
expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleRefreshWorkflowDraft.mock.calls[0]).toEqual([true])
|
||||
})
|
||||
|
||||
it('should handle multiple consecutive sync failures gracefully', async () => {
|
||||
// Create fresh error for each call to avoid bodyUsed issue
|
||||
mockSyncWorkflowDraft
|
||||
.mockRejectedValueOnce(createMockErrorResponse('draft_workflow_not_sync'))
|
||||
.mockRejectedValueOnce(createMockErrorResponse('draft_workflow_not_sync'))
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
// First sync attempt
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// Wait for first refresh call
|
||||
await waitFor(() => {
|
||||
expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Second sync attempt
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft()
|
||||
})
|
||||
|
||||
// Both should call handleRefreshWorkflowDraft with true
|
||||
await waitFor(() => {
|
||||
expect(mockHandleRefreshWorkflowDraft).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
mockHandleRefreshWorkflowDraft.mock.calls.forEach((call) => {
|
||||
expect(call).toEqual([true])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('callbacks behavior', () => {
|
||||
it('should not call onSuccess when sync fails', async () => {
|
||||
const syncError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(syncError)
|
||||
const onSuccess = vi.fn()
|
||||
const onError = vi.fn()
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSuccess, onError })
|
||||
})
|
||||
|
||||
expect(onSuccess).not.toHaveBeenCalled()
|
||||
expect(onError).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should always call onSettled regardless of success or failure', async () => {
|
||||
const onSettled = vi.fn()
|
||||
|
||||
const { result } = renderHook(() => useNodesSyncDraft())
|
||||
|
||||
// Test success case
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSettled })
|
||||
})
|
||||
expect(onSettled).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Reset
|
||||
onSettled.mockClear()
|
||||
|
||||
// Test failure case
|
||||
const syncError = createMockErrorResponse('draft_workflow_not_sync')
|
||||
mockSyncWorkflowDraft.mockRejectedValue(syncError)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.doSyncWorkflowDraft(false, { onSettled })
|
||||
})
|
||||
expect(onSettled).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -115,7 +115,7 @@ export const useNodesSyncDraft = () => {
|
||||
if (error && error.json && !error.bodyUsed) {
|
||||
error.json().then((err: any) => {
|
||||
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError)
|
||||
handleRefreshWorkflowDraft()
|
||||
handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
}
|
||||
callback?.onError?.()
|
||||
|
||||
@@ -0,0 +1,556 @@
|
||||
/**
|
||||
* Test Suite for useWorkflowRefreshDraft Hook
|
||||
*
|
||||
* PURPOSE:
|
||||
* This hook is responsible for refreshing workflow draft data from the server.
|
||||
* The key fix being tested is the `notUpdateCanvas` parameter behavior.
|
||||
*
|
||||
* MULTI-TAB PROBLEM SCENARIO:
|
||||
* 1. User opens the same workflow in Tab A and Tab B (both have hash: v1)
|
||||
* 2. Tab A saves successfully, server returns new hash: v2
|
||||
* 3. Tab B tries to save with old hash: v1, server returns 400 error (draft_workflow_not_sync)
|
||||
* 4. BEFORE FIX: handleRefreshWorkflowDraft() was called without args, which fetched
|
||||
* draft AND overwrote canvas - user lost unsaved changes in Tab B
|
||||
* 5. AFTER FIX: handleRefreshWorkflowDraft(true) is called, which fetches draft but
|
||||
* only updates hash, preserving user's canvas changes
|
||||
*
|
||||
* TESTING STRATEGY:
|
||||
* We don't simulate actual tab switching UI behavior. Instead, we test the hook's
|
||||
* response to specific inputs:
|
||||
* - When notUpdateCanvas=true: should NOT call handleUpdateWorkflowCanvas
|
||||
* - When notUpdateCanvas=false/undefined: should call handleUpdateWorkflowCanvas
|
||||
*
|
||||
* This is behavior-driven testing - we verify "what the code does when given specific
|
||||
* inputs" rather than simulating complete user interaction flows.
|
||||
*/
|
||||
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useWorkflowRefreshDraft } from './use-workflow-refresh-draft'
|
||||
|
||||
// Mock the workflow service
|
||||
const mockFetchWorkflowDraft = vi.fn()
|
||||
vi.mock('@/service/workflow', () => ({
|
||||
fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args),
|
||||
}))
|
||||
|
||||
// Mock the workflow update hook
|
||||
const mockHandleUpdateWorkflowCanvas = vi.fn()
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useWorkflowUpdate: () => ({
|
||||
handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock store state
|
||||
const mockSetSyncWorkflowDraftHash = vi.fn()
|
||||
const mockSetIsSyncingWorkflowDraft = vi.fn()
|
||||
const mockSetEnvironmentVariables = vi.fn()
|
||||
const mockSetEnvSecrets = vi.fn()
|
||||
const mockSetConversationVariables = vi.fn()
|
||||
const mockSetIsWorkflowDataLoaded = vi.fn()
|
||||
const mockCancelDebouncedSync = vi.fn()
|
||||
|
||||
const createMockStoreState = (overrides = {}) => ({
|
||||
appId: 'test-app-id',
|
||||
setSyncWorkflowDraftHash: mockSetSyncWorkflowDraftHash,
|
||||
setIsSyncingWorkflowDraft: mockSetIsSyncingWorkflowDraft,
|
||||
setEnvironmentVariables: mockSetEnvironmentVariables,
|
||||
setEnvSecrets: mockSetEnvSecrets,
|
||||
setConversationVariables: mockSetConversationVariables,
|
||||
setIsWorkflowDataLoaded: mockSetIsWorkflowDataLoaded,
|
||||
isWorkflowDataLoaded: true,
|
||||
debouncedSyncWorkflowDraft: {
|
||||
cancel: mockCancelDebouncedSync,
|
||||
},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const mockWorkflowStoreGetState = vi.fn()
|
||||
vi.mock('@/app/components/workflow/store', () => ({
|
||||
useWorkflowStore: () => ({
|
||||
getState: mockWorkflowStoreGetState,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Default mock response from fetchWorkflowDraft
|
||||
const createMockDraftResponse = (overrides = {}) => ({
|
||||
hash: 'new-hash-12345',
|
||||
graph: {
|
||||
nodes: [{ id: 'node-1', type: 'start', data: {} }],
|
||||
edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2' }],
|
||||
viewport: { x: 100, y: 200, zoom: 1.5 },
|
||||
},
|
||||
environment_variables: [
|
||||
{ id: 'env-1', name: 'API_KEY', value: 'secret-key', value_type: 'secret' },
|
||||
{ id: 'env-2', name: 'BASE_URL', value: 'https://api.example.com', value_type: 'string' },
|
||||
],
|
||||
conversation_variables: [
|
||||
{ id: 'conv-1', name: 'user_input', value: 'test' },
|
||||
],
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('useWorkflowRefreshDraft', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockWorkflowStoreGetState.mockReturnValue(createMockStoreState())
|
||||
mockFetchWorkflowDraft.mockResolvedValue(createMockDraftResponse())
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('handleRefreshWorkflowDraft function', () => {
|
||||
it('should return handleRefreshWorkflowDraft function', () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
expect(result.current.handleRefreshWorkflowDraft).toBeDefined()
|
||||
expect(typeof result.current.handleRefreshWorkflowDraft).toBe('function')
|
||||
})
|
||||
})
|
||||
|
||||
describe('notUpdateCanvas parameter behavior (THE KEY FIX)', () => {
|
||||
it('should NOT call handleUpdateWorkflowCanvas when notUpdateCanvas is true', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/apps/test-app-id/workflows/draft')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-12345')
|
||||
})
|
||||
|
||||
// THE KEY ASSERTION: Canvas should NOT be updated when notUpdateCanvas is true
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call handleUpdateWorkflowCanvas when notUpdateCanvas is false', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(false)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/apps/test-app-id/workflows/draft')
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
// Canvas SHOULD be updated when notUpdateCanvas is false
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
|
||||
nodes: [{ id: 'node-1', type: 'start', data: {} }],
|
||||
edges: [{ id: 'edge-1', source: 'node-1', target: 'node-2' }],
|
||||
viewport: { x: 100, y: 200, zoom: 1.5 },
|
||||
})
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-12345')
|
||||
})
|
||||
})
|
||||
|
||||
it('should call handleUpdateWorkflowCanvas when notUpdateCanvas is undefined (default)', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowDraft).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
// Canvas SHOULD be updated when notUpdateCanvas is undefined
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should still update hash even when notUpdateCanvas is true', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-12345')
|
||||
})
|
||||
|
||||
// Verify canvas was NOT updated
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should still update environment variables when notUpdateCanvas is true', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([
|
||||
{ id: 'env-1', name: 'API_KEY', value: '[__HIDDEN__]', value_type: 'secret' },
|
||||
{ id: 'env-2', name: 'BASE_URL', value: 'https://api.example.com', value_type: 'string' },
|
||||
])
|
||||
})
|
||||
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should still update env secrets when notUpdateCanvas is true', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetEnvSecrets).toHaveBeenCalledWith({
|
||||
'env-1': 'secret-key',
|
||||
})
|
||||
})
|
||||
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should still update conversation variables when notUpdateCanvas is true', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetConversationVariables).toHaveBeenCalledWith([
|
||||
{ id: 'conv-1', name: 'user_input', value: 'test' },
|
||||
])
|
||||
})
|
||||
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('syncing state management', () => {
|
||||
it('should set isSyncingWorkflowDraft to true before fetch', () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSetIsSyncingWorkflowDraft).toHaveBeenCalledWith(true)
|
||||
})
|
||||
|
||||
it('should set isSyncingWorkflowDraft to false after fetch completes', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetIsSyncingWorkflowDraft).toHaveBeenCalledWith(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('should set isSyncingWorkflowDraft to false even when fetch fails', async () => {
|
||||
mockFetchWorkflowDraft.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetIsSyncingWorkflowDraft).toHaveBeenCalledWith(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isWorkflowDataLoaded flag management', () => {
|
||||
it('should set isWorkflowDataLoaded to false before fetch when it was true', () => {
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockStoreState({ isWorkflowDataLoaded: true }),
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockSetIsWorkflowDataLoaded).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
it('should set isWorkflowDataLoaded to true after fetch succeeds', async () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetIsWorkflowDataLoaded).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('should restore isWorkflowDataLoaded when fetch fails and it was previously loaded', async () => {
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockStoreState({ isWorkflowDataLoaded: true }),
|
||||
)
|
||||
mockFetchWorkflowDraft.mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
// Should restore to true because wasLoaded was true
|
||||
expect(mockSetIsWorkflowDataLoaded).toHaveBeenLastCalledWith(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('debounced sync cancellation', () => {
|
||||
it('should cancel debounced sync before fetching draft', () => {
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
|
||||
expect(mockCancelDebouncedSync).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle case when debouncedSyncWorkflowDraft has no cancel method', () => {
|
||||
mockWorkflowStoreGetState.mockReturnValue(
|
||||
createMockStoreState({ debouncedSyncWorkflowDraft: {} }),
|
||||
)
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
// Should not throw
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft()
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty graph in response', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-empty',
|
||||
graph: null,
|
||||
environment_variables: [],
|
||||
conversation_variables: [],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(false)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
|
||||
nodes: [],
|
||||
edges: [],
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle missing viewport in response', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-no-viewport',
|
||||
graph: {
|
||||
nodes: [{ id: 'node-1' }],
|
||||
edges: [],
|
||||
viewport: null,
|
||||
},
|
||||
environment_variables: [],
|
||||
conversation_variables: [],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(false)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
|
||||
nodes: [{ id: 'node-1' }],
|
||||
edges: [],
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle missing environment_variables in response', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-no-env',
|
||||
graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } },
|
||||
environment_variables: undefined,
|
||||
conversation_variables: [],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([])
|
||||
expect(mockSetEnvSecrets).toHaveBeenCalledWith({})
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle missing conversation_variables in response', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-no-conv',
|
||||
graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } },
|
||||
environment_variables: [],
|
||||
conversation_variables: undefined,
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetConversationVariables).toHaveBeenCalledWith([])
|
||||
})
|
||||
})
|
||||
|
||||
it('should filter only secret type for envSecrets', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-mixed-env',
|
||||
graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } },
|
||||
environment_variables: [
|
||||
{ id: 'env-1', name: 'SECRET_KEY', value: 'secret-value', value_type: 'secret' },
|
||||
{ id: 'env-2', name: 'PUBLIC_URL', value: 'https://example.com', value_type: 'string' },
|
||||
{ id: 'env-3', name: 'ANOTHER_SECRET', value: 'another-secret', value_type: 'secret' },
|
||||
],
|
||||
conversation_variables: [],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetEnvSecrets).toHaveBeenCalledWith({
|
||||
'env-1': 'secret-value',
|
||||
'env-3': 'another-secret',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should hide secret values in environment variables', async () => {
|
||||
mockFetchWorkflowDraft.mockResolvedValue({
|
||||
hash: 'hash-secrets',
|
||||
graph: { nodes: [], edges: [], viewport: { x: 0, y: 0, zoom: 1 } },
|
||||
environment_variables: [
|
||||
{ id: 'env-1', name: 'SECRET_KEY', value: 'super-secret', value_type: 'secret' },
|
||||
{ id: 'env-2', name: 'PUBLIC_URL', value: 'https://example.com', value_type: 'string' },
|
||||
],
|
||||
conversation_variables: [],
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([
|
||||
{ id: 'env-1', name: 'SECRET_KEY', value: '[__HIDDEN__]', value_type: 'secret' },
|
||||
{ id: 'env-2', name: 'PUBLIC_URL', value: 'https://example.com', value_type: 'string' },
|
||||
])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('multi-tab scenario simulation (THE BUG FIX VERIFICATION)', () => {
|
||||
/**
|
||||
* This test verifies the fix for the multi-tab scenario:
|
||||
* 1. User opens workflow in Tab A and Tab B
|
||||
* 2. Tab A saves draft successfully
|
||||
* 3. Tab B tries to save but gets 'draft_workflow_not_sync' error (hash mismatch)
|
||||
* 4. BEFORE FIX: Tab B would fetch draft and overwrite canvas with old data
|
||||
* 5. AFTER FIX: Tab B only updates hash, preserving user's canvas changes
|
||||
*/
|
||||
it('should only update hash when called with notUpdateCanvas=true (simulating sync error recovery)', async () => {
|
||||
const mockResponse = createMockDraftResponse()
|
||||
mockFetchWorkflowDraft.mockResolvedValue(mockResponse)
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
// Simulate the sync error recovery scenario where notUpdateCanvas is true
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(true)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowDraft).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
// Hash should be updated for next sync attempt
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-12345')
|
||||
})
|
||||
|
||||
// Canvas should NOT be updated - user's changes are preserved
|
||||
expect(mockHandleUpdateWorkflowCanvas).not.toHaveBeenCalled()
|
||||
|
||||
// Other states should still be updated
|
||||
expect(mockSetEnvironmentVariables).toHaveBeenCalled()
|
||||
expect(mockSetConversationVariables).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should update canvas when called with notUpdateCanvas=false (normal refresh)', async () => {
|
||||
const mockResponse = createMockDraftResponse()
|
||||
mockFetchWorkflowDraft.mockResolvedValue(mockResponse)
|
||||
|
||||
const { result } = renderHook(() => useWorkflowRefreshDraft())
|
||||
|
||||
// Simulate normal refresh scenario
|
||||
act(() => {
|
||||
result.current.handleRefreshWorkflowDraft(false)
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowDraft).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSetSyncWorkflowDraftHash).toHaveBeenCalledWith('new-hash-12345')
|
||||
})
|
||||
|
||||
// Canvas SHOULD be updated in normal refresh
|
||||
await waitFor(() => {
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -8,7 +8,7 @@ export const useWorkflowRefreshDraft = () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
|
||||
|
||||
const handleRefreshWorkflowDraft = useCallback(() => {
|
||||
const handleRefreshWorkflowDraft = useCallback((notUpdateCanvas?: boolean) => {
|
||||
const {
|
||||
appId,
|
||||
setSyncWorkflowDraftHash,
|
||||
@@ -31,12 +31,14 @@ export const useWorkflowRefreshDraft = () => {
|
||||
fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
|
||||
.then((response) => {
|
||||
// Ensure we have a valid workflow structure with viewport
|
||||
const workflowData: WorkflowDataUpdater = {
|
||||
nodes: response.graph?.nodes || [],
|
||||
edges: response.graph?.edges || [],
|
||||
viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 },
|
||||
if (!notUpdateCanvas) {
|
||||
const workflowData: WorkflowDataUpdater = {
|
||||
nodes: response.graph?.nodes || [],
|
||||
edges: response.graph?.edges || [],
|
||||
viewport: response.graph?.viewport || { x: 0, y: 0, zoom: 1 },
|
||||
}
|
||||
handleUpdateWorkflowCanvas(workflowData)
|
||||
}
|
||||
handleUpdateWorkflowCanvas(workflowData)
|
||||
setSyncWorkflowDraftHash(response.hash)
|
||||
setEnvSecrets((response.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => {
|
||||
acc[env.id] = env.value
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
useWorkflowReadOnly,
|
||||
} from '../hooks'
|
||||
import { useStore, useWorkflowStore } from '../store'
|
||||
import { BlockEnum, ControlMode, WorkflowRunningStatus } from '../types'
|
||||
import { BlockEnum, ControlMode } from '../types'
|
||||
import {
|
||||
getLayoutByDagre,
|
||||
getLayoutForChildNodes,
|
||||
@@ -36,17 +36,12 @@ export const useWorkflowInteractions = () => {
|
||||
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
|
||||
|
||||
const handleCancelDebugAndPreviewPanel = useCallback(() => {
|
||||
const { workflowRunningData } = workflowStore.getState()
|
||||
const runningStatus = workflowRunningData?.result?.status
|
||||
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
|
||||
workflowStore.setState({
|
||||
showDebugAndPreviewPanel: false,
|
||||
...(isActiveRun ? {} : { workflowRunningData: undefined }),
|
||||
workflowRunningData: undefined,
|
||||
})
|
||||
if (!isActiveRun) {
|
||||
handleNodeCancelRunningStatus()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}
|
||||
handleNodeCancelRunningStatus()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}, [workflowStore, handleNodeCancelRunningStatus, handleEdgeCancelRunningStatus])
|
||||
|
||||
return {
|
||||
|
||||
@@ -21,7 +21,7 @@ import {
|
||||
import { BlockEnum } from '../../types'
|
||||
import ConversationVariableModal from './conversation-variable-modal'
|
||||
import Empty from './empty'
|
||||
import { useChat } from './hooks/use-chat'
|
||||
import { useChat } from './hooks'
|
||||
import UserInput from './user-input'
|
||||
|
||||
type ChatWrapperProps = {
|
||||
|
||||
516
web/app/components/workflow/panel/debug-and-preview/hooks.ts
Normal file
516
web/app/components/workflow/panel/debug-and-preview/hooks.ts
Normal file
@@ -0,0 +1,516 @@
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type {
|
||||
ChatItem,
|
||||
ChatItemInTree,
|
||||
Inputs,
|
||||
} from '@/app/components/base/chat/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import { uniqBy } from 'es-toolkit/compat'
|
||||
import { produce, setAutoFreeze } from 'immer'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import {
|
||||
getProcessedInputs,
|
||||
processOpeningStatement,
|
||||
} from '@/app/components/base/chat/chat/utils'
|
||||
import { getThreadMessages } from '@/app/components/base/chat/utils'
|
||||
import {
|
||||
getProcessedFiles,
|
||||
getProcessedFilesFromResponse,
|
||||
} from '@/app/components/base/file-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { useInvalidAllLastRun } from '@/service/use-workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../constants'
|
||||
import {
|
||||
useSetWorkflowVarsWithValue,
|
||||
useWorkflowRun,
|
||||
} from '../../hooks'
|
||||
import { useHooksStore } from '../../hooks-store'
|
||||
import { useWorkflowStore } from '../../store'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '../../types'
|
||||
|
||||
type GetAbortController = (abortController: AbortController) => void
|
||||
type SendCallback = {
|
||||
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<any>
|
||||
}
|
||||
export const useChat = (
|
||||
config: any,
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
},
|
||||
prevChatTree?: ChatItemInTree[],
|
||||
stopChat?: (taskId: string) => void,
|
||||
) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const hasStopResponded = useRef(false)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const conversationId = useRef('')
|
||||
const taskIdRef = useRef('')
|
||||
const [isResponding, setIsResponding] = useState(false)
|
||||
const isRespondingRef = useRef(false)
|
||||
const configsMap = useHooksStore(s => s.configsMap)
|
||||
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
|
||||
const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([])
|
||||
const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null)
|
||||
const {
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
} = workflowStore.getState()
|
||||
|
||||
const handleResponding = useCallback((isResponding: boolean) => {
|
||||
setIsResponding(isResponding)
|
||||
isRespondingRef.current = isResponding
|
||||
}, [])
|
||||
|
||||
const [chatTree, setChatTree] = useState<ChatItemInTree[]>(prevChatTree || [])
|
||||
const chatTreeRef = useRef<ChatItemInTree[]>(chatTree)
|
||||
const [targetMessageId, setTargetMessageId] = useState<string>()
|
||||
const threadMessages = useMemo(() => getThreadMessages(chatTree, targetMessageId), [chatTree, targetMessageId])
|
||||
|
||||
const getIntroduction = useCallback((str: string) => {
|
||||
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
|
||||
}, [formSettings?.inputs, formSettings?.inputsForm])
|
||||
|
||||
/** Final chat list that will be rendered */
|
||||
const chatList = useMemo(() => {
|
||||
const ret = [...threadMessages]
|
||||
if (config?.opening_statement) {
|
||||
const index = threadMessages.findIndex(item => item.isOpeningStatement)
|
||||
|
||||
if (index > -1) {
|
||||
ret[index] = {
|
||||
...ret[index],
|
||||
content: getIntroduction(config.opening_statement),
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
}
|
||||
}
|
||||
else {
|
||||
ret.unshift({
|
||||
id: `${Date.now()}`,
|
||||
content: getIntroduction(config.opening_statement),
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
})
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
|
||||
|
||||
useEffect(() => {
|
||||
setAutoFreeze(false)
|
||||
return () => {
|
||||
setAutoFreeze(true)
|
||||
}
|
||||
}, [])
|
||||
|
||||
/** Find the target node by bfs and then operate on it */
|
||||
const produceChatTreeNode = useCallback((targetId: string, operation: (node: ChatItemInTree) => void) => {
|
||||
return produce(chatTreeRef.current, (draft) => {
|
||||
const queue: ChatItemInTree[] = [...draft]
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!
|
||||
if (current.id === targetId) {
|
||||
operation(current)
|
||||
break
|
||||
}
|
||||
if (current.children)
|
||||
queue.push(...current.children)
|
||||
}
|
||||
})
|
||||
}, [])
|
||||
|
||||
const handleStop = useCallback(() => {
|
||||
hasStopResponded.current = true
|
||||
handleResponding(false)
|
||||
if (stopChat && taskIdRef.current)
|
||||
stopChat(taskIdRef.current)
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
if (suggestedQuestionsAbortControllerRef.current)
|
||||
suggestedQuestionsAbortControllerRef.current.abort()
|
||||
}, [handleResponding, setIterTimes, setLoopTimes, stopChat])
|
||||
|
||||
const handleRestart = useCallback(() => {
|
||||
conversationId.current = ''
|
||||
taskIdRef.current = ''
|
||||
handleStop()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
setChatTree([])
|
||||
setSuggestQuestions([])
|
||||
}, [
|
||||
handleStop,
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
])
|
||||
|
||||
const updateCurrentQAOnTree = useCallback(({
|
||||
parentId,
|
||||
responseItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
}: {
|
||||
parentId?: string
|
||||
responseItem: ChatItem
|
||||
placeholderQuestionId: string
|
||||
questionItem: ChatItem
|
||||
}) => {
|
||||
let nextState: ChatItemInTree[]
|
||||
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] }
|
||||
if (!parentId && !chatTree.some(item => [placeholderQuestionId, questionItem.id].includes(item.id))) {
|
||||
// QA whose parent is not provided is considered as a first message of the conversation,
|
||||
// and it should be a root node of the chat tree
|
||||
nextState = produce(chatTree, (draft) => {
|
||||
draft.push(currentQA)
|
||||
})
|
||||
}
|
||||
else {
|
||||
// find the target QA in the tree and update it; if not found, insert it to its parent node
|
||||
nextState = produceChatTreeNode(parentId!, (parentNode) => {
|
||||
const questionNodeIndex = parentNode.children!.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
if (questionNodeIndex === -1)
|
||||
parentNode.children!.push(currentQA)
|
||||
else
|
||||
parentNode.children![questionNodeIndex] = currentQA
|
||||
})
|
||||
}
|
||||
setChatTree(nextState)
|
||||
chatTreeRef.current = nextState
|
||||
}, [chatTree, produceChatTreeNode])
|
||||
|
||||
const handleSend = useCallback((
|
||||
params: {
|
||||
query: string
|
||||
files?: FileEntity[]
|
||||
parent_message_id?: string
|
||||
[key: string]: any
|
||||
},
|
||||
{
|
||||
onGetSuggestedQuestions,
|
||||
}: SendCallback,
|
||||
) => {
|
||||
if (isRespondingRef.current) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
|
||||
|
||||
const placeholderQuestionId = `question-${Date.now()}`
|
||||
const questionItem = {
|
||||
id: placeholderQuestionId,
|
||||
content: params.query,
|
||||
isAnswer: false,
|
||||
message_files: params.files,
|
||||
parentMessageId: params.parent_message_id,
|
||||
}
|
||||
|
||||
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
|
||||
const placeholderAnswerItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
|
||||
}
|
||||
|
||||
setTargetMessageId(parentMessage?.id)
|
||||
updateCurrentQAOnTree({
|
||||
parentId: params.parent_message_id,
|
||||
responseItem: placeholderAnswerItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
})
|
||||
|
||||
// answer
|
||||
const responseItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
agent_thoughts: [],
|
||||
message_files: [],
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex: parentMessage?.children?.length ?? chatTree.length,
|
||||
}
|
||||
|
||||
handleResponding(true)
|
||||
|
||||
const { files, inputs, ...restParams } = params
|
||||
const bodyParams = {
|
||||
files: getProcessedFiles(files || []),
|
||||
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
|
||||
...restParams,
|
||||
}
|
||||
if (bodyParams?.files?.length) {
|
||||
bodyParams.files = bodyParams.files.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file) {
|
||||
return {
|
||||
...item,
|
||||
url: '',
|
||||
}
|
||||
}
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
let hasSetResponseId = false
|
||||
|
||||
handleRun(
|
||||
bodyParams,
|
||||
{
|
||||
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
|
||||
responseItem.content = responseItem.content + message
|
||||
|
||||
if (messageId && !hasSetResponseId) {
|
||||
questionItem.id = `question-${messageId}`
|
||||
responseItem.id = messageId
|
||||
responseItem.parentMessageId = questionItem.id
|
||||
hasSetResponseId = true
|
||||
}
|
||||
|
||||
if (isFirstMessage && newConversationId)
|
||||
conversationId.current = newConversationId
|
||||
|
||||
taskIdRef.current = taskId
|
||||
if (messageId)
|
||||
responseItem.id = messageId
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
async onCompleted(hasError?: boolean, errorMessage?: string) {
|
||||
handleResponding(false)
|
||||
fetchInspectVars({})
|
||||
invalidAllLastRun()
|
||||
|
||||
if (hasError) {
|
||||
if (errorMessage) {
|
||||
responseItem.content = errorMessage
|
||||
responseItem.isError = true
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (config?.suggested_questions_after_answer?.enabled && !hasStopResponded.current && onGetSuggestedQuestions) {
|
||||
try {
|
||||
const { data }: any = await onGetSuggestedQuestions(
|
||||
responseItem.id,
|
||||
newAbortController => suggestedQuestionsAbortControllerRef.current = newAbortController,
|
||||
)
|
||||
setSuggestQuestions(data)
|
||||
}
|
||||
// eslint-disable-next-line unused-imports/no-unused-vars
|
||||
catch (error) {
|
||||
setSuggestQuestions([])
|
||||
}
|
||||
}
|
||||
},
|
||||
onMessageEnd: (messageEnd) => {
|
||||
responseItem.citation = messageEnd.metadata?.retriever_resources || []
|
||||
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
|
||||
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
responseItem.content = messageReplace.answer
|
||||
},
|
||||
onError() {
|
||||
handleResponding(false)
|
||||
},
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
taskIdRef.current = task_id
|
||||
responseItem.workflow_run_id = workflow_run_id
|
||||
responseItem.workflowProcess = {
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onWorkflowFinished: ({ data }) => {
|
||||
responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onIterationStart: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
})
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onIterationFinish: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onLoopStart: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
})
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onLoopFinish: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onNodeStarted: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as any)
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onNodeRetry: ({ data }) => {
|
||||
responseItem.workflowProcess!.tracing!.push(data)
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onNodeFinished: ({ data }) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
onAgentLog: ({ data }) => {
|
||||
const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentNodeIndex > -1) {
|
||||
const current = responseItem.workflowProcess!.tracing![currentNodeIndex]
|
||||
|
||||
if (current.execution_metadata) {
|
||||
if (current.execution_metadata.agent_log) {
|
||||
const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
|
||||
if (currentLogIndex > -1) {
|
||||
current.execution_metadata.agent_log[currentLogIndex] = {
|
||||
...current.execution_metadata.agent_log[currentLogIndex],
|
||||
...data,
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log.push(data)
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log = [data]
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata = {
|
||||
agent_log: [data],
|
||||
} as any
|
||||
}
|
||||
|
||||
responseItem.workflowProcess!.tracing[currentNodeIndex] = {
|
||||
...current,
|
||||
}
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [threadMessages, chatTree.length, updateCurrentQAOnTree, handleResponding, formSettings?.inputsForm, handleRun, notify, t, config?.suggested_questions_after_answer?.enabled, fetchInspectVars, invalidAllLastRun])
|
||||
|
||||
return {
|
||||
conversationId: conversationId.current,
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
handleSend,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
isResponding,
|
||||
suggestedQuestions,
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { ChatItem, ChatItemInTree } from '@/app/components/base/chat/types'
|
||||
import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
|
||||
export type ChatConfig = {
|
||||
opening_statement?: string
|
||||
suggested_questions?: string[]
|
||||
suggested_questions_after_answer?: {
|
||||
enabled?: boolean
|
||||
}
|
||||
text_to_speech?: unknown
|
||||
speech_to_text?: unknown
|
||||
retriever_resource?: unknown
|
||||
sensitive_word_avoidance?: unknown
|
||||
file_upload?: unknown
|
||||
}
|
||||
|
||||
export type GetAbortController = (abortController: AbortController) => void
|
||||
|
||||
export type SendCallback = {
|
||||
onGetSuggestedQuestions?: (responseItemId: string, getAbortController: GetAbortController) => Promise<unknown>
|
||||
}
|
||||
|
||||
export type SendParams = {
|
||||
query: string
|
||||
files?: FileEntity[]
|
||||
parent_message_id?: string
|
||||
inputs?: Record<string, unknown>
|
||||
conversation_id?: string
|
||||
}
|
||||
|
||||
export type UpdateCurrentQAParams = {
|
||||
parentId?: string
|
||||
responseItem: ChatItem
|
||||
placeholderQuestionId: string
|
||||
questionItem: ChatItem
|
||||
}
|
||||
|
||||
export type ChatTreeUpdater = (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void
|
||||
@@ -1,90 +0,0 @@
|
||||
import { useCallback } from 'react'
|
||||
import { DEFAULT_ITER_TIMES, DEFAULT_LOOP_TIMES } from '../../../constants'
|
||||
import { useEdgesInteractionsWithoutSync } from '../../../hooks/use-edges-interactions-without-sync'
|
||||
import { useNodesInteractionsWithoutSync } from '../../../hooks/use-nodes-interactions-without-sync'
|
||||
import { useStore, useWorkflowStore } from '../../../store'
|
||||
import { WorkflowRunningStatus } from '../../../types'
|
||||
|
||||
type UseChatFlowControlParams = {
|
||||
stopChat?: (taskId: string) => void
|
||||
}
|
||||
|
||||
export function useChatFlowControl({
|
||||
stopChat,
|
||||
}: UseChatFlowControlParams) {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const setIsResponding = useStore(s => s.setIsResponding)
|
||||
const resetChatPreview = useStore(s => s.resetChatPreview)
|
||||
const setActiveTaskId = useStore(s => s.setActiveTaskId)
|
||||
const setHasStopResponded = useStore(s => s.setHasStopResponded)
|
||||
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
|
||||
const invalidateRun = useStore(s => s.invalidateRun)
|
||||
const { handleNodeCancelRunningStatus } = useNodesInteractionsWithoutSync()
|
||||
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
|
||||
|
||||
const { setIterTimes, setLoopTimes } = workflowStore.getState()
|
||||
|
||||
const handleResponding = useCallback((responding: boolean) => {
|
||||
setIsResponding(responding)
|
||||
}, [setIsResponding])
|
||||
|
||||
const handleStop = useCallback(() => {
|
||||
const {
|
||||
activeTaskId,
|
||||
suggestedQuestionsAbortController,
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
} = workflowStore.getState()
|
||||
const runningStatus = workflowRunningData?.result?.status
|
||||
const isActiveRun = runningStatus === WorkflowRunningStatus.Running || runningStatus === WorkflowRunningStatus.Waiting
|
||||
setHasStopResponded(true)
|
||||
handleResponding(false)
|
||||
if (stopChat && activeTaskId)
|
||||
stopChat(activeTaskId)
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
if (suggestedQuestionsAbortController)
|
||||
suggestedQuestionsAbortController.abort()
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
setActiveTaskId('')
|
||||
invalidateRun()
|
||||
if (isActiveRun && workflowRunningData) {
|
||||
setWorkflowRunningData({
|
||||
...workflowRunningData,
|
||||
result: {
|
||||
...workflowRunningData.result,
|
||||
status: WorkflowRunningStatus.Stopped,
|
||||
},
|
||||
})
|
||||
}
|
||||
if (isActiveRun) {
|
||||
handleNodeCancelRunningStatus()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}
|
||||
}, [
|
||||
handleResponding,
|
||||
setIterTimes,
|
||||
setLoopTimes,
|
||||
stopChat,
|
||||
workflowStore,
|
||||
setHasStopResponded,
|
||||
setSuggestedQuestionsAbortController,
|
||||
setActiveTaskId,
|
||||
invalidateRun,
|
||||
handleNodeCancelRunningStatus,
|
||||
handleEdgeCancelRunningStatus,
|
||||
])
|
||||
|
||||
const handleRestart = useCallback(() => {
|
||||
handleStop()
|
||||
resetChatPreview()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setLoopTimes(DEFAULT_LOOP_TIMES)
|
||||
}, [handleStop, setIterTimes, setLoopTimes, resetChatPreview])
|
||||
|
||||
return {
|
||||
handleResponding,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { processOpeningStatement } from '@/app/components/base/chat/chat/utils'
|
||||
import { getThreadMessages } from '@/app/components/base/chat/utils'
|
||||
|
||||
type UseChatListParams = {
|
||||
chatTree: ChatItemInTree[]
|
||||
targetMessageId: string | undefined
|
||||
config: {
|
||||
opening_statement?: string
|
||||
suggested_questions?: string[]
|
||||
} | undefined
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
}
|
||||
}
|
||||
|
||||
export function useChatList({
|
||||
chatTree,
|
||||
targetMessageId,
|
||||
config,
|
||||
formSettings,
|
||||
}: UseChatListParams) {
|
||||
const threadMessages = useMemo(
|
||||
() => getThreadMessages(chatTree, targetMessageId),
|
||||
[chatTree, targetMessageId],
|
||||
)
|
||||
|
||||
const getIntroduction = useCallback((str: string) => {
|
||||
return processOpeningStatement(str, formSettings?.inputs || {}, formSettings?.inputsForm || [])
|
||||
}, [formSettings?.inputs, formSettings?.inputsForm])
|
||||
|
||||
const chatList = useMemo(() => {
|
||||
const ret = [...threadMessages]
|
||||
if (config?.opening_statement) {
|
||||
const index = threadMessages.findIndex(item => item.isOpeningStatement)
|
||||
|
||||
if (index > -1) {
|
||||
ret[index] = {
|
||||
...ret[index],
|
||||
content: getIntroduction(config.opening_statement),
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
}
|
||||
}
|
||||
else {
|
||||
ret.unshift({
|
||||
id: `${Date.now()}`,
|
||||
content: getIntroduction(config.opening_statement),
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
suggestedQuestions: config.suggested_questions?.map((item: string) => getIntroduction(item)),
|
||||
})
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}, [threadMessages, config?.opening_statement, getIntroduction, config?.suggested_questions])
|
||||
|
||||
return {
|
||||
threadMessages,
|
||||
chatList,
|
||||
getIntroduction,
|
||||
}
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
import type { SendCallback, SendParams, UpdateCurrentQAParams } from './types'
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItem, ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { uniqBy } from 'es-toolkit/compat'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { getProcessedInputs } from '@/app/components/base/chat/chat/utils'
|
||||
import { getProcessedFiles, getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { useInvalidAllLastRun } from '@/service/use-workflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { useSetWorkflowVarsWithValue, useWorkflowRun } from '../../../hooks'
|
||||
import { useHooksStore } from '../../../hooks-store'
|
||||
import { useStore, useWorkflowStore } from '../../../store'
|
||||
import { createWorkflowEventHandlers } from './use-workflow-event-handlers'
|
||||
|
||||
type UseChatMessageSenderParams = {
|
||||
threadMessages: ChatItemInTree[]
|
||||
config?: {
|
||||
suggested_questions_after_answer?: {
|
||||
enabled?: boolean
|
||||
}
|
||||
}
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
}
|
||||
handleResponding: (responding: boolean) => void
|
||||
updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
|
||||
}
|
||||
|
||||
export function useChatMessageSender({
|
||||
threadMessages,
|
||||
config,
|
||||
formSettings,
|
||||
handleResponding,
|
||||
updateCurrentQAOnTree,
|
||||
}: UseChatMessageSenderParams) {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const workflowStore = useWorkflowStore()
|
||||
|
||||
const configsMap = useHooksStore(s => s.configsMap)
|
||||
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue()
|
||||
const setConversationId = useStore(s => s.setConversationId)
|
||||
const setTargetMessageId = useStore(s => s.setTargetMessageId)
|
||||
const setSuggestedQuestions = useStore(s => s.setSuggestedQuestions)
|
||||
const setActiveTaskId = useStore(s => s.setActiveTaskId)
|
||||
const setSuggestedQuestionsAbortController = useStore(s => s.setSuggestedQuestionsAbortController)
|
||||
const startRun = useStore(s => s.startRun)
|
||||
|
||||
const handleSend = useCallback((
|
||||
params: SendParams,
|
||||
{ onGetSuggestedQuestions }: SendCallback,
|
||||
) => {
|
||||
if (workflowStore.getState().isResponding) {
|
||||
notify({ type: 'info', message: t('errorMessage.waitForResponse', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
const { suggestedQuestionsAbortController } = workflowStore.getState()
|
||||
if (suggestedQuestionsAbortController)
|
||||
suggestedQuestionsAbortController.abort()
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
|
||||
const runId = startRun()
|
||||
const isCurrentRun = () => runId === workflowStore.getState().activeRunId
|
||||
|
||||
const parentMessage = threadMessages.find(item => item.id === params.parent_message_id)
|
||||
|
||||
const placeholderQuestionId = `question-${Date.now()}`
|
||||
const questionItem: ChatItem = {
|
||||
id: placeholderQuestionId,
|
||||
content: params.query,
|
||||
isAnswer: false,
|
||||
message_files: params.files,
|
||||
parentMessageId: params.parent_message_id,
|
||||
}
|
||||
|
||||
const siblingIndex = parentMessage?.children?.length ?? workflowStore.getState().chatTree.length
|
||||
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
|
||||
const placeholderAnswerItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex,
|
||||
}
|
||||
|
||||
setTargetMessageId(parentMessage?.id)
|
||||
updateCurrentQAOnTree({
|
||||
parentId: params.parent_message_id,
|
||||
responseItem: placeholderAnswerItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
})
|
||||
|
||||
const responseItem: ChatItem = {
|
||||
id: placeholderAnswerId,
|
||||
content: '',
|
||||
agent_thoughts: [],
|
||||
message_files: [],
|
||||
isAnswer: true,
|
||||
parentMessageId: questionItem.id,
|
||||
siblingIndex,
|
||||
}
|
||||
|
||||
handleResponding(true)
|
||||
|
||||
const { files, inputs, ...restParams } = params
|
||||
const bodyParams = {
|
||||
files: getProcessedFiles(files || []),
|
||||
inputs: getProcessedInputs(inputs || {}, formSettings?.inputsForm || []),
|
||||
...restParams,
|
||||
}
|
||||
if (bodyParams?.files?.length) {
|
||||
bodyParams.files = bodyParams.files.map((item) => {
|
||||
if (item.transfer_method === TransferMethod.local_file) {
|
||||
return {
|
||||
...item,
|
||||
url: '',
|
||||
}
|
||||
}
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
let hasSetResponseId = false
|
||||
|
||||
const workflowHandlers = createWorkflowEventHandlers({
|
||||
responseItem,
|
||||
questionItem,
|
||||
placeholderQuestionId,
|
||||
parentMessageId: params.parent_message_id,
|
||||
updateCurrentQAOnTree,
|
||||
})
|
||||
|
||||
handleRun(
|
||||
bodyParams,
|
||||
{
|
||||
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.content = responseItem.content + message
|
||||
|
||||
if (messageId && !hasSetResponseId) {
|
||||
questionItem.id = `question-${messageId}`
|
||||
responseItem.id = messageId
|
||||
responseItem.parentMessageId = questionItem.id
|
||||
hasSetResponseId = true
|
||||
}
|
||||
|
||||
if (isFirstMessage && newConversationId)
|
||||
setConversationId(newConversationId)
|
||||
|
||||
if (taskId)
|
||||
setActiveTaskId(taskId)
|
||||
if (messageId)
|
||||
responseItem.id = messageId
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
async onCompleted(hasError?: boolean, errorMessage?: string) {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
handleResponding(false)
|
||||
fetchInspectVars({})
|
||||
invalidAllLastRun()
|
||||
|
||||
if (hasError) {
|
||||
if (errorMessage) {
|
||||
responseItem.content = errorMessage
|
||||
responseItem.isError = true
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (config?.suggested_questions_after_answer?.enabled && !workflowStore.getState().hasStopResponded && onGetSuggestedQuestions) {
|
||||
try {
|
||||
const result = await onGetSuggestedQuestions(
|
||||
responseItem.id,
|
||||
newAbortController => setSuggestedQuestionsAbortController(newAbortController),
|
||||
) as { data: string[] }
|
||||
setSuggestedQuestions(result.data)
|
||||
}
|
||||
catch {
|
||||
setSuggestedQuestions([])
|
||||
}
|
||||
finally {
|
||||
setSuggestedQuestionsAbortController(null)
|
||||
}
|
||||
}
|
||||
},
|
||||
onMessageEnd: (messageEnd) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.citation = messageEnd.metadata?.retriever_resources || []
|
||||
const processedFilesFromResponse = getProcessedFilesFromResponse(messageEnd.files || [])
|
||||
responseItem.allFiles = uniqBy([...(responseItem.allFiles || []), ...(processedFilesFromResponse || [])], 'id')
|
||||
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: params.parent_message_id,
|
||||
})
|
||||
},
|
||||
onMessageReplace: (messageReplace) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
responseItem.content = messageReplace.answer
|
||||
},
|
||||
onError() {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
handleResponding(false)
|
||||
},
|
||||
onWorkflowStarted: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
const taskId = workflowHandlers.onWorkflowStarted(event)
|
||||
if (taskId)
|
||||
setActiveTaskId(taskId)
|
||||
},
|
||||
onWorkflowFinished: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onWorkflowFinished(event)
|
||||
},
|
||||
onIterationStart: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onIterationStart(event)
|
||||
},
|
||||
onIterationFinish: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onIterationFinish(event)
|
||||
},
|
||||
onLoopStart: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onLoopStart(event)
|
||||
},
|
||||
onLoopFinish: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onLoopFinish(event)
|
||||
},
|
||||
onNodeStarted: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeStarted(event)
|
||||
},
|
||||
onNodeRetry: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeRetry(event)
|
||||
},
|
||||
onNodeFinished: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onNodeFinished(event)
|
||||
},
|
||||
onAgentLog: (event) => {
|
||||
if (!isCurrentRun())
|
||||
return
|
||||
workflowHandlers.onAgentLog(event)
|
||||
},
|
||||
},
|
||||
)
|
||||
}, [
|
||||
threadMessages,
|
||||
updateCurrentQAOnTree,
|
||||
handleResponding,
|
||||
formSettings?.inputsForm,
|
||||
handleRun,
|
||||
notify,
|
||||
t,
|
||||
config?.suggested_questions_after_answer?.enabled,
|
||||
setTargetMessageId,
|
||||
setConversationId,
|
||||
setSuggestedQuestions,
|
||||
setActiveTaskId,
|
||||
setSuggestedQuestionsAbortController,
|
||||
startRun,
|
||||
fetchInspectVars,
|
||||
invalidAllLastRun,
|
||||
workflowStore,
|
||||
])
|
||||
|
||||
return { handleSend }
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import type { ChatTreeUpdater, UpdateCurrentQAParams } from './types'
|
||||
import type { ChatItemInTree } from '@/app/components/base/chat/types'
|
||||
import { produce } from 'immer'
|
||||
import { useCallback } from 'react'
|
||||
|
||||
export function useChatTreeOperations(updateChatTree: ChatTreeUpdater) {
|
||||
const produceChatTreeNode = useCallback(
|
||||
(tree: ChatItemInTree[], targetId: string, operation: (node: ChatItemInTree) => void) => {
|
||||
return produce(tree, (draft) => {
|
||||
const queue: ChatItemInTree[] = [...draft]
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!
|
||||
if (current.id === targetId) {
|
||||
operation(current)
|
||||
break
|
||||
}
|
||||
if (current.children)
|
||||
queue.push(...current.children)
|
||||
}
|
||||
})
|
||||
},
|
||||
[],
|
||||
)
|
||||
|
||||
const updateCurrentQAOnTree = useCallback(({
|
||||
parentId,
|
||||
responseItem,
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
}: UpdateCurrentQAParams) => {
|
||||
const currentQA = { ...questionItem, children: [{ ...responseItem, children: [] }] } as ChatItemInTree
|
||||
updateChatTree((currentChatTree) => {
|
||||
if (!parentId) {
|
||||
const questionIndex = currentChatTree.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
return produce(currentChatTree, (draft) => {
|
||||
if (questionIndex === -1)
|
||||
draft.push(currentQA)
|
||||
else
|
||||
draft[questionIndex] = currentQA
|
||||
})
|
||||
}
|
||||
|
||||
return produceChatTreeNode(currentChatTree, parentId, (parentNode) => {
|
||||
if (!parentNode.children)
|
||||
parentNode.children = []
|
||||
const questionNodeIndex = parentNode.children.findIndex(item => [placeholderQuestionId, questionItem.id].includes(item.id))
|
||||
if (questionNodeIndex === -1)
|
||||
parentNode.children.push(currentQA)
|
||||
else
|
||||
parentNode.children[questionNodeIndex] = currentQA
|
||||
})
|
||||
})
|
||||
}, [produceChatTreeNode, updateChatTree])
|
||||
|
||||
return {
|
||||
produceChatTreeNode,
|
||||
updateCurrentQAOnTree,
|
||||
}
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
import type { ChatConfig } from './types'
|
||||
import type { InputForm } from '@/app/components/base/chat/chat/type'
|
||||
import type { ChatItemInTree, Inputs } from '@/app/components/base/chat/types'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useStore } from '../../../store'
|
||||
import { useChatFlowControl } from './use-chat-flow-control'
|
||||
import { useChatList } from './use-chat-list'
|
||||
import { useChatMessageSender } from './use-chat-message-sender'
|
||||
import { useChatTreeOperations } from './use-chat-tree-operations'
|
||||
|
||||
export function useChat(
|
||||
config: ChatConfig | undefined,
|
||||
formSettings?: {
|
||||
inputs: Inputs
|
||||
inputsForm: InputForm[]
|
||||
},
|
||||
prevChatTree?: ChatItemInTree[],
|
||||
stopChat?: (taskId: string) => void,
|
||||
) {
|
||||
const chatTree = useStore(s => s.chatTree)
|
||||
const conversationId = useStore(s => s.conversationId)
|
||||
const isResponding = useStore(s => s.isResponding)
|
||||
const suggestedQuestions = useStore(s => s.suggestedQuestions)
|
||||
const targetMessageId = useStore(s => s.targetMessageId)
|
||||
const updateChatTree = useStore(s => s.updateChatTree)
|
||||
const setTargetMessageId = useStore(s => s.setTargetMessageId)
|
||||
|
||||
const initialChatTreeRef = useRef(prevChatTree)
|
||||
useEffect(() => {
|
||||
const initialChatTree = initialChatTreeRef.current
|
||||
if (!initialChatTree || initialChatTree.length === 0)
|
||||
return
|
||||
updateChatTree(currentChatTree => (currentChatTree.length === 0 ? initialChatTree : currentChatTree))
|
||||
}, [updateChatTree])
|
||||
|
||||
const { updateCurrentQAOnTree } = useChatTreeOperations(updateChatTree)
|
||||
|
||||
const {
|
||||
handleResponding,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
} = useChatFlowControl({
|
||||
stopChat,
|
||||
})
|
||||
|
||||
const {
|
||||
threadMessages,
|
||||
chatList,
|
||||
} = useChatList({
|
||||
chatTree,
|
||||
targetMessageId,
|
||||
config,
|
||||
formSettings,
|
||||
})
|
||||
|
||||
const { handleSend } = useChatMessageSender({
|
||||
threadMessages,
|
||||
config,
|
||||
formSettings,
|
||||
handleResponding,
|
||||
updateCurrentQAOnTree,
|
||||
})
|
||||
|
||||
return {
|
||||
conversationId,
|
||||
chatList,
|
||||
setTargetMessageId,
|
||||
handleSend,
|
||||
handleStop,
|
||||
handleRestart,
|
||||
isResponding,
|
||||
suggestedQuestions,
|
||||
}
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
import type { UpdateCurrentQAParams } from './types'
|
||||
import type { ChatItem } from '@/app/components/base/chat/types'
|
||||
import type { AgentLogItem, NodeTracing } from '@/types/workflow'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '../../../types'
|
||||
|
||||
type WorkflowEventHandlersContext = {
|
||||
responseItem: ChatItem
|
||||
questionItem: ChatItem
|
||||
placeholderQuestionId: string
|
||||
parentMessageId?: string
|
||||
updateCurrentQAOnTree: (params: UpdateCurrentQAParams) => void
|
||||
}
|
||||
|
||||
type TracingData = Partial<NodeTracing> & { id: string }
|
||||
type AgentLogData = Partial<AgentLogItem> & { node_id: string, message_id: string }
|
||||
|
||||
export function createWorkflowEventHandlers(ctx: WorkflowEventHandlersContext) {
|
||||
const { responseItem, questionItem, placeholderQuestionId, parentMessageId, updateCurrentQAOnTree } = ctx
|
||||
|
||||
const updateTree = () => {
|
||||
updateCurrentQAOnTree({
|
||||
placeholderQuestionId,
|
||||
questionItem,
|
||||
responseItem,
|
||||
parentId: parentMessageId,
|
||||
})
|
||||
}
|
||||
|
||||
const updateTracingItem = (data: TracingData) => {
|
||||
const currentTracingIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.id === data.id)
|
||||
if (currentTracingIndex > -1) {
|
||||
responseItem.workflowProcess!.tracing[currentTracingIndex] = {
|
||||
...responseItem.workflowProcess!.tracing[currentTracingIndex],
|
||||
...data,
|
||||
}
|
||||
updateTree()
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }: { workflow_run_id: string, task_id: string }) => {
|
||||
responseItem.workflow_run_id = workflow_run_id
|
||||
responseItem.workflowProcess = {
|
||||
status: WorkflowRunningStatus.Running,
|
||||
tracing: [],
|
||||
}
|
||||
updateTree()
|
||||
return task_id
|
||||
},
|
||||
|
||||
onWorkflowFinished: ({ data }: { data: { status: string } }) => {
|
||||
responseItem.workflowProcess!.status = data.status as WorkflowRunningStatus
|
||||
updateTree()
|
||||
},
|
||||
|
||||
onIterationStart: ({ data }: { data: Partial<NodeTracing> }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as NodeTracing)
|
||||
updateTree()
|
||||
},
|
||||
|
||||
onIterationFinish: ({ data }: { data: TracingData }) => {
|
||||
updateTracingItem(data)
|
||||
},
|
||||
|
||||
onLoopStart: ({ data }: { data: Partial<NodeTracing> }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as NodeTracing)
|
||||
updateTree()
|
||||
},
|
||||
|
||||
onLoopFinish: ({ data }: { data: TracingData }) => {
|
||||
updateTracingItem(data)
|
||||
},
|
||||
|
||||
onNodeStarted: ({ data }: { data: Partial<NodeTracing> }) => {
|
||||
responseItem.workflowProcess!.tracing!.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as NodeTracing)
|
||||
updateTree()
|
||||
},
|
||||
|
||||
onNodeRetry: ({ data }: { data: NodeTracing }) => {
|
||||
responseItem.workflowProcess!.tracing!.push(data)
|
||||
updateTree()
|
||||
},
|
||||
|
||||
onNodeFinished: ({ data }: { data: TracingData }) => {
|
||||
updateTracingItem(data)
|
||||
},
|
||||
|
||||
onAgentLog: ({ data }: { data: AgentLogData }) => {
|
||||
const currentNodeIndex = responseItem.workflowProcess!.tracing!.findIndex(item => item.node_id === data.node_id)
|
||||
if (currentNodeIndex > -1) {
|
||||
const current = responseItem.workflowProcess!.tracing![currentNodeIndex]
|
||||
|
||||
if (current.execution_metadata) {
|
||||
if (current.execution_metadata.agent_log) {
|
||||
const currentLogIndex = current.execution_metadata.agent_log.findIndex(log => log.message_id === data.message_id)
|
||||
if (currentLogIndex > -1) {
|
||||
current.execution_metadata.agent_log[currentLogIndex] = {
|
||||
...current.execution_metadata.agent_log[currentLogIndex],
|
||||
...data,
|
||||
} as AgentLogItem
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log.push(data as AgentLogItem)
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata.agent_log = [data as AgentLogItem]
|
||||
}
|
||||
}
|
||||
else {
|
||||
current.execution_metadata = {
|
||||
agent_log: [data as AgentLogItem],
|
||||
} as NodeTracing['execution_metadata']
|
||||
}
|
||||
|
||||
responseItem.workflowProcess!.tracing[currentNodeIndex] = {
|
||||
...current,
|
||||
}
|
||||
|
||||
updateTree()
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
import type { StateCreator } from 'zustand'
|
||||
import type { ChatItemInTree } from '@/app/components/base/chat/types'
|
||||
|
||||
type ChatPreviewState = {
|
||||
chatTree: ChatItemInTree[]
|
||||
targetMessageId: string | undefined
|
||||
suggestedQuestions: string[]
|
||||
conversationId: string
|
||||
isResponding: boolean
|
||||
activeRunId: number
|
||||
activeTaskId: string
|
||||
hasStopResponded: boolean
|
||||
suggestedQuestionsAbortController: AbortController | null
|
||||
}
|
||||
|
||||
type ChatPreviewActions = {
|
||||
setChatTree: (chatTree: ChatItemInTree[]) => void
|
||||
updateChatTree: (updater: (chatTree: ChatItemInTree[]) => ChatItemInTree[]) => void
|
||||
setTargetMessageId: (messageId: string | undefined) => void
|
||||
setSuggestedQuestions: (questions: string[]) => void
|
||||
setConversationId: (conversationId: string) => void
|
||||
setIsResponding: (isResponding: boolean) => void
|
||||
setActiveTaskId: (taskId: string) => void
|
||||
setHasStopResponded: (hasStopResponded: boolean) => void
|
||||
setSuggestedQuestionsAbortController: (controller: AbortController | null) => void
|
||||
startRun: () => number
|
||||
invalidateRun: () => number
|
||||
resetChatPreview: () => void
|
||||
}
|
||||
|
||||
export type ChatPreviewSliceShape = ChatPreviewState & ChatPreviewActions
|
||||
|
||||
const initialState: ChatPreviewState = {
|
||||
chatTree: [],
|
||||
targetMessageId: undefined,
|
||||
suggestedQuestions: [],
|
||||
conversationId: '',
|
||||
isResponding: false,
|
||||
activeRunId: 0,
|
||||
activeTaskId: '',
|
||||
hasStopResponded: false,
|
||||
suggestedQuestionsAbortController: null,
|
||||
}
|
||||
|
||||
export const createChatPreviewSlice: StateCreator<ChatPreviewSliceShape> = (set, get) => ({
|
||||
...initialState,
|
||||
|
||||
setChatTree: chatTree => set({ chatTree }),
|
||||
|
||||
updateChatTree: updater => set((state) => {
|
||||
const nextChatTree = updater(state.chatTree)
|
||||
if (nextChatTree === state.chatTree)
|
||||
return state
|
||||
return { chatTree: nextChatTree }
|
||||
}),
|
||||
|
||||
setTargetMessageId: targetMessageId => set({ targetMessageId }),
|
||||
|
||||
setSuggestedQuestions: suggestedQuestions => set({ suggestedQuestions }),
|
||||
|
||||
setConversationId: conversationId => set({ conversationId }),
|
||||
|
||||
setIsResponding: isResponding => set({ isResponding }),
|
||||
|
||||
setActiveTaskId: activeTaskId => set({ activeTaskId }),
|
||||
|
||||
setHasStopResponded: hasStopResponded => set({ hasStopResponded }),
|
||||
|
||||
setSuggestedQuestionsAbortController: suggestedQuestionsAbortController => set({ suggestedQuestionsAbortController }),
|
||||
|
||||
startRun: () => {
|
||||
const activeRunId = get().activeRunId + 1
|
||||
set({
|
||||
activeRunId,
|
||||
activeTaskId: '',
|
||||
hasStopResponded: false,
|
||||
suggestedQuestionsAbortController: null,
|
||||
})
|
||||
return activeRunId
|
||||
},
|
||||
|
||||
invalidateRun: () => {
|
||||
const activeRunId = get().activeRunId + 1
|
||||
set({
|
||||
activeRunId,
|
||||
activeTaskId: '',
|
||||
suggestedQuestionsAbortController: null,
|
||||
})
|
||||
return activeRunId
|
||||
},
|
||||
|
||||
resetChatPreview: () => set(state => ({
|
||||
...initialState,
|
||||
activeRunId: state.activeRunId + 1,
|
||||
})),
|
||||
})
|
||||
@@ -1,7 +1,6 @@
|
||||
import type {
|
||||
StateCreator,
|
||||
} from 'zustand'
|
||||
import type { ChatPreviewSliceShape } from './chat-preview-slice'
|
||||
import type { ChatVariableSliceShape } from './chat-variable-slice'
|
||||
import type { InspectVarsSliceShape } from './debug/inspect-vars-slice'
|
||||
import type { EnvVariableSliceShape } from './env-variable-slice'
|
||||
@@ -23,7 +22,6 @@ import {
|
||||
} from 'zustand'
|
||||
import { createStore } from 'zustand/vanilla'
|
||||
import { WorkflowContext } from '@/app/components/workflow/context'
|
||||
import { createChatPreviewSlice } from './chat-preview-slice'
|
||||
import { createChatVariableSlice } from './chat-variable-slice'
|
||||
import { createInspectVarsSlice } from './debug/inspect-vars-slice'
|
||||
import { createEnvVariableSlice } from './env-variable-slice'
|
||||
@@ -44,8 +42,7 @@ export type SliceFromInjection
|
||||
& Partial<RagPipelineSliceShape>
|
||||
|
||||
export type Shape
|
||||
= ChatPreviewSliceShape
|
||||
& ChatVariableSliceShape
|
||||
= ChatVariableSliceShape
|
||||
& EnvVariableSliceShape
|
||||
& FormSliceShape
|
||||
& HelpLineSliceShape
|
||||
@@ -70,7 +67,6 @@ export const createWorkflowStore = (params: CreateWorkflowStoreParams) => {
|
||||
const { injectWorkflowStoreSliceFn } = params || {}
|
||||
|
||||
return createStore<Shape>((...args) => ({
|
||||
...createChatPreviewSlice(...args),
|
||||
...createChatVariableSlice(...args),
|
||||
...createEnvVariableSlice(...args),
|
||||
...createFormSlice(...args),
|
||||
|
||||
@@ -822,7 +822,7 @@
|
||||
"count": 2
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 15
|
||||
"count": 14
|
||||
}
|
||||
},
|
||||
"app/components/base/chat/chat/index.tsx": {
|
||||
@@ -3769,6 +3769,11 @@
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/workflow/panel/debug-and-preview/hooks.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 7
|
||||
}
|
||||
},
|
||||
"app/components/workflow/panel/env-panel/variable-modal.tsx": {
|
||||
"react-hooks-extra/no-direct-set-state-in-use-effect": {
|
||||
"count": 4
|
||||
|
||||
@@ -152,6 +152,8 @@ export const formatNumberAbbreviated = (num: number) => {
|
||||
: `${formatted}${units[unitIndex].symbol}`
|
||||
}
|
||||
}
|
||||
// Fallback: if no threshold matched, return the number string
|
||||
return num.toString()
|
||||
}
|
||||
|
||||
export const formatToLocalTime = (time: Dayjs, local: Locale, format: string) => {
|
||||
|
||||
Reference in New Issue
Block a user