Compare commits


13 Commits

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
| zxhlyh | 4d3d8b35d9 | Merge branch 'main' into feat/llm-node-support-tools | 2026-01-08 14:28:13 +08:00 |
| zxhlyh | c323028179 | feat: llm node support tools | 2026-01-08 14:27:37 +08:00 |
| zxhlyh | 70149ea05e | Merge branch 'main' into feat/llm-node-support-tools | 2026-01-07 16:29:47 +08:00 |
| zxhlyh | 1d93f41fcf | feat: llm node support tools | 2026-01-07 16:28:41 +08:00 |
| zxhlyh | 04f40303fd | Merge branch 'main' into feat/llm-node-support-tools | 2026-01-04 18:04:42 +08:00 |
| zxhlyh | ececc5ec2c | feat: llm node support tools | 2026-01-04 18:03:47 +08:00 |
| zxhlyh | e83635ee5a | Merge branch 'main' into feat/llm-node-support-tools | 2025-12-30 11:47:54 +08:00 |
| zxhlyh | d79372a46d | Merge branch 'main' into feat/llm-node-support-tools | 2025-12-30 11:47:26 +08:00 |
| zxhlyh | bbd11c9e89 | feat: llm node support tools | 2025-12-30 10:40:01 +08:00 |
| zxhlyh | d132abcdb4 | merge main | 2025-12-29 15:55:45 +08:00 |
| zxhlyh | d60348572e | feat: llm node support tools | 2025-12-29 14:55:26 +08:00 |
| zxhlyh | 0cff94d90e | Merge branch 'main' into feat/llm-node-support-tools | 2025-12-25 13:45:49 +08:00 |
| zxhlyh | a7859de625 | feat: llm node support tools | 2025-12-24 14:15:55 +08:00 |
203 changed files with 5812 additions and 18606 deletions

View File

@@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy Trigger Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/trigger-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -1,12 +1,10 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
# See: https://github.com/langgenius/dify/issues/30743
on:
repository_dispatch:
types: [i18n-sync]
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
inputs:
files:
@@ -89,35 +87,26 @@ jobs:
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
fi
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
# Triggered by push via trigger-i18n-sync.yml workflow
# Validate required payload fields
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
exit 1
else
# Push trigger - detect changed files from the push
BEFORE_SHA="${{ github.event.before }}"
# Handle edge case: first push or force push may have null/zero SHA
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
# Fallback to comparing with parent commit
BEFORE_SHA="HEAD~1"
fi
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
changed=$(git diff --name-only "$BEFORE_SHA" ${{ github.sha }} -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
echo "CHANGED_FILES=$changed" >> $GITHUB_OUTPUT
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
echo "SYNC_MODE=incremental" >> $GITHUB_OUTPUT
# Decode the base64-encoded diff from the trigger workflow
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
echo "Warning: Failed to decode base64 diff payload" >&2
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
elif [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
# Generate detailed diff for the push
git diff "$BEFORE_SHA"..${{ github.sha }} -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
if [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "Unsupported event type: ${{ github.event_name }}"
exit 1
fi
# Truncate diff if too large (keep first 50KB)
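The `BEFORE_SHA` fallback above guards against the null or all-zero SHA that first pushes and force pushes report. A Python sketch of the same rule (illustrative only; the workflow itself does this in shell):

```python
ZERO_SHA = "0" * 40

def resolve_before_sha(before_sha: str | None) -> str:
    # First push / force push can report a null or all-zero "before" SHA;
    # fall back to the parent of the pushed commit in that case.
    if not before_sha or before_sha == ZERO_SHA:
        return "HEAD~1"
    return before_sha
```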

View File

@@ -1,66 +0,0 @@
name: Trigger i18n Sync on Push
# This workflow bridges the push event to repository_dispatch
# because claude-code-action doesn't support push events directly.
# See: https://github.com/langgenius/dify/issues/30743
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Detect changed files and generate diff
id: detect
run: |
BEFORE_SHA="${{ github.event.before }}"
# Handle edge case: force push may have null/zero SHA
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
BEFORE_SHA="HEAD~1"
fi
# Detect changed i18n files
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
echo "changed_files=$changed" >> $GITHUB_OUTPUT
# Generate diff for context
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
# Truncate if too large (keep first 50KB to match receiving workflow)
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
# Base64 encode the diff for safe JSON transport (portable, single-line)
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
if [ -n "$changed" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Detected changed files: $changed"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No i18n changes detected"
fi
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
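The diff travels base64-encoded because raw diffs contain newlines and quotes that would break the JSON `client-payload`. A Python sketch of the round-trip the two workflows perform in shell (the file name is illustrative):

```python
import base64

diff_text = "--- a/web/i18n/en-US/common.json\n+++ b/web/i18n/en-US/common.json\n"

# Sender (trigger workflow): encode to a single base64 line, safe inside JSON
payload = base64.b64encode(diff_text.encode("utf-8")).decode("ascii")

# Receiver (translate workflow): decode the payload field back into a diff
assert base64.b64decode(payload).decode("utf-8") == diff_text
```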

.gitignore
View File

@@ -209,7 +209,6 @@ api/.vscode
.history
.idea/
web/migration/
# pnpm
/.pnpm-store

.nvmrc
View File

@@ -1 +1 @@
24
22.11.0

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
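The net effect of this change: `json_schema` now arrives as a string, must parse as JSON, and the parsed value must itself be a valid Draft 7 schema. A standalone sketch of that contract, assuming the `jsonschema` package the diff already uses:

```python
import json

from jsonschema import Draft7Validator
from jsonschema.exceptions import SchemaError

def validate_json_schema(schema: str | None) -> str | None:
    if schema is None:
        return None
    try:
        parsed = json.loads(schema)           # string -> dict
    except json.JSONDecodeError:
        raise ValueError(f"invalid json_schema value {schema}")
    try:
        Draft7Validator.check_schema(parsed)  # dict -> valid Draft 7 schema?
    except SchemaError as e:
        raise ValueError(f"Invalid JSON schema: {e.message}")
    return schema                             # stored as the original string

validate_json_schema('{"type": "object"}')    # ok
# validate_json_schema('{"type": 123}')       # raises: Invalid JSON schema
```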

View File

@@ -26,6 +26,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,

View File

@@ -358,6 +358,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],

View File

@@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -75,24 +76,12 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
@@ -189,8 +178,12 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")
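The same string-first convention applies to user inputs: a `JSON_OBJECT` variable must now be a JSON string, which is parsed only to verify it. A minimal sketch of the check (the function name is illustrative):

```python
import json

def check_json_object_input(variable: str, value: object) -> None:
    # JSON_OBJECT inputs now arrive as raw JSON strings, not dicts.
    if not isinstance(value, str):
        raise ValueError(f"{variable} in input form must be a string")
    try:
        json.loads(value)
    except json.JSONDecodeError:
        raise ValueError(f"{variable} in input form must be a valid JSON object")

check_json_object_input("profile", '{"name": "Ada"}')  # ok
# check_json_object_input("profile", {"name": "Ada"})  # raises: must be a string
```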

View File

@@ -1,434 +0,0 @@
# Memory Module
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
## Overview
The memory module contains two types of memory implementations:
1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Memory Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ TokenBufferMemory │ │
│ │ Scope: Conversation │ │
│ │ Storage: Database (Message table) │ │
│ │ Key: conversation_id │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: Object Storage (JSON file) │ │
│ │ Key: (app_id, conversation_id, node_id) │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
```
---
## TokenBufferMemory (Existing)
### Purpose
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
### Key Features
- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
### Data Flow
```
Message Table TokenBufferMemory LLM
│ │ │
│ SELECT * FROM messages │ │
│ WHERE conversation_id = ? │ │
│ ORDER BY created_at DESC │ │
├─────────────────────────────────▶│ │
│ │ │
│ extract_thread_messages() │
│ │ │
│ build_prompt_message_with_files() │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage]
│ ├───────────────────────▶│
│ │ │
```
### Thread Extraction
When a user regenerates a response, a new thread is created:
```
Message A (user)
└── Message A' (assistant)
└── Message B (user)
└── Message B' (assistant)
└── Message A'' (assistant, regenerated) ← New thread
└── Message C (user)
└── Message C' (assistant)
```
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
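A simplified sketch of that walk, assuming messages are ordered newest-first and each exposes `id` and `parent_message_id` (the real helper lives in `core.prompt.utils.extract_thread_messages`):

```python
def extract_thread_messages(messages):
    """Walk from the newest message up the parent_message_id chain.

    `messages` must be ordered newest-first; the returned thread is also
    newest-first (reverse it for chronological order). Simplified sketch.
    """
    thread = []
    next_id = None  # None: accept the newest message as the thread tip
    for message in messages:
        if next_id is None and thread:
            break  # reached the thread root (a message with no parent)
        if next_id is None or message.id == next_id:
            thread.append(message)
            next_id = message.parent_message_id
    return thread
```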
### Usage
```python
from core.memory.token_buffer_memory import TokenBufferMemory
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```
---
## NodeTokenBufferMemory (To Be Implemented)
### Purpose
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
### Use Cases
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design Decisions
#### Storage: Object Storage for Messages (No New Database Table)
| Aspect | Database | Object Storage |
| ------------------------- | -------------------- | ------------------ |
| Cost | High | Low |
| Query Flexibility | High | Low |
| Schema Changes | Migration required | None |
| Consistency with existing | ConversationVariable | File uploads, logs |
**Decision**: Store message data in object storage, but still use existing database tables for file metadata.
**What is stored in Object Storage:**
- Message content (text)
- Message metadata (role, token_count, created_at)
- File references (upload_file_id, tool_file_id, etc.)
- Thread relationships (message_id, parent_message_id)
**What still requires Database queries:**
- File reconstruction: When reading node memory, file references are used to query
`UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
complete `File` objects with storage_key, mime_type, etc.
**Why this hybrid approach** (see the sketch after this list)**:**
- No database migration required (no new tables)
- Message data may be large, object storage is cost-effective
- File metadata is already in database, no need to duplicate
- Aligns with existing storage patterns (file uploads, logs)
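A minimal sketch of that read path: `storage.load_once` and `file_factory.build_from_mapping` are the helpers the implementation below actually uses; the surrounding function is hypothetical.

```python
import json

def load_node_memory(storage, file_factory, tenant_id: str, key: str) -> list[dict]:
    data = json.loads(storage.load_once(key))  # object storage: message bodies
    for message in data["messages"]:
        # database: rebuild complete File objects from the stored references
        message["file_objects"] = [
            file_factory.build_from_mapping(mapping=ref, tenant_id=tenant_id)
            for ref in message.get("files", [])
        ]
    return data["messages"]
```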
#### Storage Key Format
```
node_memory/{app_id}/{conversation_id}/{node_id}.json
```
#### Data Structure
```json
{
"version": 1,
"messages": [
{
"message_id": "msg-001",
"parent_message_id": null,
"role": "user",
"content": "Analyze this image",
"files": [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": "file-uuid-123",
"belongs_to": "user"
}
],
"token_count": 15,
"created_at": "2026-01-07T10:00:00Z"
},
{
"message_id": "msg-002",
"parent_message_id": "msg-001",
"role": "assistant",
"content": "This is a landscape image...",
"files": [],
"token_count": 50,
"created_at": "2026-01-07T10:00:01Z"
}
]
}
```
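A Pydantic sketch of this structure; field names mirror the JSON above, and `NodeMemoryMessage` is the type the thread-extraction example below refers to (the file-reference model is spelled out under File Handling, so `files` stays a plain list of mappings here):

```python
from pydantic import BaseModel

class NodeMemoryMessage(BaseModel):
    message_id: str
    parent_message_id: str | None = None
    role: str                 # "user" or "assistant"
    content: str
    files: list[dict] = []    # file references; see NodeMemoryFile below
    token_count: int = 0
    created_at: str           # ISO 8601 timestamp

class NodeMemoryData(BaseModel):
    version: int = 1
    messages: list[NodeMemoryMessage] = []

# Parsing a stored file: NodeMemoryData.model_validate_json(raw_bytes)
```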
### Thread Support
Node memory also supports thread extraction (for regeneration scenarios):
```python
def _extract_thread(
self,
messages: list[NodeMemoryMessage],
current_message_id: str
) -> list[NodeMemoryMessage]:
"""
Extract messages belonging to the thread of current_message_id.
Similar to extract_thread_messages() in TokenBufferMemory.
"""
...
```
### File Handling
Files are stored as references (not full metadata):
```python
class NodeMemoryFile(BaseModel):
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None # for local_file
tool_file_id: str | None # for tool_file
url: str | None # for remote_url
belongs_to: str # user / assistant
```
When reading, files are rebuilt using `file_factory.build_from_mapping()`.
### API Design
```python
class NodeTokenBufferMemory:
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param model_instance: Model instance for token counting
"""
...
def add_messages(
self,
message_id: str,
parent_message_id: str | None,
user_content: str,
user_files: Sequence[File],
assistant_content: str,
assistant_files: Sequence[File],
) -> None:
"""
Append a dialogue turn (user + assistant) to node memory.
Call this after LLM node execution completes.
:param message_id: Current message ID (from Message table)
:param parent_message_id: Parent message ID (for thread tracking)
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
...
def get_history_prompt_messages(
self,
current_message_id: str,
tenant_id: str,
max_token_limit: int = 2000,
file_upload_config: FileUploadConfig | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
:param current_message_id: Current message ID (for thread extraction)
:param tenant_id: Tenant ID (for file reconstruction)
:param max_token_limit: Maximum tokens for history
:param file_upload_config: File upload configuration
:return: Sequence of PromptMessage for LLM context
"""
...
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
...
def clear(self) -> None:
"""
Clear all messages in this node's memory.
"""
...
```
### Data Flow
```
Object Storage NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ storage.load(key) │ │
│◀─────────────────────────────────┤ │
│ │ │
│ JSON data │ │
├─────────────────────────────────▶│ │
│ │ │
│ _extract_thread() │
│ │ │
│ _rebuild_files() via file_factory │
│ │ │
│ _build_prompt_messages() │
│ │ │
│ _truncate_by_tokens() │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
│ │◀── LLM execution complete │
│ │ │
│ │◀── add_messages() │
│ │ │
│ storage.save(key, data) │ │
│◀─────────────────────────────────┤ │
│ │ │
```
### Integration with LLM Node
```python
# In LLM Node execution
# 1. Fetch memory based on mode
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
memory = fetch_node_memory(
variable_pool=variable_pool,
app_id=app_id,
node_id=self.node_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
# Conversation-level memory (existing behavior)
memory = fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
else:
memory = None
# 2. Get history for context
if memory:
if isinstance(memory, NodeTokenBufferMemory):
history = memory.get_history_prompt_messages(
current_message_id=current_message_id,
tenant_id=tenant_id,
max_token_limit=max_token_limit,
)
else: # TokenBufferMemory
history = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
)
prompt_messages = [*history, *current_messages]
else:
prompt_messages = current_messages
# 3. Call LLM
response = model_instance.invoke(prompt_messages)
# 4. Append to node memory (only for NodeTokenBufferMemory)
if isinstance(memory, NodeTokenBufferMemory):
memory.add_messages(
message_id=message_id,
parent_message_id=parent_message_id,
user_content=user_input,
user_files=user_files,
assistant_content=response.content,
assistant_files=response_files,
)
memory.flush()
```
### Configuration
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
```python
class MemoryMode(StrEnum):
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (new, Chatflow only)
class MemoryConfig(BaseModel):
# Existing fields
role_prefix: RolePrefix | None = None
window: MemoryWindowConfig | None = None
query_prompt_template: str | None = None
# Memory mode (new)
mode: MemoryMode = MemoryMode.CONVERSATION
```
**Mode Behavior:**
| Mode | Memory Class | Scope | Availability |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
> fall back to no memory or raise a configuration error.
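A sketch of that fallback rule (illustrative only; `resolve_memory_mode` is not part of the proposal):

```python
from enum import StrEnum

class MemoryMode(StrEnum):  # mirrors the configuration block above
    CONVERSATION = "conversation"
    NODE = "node"

def resolve_memory_mode(mode: MemoryMode, conversation_id: str | None) -> MemoryMode | None:
    # Node-level memory needs a conversation_id, which only Chatflow provides.
    if mode == MemoryMode.NODE and conversation_id is None:
        return None  # fall back to no memory (or raise a configuration error)
    return mode
```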
---
## Comparison
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | Object Storage (JSON) |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via file references) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
---
## Future Considerations
1. **Cleanup Task**: Add a Celery task to clean up old node memory files
2. **Concurrency**: Consider Redis lock for concurrent node executions
3. **Compression**: Compress large memory files to reduce storage costs
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory

View File

@@ -1,15 +0,0 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeMemoryData,
NodeMemoryFile,
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeMemoryData",
"NodeMemoryFile",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]

View File

@@ -1,83 +0,0 @@
"""
Base memory interfaces and types.
This module defines the common protocol for memory implementations.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
class BaseMemory(ABC):
"""
Abstract base class for memory implementations.
Provides a common interface for both conversation-level and node-level memory.
"""
@abstractmethod
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Sequence of PromptMessage for LLM context
"""
pass
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt as formatted text.
:param human_prefix: Prefix for human messages
:param ai_prefix: Prefix for assistant messages
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Formatted history text
"""
from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@@ -1,353 +0,0 @@
"""
Node-level Token Buffer Memory for Chatflow.
This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.
Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.
Design:
- Storage is indexed by workflow_run_id (each execution stores one turn)
- Thread tracking leverages Message table's parent_message_id structure
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
"""
import logging
from collections.abc import Sequence
from pydantic import BaseModel
from sqlalchemy import select
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message
logger = logging.getLogger(__name__)
class NodeMemoryFile(BaseModel):
"""File reference stored in node memory."""
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None = None
tool_file_id: str | None = None
url: str | None = None
class NodeMemoryTurn(BaseModel):
"""A single dialogue turn (user + assistant) in node memory."""
user_content: str = ""
user_files: list[NodeMemoryFile] = []
assistant_content: str = ""
assistant_files: list[NodeMemoryFile] = []
class NodeMemoryData(BaseModel):
"""Root data structure for node memory storage."""
version: int = 1
# Key: workflow_run_id, Value: dialogue turn
turns: dict[str, NodeMemoryTurn] = {}
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history, stored in object storage.
Key design: Thread tracking is delegated to Message table's parent_message_id.
Storage is indexed by workflow_run_id for easy filtering.
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
"""
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param tenant_id: Tenant ID for file reconstruction
:param model_instance: Model instance for token counting
"""
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
self._data: NodeMemoryData | None = None
self._dirty = False
def _load(self) -> NodeMemoryData:
"""Load data from object storage."""
if self._data is not None:
return self._data
try:
raw = storage.load_once(self._storage_key)
self._data = NodeMemoryData.model_validate_json(raw)
except Exception:
# File not found or parse error, start fresh
self._data = NodeMemoryData()
return self._data
def _save(self) -> None:
"""Save data to object storage."""
if self._data is not None:
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
self._dirty = False
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
"""Convert File object to NodeMemoryFile reference."""
return NodeMemoryFile(
type=file.type.value if hasattr(file.type, "value") else str(file.type),
transfer_method=(
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
),
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
)
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
"""Convert NodeMemoryFile to mapping for file_factory."""
mapping: dict = {
"type": memory_file.type,
"transfer_method": memory_file.transfer_method,
}
if memory_file.upload_file_id:
mapping["upload_file_id"] = memory_file.upload_file_id
if memory_file.tool_file_id:
mapping["tool_file_id"] = memory_file.tool_file_id
if memory_file.url:
mapping["url"] = memory_file.url
return mapping
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
"""Rebuild File objects from NodeMemoryFile references."""
if not memory_files:
return []
from factories import file_factory
files = []
for mf in memory_files:
try:
mapping = self._memory_file_to_mapping(mf)
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
files.append(file)
except Exception as e:
logger.warning("Failed to rebuild file from memory: %s", e)
continue
return files
def _build_prompt_message(
self,
role: str,
content: str,
files: list[File],
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
) -> PromptMessage:
"""Build PromptMessage from content and files."""
from core.file import file_manager
if not files:
if role == "user":
return UserPromptMessage(content=content)
else:
return AssistantPromptMessage(content=content)
# Build multimodal content
prompt_contents: list = []
for file in files:
try:
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
prompt_contents.append(prompt_content)
except Exception as e:
logger.warning("Failed to convert file to prompt content: %s", e)
continue
prompt_contents.append(TextPromptMessageContent(data=content))
if role == "user":
return UserPromptMessage(content=prompt_contents)
else:
return AssistantPromptMessage(content=prompt_contents)
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
# Query messages for this conversation
stmt = (
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
)
messages = db.session.scalars(stmt.limit(500)).all()
if not messages:
return []
# Extract thread messages using existing logic
thread_messages = extract_thread_messages(messages)
# For a newly created message, the answer is temporarily empty; skip it
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
workflow_run_ids = []
for msg in reversed(thread_messages):
if msg.workflow_run_id:
workflow_run_ids.append(msg.workflow_run_id)
return workflow_run_ids
def add_messages(
self,
workflow_run_id: str,
user_content: str,
user_files: Sequence[File] | None = None,
assistant_content: str = "",
assistant_files: Sequence[File] | None = None,
) -> None:
"""
Add a dialogue turn to node memory.
Call this after LLM node execution completes.
:param workflow_run_id: Current workflow execution ID
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
data = self._load()
# Convert files to memory file references
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
# Store the turn indexed by workflow_run_id
data.turns[workflow_run_id] = NodeMemoryTurn(
user_content=user_content,
user_files=user_memory_files,
assistant_content=assistant_content,
assistant_files=assistant_memory_files,
)
self._dirty = True
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
Thread tracking is handled by querying Message table's parent_message_id structure.
:param max_token_limit: Maximum tokens for history
:param message_limit: unused, for interface compatibility
:return: Sequence of PromptMessage for LLM context
"""
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
_ = message_limit
detail = ImagePromptMessageContent.DETAIL.HIGH
data = self._load()
if not data.turns:
return []
# Get workflow_run_ids for current thread from Message table
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Build prompt messages in thread order
prompt_messages: list[PromptMessage] = []
for wf_run_id in thread_workflow_run_ids:
turn = data.turns.get(wf_run_id)
if not turn:
# This workflow execution didn't have node memory stored
continue
# Build user message
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
user_msg = self._build_prompt_message(
role="user",
content=turn.user_content,
files=user_files,
detail=detail,
)
prompt_messages.append(user_msg)
# Build assistant message
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
assistant_msg = self._build_prompt_message(
role="assistant",
content=turn.assistant_content,
files=assistant_files,
detail=detail,
)
prompt_messages.append(assistant_msg)
if not prompt_messages:
return []
# Truncate by token limit
try:
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
while current_tokens > max_token_limit and len(prompt_messages) > 1:
prompt_messages.pop(0)
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
except Exception as e:
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
if self._dirty:
self._save()
def clear(self) -> None:
"""Clear all messages in this node's memory."""
self._data = NodeMemoryData()
self._save()
def exists(self) -> bool:
"""Check if node memory exists in storage."""
return storage.exists(self._storage_key)

View File

@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory(BaseMemory):
class TokenBufferMemory:
def __init__(
self,
conversation: Conversation,
@@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
@@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
@@ -6,13 +5,6 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole
class MemoryMode(StrEnum):
"""Memory mode for LLM nodes."""
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class ChatModelMessage(BaseModel):
"""
Chat Message.
@@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION

File diff suppressed because it is too large.

View File

@@ -63,7 +63,6 @@ class NodeType(StrEnum):
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
GROUP = "group"
@property
def is_trigger_node(self) -> bool:

View File

@@ -307,14 +307,7 @@ class Graph:
if not node_configs:
raise ValueError("Graph must have at least one node")
# Filter out UI-only node types:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config for node_config in node_configs
if node_config.get("type", "") != "custom-note"
and node_config.get("data", {}).get("type", "") != "group"
]
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@@ -125,11 +125,6 @@ class EventHandler:
Args:
event: The node started event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_started(event)
return
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
@@ -169,11 +164,6 @@ class EventHandler:
Args:
event: The node succeeded event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_success(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
@@ -236,11 +226,6 @@ class EventHandler:
Args:
event: The node failed event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_failed(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
@@ -360,57 +345,3 @@ class EventHandler:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
def _is_virtual_node(self, node_id: str) -> bool:
"""
Check if node_id represents a virtual sub-node.
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
We check if the part before '.' exists in graph nodes.
"""
if "." in node_id:
parent_id = node_id.rsplit(".", 1)[0]
return parent_id in self._graph.nodes
return False
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle virtual node started event.
Virtual nodes don't need full execution tracking, just collect the event.
"""
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
"""
Handle virtual node success event.
Virtual nodes (extraction nodes) need special handling:
- Store outputs in variable pool (for reference by other nodes)
- Accumulate token usage
- Collect the event for logging
- Do NOT process edges or enqueue next nodes (parent node handles that)
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle virtual node failed event.
Virtual nodes (extraction nodes) failures are collected for logging,
but the parent node is responsible for handling the error.
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Collect the event for logging
self._event_collector.collect(event)

View File

@@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_type: str = ""
provider_id: str = ""
# Virtual node fields for extraction
is_virtual: bool = False
parent_node_id: str | None = None
extraction_source: str | None = None # e.g., "llm1.context"
extraction_prompt: str | None = None
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields

View File

@@ -1,13 +1,5 @@
from .entities import (
BaseIterationNodeData,
BaseIterationState,
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
VirtualNodeConfig,
)
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
__all__ = [
"BaseIterationNodeData",
@@ -16,7 +8,4 @@ __all__ = [
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
"VirtualNodeConfig",
"VirtualNodeExecutionError",
"VirtualNodeExecutor",
]

View File

@@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
return self
class VirtualNodeConfig(BaseModel):
"""Configuration for a virtual sub-node embedded within a parent node."""
# Local ID within parent node (e.g., "ext_1")
# Will be converted to global ID: "{parent_id}.{id}"
id: str
# Node type (e.g., "llm", "code", "tool")
type: str
# Full node data configuration
data: dict[str, Any] = {}
def get_global_id(self, parent_node_id: str) -> str:
"""Get the global node ID by combining parent ID and local ID."""
return f"{parent_node_id}.{self.id}"
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
@@ -193,9 +175,6 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Virtual sub-nodes that execute before the main node
virtual_nodes: list[VirtualNodeConfig] = []
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:

View File

@@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
@@ -271,52 +270,10 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute all virtual sub-nodes defined in node configuration.
Virtual nodes are complete node definitions that execute before the main node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool (via event handling)
- Supports retry via parent node's retry config
Returns:
dict mapping local_id -> outputs dict
"""
from .virtual_node_executor import VirtualNodeExecutor
virtual_nodes = self.node_data.virtual_nodes
if not virtual_nodes:
return {}
executor = VirtualNodeExecutor(
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
parent_node_id=self._node_id,
parent_retry_config=self.retry_config,
)
return (yield from executor.execute_virtual_nodes(virtual_nodes))
@property
def virtual_node_outputs(self) -> dict[str, Any]:
"""
Get the outputs from virtual sub-nodes.
Returns:
dict mapping local_id -> outputs dict
"""
return self._virtual_node_outputs
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute virtual sub-nodes before main node execution
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=execution_id,

View File

@@ -1,213 +0,0 @@
"""
Virtual Node Executor for running embedded sub-nodes within a parent node.
This module handles the execution of virtual nodes defined in a parent node's
`virtual_nodes` configuration. Virtual nodes are complete node definitions
that execute before the parent node.
Example configuration:
virtual_nodes:
- id: ext_1
type: llm
data:
model: {...}
prompt_template: [...]
"""
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from uuid import uuid4
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunFailedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from libs.datetime_utils import naive_utc_now
from .entities import RetryConfig, VirtualNodeConfig
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class VirtualNodeExecutionError(Exception):
"""Error during virtual node execution"""
def __init__(self, node_id: str, original_error: Exception):
self.node_id = node_id
self.original_error = original_error
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
class VirtualNodeExecutor:
"""
Executes virtual sub-nodes embedded within a parent node.
Virtual nodes are complete node definitions that execute before the parent node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool
- Supports retry via parent node's retry config
"""
def __init__(
self,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
parent_node_id: str,
parent_retry_config: RetryConfig | None = None,
):
self._graph_init_params = graph_init_params
self._graph_runtime_state = graph_runtime_state
self._parent_node_id = parent_node_id
self._parent_retry_config = parent_retry_config or RetryConfig()
def execute_virtual_nodes(
self,
virtual_nodes: list[VirtualNodeConfig],
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute all virtual nodes in order.
Args:
virtual_nodes: List of virtual node configurations
Yields:
Node events from each virtual node execution
Returns:
dict mapping local_id -> outputs dict
"""
results: dict[str, Any] = {}
for vnode_config in virtual_nodes:
global_id = vnode_config.get_global_id(self._parent_node_id)
# Execute with retry
outputs = yield from self._execute_with_retry(vnode_config, global_id)
results[vnode_config.id] = outputs
return results
def _execute_with_retry(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute virtual node with retry support.
"""
retry_config = self._parent_retry_config
last_error: Exception | None = None
for attempt in range(retry_config.max_retries + 1):
try:
return (yield from self._execute_single_node(vnode_config, global_id))
except Exception as e:
last_error = e
if attempt < retry_config.max_retries:
# Yield retry event
yield NodeRunRetryEvent(
id=str(uuid4()),
node_id=global_id,
node_type=self._get_node_type(vnode_config.type),
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
start_at=naive_utc_now(),
error=str(e),
retry_index=attempt + 1,
)
time.sleep(retry_config.retry_interval_seconds)
continue
raise VirtualNodeExecutionError(global_id, e) from e
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
def _execute_single_node(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute a single virtual node by instantiating and running it.
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
# Build node config
node_config: dict[str, Any] = {
"id": global_id,
"data": {
**vnode_config.data,
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
},
}
# Get the node class for this type
node_type = self._get_node_type(vnode_config.type)
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
node_version = str(vnode_config.data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
raise ValueError(f"No class found for node type: {node_type}")
# Instantiate the node
node = node_cls(
id=global_id,
config=node_config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self._graph_runtime_state,
)
# Run and collect events
outputs: dict[str, Any] = {}
for event in node.run():
# Mark event as coming from virtual node
self._mark_event_as_virtual(event, vnode_config)
yield event
if isinstance(event, NodeRunSucceededEvent):
outputs = event.node_run_result.outputs or {}
elif isinstance(event, NodeRunFailedEvent):
raise Exception(event.error or "Virtual node execution failed")
return outputs
def _mark_event_as_virtual(
self,
event: GraphNodeEventBase,
vnode_config: VirtualNodeConfig,
) -> None:
"""Mark event as coming from a virtual node."""
if isinstance(event, NodeRunStartedEvent):
event.is_virtual = True
event.parent_node_id = self._parent_node_id
def _get_node_type(self, type_str: str) -> NodeType:
"""Convert type string to NodeType enum."""
type_mapping = {
"llm": NodeType.LLM,
"code": NodeType.CODE,
"tool": NodeType.TOOL,
"if-else": NodeType.IF_ELSE,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"template-transform": NodeType.TEMPLATE_TRANSFORM,
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
"http-request": NodeType.HTTP_REQUEST,
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
}
return type_mapping.get(type_str, NodeType.LLM)

View File

@@ -8,13 +8,12 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
@@ -87,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool,
app_id: str,
tenant_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
node_id: str = "",
) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
:param variable_pool: Variable pool containing system variables
:param app_id: Application ID
:param tenant_id: Tenant ID
:param node_data_memory: Memory configuration
:param model_instance: Model instance for token counting
:param node_id: Node ID in the workflow (required for node mode)
:return: Memory instance or None if not applicable
"""
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
# Get conversation_id from variable pool (required for both modes in Chatflow)
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if node_data_memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
if not node_id:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return NodeTokenBufferMemory(
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
tenant_id=tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):

View File

@@ -16,8 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
@@ -209,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
query: str | None = None
@@ -304,41 +301,12 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Write to Node Memory if in node memory mode
if isinstance(memory, NodeTokenBufferMemory):
# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else ""
if workflow_run_id:
# Resolve the query template to get actual user content
# query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}"
actual_query = variable_pool.convert_template(query or "").text
# Get user files from sys.files
user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
user_files: list[File] = []
if isinstance(user_files_var, ArrayFileSegment):
user_files = list(user_files_var.value)
elif isinstance(user_files_var, FileSegment):
user_files = [user_files_var.value]
memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=actual_query,
user_files=user_files,
assistant_content=clean_text,
assistant_files=self._file_outputs,
)
memory.flush()
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@@ -598,22 +566,6 @@ class LLMNode(Node[LLMNodeData]):
            # Separated mode: always return clean text and reasoning_content
            return clean_text, reasoning_content or ""
-    @staticmethod
-    def _build_context(
-        prompt_messages: Sequence[PromptMessage],
-        assistant_response: str,
-        model_mode: str,
-    ) -> list[dict[str, Any]]:
-        """
-        Build context from prompt messages and assistant response.
-        Excludes system messages and includes the current LLM response.
-        """
-        context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
-        context_messages.append(AssistantPromptMessage(content=assistant_response))
-        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-            model_mode=model_mode, prompt_messages=context_messages
-        )
    def _transform_chat_messages(
        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -826,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
-        memory: BaseMemory | None = None,
+        memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
@@ -1385,7 +1337,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
-    memory: BaseMemory | None,
+    memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1402,7 +1354,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
-    memory: BaseMemory | None,
+    memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:

View File

@@ -1,3 +1,4 @@
+import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -42,22 +43,25 @@ class StartNode(Node[StartNodeData]):
            if value is None and variable.required:
                raise ValueError(f"{key} is required in input form")
-            # If no value provided, skip further processing for this key
-            if not value:
-                continue
-            if not isinstance(value, dict):
-                raise ValueError(f"JSON object for '{key}' must be an object")
-            # Overwrite with normalized dict to ensure downstream consistency
-            node_inputs[key] = value
            # If schema exists, then validate against it
            schema = variable.json_schema
            if not schema:
                continue
+            if not value:
+                continue
            try:
-                Draft7Validator(schema).validate(value)
+                json_schema = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"{schema} must be a valid JSON object")
+            try:
+                json_value = json.loads(value)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"{value} must be a valid JSON object")
+            try:
+                Draft7Validator(json_schema).validate(json_value)
            except ValidationError as e:
                raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
+            node_inputs[key] = json_value
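
The new validation path treats both the schema and the submitted value as JSON strings, parses each, and only assigns the parsed value to node_inputs on success. A minimal, self-contained sketch of the same flow (only the jsonschema package is assumed; the StartNode plumbing is omitted):

```python
import json

from jsonschema import Draft7Validator, ValidationError


def validate_json_input(key: str, raw_value: str, raw_schema: str) -> dict:
    """Parse a JSON-string input and validate it against a JSON-string schema."""
    try:
        schema = json.loads(raw_schema)
    except json.JSONDecodeError:
        raise ValueError(f"{raw_schema} must be a valid JSON object")
    try:
        value = json.loads(raw_value)
    except json.JSONDecodeError:
        raise ValueError(f"{raw_value} must be a valid JSON object")
    try:
        Draft7Validator(schema).validate(value)
    except ValidationError as e:
        raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
    return value


# Example mirroring the updated unit tests further down.
profile = validate_json_input(
    "profile",
    json.dumps({"age": 20, "name": "Tom"}),
    json.dumps({"type": "object", "properties": {"age": {"type": "number"}}, "required": ["age", "name"]}),
)
assert profile == {"age": 20, "name": "Tom"}
```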

View File

@@ -89,20 +89,18 @@ class ToolNode(Node[ToolNodeData]):
            )
            return
-        # get parameters (use virtual_node_outputs from base class)
+        # get parameters
        tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
        parameters = self._generate_parameters(
            tool_parameters=tool_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
-            virtual_node_outputs=self.virtual_node_outputs,
        )
        parameters_for_log = self._generate_parameters(
            tool_parameters=tool_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
            for_log=True,
-            virtual_node_outputs=self.virtual_node_outputs,
        )
        # get conversation id
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -178,7 +176,6 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
-        virtual_node_outputs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -187,17 +184,12 @@ class ToolNode(Node[ToolNodeData]):
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
for_log (bool): Whether to generate parameters for logging.
-            virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
-                Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
-                with global IDs like "{parent_id}.{local_id}".
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
-        virtual_node_outputs = virtual_node_outputs or {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
@@ -207,25 +199,14 @@ class ToolNode(Node[ToolNodeData]):
                continue
            tool_input = node_data.tool_parameters[parameter_name]
            if tool_input.type == "variable":
-                # Check if this references a virtual node output (local ID like [ext_1, text])
-                selector = tool_input.value
-                if len(selector) >= 2 and selector[0] in virtual_node_outputs:
-                    # Reference to virtual node output
-                    local_id = selector[0]
-                    var_name = selector[1]
-                    outputs = virtual_node_outputs.get(local_id, {})
-                    parameter_value = outputs.get(var_name)
-                else:
-                    # Normal variable reference
-                    variable = variable_pool.get(selector)
-                    if variable is None:
-                        if parameter.required:
-                            raise ToolParameterError(f"Variable {selector} does not exist")
-                        continue
-                    parameter_value = variable.value
+                variable = variable_pool.get(tool_input.value)
+                if variable is None:
+                    if parameter.required:
+                        raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+                    continue
+                parameter_value = variable.value
            elif tool_input.type in {"mixed", "constant"}:
-                template = str(tool_input.value)
-                segment_group = variable_pool.convert_template(template)
+                segment_group = variable_pool.convert_template(str(tool_input.value))
                parameter_value = segment_group.log if for_log else segment_group.text
            else:
                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")

View File

@@ -1,266 +0,0 @@
app:
description: Test for variable extraction feature
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: pav-test-extraction
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1767773675796-llm
source: '1767773675796'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
isInIteration: false
isInLoop: false
sourceType: llm
targetType: tool
id: llm-source-1767773709491-target
source: llm
sourceHandle: source
target: '1767773709491'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: answer
id: tool-source-answer-target
source: '1767773709491'
sourceHandle: source
target: answer
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1767773675796'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
query_prompt_template: ''
role_prefix:
assistant: ''
user: ''
window:
enabled: true
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
prompt_template:
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
role: system
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
请与用户进行对话,了解他们的搜索需求。
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
'
selected: false
title: LLM
type: llm
vision:
enabled: false
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: used for searching
ja_JP: used for searching
pt_BR: used for searching
zh_Hans: 用于搜索网页内容
label:
en_US: Query string
ja_JP: Query string
pt_BR: Query string
zh_Hans: 查询语句
llm_description: key words for searching
max: null
min: null
name: query
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
params:
query: ''
plugin_id: langgenius/google
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
provider_id: langgenius/google/google
provider_name: langgenius/google/google
provider_type: builtin
selected: false
title: GoogleSearch
tool_configurations: {}
tool_description: A tool for performing a Google SERP search and extracting
snippets and webpages.Input should be a search query.
tool_label: GoogleSearch
tool_name: google_search
tool_node_version: '2'
tool_parameters:
query:
type: variable
value:
- ext_1
- text
type: tool
virtual_nodes:
- data:
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
context:
enabled: false
prompt_template:
- role: user
text: '{{#llm.context#}}'
- role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
title: 提取搜索关键词
id: ext_1
type: llm
height: 52
id: '1767773709491'
position:
x: 682
y: 282
positionAbsolute:
x: 682
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '搜索结果:
{{#1767773709491.text#}}
'
selected: false
title: Answer
type: answer
height: 103
id: answer
position:
x: 984
y: 282
positionAbsolute:
x: 984
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 151
y: 141.5
zoom: 1
rag_pipeline_variables: []

View File

@@ -0,0 +1,390 @@
"""
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
specifically testing the ANSWER node message_replace logic.
"""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeSucceededEvent
from core.workflow.enums import NodeType
from models import EndUser
from models.model import AppMode
class TestAnswerNodeMessageReplace:
"""Test cases for ANSWER node message_replace event logic."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock(spec=AdvancedChatAppGenerateEntity)
entity.task_id = "test-task-id"
entity.app_id = "test-app-id"
entity.workflow_run_id = "test-workflow-run-id"
# minimal app_config used by pipeline internals
entity.app_config = SimpleNamespace(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.ADVANCED_CHAT,
app_model_config_dict={},
additional_features=None,
sensitive_word_avoidance=None,
)
entity.query = "test query"
entity.files = []
entity.extras = {}
entity.trace_manager = None
entity.inputs = {}
entity.invoke_from = "debugger"
return entity
@pytest.fixture
def mock_workflow(self):
"""Create a mock workflow."""
workflow = Mock()
workflow.id = "test-workflow-id"
workflow.features_dict = {}
return workflow
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = Mock()
manager.listen.return_value = []
manager.graph_runtime_state = None
return manager
@pytest.fixture
def mock_conversation(self):
"""Create a mock conversation."""
conversation = Mock()
conversation.id = "test-conversation-id"
conversation.mode = "advanced_chat"
return conversation
@pytest.fixture
def mock_message(self):
"""Create a mock message."""
message = Mock()
message.id = "test-message-id"
message.query = "test query"
message.created_at = Mock()
message.created_at.timestamp.return_value = 1234567890
return message
@pytest.fixture
def mock_user(self):
"""Create a mock end user."""
user = MagicMock(spec=EndUser)
user.id = "test-user-id"
user.session_id = "test-session-id"
return user
@pytest.fixture
def mock_draft_var_saver_factory(self):
"""Create a mock draft variable saver factory."""
return Mock()
@pytest.fixture
def pipeline(
self,
mock_application_generate_entity,
mock_workflow,
mock_queue_manager,
mock_conversation,
mock_message,
mock_user,
mock_draft_var_saver_factory,
):
"""Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=mock_application_generate_entity,
workflow=mock_workflow,
queue_manager=mock_queue_manager,
conversation=mock_conversation,
message=mock_message,
user=mock_user,
stream=True,
dialogue_count=1,
draft_var_saver_factory=mock_draft_var_saver_factory,
)
# Initialize workflow run id to avoid validation errors
pipeline._workflow_run_id = "test-workflow-run-id"
# Mock the message cycle manager methods we need to track
pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
return pipeline
def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
"""
Test that when an ANSWER node's final output differs from accumulated answer,
a message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "initial answer"
# Create ANSWER node succeeded event with different final output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": "updated final answer"},
)
# Mock the workflow response converter to avoid extra processing
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
responses = list(pipeline._handle_node_succeeded_event(event))
# Assert
assert pipeline._task_state.answer == "updated final answer"
# Verify message_replace was called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
answer="updated final answer", reason="variable_update"
)
def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node's final output is the same as accumulated answer,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "same answer"
# Create ANSWER node succeeded event with same output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": "same answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "same answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node's output is None or missing 'answer' key,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with None output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": None},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node has empty outputs dict,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with empty outputs
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
"""
Test that when an ANSWER node's outputs don't contain 'answer' key,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event without 'answer' key in outputs
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"other_key": "some value"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_non_answer_node_does_not_send_message_replace(self, pipeline):
"""
Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Test with LLM node
llm_event = QueueNodeSucceededEvent(
node_execution_id="test-llm-execution-id",
node_id="test-llm-node",
node_type=NodeType.LLM,
start_at=datetime.now(),
outputs={"answer": "different answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(llm_event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_end_node_does_not_send_message_replace(self, pipeline):
"""
Test that END nodes don't trigger message_replace events even with 'answer' output.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create END node succeeded event with answer output
event = QueueNodeSucceededEvent(
node_execution_id="test-end-execution-id",
node_id="test-end-node",
node_type=NodeType.END,
start_at=datetime.now(),
outputs={"answer": "different answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
"""
Test that when an ANSWER node's final output is numeric,
it gets converted to string properly.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "text answer"
# Create ANSWER node succeeded event with numeric output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": 12345},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should be converted to string
assert pipeline._task_state.answer == "12345"
# Verify message_replace was called with string
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
answer="12345", reason="variable_update"
)
def test_answer_node_files_are_recorded(self, pipeline):
"""
Test that ANSWER nodes properly record files from outputs.
"""
# Arrange
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with files
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={
"answer": "same answer",
"files": [
{"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
],
},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"])
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: files should be recorded
assert len(pipeline._recorded_files) == 1
assert pipeline._recorded_files[0] == event.outputs["files"][0]
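
The rule these tests pin down is a single comparison: emit message_replace only when the ANSWER node's final output is present and differs from the accumulated answer, stringifying non-string outputs. A compact sketch of that decision (function name hypothetical):

```python
def should_replace_answer(accumulated: str, outputs: dict) -> str | None:
    """Return the replacement text, or None when no message_replace should be sent."""
    final = outputs.get("answer")
    if final is None:
        return None
    final_text = str(final)  # numeric outputs are stringified, per the test above
    return final_text if final_text != accumulated else None


assert should_replace_answer("initial answer", {"answer": "updated final answer"}) == "updated final answer"
assert should_replace_answer("same answer", {"answer": "same answer"}) is None
assert should_replace_answer("existing answer", {}) is None
assert should_replace_answer("text answer", {"answer": 12345}) == "12345"
```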

View File

@@ -1,77 +0,0 @@
"""
Unit tests for virtual node configuration.
"""
from core.workflow.nodes.base.entities import VirtualNodeConfig
class TestVirtualNodeConfig:
"""Tests for VirtualNodeConfig entity."""
def test_create_basic_config(self):
"""Test creating a basic virtual node config."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={
"title": "Extract keywords",
"model": {"provider": "openai", "name": "gpt-4o-mini"},
},
)
assert config.id == "ext_1"
assert config.type == "llm"
assert config.data["title"] == "Extract keywords"
def test_get_global_id(self):
"""Test generating global ID from parent ID."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={},
)
global_id = config.get_global_id("tool1")
assert global_id == "tool1.ext_1"
def test_get_global_id_with_different_parents(self):
"""Test global ID generation with different parent IDs."""
config = VirtualNodeConfig(id="sub_node", type="code", data={})
assert config.get_global_id("parent1") == "parent1.sub_node"
assert config.get_global_id("node_123") == "node_123.sub_node"
def test_empty_data(self):
"""Test virtual node config with empty data."""
config = VirtualNodeConfig(
id="test",
type="tool",
)
assert config.id == "test"
assert config.type == "tool"
assert config.data == {}
def test_complex_data(self):
"""Test virtual node config with complex data."""
config = VirtualNodeConfig(
id="llm_1",
type="llm",
data={
"title": "Generate summary",
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {"temperature": 0.7, "max_tokens": 500},
},
"prompt_template": [
{"role": "user", "text": "{{#llm1.context#}}"},
{"role": "user", "text": "Please summarize the conversation"},
],
},
)
assert config.data["model"]["provider"] == "openai"
assert len(config.data["prompt_template"]) == 2

View File

@@ -58,8 +58,6 @@ def test_json_object_valid_schema():
}
)
-    schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -70,7 +68,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
node = make_start_node(user_inputs, variables)
result = node._run()
@@ -89,8 +87,6 @@ def test_json_object_invalid_json_string():
"required": ["age", "name"],
}
)
-    schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -101,12 +97,12 @@ def test_json_object_invalid_json_string():
)
]
-    # Providing a string instead of an object should raise a type error
+    # Missing closing brace makes this invalid JSON
    user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
    node = make_start_node(user_inputs, variables)
-    with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+    with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
node._run()
@@ -122,8 +118,6 @@ def test_json_object_does_not_match_schema():
}
)
-    schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -135,7 +129,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
node = make_start_node(user_inputs, variables)
@@ -155,8 +149,6 @@ def test_json_object_missing_required_schema_field():
}
)
-    schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -168,7 +160,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
user_inputs = {"profile": json.dumps({"age": 20})}
node = make_start_node(user_inputs, variables)

View File

@@ -2,11 +2,11 @@ import Marketplace from '@/app/components/plugins/marketplace'
import PluginPage from '@/app/components/plugins/plugin-page'
import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel'
-const PluginList = () => {
+const PluginList = async () => {
  return (
    <PluginPage
      plugins={<PluginsPanel />}
-      marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" />}
+      marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" showSearchParams={false} />}
    />
  )
}

View File

@@ -26,7 +26,6 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
-import { useInvalidateAppList } from '@/service/use-apps'
import { fetchWorkflowDraft } from '@/service/workflow'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
@@ -67,7 +66,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
const { onPlanInfoChanged } = useProviderContext()
const appDetail = useAppStore(state => state.appDetail)
const setAppDetail = useAppStore(state => state.setAppDetail)
-  const invalidateAppList = useInvalidateAppList()
const [open, setOpen] = useState(openState)
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
@@ -193,7 +191,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
try {
await deleteApp(appDetail.id)
notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) })
-      invalidateAppList()
onPlanInfoChanged()
setAppDetail()
replace('/apps')
@@ -205,7 +202,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
})
}
setShowConfirmDelete(false)
-  }, [appDetail, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t])
+  }, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])
const { isCurrentWorkspaceEditor } = useAppContext()

View File

@@ -83,7 +83,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
if (!isJsonObject || !tempPayload.json_schema)
return ''
try {
-        return tempPayload.json_schema
+        return JSON.stringify(JSON.parse(tempPayload.json_schema), null, 2)
}
catch {
return ''

View File

@@ -1,228 +0,0 @@
/**
* Tests for race condition prevention logic in chat message loading.
* These tests verify the core algorithms used in fetchData and loadMoreMessages
* to prevent race conditions, infinite loops, and stale state issues.
* See GitHub issue #30259 for context.
*/
// Test the race condition prevention logic in isolation
describe('Chat Message Loading Race Condition Prevention', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
describe('Request Deduplication', () => {
it('should deduplicate messages with same IDs when merging responses', async () => {
// Simulate the deduplication logic used in setAllChatItems
const existingItems = [
{ id: 'msg-1', isAnswer: false },
{ id: 'msg-2', isAnswer: true },
]
const newItems = [
{ id: 'msg-2', isAnswer: true }, // duplicate
{ id: 'msg-3', isAnswer: false }, // new
]
const existingIds = new Set(existingItems.map(item => item.id))
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
const mergedItems = [...uniqueNewItems, ...existingItems]
expect(uniqueNewItems).toHaveLength(1)
expect(uniqueNewItems[0].id).toBe('msg-3')
expect(mergedItems).toHaveLength(3)
})
})
describe('Retry Counter Logic', () => {
const MAX_RETRY_COUNT = 3
it('should increment retry counter when no unique items found', () => {
const state = { retryCount: 0 }
const prevItemsLength = 5
// Simulate the retry logic from loadMoreMessages
const uniqueNewItemsLength = 0
if (uniqueNewItemsLength === 0) {
if (state.retryCount < MAX_RETRY_COUNT && prevItemsLength > 1) {
state.retryCount++
}
else {
state.retryCount = 0
}
}
expect(state.retryCount).toBe(1)
})
it('should reset retry counter after MAX_RETRY_COUNT attempts', () => {
const state = { retryCount: MAX_RETRY_COUNT }
const prevItemsLength = 5
const uniqueNewItemsLength = 0
if (uniqueNewItemsLength === 0) {
if (state.retryCount < MAX_RETRY_COUNT && prevItemsLength > 1) {
state.retryCount++
}
else {
state.retryCount = 0
}
}
expect(state.retryCount).toBe(0)
})
it('should reset retry counter when unique items are found', () => {
const state = { retryCount: 2 }
// Simulate finding unique items (length > 0)
const processRetry = (uniqueCount: number) => {
if (uniqueCount === 0) {
state.retryCount++
}
else {
state.retryCount = 0
}
}
processRetry(3) // Found 3 unique items
expect(state.retryCount).toBe(0)
})
})
describe('Throttling Logic', () => {
const SCROLL_DEBOUNCE_MS = 200
it('should throttle requests within debounce window', () => {
const state = { lastLoadTime: 0 }
const results: boolean[] = []
const tryRequest = (now: number): boolean => {
if (now - state.lastLoadTime >= SCROLL_DEBOUNCE_MS) {
state.lastLoadTime = now
return true
}
return false
}
// First request - should pass
results.push(tryRequest(1000))
// Second request within debounce - should be blocked
results.push(tryRequest(1100))
// Third request after debounce - should pass
results.push(tryRequest(1300))
expect(results).toEqual([true, false, true])
})
})
describe('AbortController Cancellation', () => {
it('should abort previous request when new request starts', () => {
const state: { controller: AbortController | null } = { controller: null }
const abortedSignals: boolean[] = []
// First request
const controller1 = new AbortController()
state.controller = controller1
// Second request - should abort first
if (state.controller) {
state.controller.abort()
abortedSignals.push(state.controller.signal.aborted)
}
const controller2 = new AbortController()
state.controller = controller2
expect(abortedSignals).toEqual([true])
expect(controller1.signal.aborted).toBe(true)
expect(controller2.signal.aborted).toBe(false)
})
})
describe('Stale Response Detection', () => {
it('should ignore responses from outdated requests', () => {
const state = { requestId: 0 }
const processedResponses: number[] = []
// Simulate concurrent requests - each gets its own captured ID
const request1Id = ++state.requestId
const request2Id = ++state.requestId
// Request 2 completes first (current requestId is 2)
if (request2Id === state.requestId) {
processedResponses.push(request2Id)
}
// Request 1 completes later (stale - requestId is still 2)
if (request1Id === state.requestId) {
processedResponses.push(request1Id)
}
expect(processedResponses).toEqual([2])
expect(processedResponses).not.toContain(1)
})
})
describe('Pagination Anchor Management', () => {
it('should track oldest answer ID for pagination', () => {
let oldestAnswerIdRef: string | undefined
const chatItems = [
{ id: 'question-1', isAnswer: false },
{ id: 'answer-1', isAnswer: true },
{ id: 'question-2', isAnswer: false },
{ id: 'answer-2', isAnswer: true },
]
// Update pagination anchor with oldest answer ID
const answerItems = chatItems.filter(item => item.isAnswer)
const oldestAnswer = answerItems[answerItems.length - 1]
if (oldestAnswer?.id) {
oldestAnswerIdRef = oldestAnswer.id
}
expect(oldestAnswerIdRef).toBe('answer-2')
})
it('should use pagination anchor in subsequent requests', () => {
const oldestAnswerIdRef = 'answer-123'
const params: { conversation_id: string, limit: number, first_id?: string } = {
conversation_id: 'conv-1',
limit: 10,
}
if (oldestAnswerIdRef) {
params.first_id = oldestAnswerIdRef
}
expect(params.first_id).toBe('answer-123')
})
})
})
describe('Functional State Update Pattern', () => {
it('should use functional update to avoid stale closures', () => {
// Simulate the functional update pattern used in setAllChatItems
let state = [{ id: '1' }, { id: '2' }]
const newItems = [{ id: '3' }, { id: '2' }] // id '2' is duplicate
// Functional update pattern
const updater = (prevItems: { id: string }[]) => {
const existingIds = new Set(prevItems.map(item => item.id))
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
return [...uniqueNewItems, ...prevItems]
}
state = updater(state)
expect(state).toHaveLength(3)
expect(state.map(i => i.id)).toEqual(['3', '1', '2'])
})
})

View File

@@ -209,6 +209,7 @@ type IDetailPanel = {
function DetailPanel({ detail, onFeedback }: IDetailPanel) {
  const MIN_ITEMS_FOR_SCROLL_LOADING = 8
  const SCROLL_THRESHOLD_PX = 50
+  const SCROLL_DEBOUNCE_MS = 200
  const { userProfile: { timezone } } = useAppContext()
  const { formatTime } = useTimestamp()
@@ -227,103 +228,69 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
  const [hasMore, setHasMore] = useState(true)
  const [varValues, setVarValues] = useState<Record<string, string>>({})
  const isLoadingRef = useRef(false)
-  const abortControllerRef = useRef<AbortController | null>(null)
-  const requestIdRef = useRef(0)
-  const lastLoadTimeRef = useRef(0)
-  const retryCountRef = useRef(0)
-  const oldestAnswerIdRef = useRef<string | undefined>(undefined)
-  const MAX_RETRY_COUNT = 3
  const [allChatItems, setAllChatItems] = useState<IChatItem[]>([])
  const [chatItemTree, setChatItemTree] = useState<ChatItemInTree[]>([])
  const [threadChatItems, setThreadChatItems] = useState<IChatItem[]>([])
  const fetchData = useCallback(async () => {
-    if (isLoadingRef.current || !hasMore)
+    if (isLoadingRef.current)
      return
-    // Cancel any in-flight request
-    if (abortControllerRef.current) {
-      abortControllerRef.current.abort()
-    }
-    const controller = new AbortController()
-    abortControllerRef.current = controller
-    const currentRequestId = ++requestIdRef.current
    try {
      isLoadingRef.current = true
+      if (!hasMore)
+        return
      const params: ChatMessagesRequest = {
        conversation_id: detail.id,
        limit: 10,
      }
-      // Use ref for pagination anchor to avoid stale closure issues
-      if (oldestAnswerIdRef.current)
-        params.first_id = oldestAnswerIdRef.current
+      // Use the oldest answer item ID for pagination
+      const answerItems = allChatItems.filter(item => item.isAnswer)
+      const oldestAnswerItem = answerItems[answerItems.length - 1]
+      if (oldestAnswerItem?.id)
+        params.first_id = oldestAnswerItem.id
      const messageRes = await fetchChatMessages({
        url: `/apps/${appDetail?.id}/chat-messages`,
        params,
      })
-      // Ignore stale responses
-      if (currentRequestId !== requestIdRef.current || controller.signal.aborted)
-        return
      if (messageRes.data.length > 0) {
        const varValues = messageRes.data.at(-1)!.inputs
        setVarValues(varValues)
      }
      setHasMore(messageRes.has_more)
-      const newItems = getFormattedChatList(messageRes.data, detail.id, timezone!, t('dateTimeFormat', { ns: 'appLog' }) as string)
+      const newAllChatItems = [
+        ...getFormattedChatList(messageRes.data, detail.id, timezone!, t('dateTimeFormat', { ns: 'appLog' }) as string),
+        ...allChatItems,
+      ]
+      setAllChatItems(newAllChatItems)
-      // Use functional update to avoid stale state issues
-      setAllChatItems((prevItems: IChatItem[]) => {
-        const existingIds = new Set(prevItems.map(item => item.id))
-        const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
-        return [...uniqueNewItems, ...prevItems]
-      })
+      let tree = buildChatItemTree(newAllChatItems)
+      if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) {
+        tree = [{
+          id: 'introduction',
+          isAnswer: true,
+          isOpeningStatement: true,
+          content: detail?.model_config?.configs?.introduction ?? 'hello',
+          feedbackDisabled: true,
+          children: tree,
+        }]
+      }
+      setChatItemTree(tree)
+      const lastMessageId = newAllChatItems.length > 0 ? newAllChatItems[newAllChatItems.length - 1].id : undefined
+      setThreadChatItems(getThreadMessages(tree, lastMessageId))
    }
-    catch (err: unknown) {
-      if (err instanceof Error && err.name === 'AbortError')
-        return
+    catch (err) {
      console.error('fetchData execution failed:', err)
    }
    finally {
      isLoadingRef.current = false
-      if (abortControllerRef.current === controller)
-        abortControllerRef.current = null
    }
-  }, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
-  // Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems
-  useEffect(() => {
-    if (allChatItems.length === 0)
-      return
-    let tree = buildChatItemTree(allChatItems)
-    if (!hasMore && detail?.model_config?.configs?.introduction) {
-      tree = [{
-        id: 'introduction',
-        isAnswer: true,
-        isOpeningStatement: true,
-        content: detail?.model_config?.configs?.introduction ?? 'hello',
-        feedbackDisabled: true,
-        children: tree,
-      }]
-    }
-    setChatItemTree(tree)
-    const lastMessageId = allChatItems.length > 0 ? allChatItems[allChatItems.length - 1].id : undefined
-    setThreadChatItems(getThreadMessages(tree, lastMessageId))
-    // Update pagination anchor ref with the oldest answer ID
-    const answerItems = allChatItems.filter(item => item.isAnswer)
-    const oldestAnswer = answerItems[answerItems.length - 1]
-    if (oldestAnswer?.id)
-      oldestAnswerIdRef.current = oldestAnswer.id
-  }, [allChatItems, hasMore, detail?.model_config?.configs?.introduction])
+  }, [allChatItems, detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
const switchSibling = useCallback((siblingMessageId: string) => {
const newThreadChatItems = getThreadMessages(chatItemTree, siblingMessageId)
@@ -430,12 +397,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
    if (isLoading || !hasMore || !appDetail?.id || !detail.id)
      return
-    // Throttle using ref to persist across re-renders
-    const now = Date.now()
-    if (now - lastLoadTimeRef.current < SCROLL_DEBOUNCE_MS)
-      return
-    lastLoadTimeRef.current = now
    setIsLoading(true)
    try {
@@ -444,9 +405,15 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
        limit: 10,
      }
-      // Use ref for pagination anchor to avoid stale closure issues
-      if (oldestAnswerIdRef.current) {
-        params.first_id = oldestAnswerIdRef.current
+      // Use the earliest response item as the first_id
+      const answerItems = allChatItems.filter(item => item.isAnswer)
+      const oldestAnswerItem = answerItems[answerItems.length - 1]
+      if (oldestAnswerItem?.id) {
+        params.first_id = oldestAnswerItem.id
      }
+      else if (allChatItems.length > 0 && allChatItems[0]?.id) {
+        const firstId = allChatItems[0].id.replace('question-', '').replace('answer-', '')
+        params.first_id = firstId
+      }
const messageRes = await fetchChatMessages({
@@ -456,7 +423,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
if (!messageRes.data || messageRes.data.length === 0) {
setHasMore(false)
-        retryCountRef.current = 0
return
}
@@ -474,36 +440,91 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
        t('dateTimeFormat', { ns: 'appLog' }) as string,
      )
-      // Use functional update to get latest state and avoid stale closures
-      setAllChatItems((prevItems: IChatItem[]) => {
-        const existingIds = new Set(prevItems.map(item => item.id))
-        const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
+      // Check for duplicate messages
+      const existingIds = new Set(allChatItems.map(item => item.id))
+      const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
-        // If no unique items and we haven't exceeded retry limit, signal retry needed
-        if (uniqueNewItems.length === 0) {
-          if (retryCountRef.current < MAX_RETRY_COUNT && prevItems.length > 1) {
-            retryCountRef.current++
-            return prevItems
+      if (uniqueNewItems.length === 0) {
+        if (allChatItems.length > 1) {
+          const nextId = allChatItems[1].id.replace('question-', '').replace('answer-', '')
+          const retryParams = {
+            ...params,
+            first_id: nextId,
          }
-          else {
-            retryCountRef.current = 0
-            return prevItems
+          const retryRes = await fetchChatMessages({
+            url: `/apps/${appDetail.id}/chat-messages`,
+            params: retryParams,
+          })
+          if (retryRes.data && retryRes.data.length > 0) {
+            const retryItems = getFormattedChatList(
+              retryRes.data,
+              detail.id,
+              timezone!,
+              t('dateTimeFormat', { ns: 'appLog' }) as string,
+            )
+            const retryUniqueItems = retryItems.filter(item => !existingIds.has(item.id))
+            if (retryUniqueItems.length > 0) {
+              const newAllChatItems = [
+                ...retryUniqueItems,
+                ...allChatItems,
+              ]
+              setAllChatItems(newAllChatItems)
+              let tree = buildChatItemTree(newAllChatItems)
+              if (retryRes.has_more === false && detail?.model_config?.configs?.introduction) {
+                tree = [{
+                  id: 'introduction',
+                  isAnswer: true,
+                  isOpeningStatement: true,
+                  content: detail?.model_config?.configs?.introduction ?? 'hello',
+                  feedbackDisabled: true,
+                  children: tree,
+                }]
+              }
+              setChatItemTree(tree)
+              setHasMore(retryRes.has_more)
+              setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id))
+              return
+            }
+          }
        }
      }
-      retryCountRef.current = 0
-      return [...uniqueNewItems, ...prevItems]
-      })
+      const newAllChatItems = [
+        ...uniqueNewItems,
+        ...allChatItems,
+      ]
+      setAllChatItems(newAllChatItems)
+      let tree = buildChatItemTree(newAllChatItems)
+      if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) {
+        tree = [{
+          id: 'introduction',
+          isAnswer: true,
+          isOpeningStatement: true,
+          content: detail?.model_config?.configs?.introduction ?? 'hello',
+          feedbackDisabled: true,
+          children: tree,
+        }]
+      }
+      setChatItemTree(tree)
+      setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id))
    }
    catch (error) {
      console.error(error)
      setHasMore(false)
-      retryCountRef.current = 0
    }
    finally {
      setIsLoading(false)
    }
-  }, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
+  }, [allChatItems, detail.id, hasMore, isLoading, timezone, t, appDetail])
useEffect(() => {
const scrollableDiv = document.getElementById('scrollableDiv')
@@ -535,11 +556,24 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
    if (!scrollContainer)
      return
+    let lastLoadTime = 0
+    const throttleDelay = 200
    const handleScroll = () => {
      const currentScrollTop = scrollContainer!.scrollTop
-      const isNearTop = currentScrollTop < 30
+      const scrollHeight = scrollContainer!.scrollHeight
+      const clientHeight = scrollContainer!.clientHeight
-      if (isNearTop && hasMore && !isLoading) {
+      const distanceFromTop = currentScrollTop
+      const distanceFromBottom = scrollHeight - currentScrollTop - clientHeight
+      const now = Date.now()
+      const isNearTop = distanceFromTop < 30
+      // eslint-disable-next-line sonarjs/no-unused-vars
+      const _distanceFromBottom = distanceFromBottom < 30
+      if (isNearTop && hasMore && !isLoading && (now - lastLoadTime > throttleDelay)) {
+        lastLoadTime = now
        loadMoreMessages()
      }
    }
@@ -585,6 +619,36 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
return () => cancelAnimationFrame(raf)
}, [])
// Add scroll listener to ensure loading is triggered
useEffect(() => {
if (threadChatItems.length >= MIN_ITEMS_FOR_SCROLL_LOADING && hasMore) {
const scrollableDiv = document.getElementById('scrollableDiv')
if (scrollableDiv) {
let loadingTimeout: NodeJS.Timeout | null = null
const handleScroll = () => {
const { scrollTop } = scrollableDiv
// Trigger loading when scrolling near the top
if (scrollTop < SCROLL_THRESHOLD_PX && !isLoadingRef.current) {
if (loadingTimeout)
clearTimeout(loadingTimeout)
loadingTimeout = setTimeout(fetchData, SCROLL_DEBOUNCE_MS) // 200ms debounce
}
}
scrollableDiv.addEventListener('scroll', handleScroll)
return () => {
scrollableDiv.removeEventListener('scroll', handleScroll)
if (loadingTimeout)
clearTimeout(loadingTimeout)
}
}
}
}, [threadChatItems.length, hasMore, fetchData])
return (
<div ref={ref} className="flex h-full flex-col rounded-xl border-[0.5px] border-components-panel-border">
{/* Panel Header */}
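
The replacement loadMoreMessages handles the all-duplicates page by retrying once from the second-oldest item, after stripping the UI-only id prefixes. The anchor-shifting step in isolation, as a sketch (helper name is hypothetical):

```python
def shift_anchor(chat_item_ids: list[str]) -> str | None:
    """When a page returns only duplicates, retry from the second-oldest item.

    Ids carry UI prefixes like 'question-<id>' / 'answer-<id>' that the API
    does not know about, so they are stripped before reuse as first_id.
    """
    if len(chat_item_ids) <= 1:
        return None
    return chat_item_ids[1].removeprefix("question-").removeprefix("answer-")


assert shift_anchor(["question-a1", "answer-a1", "question-a0"]) == "a1"
assert shift_anchor(["question-a1"]) is None  # nothing further back to anchor on
```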

View File

@@ -10,7 +10,6 @@ const mockReplace = vi.fn()
const mockRouter = { replace: mockReplace }
vi.mock('next/navigation', () => ({
useRouter: () => mockRouter,
-  useSearchParams: () => new URLSearchParams(''),
}))
// Mock app context

View File

@@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks'
import dynamic from 'next/dynamic'
import {
useRouter,
-  useSearchParams,
} from 'next/navigation'
import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useRef, useState } from 'react'
@@ -37,16 +36,6 @@ import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import NewAppCard from './new-app-card'
-// Define valid tabs at module scope to avoid re-creation on each render and stale closures
-const validTabs = new Set<string | AppModeEnum>([
-  'all',
-  AppModeEnum.WORKFLOW,
-  AppModeEnum.ADVANCED_CHAT,
-  AppModeEnum.CHAT,
-  AppModeEnum.AGENT_CHAT,
-  AppModeEnum.COMPLETION,
-])
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
})
@@ -58,41 +47,12 @@ const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
-  const searchParams = useSearchParams()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [activeTab, setActiveTab] = useQueryState(
'category',
parseAsString.withDefault('all').withOptions({ history: 'push' }),
)
-  // valid tabs for apps list; anything else should fallback to 'all'
-  // 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
-  useEffect(() => {
-    // avoid running on server
-    if (typeof window === 'undefined')
-      return
-    const mode = searchParams.get('mode')
-    if (!mode)
-      return
-    const url = new URL(window.location.href)
-    url.searchParams.delete('mode')
-    if (validTabs.has(mode)) {
-      // migrate to category key
-      url.searchParams.set('category', mode)
-    }
-    else {
-      url.searchParams.set('category', 'all')
-    }
-    router.replace(url.pathname + url.search)
-  }, [router, searchParams])
-  // 2) If category has an invalid value (e.g., 'discover'), reset to 'all'
-  useEffect(() => {
-    if (!validTabs.has(activeTab))
-      setActiveTab('all')
-  }, [activeTab, setActiveTab])
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)

View File

@@ -21,6 +21,7 @@ import BasicContent from './basic-content'
import More from './more'
import Operation from './operation'
import SuggestedQuestions from './suggested-questions'
import ToolCalls from './tool-calls'
import WorkflowProcessItem from './workflow-process'
type AnswerProps = {
@@ -61,6 +62,7 @@ const Answer: FC<AnswerProps> = ({
workflowProcess,
allFiles,
message_files,
toolCalls,
} = item
const hasAgentThoughts = !!agent_thoughts?.length
@@ -154,6 +156,11 @@ const Answer: FC<AnswerProps> = ({
/>
)
}
{
!!toolCalls?.length && (
<ToolCalls toolCalls={toolCalls} />
)
}
{
responding && contentIsEmpty && !hasAgentThoughts && (
<div className="flex h-5 w-6 items-center justify-center">

View File

@@ -0,0 +1,23 @@
import type { ToolCallItem } from '@/types/workflow'
import ToolCallItemComponent from '@/app/components/workflow/run/llm-log/tool-call-item'
type ToolCallsProps = {
toolCalls: ToolCallItem[]
}
const ToolCalls = ({
toolCalls,
}: ToolCallsProps) => {
return (
<div className="my-1 space-y-1">
{toolCalls.map((toolCall: ToolCallItem, index: number) => (
<ToolCallItemComponent
key={index}
payload={toolCall}
className="bg-background-gradient-bg-fill-chat-bubble-bg-2 shadow-none"
/>
))}
</div>
)
}
export default ToolCalls

View File

@@ -45,7 +45,7 @@ const WorkflowProcessItem = ({
return (
<div
className={cn(
-        '-mx-1 rounded-xl px-2.5',
+        'rounded-xl px-2.5',
collapse ? 'border-l-[0.25px] border-components-panel-border py-[7px]' : 'border-[0.5px] border-components-panel-border-subtle px-1 pb-1 pt-[7px]',
running && !collapse && 'bg-background-section-burn',
succeeded && !collapse && 'bg-state-success-hover',

View File

@@ -319,6 +319,9 @@ export const useChat = (
return player
}
let toolCallId = ''
let thoughtId = ''
ssePost(
url,
{
@@ -326,7 +329,19 @@ export const useChat = (
},
{
isPublicAPI,
-        onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
+        onData: (message: string, isFirstMessage: boolean, {
+          conversationId: newConversationId,
+          messageId,
+          taskId,
+          chunk_type,
+          tool_icon,
+          tool_icon_dark,
+          tool_name,
+          tool_arguments,
+          tool_files,
+          tool_error,
+          tool_elapsed_time,
+        }: any) => {
if (!isAgentMode) {
responseItem.content = responseItem.content + message
}
@@ -336,6 +351,57 @@ export const useChat = (
lastThought.thought = lastThought.thought + message // need immer setAutoFreeze
}
if (chunk_type === 'tool_call') {
if (!responseItem.toolCalls)
responseItem.toolCalls = []
toolCallId = uuidV4()
responseItem.toolCalls?.push({
id: toolCallId,
type: 'tool',
toolName: tool_name,
toolArguments: tool_arguments,
toolIcon: tool_icon,
toolIconDark: tool_icon_dark,
})
}
if (chunk_type === 'tool_result') {
const currentToolCallIndex = responseItem.toolCalls?.findIndex(item => item.id === toolCallId) ?? -1
if (currentToolCallIndex > -1) {
responseItem.toolCalls![currentToolCallIndex].toolError = tool_error
responseItem.toolCalls![currentToolCallIndex].toolDuration = tool_elapsed_time
responseItem.toolCalls![currentToolCallIndex].toolFiles = tool_files
responseItem.toolCalls![currentToolCallIndex].toolOutput = message
}
}
if (chunk_type === 'thought_start') {
if (!responseItem.toolCalls)
responseItem.toolCalls = []
thoughtId = uuidV4()
responseItem.toolCalls.push({
id: thoughtId,
type: 'thought',
thoughtOutput: '',
})
}
if (chunk_type === 'thought') {
const currentThoughtIndex = responseItem.toolCalls?.findIndex(item => item.id === thoughtId) ?? -1
if (currentThoughtIndex > -1) {
responseItem.toolCalls![currentThoughtIndex].thoughtOutput += message
}
}
if (chunk_type === 'thought_end') {
const currentThoughtIndex = responseItem.toolCalls?.findIndex(item => item.id === thoughtId) ?? -1
if (currentThoughtIndex > -1) {
responseItem.toolCalls![currentThoughtIndex].thoughtOutput += message
responseItem.toolCalls![currentThoughtIndex].thoughtCompleted = true
}
}
if (messageId && !hasSetResponseId) {
questionItem.id = `question-${messageId}`
responseItem.id = messageId
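
The chunk handling added above is effectively a reducer keyed by the most recently opened tool call or thought. A Python sketch of the same accumulation logic (payload field names follow the hunk; everything else is illustrative):

```python
import uuid


def reduce_chunk(tool_calls: list[dict], state: dict, chunk_type: str, message: str, payload: dict) -> None:
    """Accumulate streamed tool_call/thought chunks into a toolCalls list."""
    if chunk_type == "tool_call":
        # A new tool call opens; remember its id so tool_result can find it.
        state["tool_call_id"] = str(uuid.uuid4())
        tool_calls.append({
            "id": state["tool_call_id"],
            "type": "tool",
            "toolName": payload.get("tool_name"),
            "toolArguments": payload.get("tool_arguments"),
        })
    elif chunk_type == "tool_result":
        current = next((t for t in tool_calls if t["id"] == state.get("tool_call_id")), None)
        if current is not None:
            current["toolError"] = payload.get("tool_error")
            current["toolDuration"] = payload.get("tool_elapsed_time")
            current["toolOutput"] = message
    elif chunk_type == "thought_start":
        state["thought_id"] = str(uuid.uuid4())
        tool_calls.append({"id": state["thought_id"], "type": "thought", "thoughtOutput": ""})
    elif chunk_type in {"thought", "thought_end"}:
        current = next((t for t in tool_calls if t["id"] == state.get("thought_id")), None)
        if current is not None:
            current["thoughtOutput"] += message
            if chunk_type == "thought_end":
                current["thoughtCompleted"] = True


calls: list[dict] = []
state: dict = {}
reduce_chunk(calls, state, "tool_call", "", {"tool_name": "google_search", "tool_arguments": '{"query": "dify"}'})
reduce_chunk(calls, state, "tool_result", "3 results", {"tool_error": None, "tool_elapsed_time": 0.4})
assert calls[0]["toolOutput"] == "3 results"
```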

View File

@@ -2,7 +2,7 @@ import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { TypeWithI18N } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { InputVarType } from '@/app/components/workflow/types'
import type { Annotation, MessageRating } from '@/models/log'
-import type { FileResponse } from '@/types/workflow'
+import type { FileResponse, ToolCallItem } from '@/types/workflow'
export type MessageMore = {
time: string
@@ -104,6 +104,7 @@ export type IChatItem = {
siblingIndex?: number
prevSibling?: string
nextSibling?: string
toolCalls?: ToolCallItem[]
}
export type Metadata = {

View File

@@ -37,7 +37,7 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
return
}
-    if (inputValue == null)
+    if (!inputValue)
return
if (item.type === InputVarType.singleFile) {
@@ -52,20 +52,6 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
else
processedInputs[item.variable] = getProcessedFiles(inputValue)
}
-    else if (item.type === InputVarType.jsonObject) {
-      // Prefer sending an object if the user entered valid JSON; otherwise keep the raw string.
-      try {
-        const v = typeof inputValue === 'string' ? JSON.parse(inputValue) : inputValue
-        if (v && typeof v === 'object' && !Array.isArray(v))
-          processedInputs[item.variable] = v
-        else
-          processedInputs[item.variable] = inputValue
-      }
-      catch {
-        // keep original string; backend will parse/validate
-        processedInputs[item.variable] = inputValue
-      }
-    }
})
return processedInputs

View File

@@ -66,9 +66,7 @@ const Header: FC<IHeaderProps> = ({
const listener = (event: MessageEvent) => handleMessageReceived(event)
window.addEventListener('message', listener)
-    // Security: Use document.referrer to get parent origin
-    const targetOrigin = document.referrer ? new URL(document.referrer).origin : '*'
-    window.parent.postMessage({ type: 'dify-chatbot-iframe-ready' }, targetOrigin)
+    window.parent.postMessage({ type: 'dify-chatbot-iframe-ready' }, '*')
return () => window.removeEventListener('message', listener)
}, [isIframe, handleMessageReceived])

View File

@@ -0,0 +1,4 @@
<svg width="12" height="14" viewBox="0 0 12 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2 9.49479C0.782372 8.51826 0 7.01768 0 5.33333C0 2.38782 2.38782 0 5.33333 0C8.20841 0 10.5503 2.27504 10.6608 5.12305L11.888 6.96354C12.0843 7.25794 12.0161 7.65424 11.7331 7.86654L10.6667 8.66602V10C10.6667 10.7364 10.0697 11.3333 9.33333 11.3333H8V13.3333H6.66667V10.6667C6.66667 10.2985 6.96514 10 7.33333 10H9.33333V8.33333C9.33333 8.12349 9.43239 7.92603 9.60026 7.80013L10.4284 7.17838L9.44531 5.70312C9.3723 5.59361 9.33333 5.46495 9.33333 5.33333C9.33333 3.1242 7.54248 1.33333 5.33333 1.33333C3.1242 1.33333 1.33333 3.1242 1.33333 5.33333C1.33333 6.69202 2.0103 7.89261 3.04818 8.61654C3.2269 8.74119 3.33329 8.94552 3.33333 9.16341V13.3333H2V9.49479Z" fill="#354052"/>
<path d="M6.04367 4.24012L5.6504 3.21778C5.59993 3.08657 5.47393 3 5.33333 3C5.19273 3 5.06673 3.08657 5.01627 3.21778L4.62303 4.24012C4.55531 4.41618 4.41618 4.55531 4.24012 4.62303L3.21778 5.01624C3.08657 5.0667 3 5.19276 3 5.33333C3 5.47393 3.08657 5.59993 3.21778 5.6504L4.24012 6.04367C4.41618 6.11133 4.55531 6.25047 4.62303 6.42653L5.01627 7.44887C5.06673 7.58007 5.19273 7.66667 5.33333 7.66667C5.47393 7.66667 5.59993 7.58007 5.6504 7.44887L6.04367 6.42653C6.11133 6.25047 6.25047 6.11133 6.42653 6.04367L7.44887 5.6504C7.58007 5.59993 7.66667 5.47393 7.66667 5.33333C7.66667 5.19276 7.58007 5.0667 7.44887 5.01624L6.42653 4.62303C6.25047 4.55531 6.11133 4.41618 6.04367 4.24012Z" fill="#354052"/>
</svg>

View File

@@ -0,0 +1,35 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "12",
"height": "14",
"viewBox": "0 0 12 14",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M2 9.49479C0.782372 8.51826 0 7.01768 0 5.33333C0 2.38782 2.38782 0 5.33333 0C8.20841 0 10.5503 2.27504 10.6608 5.12305L11.888 6.96354C12.0843 7.25794 12.0161 7.65424 11.7331 7.86654L10.6667 8.66602V10C10.6667 10.7364 10.0697 11.3333 9.33333 11.3333H8V13.3333H6.66667V10.6667C6.66667 10.2985 6.96514 10 7.33333 10H9.33333V8.33333C9.33333 8.12349 9.43239 7.92603 9.60026 7.80013L10.4284 7.17838L9.44531 5.70312C9.3723 5.59361 9.33333 5.46495 9.33333 5.33333C9.33333 3.1242 7.54248 1.33333 5.33333 1.33333C3.1242 1.33333 1.33333 3.1242 1.33333 5.33333C1.33333 6.69202 2.0103 7.89261 3.04818 8.61654C3.2269 8.74119 3.33329 8.94552 3.33333 9.16341V13.3333H2V9.49479Z",
"fill": "currentColor"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M6.04367 4.24012L5.6504 3.21778C5.59993 3.08657 5.47393 3 5.33333 3C5.19273 3 5.06673 3.08657 5.01627 3.21778L4.62303 4.24012C4.55531 4.41618 4.41618 4.55531 4.24012 4.62303L3.21778 5.01624C3.08657 5.0667 3 5.19276 3 5.33333C3 5.47393 3.08657 5.59993 3.21778 5.6504L4.24012 6.04367C4.41618 6.11133 4.55531 6.25047 4.62303 6.42653L5.01627 7.44887C5.06673 7.58007 5.19273 7.66667 5.33333 7.66667C5.47393 7.66667 5.59993 7.58007 5.6504 7.44887L6.04367 6.42653C6.11133 6.25047 6.25047 6.11133 6.42653 6.04367L7.44887 5.6504C7.58007 5.59993 7.66667 5.47393 7.66667 5.33333C7.66667 5.19276 7.58007 5.0667 7.44887 5.01624L6.42653 4.62303C6.25047 4.55531 6.11133 4.41618 6.04367 4.24012Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "Thinking"
}

View File

@@ -0,0 +1,20 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './Thinking.json'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'Thinking'
export default Icon

View File

@@ -24,6 +24,7 @@ export { default as ParameterExtractor } from './ParameterExtractor'
export { default as QuestionClassifier } from './QuestionClassifier'
export { default as Schedule } from './Schedule'
export { default as TemplatingTransform } from './TemplatingTransform'
export { default as Thinking } from './Thinking'
export { default as TriggerAll } from './TriggerAll'
export { default as VariableX } from './VariableX'
export { default as WebhookLine } from './WebhookLine'

View File

@@ -5,7 +5,6 @@ import type {
} from 'lexical'
import type { FC } from 'react'
import type {
AgentBlockType,
ContextBlockType,
CurrentBlockType,
ErrorMessageBlockType,
@@ -104,7 +103,6 @@ export type PromptEditorProps = {
currentBlock?: CurrentBlockType
errorMessageBlock?: ErrorMessageBlockType
lastRunBlock?: LastRunBlockType
agentBlock?: AgentBlockType
isSupportFileVar?: boolean
}
@@ -130,7 +128,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
currentBlock,
errorMessageBlock,
lastRunBlock,
agentBlock,
isSupportFileVar,
}) => {
const { eventEmitter } = useEventEmitterContextContext()
@@ -142,7 +139,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
{
replace: TextNode,
with: (node: TextNode) => new CustomTextNode(node.__text),
withKlass: CustomTextNode,
},
ContextBlockNode,
HistoryBlockNode,
@@ -216,22 +212,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
lastRunBlock={lastRunBlock}
isSupportFileVar={isSupportFileVar}
/>
{(!agentBlock || agentBlock.show) && (
<ComponentPickerBlock
triggerString="@"
contextBlock={contextBlock}
historyBlock={historyBlock}
queryBlock={queryBlock}
variableBlock={variableBlock}
externalToolBlock={externalToolBlock}
workflowVariableBlock={workflowVariableBlock}
currentBlock={currentBlock}
errorMessageBlock={errorMessageBlock}
lastRunBlock={lastRunBlock}
agentBlock={agentBlock}
isSupportFileVar={isSupportFileVar}
/>
)}
<ComponentPickerBlock
triggerString="{"
contextBlock={contextBlock}

View File

@@ -1,7 +1,6 @@
import type { MenuRenderFn } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import type { TextNode } from 'lexical'
import type {
AgentBlockType,
ContextBlockType,
CurrentBlockType,
ErrorMessageBlockType,
@@ -21,11 +20,7 @@ import {
} from '@floating-ui/react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import {
$getRoot,
$insertNodes,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import { KEY_ESCAPE_COMMAND } from 'lexical'
import {
Fragment,
memo,
@@ -34,9 +29,7 @@ import {
} from 'react'
import ReactDOM from 'react-dom'
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
import AgentNodeList from '@/app/components/workflow/nodes/_base/components/agent-node-list'
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
import { BlockEnum } from '@/app/components/workflow/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useBasicTypeaheadTriggerMatch } from '../../hooks'
import { $splitNodeContainingQuery } from '../../utils'
@@ -45,7 +38,6 @@ import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../error-message-block'
import { INSERT_LAST_RUN_BLOCK_COMMAND } from '../last-run-block'
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '../variable-block'
import { INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND } from '../workflow-variable-block'
import { $createWorkflowVariableBlockNode } from '../workflow-variable-block/node'
import { useOptions } from './hooks'
type ComponentPickerProps = {
@@ -59,7 +51,6 @@ type ComponentPickerProps = {
currentBlock?: CurrentBlockType
errorMessageBlock?: ErrorMessageBlockType
lastRunBlock?: LastRunBlockType
agentBlock?: AgentBlockType
isSupportFileVar?: boolean
}
const ComponentPicker = ({
@@ -73,7 +64,6 @@ const ComponentPicker = ({
currentBlock,
errorMessageBlock,
lastRunBlock,
agentBlock,
isSupportFileVar,
}: ComponentPickerProps) => {
const { eventEmitter } = useEventEmitterContextContext()
@@ -161,41 +151,12 @@ const ComponentPicker = ({
editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
}, [editor])
const handleSelectAgent = useCallback((agent: { id: string, title: string }) => {
editor.update(() => {
const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!)
if (needRemove)
needRemove.remove()
const root = $getRoot()
const firstChild = root.getFirstChild()
if (firstChild) {
const selection = firstChild.selectStart()
if (selection) {
const workflowVariableBlockNode = $createWorkflowVariableBlockNode([agent.id, 'text'], {}, undefined)
$insertNodes([workflowVariableBlockNode])
}
}
})
agentBlock?.onSelect?.(agent)
handleClose()
}, [editor, checkForTriggerMatch, triggerString, agentBlock, handleClose])
const isAgentTrigger = triggerString === '@' && agentBlock?.show
const agentNodes = agentBlock?.agentNodes || []
const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
anchorElementRef,
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
) => {
if (isAgentTrigger) {
if (!(anchorElementRef.current && agentNodes.length))
return null
}
else {
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
}
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
setTimeout(() => {
if (anchorElementRef.current)
@@ -206,6 +167,9 @@ const ComponentPicker = ({
<>
{
ReactDOM.createPortal(
// The `LexicalMenu` will try to calculate the position of the floating menu based on the first child.
// Since we use floating ui, we need to wrap it with a div to prevent the position calculation being affected.
// See https://github.com/facebook/lexical/blob/ac97dfa9e14a73ea2d6934ff566282d7f758e8bb/packages/lexical-react/src/shared/LexicalMenu.ts#L493
<div className="h-0 w-0">
<div
className="w-[260px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
@@ -215,73 +179,56 @@ const ComponentPicker = ({
}}
ref={refs.setFloating}
>
{isAgentTrigger
? (
<AgentNodeList
nodes={agentNodes.map(node => ({
...node,
type: BlockEnum.Agent,
}))}
onSelect={handleSelectAgent}
{
workflowVariableBlock?.show && (
<div className="p-1">
<VarReferenceVars
searchBoxClassName="mt-1"
vars={workflowVariableOptions}
onChange={(variables: string[]) => {
handleSelectWorkflowVariable(variables)
}}
maxHeightClass="max-h-[34vh]"
isSupportFileVar={isSupportFileVar}
onClose={handleClose}
onBlur={handleClose}
maxHeightClass="max-h-[34vh]"
showManageInputField={workflowVariableBlock.showManageInputField}
onManageInputField={workflowVariableBlock.onManageInputField}
autoFocus={false}
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
/>
)
: (
<>
</div>
)
}
{
workflowVariableBlock?.show && !!options.length && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
<div>
{
options.map((option, index) => (
<Fragment key={option.key}>
{
workflowVariableBlock?.show && (
<div className="p-1">
<VarReferenceVars
searchBoxClassName="mt-1"
vars={workflowVariableOptions}
onChange={(variables: string[]) => {
handleSelectWorkflowVariable(variables)
}}
maxHeightClass="max-h-[34vh]"
isSupportFileVar={isSupportFileVar}
onClose={handleClose}
onBlur={handleClose}
showManageInputField={workflowVariableBlock.showManageInputField}
onManageInputField={workflowVariableBlock.onManageInputField}
autoFocus={false}
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
/>
</div>
)
}
{
workflowVariableBlock?.show && !!options.length && (
// Divider
index !== 0 && options.at(index - 1)?.group !== option.group && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
<div>
{
options.map((option, index) => (
<Fragment key={option.key}>
{
index !== 0 && options.at(index - 1)?.group !== option.group && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
{option.renderMenuOption({
queryString,
isSelected: selectedIndex === index,
onSelect: () => {
selectOptionAndCleanUp(option)
},
onSetHighlight: () => {
setHighlightedIndex(index)
},
})}
</Fragment>
))
}
</div>
</>
)}
{option.renderMenuOption({
queryString,
isSelected: selectedIndex === index,
onSelect: () => {
selectOptionAndCleanUp(option)
},
onSetHighlight: () => {
setHighlightedIndex(index)
},
})}
</Fragment>
))
}
</div>
</div>
</div>,
anchorElementRef.current,
@@ -289,7 +236,7 @@ const ComponentPicker = ({
}
</>
)
}, [isAgentTrigger, agentNodes, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, handleSelectAgent, handleClose, workflowVariableOptions, isSupportFileVar, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
return (
<LexicalTypeaheadMenuPlugin

View File

@@ -21,7 +21,6 @@ import {
VariableLabelInEditor,
} from '@/app/components/workflow/nodes/_base/components/variable/variable-label'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import { BlockEnum } from '@/app/components/workflow/types'
import { isExceptionVariable } from '@/app/components/workflow/utils'
import { useSelectOrDelete } from '../../hooks'
import {
@@ -67,7 +66,6 @@ const WorkflowVariableBlockComponent = ({
)()
const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState<WorkflowNodesMap>(workflowNodesMap)
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
const isAgentVariable = node?.type === BlockEnum.Agent
const isException = isExceptionVariable(varName, node?.type)
const variableValid = useMemo(() => {
@@ -136,9 +134,6 @@ const WorkflowVariableBlockComponent = ({
})
}, [node, reactflow, store])
if (isAgentVariable)
return <span className="hidden" ref={ref} />
const Item = (
<VariableLabelInEditor
nodeType={node?.type}

View File

@@ -73,17 +73,6 @@ export type WorkflowVariableBlockType = {
onManageInputField?: () => void
}
export type AgentNode = {
id: string
title: string
}
export type AgentBlockType = {
show?: boolean
agentNodes?: AgentNode[]
onSelect?: (agent: AgentNode) => void
}
export type MenuTextMatch = {
leadOffset: number
matchingString: string

View File

@@ -1,29 +1,47 @@
import type { FC } from 'react'
import type { FullDocumentDetail } from '@/models/datasets'
import type { RETRIEVE_METHOD } from '@/types/app'
import type {
DataSourceInfo,
FullDocumentDetail,
IndexingStatusResponse,
LegacyDataSourceInfo,
ProcessRuleResponse,
} from '@/models/datasets'
import {
RiArrowRightLine,
RiCheckboxCircleFill,
RiErrorWarningFill,
RiLoader2Fill,
RiTerminalBoxLine,
} from '@remixicon/react'
import Image from 'next/image'
import Link from 'next/link'
import { useRouter } from 'next/navigation'
import { useMemo } from 'react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
import NotionIcon from '@/app/components/base/notion-icon'
import Tooltip from '@/app/components/base/tooltip'
import PriorityLabel from '@/app/components/billing/priority-label'
import { Plan } from '@/app/components/billing/type'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
import { useProviderContext } from '@/context/provider-context'
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
import { DataSourceType, ProcessMode } from '@/models/datasets'
import { fetchIndexingStatusBatch as doFetchIndexingStatus } from '@/service/datasets'
import { useProcessRule } from '@/service/knowledge/use-dataset'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'
import IndexingProgressItem from './indexing-progress-item'
import RuleDetail from './rule-detail'
import UpgradeBanner from './upgrade-banner'
import { useIndexingStatusPolling } from './use-indexing-status-polling'
import { createDocumentLookup } from './utils'
import { RETRIEVE_METHOD } from '@/types/app'
import { sleep } from '@/utils'
import { cn } from '@/utils/classnames'
import DocumentFileIcon from '../../common/document-file-icon'
import { indexMethodIcon, retrievalIcon } from '../icons'
import { IndexingType } from '../step-two'
type EmbeddingProcessProps = {
type Props = {
datasetId: string
batchId: string
documents?: FullDocumentDetail[]
@@ -31,121 +49,333 @@ type EmbeddingProcessProps = {
retrievalMethod?: RETRIEVE_METHOD
}
// Status header component
const StatusHeader: FC<{ isEmbedding: boolean, isCompleted: boolean }> = ({
isEmbedding,
isCompleted,
}) => {
const RuleDetail: FC<{
sourceData?: ProcessRuleResponse
indexingType?: string
retrievalMethod?: RETRIEVE_METHOD
}> = ({ sourceData, indexingType, retrievalMethod }) => {
const { t } = useTranslation()
const segmentationRuleMap = {
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
}
const getRuleName = (key: string) => {
if (key === 'remove_extra_spaces')
return t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' })
if (key === 'remove_urls_emails')
return t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' })
if (key === 'remove_stopwords')
return t('stepTwo.removeStopwords', { ns: 'datasetCreation' })
}
const isNumber = (value: unknown) => {
return typeof value === 'number'
}
const getValue = useCallback((field: string) => {
let value: string | number | undefined = '-'
const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens)
? sourceData.rules.segmentation.max_tokens
: value
const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens)
? sourceData.rules.subchunk_segmentation.max_tokens
: value
switch (field) {
case 'mode':
value = !sourceData?.mode
? value
: sourceData.mode === ProcessMode.general
? (t('embedding.custom', { ns: 'datasetDocuments' }) as string)
: `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${sourceData?.rules?.parent_mode === 'paragraph'
? t('parentMode.paragraph', { ns: 'dataset' })
: t('parentMode.fullDoc', { ns: 'dataset' })}`
break
case 'segmentLength':
value = !sourceData?.mode
? value
: sourceData.mode === ProcessMode.general
? maxTokens
: `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
break
default:
value = !sourceData?.mode
? value
: sourceData?.rules?.pre_processing_rules?.filter(rule =>
rule.enabled).map(rule => getRuleName(rule.id)).join(',')
break
}
return value
}, [sourceData])
return (
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
{isEmbedding && (
<>
<RiLoader2Fill className="size-4 animate-spin" />
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
</>
)}
{isCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
<div className="flex flex-col gap-1">
{Object.keys(segmentationRuleMap).map((field) => {
return (
<FieldInfo
key={field}
label={segmentationRuleMap[field as keyof typeof segmentationRuleMap]}
displayedValue={String(getValue(field))}
/>
)
})}
<FieldInfo
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
displayedValue={t(`stepTwo.${indexingType === IndexingType.ECONOMICAL ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' }) as string}
valueIcon={(
<Image
className="size-4"
src={
indexingType === IndexingType.ECONOMICAL
? indexMethodIcon.economical
: indexMethodIcon.high_quality
}
alt=""
/>
)}
/>
<FieldInfo
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
// displayedValue={t(`datasetSettings.form.retrievalSetting.${retrievalMethod}`) as string}
displayedValue={t(`retrieval.${indexingType === IndexingType.ECONOMICAL ? 'keyword_search' : retrievalMethod ?? 'semantic_search'}.title`, { ns: 'dataset' })}
valueIcon={(
<Image
className="size-4"
src={
retrievalMethod === RETRIEVE_METHOD.fullText
? retrievalIcon.fullText
: retrievalMethod === RETRIEVE_METHOD.hybrid
? retrievalIcon.hybrid
: retrievalIcon.vector
}
alt=""
/>
)}
/>
</div>
)
}
// Action buttons component
const ActionButtons: FC<{
apiReferenceUrl: string
onNavToDocuments: () => void
}> = ({ apiReferenceUrl, onNavToDocuments }) => {
const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], indexingType, retrievalMethod }) => {
const { t } = useTranslation()
return (
<div className="mt-6 flex items-center gap-x-2 py-2">
<Link href={apiReferenceUrl} target="_blank" rel="noopener noreferrer">
<Button className="w-fit gap-x-0.5 px-3">
<RiTerminalBoxLine className="size-4" />
<span className="px-0.5">Access the API</span>
</Button>
</Link>
<Button
className="w-fit gap-x-0.5 px-3"
variant="primary"
onClick={onNavToDocuments}
>
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
</Button>
</div>
)
}
const EmbeddingProcess: FC<EmbeddingProcessProps> = ({
datasetId,
batchId,
documents = [],
indexingType,
retrievalMethod,
}) => {
const { enableBilling, plan } = useProviderContext()
const getFirstDocument = documents[0]
const [indexingStatusBatchDetail, setIndexingStatusDetail] = useState<IndexingStatusResponse[]>([])
const fetchIndexingStatus = async () => {
const status = await doFetchIndexingStatus({ datasetId, batchId })
setIndexingStatusDetail(status.data)
return status.data
}
const [isStopQuery, setIsStopQuery] = useState(false)
const isStopQueryRef = useRef(isStopQuery)
useEffect(() => {
isStopQueryRef.current = isStopQuery
}, [isStopQuery])
const stopQueryStatus = () => {
setIsStopQuery(true)
}
const startQueryStatus = async () => {
if (isStopQueryRef.current)
return
try {
const indexingStatusBatchDetail = await fetchIndexingStatus()
const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail.indexing_status))
if (isCompleted) {
stopQueryStatus()
return
}
await sleep(2500)
await startQueryStatus()
}
catch {
await sleep(2500)
await startQueryStatus()
}
}
useEffect(() => {
setIsStopQuery(false)
startQueryStatus()
return () => {
stopQueryStatus()
}
}, [])
// get rule
const { data: ruleDetail } = useProcessRule(getFirstDocument?.id)
const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const apiReferenceUrl = useDatasetApiAccessUrl()
// Polling hook for indexing status
const { statusList, isEmbedding, isEmbeddingCompleted } = useIndexingStatusPolling({
datasetId,
batchId,
})
// Get process rule for the first document
const firstDocumentId = documents[0]?.id
const { data: ruleDetail } = useProcessRule(firstDocumentId)
// Document lookup utilities - memoized for performance
const documentLookup = useMemo(
() => createDocumentLookup(documents),
[documents],
)
const handleNavToDocuments = () => {
const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`)
}
const apiReferenceUrl = useDatasetApiAccessUrl()
const showUpgradeBanner = enableBilling && plan.type !== Plan.team
const isEmbedding = useMemo(() => {
return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || ''))
}, [indexingStatusBatchDetail])
const isEmbeddingCompleted = useMemo(() => {
return indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status || ''))
}, [indexingStatusBatchDetail])
const getSourceName = (id: string) => {
const doc = documents.find(document => document.id === id)
return doc?.name
}
const getFileType = (name?: string) => name?.split('.').pop() || 'txt'
const getSourcePercent = (detail: IndexingStatusResponse) => {
const completedCount = detail.completed_segments || 0
const totalCount = detail.total_segments || 0
if (totalCount === 0)
return 0
const percent = Math.round(completedCount * 100 / totalCount)
return percent > 100 ? 100 : percent
}
const getSourceType = (id: string) => {
const doc = documents.find(document => document.id === id)
return doc?.data_source_type as DataSourceType
}
const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
}
const getIcon = (id: string) => {
const doc = documents.find(document => document.id === id)
const info = doc?.data_source_info
if (info && isLegacyDataSourceInfo(info))
return info.notion_page_icon
return undefined
}
const isSourceEmbedding = (detail: IndexingStatusResponse) =>
['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '')
return (
<>
<div className="flex flex-col gap-y-3">
<StatusHeader isEmbedding={isEmbedding} isCompleted={isEmbeddingCompleted} />
{showUpgradeBanner && <UpgradeBanner />}
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
{isEmbedding && (
<>
<RiLoader2Fill className="size-4 animate-spin" />
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
</>
)}
{isEmbeddingCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
</div>
{
enableBilling && plan.type !== Plan.team && (
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
<ZapFast className="h-4 w-4 text-[#FB6514]" />
</div>
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
</div>
<UpgradeBtn loc="knowledge-speed-up" />
</div>
)
}
<div className="flex flex-col gap-0.5 pb-2">
{statusList.map(detail => (
<IndexingProgressItem
key={detail.id}
detail={detail}
name={documentLookup.getName(detail.id)}
sourceType={documentLookup.getSourceType(detail.id)}
notionIcon={documentLookup.getNotionIcon(detail.id)}
enableBilling={enableBilling}
/>
{indexingStatusBatchDetail.map(indexingStatusDetail => (
<div
key={indexingStatusDetail.id}
className={cn(
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
indexingStatusDetail.indexing_status === 'error' && 'bg-state-destructive-hover-alt',
)}
>
{isSourceEmbedding(indexingStatusDetail) && (
<div
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
style={{ width: `${getSourcePercent(indexingStatusDetail)}%` }}
/>
)}
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && (
<DocumentFileIcon
size="sm"
className="shrink-0"
name={getSourceName(indexingStatusDetail.id)}
extension={getFileType(getSourceName(indexingStatusDetail.id))}
/>
)}
{getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && (
<NotionIcon
className="shrink-0"
type="page"
src={getIcon(indexingStatusDetail.id)}
/>
)}
<div className="flex w-0 grow items-center gap-1" title={getSourceName(indexingStatusDetail.id)}>
<div className="system-xs-medium truncate text-text-secondary">
{getSourceName(indexingStatusDetail.id)}
</div>
{
enableBilling && (
<PriorityLabel className="ml-0" />
)
}
</div>
{isSourceEmbedding(indexingStatusDetail) && (
<div className="shrink-0 text-xs text-text-secondary">{`${getSourcePercent(indexingStatusDetail)}%`}</div>
)}
{indexingStatusDetail.indexing_status === 'error' && (
<Tooltip
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
offset={4}
popupContent={indexingStatusDetail.error}
>
<span>
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
</span>
</Tooltip>
)}
{indexingStatusDetail.indexing_status === 'completed' && (
<RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
)}
</div>
</div>
))}
</div>
<Divider type="horizontal" className="my-0 bg-divider-subtle" />
<RuleDetail
sourceData={ruleDetail}
indexingType={indexingType}
retrievalMethod={retrievalMethod}
/>
</div>
<ActionButtons
apiReferenceUrl={apiReferenceUrl}
onNavToDocuments={handleNavToDocuments}
/>
<div className="mt-6 flex items-center gap-x-2 py-2">
<Link
href={apiReferenceUrl}
target="_blank"
rel="noopener noreferrer"
>
<Button
className="w-fit gap-x-0.5 px-3"
>
<RiTerminalBoxLine className="size-4" />
<span className="px-0.5">Access the API</span>
</Button>
</Link>
<Button
className="w-fit gap-x-0.5 px-3"
variant="primary"
onClick={navToDocumentList}
>
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
</Button>
</div>
</>
)
}

View File

@@ -1,120 +0,0 @@
import type { FC } from 'react'
import type { IndexingStatusResponse } from '@/models/datasets'
import {
RiCheckboxCircleFill,
RiErrorWarningFill,
} from '@remixicon/react'
import NotionIcon from '@/app/components/base/notion-icon'
import Tooltip from '@/app/components/base/tooltip'
import PriorityLabel from '@/app/components/billing/priority-label'
import { DataSourceType } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import DocumentFileIcon from '../../common/document-file-icon'
import { getFileType, getSourcePercent, isSourceEmbedding } from './utils'
type IndexingProgressItemProps = {
detail: IndexingStatusResponse
name?: string
sourceType?: DataSourceType
notionIcon?: string
enableBilling?: boolean
}
// Status icon component for completed/error states
const StatusIcon: FC<{ status: string, error?: string }> = ({ status, error }) => {
if (status === 'completed')
return <RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
if (status === 'error') {
return (
<Tooltip
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
offset={4}
popupContent={error}
>
<span>
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
</span>
</Tooltip>
)
}
return null
}
// Source type icon component
const SourceTypeIcon: FC<{
sourceType?: DataSourceType
name?: string
notionIcon?: string
}> = ({ sourceType, name, notionIcon }) => {
if (sourceType === DataSourceType.FILE) {
return (
<DocumentFileIcon
size="sm"
className="shrink-0"
name={name}
extension={getFileType(name)}
/>
)
}
if (sourceType === DataSourceType.NOTION) {
return (
<NotionIcon
className="shrink-0"
type="page"
src={notionIcon}
/>
)
}
return null
}
const IndexingProgressItem: FC<IndexingProgressItemProps> = ({
detail,
name,
sourceType,
notionIcon,
enableBilling,
}) => {
const isEmbedding = isSourceEmbedding(detail)
const percent = getSourcePercent(detail)
const isError = detail.indexing_status === 'error'
return (
<div
className={cn(
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
isError && 'bg-state-destructive-hover-alt',
)}
>
{isEmbedding && (
<div
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
style={{ width: `${percent}%` }}
/>
)}
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
<SourceTypeIcon
sourceType={sourceType}
name={name}
notionIcon={notionIcon}
/>
<div className="flex w-0 grow items-center gap-1" title={name}>
<div className="system-xs-medium truncate text-text-secondary">
{name}
</div>
{enableBilling && <PriorityLabel className="ml-0" />}
</div>
{isEmbedding && (
<div className="shrink-0 text-xs text-text-secondary">{`${percent}%`}</div>
)}
<StatusIcon status={detail.indexing_status} error={detail.error} />
</div>
</div>
)
}
export default IndexingProgressItem

View File

@@ -1,133 +0,0 @@
import type { FC } from 'react'
import type { ProcessRuleResponse } from '@/models/datasets'
import Image from 'next/image'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
import { ProcessMode } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { indexMethodIcon, retrievalIcon } from '../icons'
import { IndexingType } from '../step-two'
type RuleDetailProps = {
sourceData?: ProcessRuleResponse
indexingType?: string
retrievalMethod?: RETRIEVE_METHOD
}
// Lookup table for pre-processing rule names
const PRE_PROCESSING_RULE_KEYS = {
remove_extra_spaces: 'stepTwo.removeExtraSpaces',
remove_urls_emails: 'stepTwo.removeUrlEmails',
remove_stopwords: 'stepTwo.removeStopwords',
} as const
// Lookup table for retrieval method icons
const RETRIEVAL_ICON_MAP: Partial<Record<RETRIEVE_METHOD, string>> = {
[RETRIEVE_METHOD.fullText]: retrievalIcon.fullText,
[RETRIEVE_METHOD.hybrid]: retrievalIcon.hybrid,
[RETRIEVE_METHOD.semantic]: retrievalIcon.vector,
[RETRIEVE_METHOD.invertedIndex]: retrievalIcon.fullText,
[RETRIEVE_METHOD.keywordSearch]: retrievalIcon.fullText,
}
const isNumber = (value: unknown): value is number => typeof value === 'number'
const RuleDetail: FC<RuleDetailProps> = ({ sourceData, indexingType, retrievalMethod }) => {
const { t } = useTranslation()
const segmentationRuleLabels = {
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
}
const getRuleName = useCallback((key: string): string | undefined => {
const translationKey = PRE_PROCESSING_RULE_KEYS[key as keyof typeof PRE_PROCESSING_RULE_KEYS]
return translationKey ? t(translationKey, { ns: 'datasetCreation' }) : undefined
}, [t])
const getModeValue = useCallback((): string => {
if (!sourceData?.mode)
return '-'
if (sourceData.mode === ProcessMode.general)
return t('embedding.custom', { ns: 'datasetDocuments' })
const parentModeLabel = sourceData.rules?.parent_mode === 'paragraph'
? t('parentMode.paragraph', { ns: 'dataset' })
: t('parentMode.fullDoc', { ns: 'dataset' })
return `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${parentModeLabel}`
}, [sourceData, t])
const getSegmentLengthValue = useCallback((): string | number => {
if (!sourceData?.mode)
return '-'
const maxTokens = isNumber(sourceData.rules?.segmentation?.max_tokens)
? sourceData.rules.segmentation.max_tokens
: '-'
if (sourceData.mode === ProcessMode.general)
return maxTokens
const childMaxTokens = isNumber(sourceData.rules?.subchunk_segmentation?.max_tokens)
? sourceData.rules.subchunk_segmentation.max_tokens
: '-'
return `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
}, [sourceData, t])
const getTextCleaningValue = useCallback((): string => {
if (!sourceData?.mode)
return '-'
const enabledRules = sourceData.rules?.pre_processing_rules?.filter(rule => rule.enabled) || []
const ruleNames = enabledRules
.map((rule) => {
const name = getRuleName(rule.id)
return typeof name === 'string' ? name : ''
})
.filter(name => name)
return ruleNames.length > 0 ? ruleNames.join(',') : '-'
}, [sourceData, getRuleName])
const fieldValueGetters: Record<string, () => string | number> = {
mode: getModeValue,
segmentLength: getSegmentLengthValue,
textCleaning: getTextCleaningValue,
}
const isEconomical = indexingType === IndexingType.ECONOMICAL
const indexMethodIconSrc = isEconomical ? indexMethodIcon.economical : indexMethodIcon.high_quality
const indexModeLabel = t(`stepTwo.${isEconomical ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' })
const effectiveRetrievalMethod = isEconomical ? 'keyword_search' : (retrievalMethod ?? 'semantic_search')
const retrievalLabel = t(`retrieval.${effectiveRetrievalMethod}.title`, { ns: 'dataset' })
const retrievalIconSrc = RETRIEVAL_ICON_MAP[retrievalMethod as keyof typeof RETRIEVAL_ICON_MAP] ?? retrievalIcon.vector
return (
<div className="flex flex-col gap-1">
{Object.keys(segmentationRuleLabels).map(field => (
<FieldInfo
key={field}
label={segmentationRuleLabels[field as keyof typeof segmentationRuleLabels]}
displayedValue={String(fieldValueGetters[field]())}
/>
))}
<FieldInfo
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
displayedValue={indexModeLabel}
valueIcon={<Image className="size-4" src={indexMethodIconSrc} alt="" />}
/>
<FieldInfo
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
displayedValue={retrievalLabel}
valueIcon={<Image className="size-4" src={retrievalIconSrc} alt="" />}
/>
</div>
)
}
export default RuleDetail

View File

@@ -1,22 +0,0 @@
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
const UpgradeBanner: FC = () => {
const { t } = useTranslation()
return (
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
<ZapFast className="h-4 w-4 text-[#FB6514]" />
</div>
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
</div>
<UpgradeBtn loc="knowledge-speed-up" />
</div>
)
}
export default UpgradeBanner

View File

@@ -1,90 +0,0 @@
import type { IndexingStatusResponse } from '@/models/datasets'
import { useEffect, useRef, useState } from 'react'
import { fetchIndexingStatusBatch } from '@/service/datasets'
const POLLING_INTERVAL = 2500
const COMPLETED_STATUSES = ['completed', 'error', 'paused'] as const
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
type IndexingStatusPollingParams = {
datasetId: string
batchId: string
}
type IndexingStatusPollingResult = {
statusList: IndexingStatusResponse[]
isEmbedding: boolean
isEmbeddingCompleted: boolean
}
const isStatusCompleted = (status: string): boolean =>
COMPLETED_STATUSES.includes(status as typeof COMPLETED_STATUSES[number])
const isAllCompleted = (statusList: IndexingStatusResponse[]): boolean =>
statusList.every(item => isStatusCompleted(item.indexing_status))
/**
* Custom hook for polling indexing status with automatic stop on completion.
* Handles the polling lifecycle and provides derived states for UI rendering.
*/
export const useIndexingStatusPolling = ({
datasetId,
batchId,
}: IndexingStatusPollingParams): IndexingStatusPollingResult => {
const [statusList, setStatusList] = useState<IndexingStatusResponse[]>([])
const isStopPollingRef = useRef(false)
useEffect(() => {
// Reset polling state on mount
isStopPollingRef.current = false
let timeoutId: ReturnType<typeof setTimeout> | null = null
const fetchStatus = async (): Promise<IndexingStatusResponse[]> => {
const response = await fetchIndexingStatusBatch({ datasetId, batchId })
setStatusList(response.data)
return response.data
}
const poll = async (): Promise<void> => {
if (isStopPollingRef.current)
return
try {
const data = await fetchStatus()
if (isAllCompleted(data)) {
isStopPollingRef.current = true
return
}
}
catch {
// Continue polling on error
}
if (!isStopPollingRef.current) {
timeoutId = setTimeout(() => {
poll()
}, POLLING_INTERVAL)
}
}
poll()
return () => {
isStopPollingRef.current = true
if (timeoutId)
clearTimeout(timeoutId)
}
}, [datasetId, batchId])
const isEmbedding = statusList.some(item =>
EMBEDDING_STATUSES.includes(item?.indexing_status as typeof EMBEDDING_STATUSES[number]),
)
const isEmbeddingCompleted = statusList.length > 0 && isAllCompleted(statusList)
return {
statusList,
isEmbedding,
isEmbeddingCompleted,
}
}
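
For reference, this deleted hook was consumed by the old embedding-process.tsx exactly as shown in the component diff above: polling starts on mount and stops on its own once every document reaches 'completed', 'error', or 'paused'. A minimal usage sketch (the surrounding component is assumed, not shown here):

const { statusList, isEmbedding, isEmbeddingCompleted } = useIndexingStatusPolling({
  datasetId,
  batchId,
})
// statusList drives the per-document progress rows; isEmbedding and
// isEmbeddingCompleted drive the spinner and the "completed" header state.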

View File

@@ -1,64 +0,0 @@
import type {
DataSourceInfo,
DataSourceType,
FullDocumentDetail,
IndexingStatusResponse,
LegacyDataSourceInfo,
} from '@/models/datasets'
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
/**
* Type guard for legacy data source info with upload_file property
*/
export const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
}
/**
* Check if a status indicates the source is being embedded
*/
export const isSourceEmbedding = (detail: IndexingStatusResponse): boolean =>
EMBEDDING_STATUSES.includes(detail.indexing_status as typeof EMBEDDING_STATUSES[number])
/**
* Calculate the progress percentage for a document
*/
export const getSourcePercent = (detail: IndexingStatusResponse): number => {
const completedCount = detail.completed_segments || 0
const totalCount = detail.total_segments || 0
if (totalCount === 0)
return 0
const percent = Math.round(completedCount * 100 / totalCount)
return Math.min(percent, 100)
}
/**
* Get file extension from filename, defaults to 'txt'
*/
export const getFileType = (name?: string): string =>
name?.split('.').pop() || 'txt'
/**
* Document lookup utilities - provides document info by ID from a list
*/
export const createDocumentLookup = (documents: FullDocumentDetail[]) => {
const documentMap = new Map(documents.map(doc => [doc.id, doc]))
return {
getDocument: (id: string) => documentMap.get(id),
getName: (id: string) => documentMap.get(id)?.name,
getSourceType: (id: string) => documentMap.get(id)?.data_source_type as DataSourceType | undefined,
getNotionIcon: (id: string) => {
const info = documentMap.get(id)?.data_source_info
if (info && isLegacyDataSourceInfo(info))
return info.notion_page_icon
return undefined
},
}
}
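
For reference, a short usage sketch of these deleted helpers, mirroring the memoized lookup in the component diff above (values in the comments are illustrative):

const lookup = createDocumentLookup(documents) // build the Map once per documents change
lookup.getName(detail.id) // document name for a status row
lookup.getSourceType(detail.id) // DataSourceType.FILE, DataSourceType.NOTION, ...
lookup.getNotionIcon(detail.id) // notion_page_icon for legacy data sources

getSourcePercent({ ...detail, completed_segments: 50, total_segments: 200 }) // => 25
getFileType('report.pdf') // => 'pdf'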

View File

@@ -1,199 +0,0 @@
'use client'
import type { FC } from 'react'
import type { PreProcessingRule } from '@/models/datasets'
import {
RiAlertFill,
RiSearchEyeLine,
} from '@remixicon/react'
import Image from 'next/image'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import Tooltip from '@/app/components/base/tooltip'
import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import SettingCog from '../../assets/setting-gear-mod.svg'
import s from '../index.module.css'
import LanguageSelect from '../language-select'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import { OptionCard } from './option-card'
type TextLabelProps = {
children: React.ReactNode
}
const TextLabel: FC<TextLabelProps> = ({ children }) => {
return <label className="system-sm-semibold text-text-secondary">{children}</label>
}
type GeneralChunkingOptionsProps = {
// State
segmentIdentifier: string
maxChunkLength: number
overlap: number
rules: PreProcessingRule[]
currentDocForm: ChunkingMode
docLanguage: string
// Flags
isActive: boolean
isInUpload: boolean
isNotUploadInEmptyDataset: boolean
hasCurrentDatasetDocForm: boolean
// Actions
onSegmentIdentifierChange: (value: string) => void
onMaxChunkLengthChange: (value: number) => void
onOverlapChange: (value: number) => void
onRuleToggle: (id: string) => void
onDocFormChange: (form: ChunkingMode) => void
onDocLanguageChange: (lang: string) => void
onPreview: () => void
onReset: () => void
// Locale
locale: string
}
export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
segmentIdentifier,
maxChunkLength,
overlap,
rules,
currentDocForm,
docLanguage,
isActive,
isInUpload,
isNotUploadInEmptyDataset,
hasCurrentDatasetDocForm,
onSegmentIdentifierChange,
onMaxChunkLengthChange,
onOverlapChange,
onRuleToggle,
onDocFormChange,
onDocLanguageChange,
onPreview,
onReset,
locale,
}) => {
const { t } = useTranslation()
const getRuleName = (key: string): string => {
const ruleNameMap: Record<string, string> = {
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
}
return ruleNameMap[key] ?? key
}
return (
<OptionCard
className="mb-2 bg-background-section"
title={t('stepTwo.general', { ns: 'datasetCreation' })}
icon={<Image width={20} height={20} src={SettingCog} alt={t('stepTwo.general', { ns: 'datasetCreation' })} />}
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
description={t('stepTwo.generalTip', { ns: 'datasetCreation' })}
isActive={isActive}
onSwitched={() => onDocFormChange(ChunkingMode.text)}
actions={(
<>
<Button variant="secondary-accent" onClick={onPreview}>
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
</Button>
<Button variant="ghost" onClick={onReset}>
{t('stepTwo.reset', { ns: 'datasetCreation' })}
</Button>
</>
)}
noHighlight={isInUpload && isNotUploadInEmptyDataset}
>
<div className="flex flex-col gap-y-4">
<div className="flex gap-3">
<DelimiterInput
value={segmentIdentifier}
onChange={e => onSegmentIdentifierChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={maxChunkLength}
onChange={onMaxChunkLengthChange}
/>
<OverlapInput
unit="characters"
value={overlap}
min={1}
onChange={onOverlapChange}
/>
</div>
<div className="flex w-full flex-col">
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1">
{rules.map(rule => (
<div
key={rule.id}
className={s.ruleItem}
onClick={() => onRuleToggle(rule.id)}
>
<Checkbox checked={rule.enabled} />
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{getRuleName(rule.id)}
</label>
</div>
))}
{IS_CE_EDITION && (
<>
<Divider type="horizontal" className="my-4 bg-divider-subtle" />
<div className="flex items-center py-0.5">
<div
className="flex items-center"
onClick={() => {
if (hasCurrentDatasetDocForm)
return
if (currentDocForm === ChunkingMode.qa)
onDocFormChange(ChunkingMode.text)
else
onDocFormChange(ChunkingMode.qa)
}}
>
<Checkbox
checked={currentDocForm === ChunkingMode.qa}
disabled={hasCurrentDatasetDocForm}
/>
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{t('stepTwo.useQALanguage', { ns: 'datasetCreation' })}
</label>
</div>
<LanguageSelect
currentLanguage={docLanguage || locale}
onSelect={onDocLanguageChange}
disabled={currentDocForm !== ChunkingMode.qa}
/>
<Tooltip popupContent={t('stepTwo.QATip', { ns: 'datasetCreation' })} />
</div>
{currentDocForm === ChunkingMode.qa && (
<div
style={{
background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.1) 0%, rgba(255, 255, 255, 0.00) 100%)',
}}
className="mt-2 flex h-10 items-center gap-2 rounded-xl border border-components-panel-border px-3 text-xs shadow-xs backdrop-blur-[5px]"
>
<RiAlertFill className="size-4 text-text-warning-secondary" />
<span className="system-xs-medium text-text-primary">
{t('stepTwo.QATip', { ns: 'datasetCreation' })}
</span>
</div>
)}
</>
)}
</div>
</div>
</div>
</OptionCard>
)
}

View File

@@ -1,5 +0,0 @@
export { GeneralChunkingOptions } from './general-chunking-options'
export { IndexingModeSection } from './indexing-mode-section'
export { ParentChildOptions } from './parent-child-options'
export { PreviewPanel } from './preview-panel'
export { StepTwoFooter } from './step-two-footer'

View File

@@ -1,253 +0,0 @@
'use client'
import type { FC } from 'react'
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import Image from 'next/image'
import Link from 'next/link'
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import CustomDialog from '@/app/components/base/dialog'
import Divider from '@/app/components/base/divider'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useDocLink } from '@/context/i18n'
import { ChunkingMode } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import { indexMethodIcon } from '../../icons'
import { IndexingType } from '../hooks'
import s from '../index.module.css'
import { OptionCard } from './option-card'
type IndexingModeSectionProps = {
// State
indexType: IndexingType
hasSetIndexType: boolean
docForm: ChunkingMode
embeddingModel: DefaultModel
embeddingModelList?: Model[]
retrievalConfig: RetrievalConfig
showMultiModalTip: boolean
// Flags
isModelAndRetrievalConfigDisabled: boolean
datasetId?: string
// Modal state
isQAConfirmDialogOpen: boolean
// Actions
onIndexTypeChange: (type: IndexingType) => void
onEmbeddingModelChange: (model: DefaultModel) => void
onRetrievalConfigChange: (config: RetrievalConfig) => void
onQAConfirmDialogClose: () => void
onQAConfirmDialogConfirm: () => void
}
export const IndexingModeSection: FC<IndexingModeSectionProps> = ({
indexType,
hasSetIndexType,
docForm,
embeddingModel,
embeddingModelList,
retrievalConfig,
showMultiModalTip,
isModelAndRetrievalConfigDisabled,
datasetId,
isQAConfirmDialogOpen,
onIndexTypeChange,
onEmbeddingModelChange,
onRetrievalConfigChange,
onQAConfirmDialogClose,
onQAConfirmDialogConfirm,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const getIndexingTechnique = () => indexType
return (
<>
{/* Index Mode */}
<div className="system-md-semibold mb-1 text-text-secondary">
{t('stepTwo.indexMode', { ns: 'datasetCreation' })}
</div>
<div className="flex items-center gap-2">
{/* Qualified option */}
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.QUALIFIED)) && (
<OptionCard
className="flex-1 self-stretch"
title={(
<div className="flex items-center">
{t('stepTwo.qualified', { ns: 'datasetCreation' })}
<Badge
className={cn(
'ml-1 h-[18px]',
(!hasSetIndexType && indexType === IndexingType.QUALIFIED)
? 'border-text-accent-secondary text-text-accent-secondary'
: '',
)}
uppercase
>
{t('stepTwo.recommend', { ns: 'datasetCreation' })}
</Badge>
<span className="ml-auto">
{!hasSetIndexType && <span className={cn(s.radio)} />}
</span>
</div>
)}
description={t('stepTwo.qualifiedTip', { ns: 'datasetCreation' })}
icon={<Image src={indexMethodIcon.high_quality} alt="" />}
isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED}
disabled={hasSetIndexType}
onSwitched={() => onIndexTypeChange(IndexingType.QUALIFIED)}
/>
)}
{/* Economical option */}
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.ECONOMICAL)) && (
<>
<CustomDialog show={isQAConfirmDialogOpen} onClose={onQAConfirmDialogClose} className="w-[432px]">
<header className="mb-4 pt-6">
<h2 className="text-lg font-semibold text-text-primary">
{t('stepTwo.qaSwitchHighQualityTipTitle', { ns: 'datasetCreation' })}
</h2>
<p className="mt-2 text-sm font-normal text-text-secondary">
{t('stepTwo.qaSwitchHighQualityTipContent', { ns: 'datasetCreation' })}
</p>
</header>
<div className="flex gap-2 pb-6">
<Button className="ml-auto" onClick={onQAConfirmDialogClose}>
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
</Button>
<Button variant="primary" onClick={onQAConfirmDialogConfirm}>
{t('stepTwo.switch', { ns: 'datasetCreation' })}
</Button>
</div>
</CustomDialog>
<Tooltip
popupContent={(
<div className="rounded-lg border-components-panel-border bg-components-tooltip-bg p-3 text-xs font-medium text-text-secondary shadow-lg">
{docForm === ChunkingMode.qa
? t('stepTwo.notAvailableForQA', { ns: 'datasetCreation' })
: t('stepTwo.notAvailableForParentChild', { ns: 'datasetCreation' })}
</div>
)}
noDecoration
position="top"
asChild={false}
triggerClassName="flex-1 self-stretch"
>
<OptionCard
className="h-full"
title={t('stepTwo.economical', { ns: 'datasetCreation' })}
description={t('stepTwo.economicalTip', { ns: 'datasetCreation' })}
icon={<Image src={indexMethodIcon.economical} alt="" />}
isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL}
disabled={hasSetIndexType || docForm !== ChunkingMode.text}
onSwitched={() => onIndexTypeChange(IndexingType.ECONOMICAL)}
/>
</Tooltip>
</>
)}
</div>
{/* High quality tip */}
{!hasSetIndexType && indexType === IndexingType.QUALIFIED && (
<div className="mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-2 shadow-xs backdrop-blur-[5px]">
<div className="absolute bottom-0 left-0 right-0 top-0 bg-dataset-warning-message-bg opacity-40"></div>
<div className="p-1">
<AlertTriangle className="size-4 text-text-warning-secondary" />
</div>
<span className="system-xs-medium text-text-primary">
{t('stepTwo.highQualityTip', { ns: 'datasetCreation' })}
</span>
</div>
)}
{/* Economical index setting tip */}
{hasSetIndexType && indexType === IndexingType.ECONOMICAL && (
<div className="system-xs-medium mt-2 text-text-tertiary">
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
</Link>
</div>
)}
{/* Embedding model */}
{indexType === IndexingType.QUALIFIED && (
<div className="mt-5">
<div className={cn('system-md-semibold mb-1 text-text-secondary', datasetId && 'flex items-center justify-between')}>
{t('form.embeddingModel', { ns: 'datasetSettings' })}
</div>
<ModelSelector
readonly={isModelAndRetrievalConfigDisabled}
triggerClassName={isModelAndRetrievalConfigDisabled ? 'opacity-50' : ''}
defaultModel={embeddingModel}
modelList={embeddingModelList ?? []}
onSelect={onEmbeddingModelChange}
/>
{isModelAndRetrievalConfigDisabled && (
<div className="system-xs-medium mt-2 text-text-tertiary">
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
</Link>
</div>
)}
</div>
)}
<Divider className="my-5" />
{/* Retrieval Method Config */}
<div>
{!isModelAndRetrievalConfigDisabled
? (
<div className="mb-1">
<div className="system-md-semibold mb-0.5 text-text-secondary">
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
</div>
<div className="body-xs-regular text-text-tertiary">
<a
target="_blank"
rel="noopener noreferrer"
href={docLink('/guides/knowledge-base/create-knowledge-and-upload-documents')}
className="text-text-accent"
>
{t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })}
</a>
{t('form.retrievalSetting.longDescription', { ns: 'datasetSettings' })}
</div>
</div>
)
: (
<div className={cn('system-md-semibold mb-0.5 text-text-secondary', 'flex items-center justify-between')}>
<div>{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}</div>
</div>
)}
<div>
{getIndexingTechnique() === IndexingType.QUALIFIED
? (
<RetrievalMethodConfig
disabled={isModelAndRetrievalConfigDisabled}
value={retrievalConfig}
onChange={onRetrievalConfigChange}
showMultiModalTip={showMultiModalTip}
/>
)
: (
<EconomicalRetrievalMethodConfig
disabled={isModelAndRetrievalConfigDisabled}
value={retrievalConfig}
onChange={onRetrievalConfigChange}
/>
)}
</div>
</div>
</>
)
}

View File

@@ -1,191 +0,0 @@
'use client'
import type { FC } from 'react'
import type { ParentChildConfig } from '../hooks'
import type { ParentMode, PreProcessingRule } from '@/models/datasets'
import { RiSearchEyeLine } from '@remixicon/react'
import Image from 'next/image'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
import BlueEffect from '../../assets/option-card-effect-blue.svg'
import s from '../index.module.css'
import { DelimiterInput, MaxLengthInput } from './inputs'
import { OptionCard } from './option-card'
type TextLabelProps = {
children: React.ReactNode
}
const TextLabel: FC<TextLabelProps> = ({ children }) => {
return <label className="system-sm-semibold text-text-secondary">{children}</label>
}
type ParentChildOptionsProps = {
// State
parentChildConfig: ParentChildConfig
rules: PreProcessingRule[]
currentDocForm: ChunkingMode
// Flags
isActive: boolean
isInUpload: boolean
isNotUploadInEmptyDataset: boolean
// Actions
onDocFormChange: (form: ChunkingMode) => void
onChunkForContextChange: (mode: ParentMode) => void
onParentDelimiterChange: (value: string) => void
onParentMaxLengthChange: (value: number) => void
onChildDelimiterChange: (value: string) => void
onChildMaxLengthChange: (value: number) => void
onRuleToggle: (id: string) => void
onPreview: () => void
onReset: () => void
}
export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
parentChildConfig,
rules,
currentDocForm: _currentDocForm,
isActive,
isInUpload,
isNotUploadInEmptyDataset,
onDocFormChange,
onChunkForContextChange,
onParentDelimiterChange,
onParentMaxLengthChange,
onChildDelimiterChange,
onChildMaxLengthChange,
onRuleToggle,
onPreview,
onReset,
}) => {
const { t } = useTranslation()
const getRuleName = (key: string): string => {
const ruleNameMap: Record<string, string> = {
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
}
return ruleNameMap[key] ?? key
}
return (
<OptionCard
title={t('stepTwo.parentChild', { ns: 'datasetCreation' })}
icon={<ParentChildChunk className="h-[20px] w-[20px]" />}
effectImg={BlueEffect.src}
className="text-util-colors-blue-light-blue-light-500"
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
description={t('stepTwo.parentChildTip', { ns: 'datasetCreation' })}
isActive={isActive}
onSwitched={() => onDocFormChange(ChunkingMode.parentChild)}
actions={(
<>
<Button variant="secondary-accent" onClick={onPreview}>
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
</Button>
<Button variant="ghost" onClick={onReset}>
{t('stepTwo.reset', { ns: 'datasetCreation' })}
</Button>
</>
)}
noHighlight={isInUpload && isNotUploadInEmptyDataset}
>
<div className="flex flex-col gap-4">
{/* Parent chunk for context */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.parentChunkForContext', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<RadioCard
className="mt-1"
icon={<Image src={Note} alt="" />}
title={t('stepTwo.paragraph', { ns: 'datasetCreation' })}
description={t('stepTwo.paragraphTip', { ns: 'datasetCreation' })}
isChosen={parentChildConfig.chunkForContext === 'paragraph'}
onChosen={() => onChunkForContextChange('paragraph')}
chosenConfig={(
<div className="flex gap-3">
<DelimiterInput
value={parentChildConfig.parent.delimiter}
tooltip={t('stepTwo.parentChildDelimiterTip', { ns: 'datasetCreation' })!}
onChange={e => onParentDelimiterChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={parentChildConfig.parent.maxLength}
onChange={onParentMaxLengthChange}
/>
</div>
)}
/>
<RadioCard
className="mt-2"
icon={<Image src={FileList} alt="" />}
title={t('stepTwo.fullDoc', { ns: 'datasetCreation' })}
description={t('stepTwo.fullDocTip', { ns: 'datasetCreation' })}
onChosen={() => onChunkForContextChange('full-doc')}
isChosen={parentChildConfig.chunkForContext === 'full-doc'}
/>
</div>
{/* Child chunk for retrieval */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.childChunkForRetrieval', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1 flex gap-3">
<DelimiterInput
value={parentChildConfig.child.delimiter}
tooltip={t('stepTwo.parentChildChunkDelimiterTip', { ns: 'datasetCreation' })!}
onChange={e => onChildDelimiterChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={parentChildConfig.child.maxLength}
onChange={onChildMaxLengthChange}
/>
</div>
</div>
{/* Rules */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1">
{rules.map(rule => (
<div
key={rule.id}
className={s.ruleItem}
onClick={() => onRuleToggle(rule.id)}
>
<Checkbox checked={rule.enabled} />
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{getRuleName(rule.id)}
</label>
</div>
))}
</div>
</div>
</div>
</OptionCard>
)
}
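A usage sketch for the card above. The wiring is an assumption, but every handler maps onto a setter exposed by useSegmentationState, which appears later in this diff:

// Assumed call site; the import paths are guesses, the hook API is real.
import { ChunkingMode } from '@/models/datasets'
import { useSegmentationState } from '../hooks'
import { ParentChildOptions } from './parent-child-options'

const ChunkingStep = () => {
  const seg = useSegmentationState()
  return (
    <ParentChildOptions
      parentChildConfig={seg.parentChildConfig}
      rules={seg.rules}
      currentDocForm={ChunkingMode.parentChild}
      isActive
      isInUpload={false}
      isNotUploadInEmptyDataset={false}
      onDocFormChange={() => { /* switch the doc form state here */ }}
      onChunkForContextChange={seg.setChunkForContext}
      onParentDelimiterChange={value => seg.updateParentConfig('delimiter', value)}
      onParentMaxLengthChange={value => seg.updateParentConfig('maxLength', value)}
      onChildDelimiterChange={value => seg.updateChildConfig('delimiter', value)}
      onChildMaxLengthChange={value => seg.updateChildConfig('maxLength', value)}
      onRuleToggle={seg.toggleRule}
      onPreview={() => { /* trigger the indexing estimate here */ }}
      onReset={seg.resetToDefaults}
    />
  )
}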

View File

@@ -1,171 +0,0 @@
'use client'
import type { FC } from 'react'
import type { ParentChildConfig } from '../hooks'
import type { DataSourceType, FileIndexingEstimateResponse } from '@/models/datasets'
import { RiSearchEyeLine } from '@remixicon/react'
import { noop } from 'es-toolkit/function'
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import FloatRightContainer from '@/app/components/base/float-right-container'
import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton'
import { FULL_DOC_PREVIEW_LENGTH } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import { ChunkContainer, QAPreview } from '../../../chunk'
import PreviewDocumentPicker from '../../../common/document-picker/preview-document-picker'
import { PreviewSlice } from '../../../formatted-text/flavours/preview-slice'
import { FormattedText } from '../../../formatted-text/formatted'
import PreviewContainer from '../../../preview/container'
import { PreviewHeader } from '../../../preview/header'
type PreviewPanelProps = {
// State
isMobile: boolean
dataSourceType: DataSourceType
currentDocForm: ChunkingMode
estimate?: FileIndexingEstimateResponse
parentChildConfig: ParentChildConfig
isSetting?: boolean
// Picker
pickerFiles: Array<{ id: string, name: string, extension: string }>
pickerValue: { id: string, name: string, extension: string }
// Mutation state
isIdle: boolean
isPending: boolean
// Actions
onPickerChange: (selected: { id: string, name: string }) => void
}
export const PreviewPanel: FC<PreviewPanelProps> = ({
isMobile,
dataSourceType: _dataSourceType,
currentDocForm,
estimate,
parentChildConfig,
isSetting,
pickerFiles,
pickerValue,
isIdle,
isPending,
onPickerChange,
}) => {
const { t } = useTranslation()
return (
<FloatRightContainer isMobile={isMobile} isOpen={true} onClose={noop} footer={null}>
<PreviewContainer
header={(
<PreviewHeader title={t('stepTwo.preview', { ns: 'datasetCreation' })}>
<div className="flex items-center gap-1">
<PreviewDocumentPicker
files={pickerFiles as Array<Required<{ id: string, name: string, extension: string }>>}
onChange={onPickerChange}
value={isSetting ? pickerFiles[0] : pickerValue}
/>
{currentDocForm !== ChunkingMode.qa && (
<Badge
text={t('stepTwo.previewChunkCount', {
ns: 'datasetCreation',
count: estimate?.total_segments || 0,
}) as string}
/>
)}
</div>
</PreviewHeader>
)}
className={cn('relative flex h-full w-1/2 shrink-0 p-4 pr-0', isMobile && 'w-full max-w-[524px]')}
mainClassName="space-y-6"
>
{/* QA Preview */}
{currentDocForm === ChunkingMode.qa && estimate?.qa_preview && (
estimate.qa_preview.map((item, index) => (
<ChunkContainer
key={item.question}
label={`Chunk-${index + 1}`}
characterCount={item.question.length + item.answer.length}
>
<QAPreview qa={item} />
</ChunkContainer>
))
)}
{/* Text Preview */}
{currentDocForm === ChunkingMode.text && estimate?.preview && (
estimate.preview.map((item, index) => (
<ChunkContainer
key={item.content}
label={`Chunk-${index + 1}`}
characterCount={item.content.length}
>
{item.content}
</ChunkContainer>
))
)}
{/* Parent-Child Preview */}
{currentDocForm === ChunkingMode.parentChild && estimate?.preview && (
estimate.preview.map((item, index) => {
const indexForLabel = index + 1
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
: item.child_chunks
return (
<ChunkContainer
key={item.content}
label={`Chunk-${indexForLabel}`}
characterCount={item.content.length}
>
<FormattedText>
{childChunks.map((child, childIndex) => {
const childIndexForLabel = childIndex + 1
return (
<PreviewSlice
key={`C-${childIndexForLabel}-${child}`}
label={`C-${childIndexForLabel}`}
text={child}
tooltip={`Child-chunk-${childIndexForLabel} · ${child.length} Characters`}
labelInnerClassName="text-[10px] font-semibold align-bottom leading-7"
dividerClassName="leading-7"
/>
)
})}
</FormattedText>
</ChunkContainer>
)
})
)}
{/* Idle State */}
{isIdle && (
<div className="flex h-full w-full items-center justify-center">
<div className="flex flex-col items-center justify-center gap-3">
<RiSearchEyeLine className="size-10 text-text-empty-state-icon" />
<p className="text-sm text-text-tertiary">
{t('stepTwo.previewChunkTip', { ns: 'datasetCreation' })}
</p>
</div>
</div>
)}
{/* Loading State */}
{isPending && (
<div className="space-y-6">
{Array.from({ length: 10 }, (_, i) => (
<SkeletonContainer key={i}>
<SkeletonRow>
<SkeletonRectangle className="w-20" />
<SkeletonPoint />
<SkeletonRectangle className="w-24" />
</SkeletonRow>
<SkeletonRectangle className="w-full" />
<SkeletonRectangle className="w-full" />
<SkeletonRectangle className="w-[422px]" />
</SkeletonContainer>
))}
</div>
)}
</PreviewContainer>
</FloatRightContainer>
)
}
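A sketch of the intended wiring (an assumption, not the real call site): estimate state comes from useIndexingEstimate and picker state from usePreviewState, both defined later in this diff.

// The option types come from the hook modules below; paths are assumed.
import { useIndexingEstimate, usePreviewState } from '../hooks'

const StepTwoPreview = ({ isMobile, dataSourceType, currentDocForm, parentChildConfig, previewOptions, estimateOptions }: {
  isMobile: boolean
  dataSourceType: DataSourceType
  currentDocForm: ChunkingMode
  parentChildConfig: ParentChildConfig
  previewOptions: UsePreviewStateOptions
  estimateOptions: UseIndexingEstimateOptions
}) => {
  const previewState = usePreviewState(previewOptions)
  const { estimate, isIdle, isPending } = useIndexingEstimate(estimateOptions)
  return (
    <PreviewPanel
      isMobile={isMobile}
      dataSourceType={dataSourceType}
      currentDocForm={currentDocForm}
      parentChildConfig={parentChildConfig}
      estimate={estimate}
      pickerFiles={previewState.getPreviewPickerItems()}
      pickerValue={previewState.getPreviewPickerValue()}
      isIdle={isIdle}
      isPending={isPending}
      onPickerChange={previewState.handlePreviewChange}
    />
  )
}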

View File

@@ -1,58 +0,0 @@
'use client'
import type { FC } from 'react'
import { RiArrowLeftLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
type StepTwoFooterProps = {
isSetting?: boolean
isCreating: boolean
onPrevious: () => void
onCreate: () => void
onCancel?: () => void
}
export const StepTwoFooter: FC<StepTwoFooterProps> = ({
isSetting,
isCreating,
onPrevious,
onCreate,
onCancel,
}) => {
const { t } = useTranslation()
if (!isSetting) {
return (
<div className="mt-8 flex items-center py-2">
<Button onClick={onPrevious}>
<RiArrowLeftLine className="mr-1 h-4 w-4" />
{t('stepTwo.previousStep', { ns: 'datasetCreation' })}
</Button>
<Button
className="ml-auto"
loading={isCreating}
variant="primary"
onClick={onCreate}
>
{t('stepTwo.nextStep', { ns: 'datasetCreation' })}
</Button>
</div>
)
}
return (
<div className="mt-8 flex items-center py-2">
<Button
loading={isCreating}
variant="primary"
onClick={onCreate}
>
{t('stepTwo.save', { ns: 'datasetCreation' })}
</Button>
<Button className="ml-2" onClick={onCancel}>
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
</Button>
</div>
)
}
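A usage sketch (the handler names are assumptions): with isSetting unset the footer renders the Previous/Next pair for the creation wizard; otherwise it collapses to Save/Cancel.

// Inside the step-two component's return; creation is the
// useDocumentCreation result defined below in this diff.
<StepTwoFooter
  isSetting={isSetting}
  isCreating={creation.isCreating}
  onPrevious={() => onStepChange?.(-1)}
  onCreate={handleCreateDocument}
  onCancel={onCancel}
/>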

View File

@@ -1,14 +0,0 @@
export { useDocumentCreation } from './use-document-creation'
export type { DocumentCreation, ValidationParams } from './use-document-creation'
export { IndexingType, useIndexingConfig } from './use-indexing-config'
export type { IndexingConfig } from './use-indexing-config'
export { useIndexingEstimate } from './use-indexing-estimate'
export type { IndexingEstimate } from './use-indexing-estimate'
export { usePreviewState } from './use-preview-state'
export type { PreviewState } from './use-preview-state'
export { DEFAULT_MAXIMUM_CHUNK_LENGTH, DEFAULT_OVERLAP, DEFAULT_SEGMENT_IDENTIFIER, defaultParentChildConfig, MAXIMUM_CHUNK_TOKEN_LENGTH, useSegmentationState } from './use-segmentation-state'
export type { ParentChildConfig, SegmentationState } from './use-segmentation-state'

View File

@@ -1,279 +0,0 @@
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { NotionPage } from '@/models/common'
import type {
ChunkingMode,
CrawlOptions,
CrawlResultItem,
CreateDocumentReq,
createDocumentResponse,
CustomFile,
FullDocumentDetail,
ProcessRule,
} from '@/models/datasets'
import type { RetrievalConfig, RETRIEVE_METHOD } from '@/types/app'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { DataSourceProvider } from '@/models/common'
import {
DataSourceType,
} from '@/models/datasets'
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument } from '@/service/knowledge/use-create-dataset'
import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
import { IndexingType } from './use-indexing-config'
import { MAXIMUM_CHUNK_TOKEN_LENGTH } from './use-segmentation-state'
export type UseDocumentCreationOptions = {
datasetId?: string
isSetting?: boolean
documentDetail?: FullDocumentDetail
dataSourceType: DataSourceType
files: CustomFile[]
notionPages: NotionPage[]
notionCredentialId: string
websitePages: CrawlResultItem[]
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
// Callbacks
onStepChange?: (delta: number) => void
updateIndexingTypeCache?: (type: string) => void
updateResultCache?: (res: createDocumentResponse) => void
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
onSave?: () => void
mutateDatasetRes?: () => void
}
export type ValidationParams = {
segmentationType: string
maxChunkLength: number
limitMaxChunkLength: number
overlap: number
indexType: IndexingType
embeddingModel: DefaultModel
rerankModelList: Model[]
retrievalConfig: RetrievalConfig
}
export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
const { t } = useTranslation()
const {
datasetId,
isSetting,
documentDetail,
dataSourceType,
files,
notionPages,
notionCredentialId,
websitePages,
crawlOptions,
websiteCrawlProvider = DataSourceProvider.jinaReader,
websiteCrawlJobId = '',
onStepChange,
updateIndexingTypeCache,
updateResultCache,
updateRetrievalMethodCache,
onSave,
mutateDatasetRes,
} = options
const createFirstDocumentMutation = useCreateFirstDocument()
const createDocumentMutation = useCreateDocument(datasetId!)
const invalidDatasetList = useInvalidDatasetList()
const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending
// Validate creation params
const validateParams = useCallback((params: ValidationParams): boolean => {
const {
segmentationType,
maxChunkLength,
limitMaxChunkLength,
overlap,
indexType,
embeddingModel,
rerankModelList,
retrievalConfig,
} = params
if (segmentationType === 'general' && overlap > maxChunkLength) {
Toast.notify({ type: 'error', message: t('stepTwo.overlapCheck', { ns: 'datasetCreation' }) })
return false
}
if (segmentationType === 'general' && maxChunkLength > limitMaxChunkLength) {
Toast.notify({
type: 'error',
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: limitMaxChunkLength }),
})
return false
}
if (!isSetting) {
if (indexType === IndexingType.QUALIFIED && (!embeddingModel.model || !embeddingModel.provider)) {
Toast.notify({
type: 'error',
message: t('datasetConfig.embeddingModelRequired', { ns: 'appDebug' }),
})
return false
}
if (!isReRankModelSelected({
rerankModelList,
retrievalConfig,
indexMethod: indexType,
})) {
Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) })
return false
}
}
return true
}, [t, isSetting])
// Build creation params
const buildCreationParams = useCallback((
currentDocForm: ChunkingMode,
docLanguage: string,
processRule: ProcessRule,
retrievalConfig: RetrievalConfig,
embeddingModel: DefaultModel,
indexingTechnique: string,
): CreateDocumentReq | null => {
if (isSetting) {
return {
original_document_id: documentDetail?.id,
doc_form: currentDocForm,
doc_language: docLanguage,
process_rule: processRule,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
indexing_technique: indexingTechnique,
} as CreateDocumentReq
}
const params: CreateDocumentReq = {
data_source: {
type: dataSourceType,
info_list: {
data_source_type: dataSourceType,
},
},
indexing_technique: indexingTechnique,
process_rule: processRule,
doc_form: currentDocForm,
doc_language: docLanguage,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
} as CreateDocumentReq
// Add data source specific info
if (dataSourceType === DataSourceType.FILE) {
params.data_source!.info_list.file_info_list = {
file_ids: files.map(file => file.id || '').filter(Boolean),
}
}
if (dataSourceType === DataSourceType.NOTION)
params.data_source!.info_list.notion_info_list = getNotionInfo(notionPages, notionCredentialId)
if (dataSourceType === DataSourceType.WEB) {
params.data_source!.info_list.website_info_list = getWebsiteInfo({
websiteCrawlProvider,
websiteCrawlJobId,
websitePages,
crawlOptions,
})
}
return params
}, [
isSetting,
documentDetail,
dataSourceType,
files,
notionPages,
notionCredentialId,
websitePages,
websiteCrawlProvider,
websiteCrawlJobId,
crawlOptions,
])
// Execute creation
const executeCreation = useCallback(async (
params: CreateDocumentReq,
indexType: IndexingType,
retrievalConfig: RetrievalConfig,
) => {
if (!datasetId) {
await createFirstDocumentMutation.mutateAsync(params, {
onSuccess(data) {
updateIndexingTypeCache?.(indexType)
updateResultCache?.(data)
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
},
})
}
else {
await createDocumentMutation.mutateAsync(params, {
onSuccess(data) {
updateIndexingTypeCache?.(indexType)
updateResultCache?.(data)
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
},
})
}
mutateDatasetRes?.()
invalidDatasetList()
trackEvent('create_datasets', {
data_source_type: dataSourceType,
indexing_technique: indexType,
})
onStepChange?.(+1)
if (isSetting)
onSave?.()
}, [
datasetId,
createFirstDocumentMutation,
createDocumentMutation,
updateIndexingTypeCache,
updateResultCache,
updateRetrievalMethodCache,
mutateDatasetRes,
invalidDatasetList,
dataSourceType,
onStepChange,
isSetting,
onSave,
])
// Validate preview params
const validatePreviewParams = useCallback((maxChunkLength: number): boolean => {
if (maxChunkLength > MAXIMUM_CHUNK_TOKEN_LENGTH) {
Toast.notify({
type: 'error',
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: MAXIMUM_CHUNK_TOKEN_LENGTH }),
})
return false
}
return true
}, [t])
return {
isCreating,
validateParams,
buildCreationParams,
executeCreation,
validatePreviewParams,
}
}
export type DocumentCreation = ReturnType<typeof useDocumentCreation>
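The hook splits document creation into three steps: validate, build, execute. A sketch of a caller, where the local variable names are assumptions but the three methods and their signatures are the ones defined above:

// handleCreateDocument would back the footer's onCreate; getProcessRule and
// getIndexingTechnique come from the segmentation and indexing hooks in this diff.
const creation = useDocumentCreation(options)

const handleCreateDocument = async () => {
  if (!creation.validateParams({
    segmentationType, maxChunkLength, limitMaxChunkLength, overlap,
    indexType, embeddingModel, rerankModelList, retrievalConfig,
  }))
    return
  const params = creation.buildCreationParams(
    currentDocForm,
    docLanguage,
    getProcessRule(currentDocForm),
    retrievalConfig,
    embeddingModel,
    getIndexingTechnique(),
  )
  if (params)
    await creation.executeCreation(params, indexType, retrievalConfig)
}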

View File

@@ -1,143 +0,0 @@
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import { useEffect, useMemo, useState } from 'react'
import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { RETRIEVE_METHOD } from '@/types/app'
export enum IndexingType {
QUALIFIED = 'high_quality',
ECONOMICAL = 'economy',
}
const DEFAULT_RETRIEVAL_CONFIG: RetrievalConfig = {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
}
export type UseIndexingConfigOptions = {
initialIndexType?: IndexingType
initialEmbeddingModel?: DefaultModel
initialRetrievalConfig?: RetrievalConfig
isAPIKeySet: boolean
hasSetIndexType: boolean
}
export const useIndexingConfig = (options: UseIndexingConfigOptions) => {
const {
initialIndexType,
initialEmbeddingModel,
initialRetrievalConfig,
isAPIKeySet,
hasSetIndexType,
} = options
// Rerank model
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
// Embedding model list
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
// Index type state
const [indexType, setIndexType] = useState<IndexingType>(() => {
if (initialIndexType)
return initialIndexType
return isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL
})
// Embedding model state
const [embeddingModel, setEmbeddingModel] = useState<DefaultModel>(
initialEmbeddingModel ?? {
provider: defaultEmbeddingModel?.provider.provider || '',
model: defaultEmbeddingModel?.model || '',
},
)
// Retrieval config state
const [retrievalConfig, setRetrievalConfig] = useState<RetrievalConfig>(
initialRetrievalConfig ?? DEFAULT_RETRIEVAL_CONFIG,
)
// Sync retrieval config with rerank model when available
useEffect(() => {
if (initialRetrievalConfig)
return
setRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: !!isRerankDefaultModelValid,
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
})
}, [rerankDefaultModel, isRerankDefaultModelValid, initialRetrievalConfig])
// Sync index type with props
useEffect(() => {
if (initialIndexType)
setIndexType(initialIndexType)
else
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, initialIndexType])
// Show multimodal tip
const showMultiModalTip = useMemo(() => {
return checkShowMultiModalTip({
embeddingModel,
rerankingEnable: retrievalConfig.reranking_enable,
rerankModel: {
rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name,
rerankingModelName: retrievalConfig.reranking_model.reranking_model_name,
},
indexMethod: indexType,
embeddingModelList,
rerankModelList,
})
}, [embeddingModel, retrievalConfig, indexType, embeddingModelList, rerankModelList])
// Get effective indexing technique
const getIndexingTechnique = () => initialIndexType || indexType
return {
// Index type
indexType,
setIndexType,
hasSetIndexType,
getIndexingTechnique,
// Embedding model
embeddingModel,
setEmbeddingModel,
embeddingModelList,
defaultEmbeddingModel,
// Retrieval config
retrievalConfig,
setRetrievalConfig,
rerankModelList,
rerankDefaultModel,
isRerankDefaultModelValid,
// Computed
showMultiModalTip,
}
}
export type IndexingConfig = ReturnType<typeof useIndexingConfig>
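A usage sketch: getIndexingTechnique() prefers the initial (cached) type over local state, which is the value the retrieval-method branch near the top of this diff switches on.

// indexingType and isAPIKeySet would come from StepTwoProps; the option
// names are the real ones declared above.
const indexing = useIndexingConfig({
  initialIndexType: indexingType,
  isAPIKeySet,
  hasSetIndexType: !!indexingType,
})
const technique = indexing.getIndexingTechnique() // initialIndexType || indexType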

View File

@@ -1,123 +0,0 @@
import type { IndexingType } from './use-indexing-config'
import type { NotionPage } from '@/models/common'
import type { ChunkingMode, CrawlOptions, CrawlResultItem, CustomFile, ProcessRule } from '@/models/datasets'
import { useCallback } from 'react'
import { DataSourceProvider } from '@/models/common'
import { DataSourceType } from '@/models/datasets'
import {
useFetchFileIndexingEstimateForFile,
useFetchFileIndexingEstimateForNotion,
useFetchFileIndexingEstimateForWeb,
} from '@/service/knowledge/use-create-dataset'
export type UseIndexingEstimateOptions = {
dataSourceType: DataSourceType
datasetId?: string
// Document settings
currentDocForm: ChunkingMode
docLanguage: string
// File data source
files: CustomFile[]
previewFileName?: string
// Notion data source
previewNotionPage: NotionPage
notionCredentialId: string
// Website data source
previewWebsitePage: CrawlResultItem
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
// Processing
indexingTechnique: IndexingType
processRule: ProcessRule
}
export const useIndexingEstimate = (options: UseIndexingEstimateOptions) => {
const {
dataSourceType,
datasetId,
currentDocForm,
docLanguage,
files,
previewFileName,
previewNotionPage,
notionCredentialId,
previewWebsitePage,
crawlOptions,
websiteCrawlProvider,
websiteCrawlJobId,
indexingTechnique,
processRule,
} = options
// File indexing estimate
const fileQuery = useFetchFileIndexingEstimateForFile({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.FILE,
files: previewFileName
? [files.find(file => file.name === previewFileName)!]
: files,
indexingTechnique,
processRule,
dataset_id: datasetId!,
})
// Notion indexing estimate
const notionQuery = useFetchFileIndexingEstimateForNotion({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.NOTION,
notionPages: [previewNotionPage],
indexingTechnique,
processRule,
dataset_id: datasetId || '',
credential_id: notionCredentialId,
})
// Website indexing estimate
const websiteQuery = useFetchFileIndexingEstimateForWeb({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.WEB,
websitePages: [previewWebsitePage],
crawlOptions,
websiteCrawlProvider: websiteCrawlProvider ?? DataSourceProvider.jinaReader,
websiteCrawlJobId: websiteCrawlJobId ?? '',
indexingTechnique,
processRule,
dataset_id: datasetId || '',
})
// Get current mutation based on data source type
const getCurrentMutation = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
return fileQuery
if (dataSourceType === DataSourceType.NOTION)
return notionQuery
return websiteQuery
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
const currentMutation = getCurrentMutation()
// Trigger estimate fetch
const fetchEstimate = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
fileQuery.mutate()
else if (dataSourceType === DataSourceType.NOTION)
notionQuery.mutate()
else
websiteQuery.mutate()
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
return {
currentMutation,
estimate: currentMutation.data,
isIdle: currentMutation.isIdle,
isPending: currentMutation.isPending,
fetchEstimate,
reset: currentMutation.reset,
}
}
export type IndexingEstimate = ReturnType<typeof useIndexingEstimate>
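A usage sketch: one hook instance prepares all three estimate mutations and surfaces only the one matching dataSourceType. A "Preview Chunk" click would validate first, then fetch:

// validatePreviewParams comes from useDocumentCreation above;
// maxChunkLength from useSegmentationState below.
const { estimate, isIdle, isPending, fetchEstimate } = useIndexingEstimate(options)

const onPreviewChunk = () => {
  if (creation.validatePreviewParams(maxChunkLength))
    fetchEstimate()
}
// isIdle and isPending drive PreviewPanel's empty and skeleton states.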

View File

@@ -1,127 +0,0 @@
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile, DocumentItem, FullDocumentDetail } from '@/models/datasets'
import { useCallback, useState } from 'react'
import { DataSourceType } from '@/models/datasets'
export type UsePreviewStateOptions = {
dataSourceType: DataSourceType
files: CustomFile[]
notionPages: NotionPage[]
websitePages: CrawlResultItem[]
documentDetail?: FullDocumentDetail
datasetId?: string
}
export const usePreviewState = (options: UsePreviewStateOptions) => {
const {
dataSourceType,
files,
notionPages,
websitePages,
documentDetail,
datasetId,
} = options
// File preview state
const [previewFile, setPreviewFile] = useState<DocumentItem>(
(datasetId && documentDetail)
? documentDetail.file
: files[0],
)
// Notion page preview state
const [previewNotionPage, setPreviewNotionPage] = useState<NotionPage>(
(datasetId && documentDetail)
? documentDetail.notion_page
: notionPages[0],
)
// Website page preview state
const [previewWebsitePage, setPreviewWebsitePage] = useState<CrawlResultItem>(
(datasetId && documentDetail)
? documentDetail.website_page
: websitePages[0],
)
// Get preview items for document picker based on data source type
const getPreviewPickerItems = useCallback(() => {
if (dataSourceType === DataSourceType.FILE) {
return files as Array<Required<CustomFile>>
}
if (dataSourceType === DataSourceType.NOTION) {
return notionPages.map(page => ({
id: page.page_id,
name: page.page_name,
extension: 'md',
}))
}
if (dataSourceType === DataSourceType.WEB) {
return websitePages.map(page => ({
id: page.source_url,
name: page.title,
extension: 'md',
}))
}
return []
}, [dataSourceType, files, notionPages, websitePages])
// Get current preview value for picker
const getPreviewPickerValue = useCallback(() => {
if (dataSourceType === DataSourceType.FILE) {
return previewFile as Required<CustomFile>
}
if (dataSourceType === DataSourceType.NOTION) {
return {
id: previewNotionPage?.page_id || '',
name: previewNotionPage?.page_name || '',
extension: 'md',
}
}
if (dataSourceType === DataSourceType.WEB) {
return {
id: previewWebsitePage?.source_url || '',
name: previewWebsitePage?.title || '',
extension: 'md',
}
}
return { id: '', name: '', extension: '' }
}, [dataSourceType, previewFile, previewNotionPage, previewWebsitePage])
// Handle preview change
const handlePreviewChange = useCallback((selected: { id: string, name: string }) => {
if (dataSourceType === DataSourceType.FILE) {
setPreviewFile(selected as DocumentItem)
}
else if (dataSourceType === DataSourceType.NOTION) {
const selectedPage = notionPages.find(page => page.page_id === selected.id)
if (selectedPage)
setPreviewNotionPage(selectedPage)
}
else if (dataSourceType === DataSourceType.WEB) {
const selectedPage = websitePages.find(page => page.source_url === selected.id)
if (selectedPage)
setPreviewWebsitePage(selectedPage)
}
}, [dataSourceType, notionPages, websitePages])
return {
// File preview
previewFile,
setPreviewFile,
// Notion preview
previewNotionPage,
setPreviewNotionPage,
// Website preview
previewWebsitePage,
setPreviewWebsitePage,
// Picker helpers
getPreviewPickerItems,
getPreviewPickerValue,
handlePreviewChange,
}
}
export type PreviewState = ReturnType<typeof usePreviewState>
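A usage sketch: the getters normalize every data source into the { id, name, extension } shape PreviewDocumentPicker expects, e.g. Notion pages become { id: page_id, name: page_name, extension: 'md' }.

// The option values would come from StepTwoProps; documentDetail restores
// the selection in settings mode.
const preview = usePreviewState({
  dataSourceType,
  files,
  notionPages,
  websitePages,
  documentDetail,
  datasetId,
})
const pickerItems = preview.getPreviewPickerItems()
const pickerValue = preview.getPreviewPickerValue()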

View File

@@ -1,222 +0,0 @@
import type { ParentMode, PreProcessingRule, ProcessRule, Rules } from '@/models/datasets'
import { useCallback, useState } from 'react'
import { ChunkingMode, ProcessMode } from '@/models/datasets'
import escape from './escape'
import unescape from './unescape'
// Constants
export const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
export const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024
export const DEFAULT_OVERLAP = 50
export const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(
globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000',
10,
)
export type ParentChildConfig = {
chunkForContext: ParentMode
parent: {
delimiter: string
maxLength: number
}
child: {
delimiter: string
maxLength: number
}
}
export const defaultParentChildConfig: ParentChildConfig = {
chunkForContext: 'paragraph',
parent: {
delimiter: '\\n\\n',
maxLength: 1024,
},
child: {
delimiter: '\\n',
maxLength: 512,
},
}
export type UseSegmentationStateOptions = {
initialSegmentationType?: ProcessMode
}
export const useSegmentationState = (options: UseSegmentationStateOptions = {}) => {
const { initialSegmentationType } = options
// Segmentation type (general or parent-child)
const [segmentationType, setSegmentationType] = useState<ProcessMode>(
initialSegmentationType ?? ProcessMode.general,
)
// General chunking settings
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH)
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(MAXIMUM_CHUNK_TOKEN_LENGTH)
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
// Pre-processing rules
const [rules, setRules] = useState<PreProcessingRule[]>([])
const [defaultConfig, setDefaultConfig] = useState<Rules>()
// Parent-child config
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
// Escaped segment identifier setter
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
if (value) {
doSetSegmentIdentifier(escape(value))
}
else {
doSetSegmentIdentifier(canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)
}
}, [])
// Rule toggle handler
const toggleRule = useCallback((id: string) => {
setRules(prev => prev.map(rule =>
rule.id === id ? { ...rule, enabled: !rule.enabled } : rule,
))
}, [])
// Reset to defaults
const resetToDefaults = useCallback(() => {
if (defaultConfig) {
setSegmentIdentifier(defaultConfig.segmentation.separator)
setMaxChunkLength(defaultConfig.segmentation.max_tokens)
setOverlap(defaultConfig.segmentation.chunk_overlap!)
setRules(defaultConfig.pre_processing_rules)
}
setParentChildConfig(defaultParentChildConfig)
}, [defaultConfig, setSegmentIdentifier])
// Apply config from document detail
const applyConfigFromRules = useCallback((rulesConfig: Rules, isHierarchical: boolean) => {
const separator = rulesConfig.segmentation.separator
const max = rulesConfig.segmentation.max_tokens
const chunkOverlap = rulesConfig.segmentation.chunk_overlap
setSegmentIdentifier(separator)
setMaxChunkLength(max)
setOverlap(chunkOverlap!)
setRules(rulesConfig.pre_processing_rules)
setDefaultConfig(rulesConfig)
if (isHierarchical) {
setParentChildConfig({
chunkForContext: rulesConfig.parent_mode || 'paragraph',
parent: {
delimiter: escape(rulesConfig.segmentation.separator),
maxLength: rulesConfig.segmentation.max_tokens,
},
child: {
delimiter: escape(rulesConfig.subchunk_segmentation!.separator),
maxLength: rulesConfig.subchunk_segmentation!.max_tokens,
},
})
}
}, [setSegmentIdentifier])
// Get process rule for API
const getProcessRule = useCallback((docForm: ChunkingMode): ProcessRule => {
if (docForm === ChunkingMode.parentChild) {
return {
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(parentChildConfig.parent.delimiter),
max_tokens: parentChildConfig.parent.maxLength,
},
parent_mode: parentChildConfig.chunkForContext,
subchunk_segmentation: {
separator: unescape(parentChildConfig.child.delimiter),
max_tokens: parentChildConfig.child.maxLength,
},
},
mode: 'hierarchical',
} as ProcessRule
}
return {
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(segmentIdentifier),
max_tokens: maxChunkLength,
chunk_overlap: overlap,
},
},
mode: segmentationType,
} as ProcessRule
}, [rules, parentChildConfig, segmentIdentifier, maxChunkLength, overlap, segmentationType])
// Update parent config field
const updateParentConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
setParentChildConfig((prev) => {
let newValue: string | number
if (field === 'delimiter')
newValue = value ? escape(value as string) : ''
else
newValue = value
return {
...prev,
parent: { ...prev.parent, [field]: newValue },
}
})
}, [])
// Update child config field
const updateChildConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
setParentChildConfig((prev) => {
let newValue: string | number
if (field === 'delimiter')
newValue = value ? escape(value as string) : ''
else
newValue = value
return {
...prev,
child: { ...prev.child, [field]: newValue },
}
})
}, [])
// Set chunk for context mode
const setChunkForContext = useCallback((mode: ParentMode) => {
setParentChildConfig(prev => ({ ...prev, chunkForContext: mode }))
}, [])
return {
// General chunking state
segmentationType,
setSegmentationType,
segmentIdentifier,
setSegmentIdentifier,
maxChunkLength,
setMaxChunkLength,
limitMaxChunkLength,
setLimitMaxChunkLength,
overlap,
setOverlap,
// Rules
rules,
setRules,
defaultConfig,
setDefaultConfig,
toggleRule,
// Parent-child config
parentChildConfig,
setParentChildConfig,
updateParentConfig,
updateChildConfig,
setChunkForContext,
// Actions
resetToDefaults,
applyConfigFromRules,
getProcessRule,
}
}
export type SegmentationState = ReturnType<typeof useSegmentationState>
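A worked example of getProcessRule with the default parent-child config: the stored delimiters are unescaped back to real control characters before the rule is sent to the API.

const seg = useSegmentationState()
const rule = seg.getProcessRule(ChunkingMode.parentChild)
// rule is approximately:
// {
//   mode: 'hierarchical',
//   rules: {
//     pre_processing_rules: [],
//     segmentation: { separator: '\n\n', max_tokens: 1024 },
//     parent_mode: 'paragraph',
//     subchunk_segmentation: { separator: '\n', max_tokens: 512 },
//   },
// }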

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,28 +0,0 @@
import type { IndexingType } from './hooks'
import type { DataSourceProvider, NotionPage } from '@/models/common'
import type { CrawlOptions, CrawlResultItem, createDocumentResponse, CustomFile, DataSourceType, FullDocumentDetail } from '@/models/datasets'
import type { RETRIEVE_METHOD } from '@/types/app'
export type StepTwoProps = {
isSetting?: boolean
documentDetail?: FullDocumentDetail
isAPIKeySet: boolean
onSetting: () => void
datasetId?: string
indexingType?: IndexingType
retrievalMethod?: string
dataSourceType: DataSourceType
files: CustomFile[]
notionPages?: NotionPage[]
notionCredentialId: string
websitePages?: CrawlResultItem[]
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
onStepChange?: (delta: number) => void
updateIndexingTypeCache?: (type: string) => void
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
updateResultCache?: (res: createDocumentResponse) => void
onSave?: () => void
onCancel?: () => void
}

View File

@@ -2,6 +2,7 @@ import type { FC } from 'react'
import {
RiFileTextLine,
RiFilmAiLine,
RiHammerLine,
RiImageCircleAiLine,
RiVoiceAiFill,
} from '@remixicon/react'
@@ -38,17 +39,33 @@ const FeatureIcon: FC<FeatureIconProps> = ({
// )
// }
// if (feature === ModelFeatureEnum.toolCall) {
// return (
// <Tooltip
// popupContent={t('common.modelProvider.featureSupported', { feature: ModelFeatureTextEnum.toolCall })}
// >
// <ModelBadge className={`mr-0.5 !px-0 w-[18px] justify-center text-gray-500 ${className}`}>
// <MagicWand className='w-3 h-3' />
// </ModelBadge>
// </Tooltip>
// )
// }
if (feature === ModelFeatureEnum.toolCall) {
if (showFeaturesLabel) {
return (
<ModelBadge className={cn('gap-x-0.5', className)}>
<RiHammerLine className="size-3" />
<span>{ModelFeatureTextEnum.toolCall}</span>
</ModelBadge>
)
}
return (
<Tooltip
popupContent={t('modelProvider.featureSupported', { ns: 'common', feature: ModelFeatureTextEnum.toolCall })}
>
<div className="inline-block cursor-help">
<ModelBadge
className={cn(
'w-[18px] justify-center !px-0',
className,
)}
>
<RiHammerLine className="size-3" />
</ModelBadge>
</div>
</Tooltip>
)
}
// if (feature === ModelFeatureEnum.multiToolCall) {
// return (
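The hunk above re-enables the tool-call badge in two modes. A sketch of both call sites (placement is hypothetical):

// Compact rows: 18px icon-only badge wrapped in a tooltip.
<FeatureIcon feature={ModelFeatureEnum.toolCall} className="mr-0.5" />
// Detail popups: icon plus the ModelFeatureTextEnum.toolCall text label.
<FeatureIcon feature={ModelFeatureEnum.toolCall} showFeaturesLabel />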

View File

@@ -96,6 +96,14 @@ const PopupItem: FC<PopupItemProps> = ({
<div className='text-text-tertiary system-xs-regular'>{currentProvider?.description?.[language] || currentProvider?.description?.en_US}</div>
)} */}
<div className="flex flex-wrap gap-1">
{
modelItem.features?.includes(ModelFeatureEnum.toolCall) && (
<FeatureIcon
feature={ModelFeatureEnum.toolCall}
showFeaturesLabel
/>
)
}
{modelItem.model_type && (
<ModelBadge>
{modelTypeFormat(modelItem.model_type)}
@@ -118,7 +126,7 @@ const PopupItem: FC<PopupItemProps> = ({
<div className="pt-2">
<div className="system-2xs-medium-uppercase mb-1 text-text-tertiary">{t('model.capabilities', { ns: 'common' })}</div>
<div className="flex flex-wrap gap-1">
{modelItem.features?.map(feature => (
{modelItem.features?.filter(feature => feature !== ModelFeatureEnum.toolCall).map(feature => (
<FeatureIcon
key={feature}
feature={feature}
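Net effect of this hunk: tool-call support is promoted to a dedicated badge next to the model-type badge, and the capabilities row filters it out so the feature never renders twice. The dedup condition in isolation (a sketch):

const hasToolCall = !!modelItem.features?.includes(ModelFeatureEnum.toolCall)
const capabilityFeatures = modelItem.features?.filter(f => f !== ModelFeatureEnum.toolCall) ?? []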

View File

@@ -1,81 +0,0 @@
import type { ActivePluginType } from './constants'
import type { PluginsSort, SearchParamsFromCollection } from './types'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useQueryState } from 'nuqs'
import { useCallback } from 'react'
import { DEFAULT_SORT, PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceSearchParamsParsers } from './search-params'
const marketplaceSortAtom = atom<PluginsSort>(DEFAULT_SORT)
export function useMarketplaceSort() {
return useAtom(marketplaceSortAtom)
}
export function useMarketplaceSortValue() {
return useAtomValue(marketplaceSortAtom)
}
export function useSetMarketplaceSort() {
return useSetAtom(marketplaceSortAtom)
}
/**
 * Preserve the marketplace search state in the URL query string
 */
export const preserveSearchStateInQueryAtom = atom<boolean>(false)
const searchPluginTextAtom = atom<string>('')
const activePluginTypeAtom = atom<ActivePluginType>('all')
const filterPluginTagsAtom = atom<string[]>([])
export function useSearchPluginText() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('q', marketplaceSearchParamsParsers.q)
const atomState = useAtom(searchPluginTextAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
export function useActivePluginType() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('category', marketplaceSearchParamsParsers.category)
const atomState = useAtom(activePluginTypeAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
export function useFilterPluginTags() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('tags', marketplaceSearchParamsParsers.tags)
const atomState = useAtom(filterPluginTagsAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
/**
* Not all categories have collections, so we need to
* force the search mode for those categories.
*/
export const searchModeAtom = atom<true | null>(null)
export function useMarketplaceSearchMode() {
const [searchPluginText] = useSearchPluginText()
const [filterPluginTags] = useFilterPluginTags()
const [activePluginType] = useActivePluginType()
const searchMode = useAtomValue(searchModeAtom)
const isSearchMode = !!searchPluginText
|| filterPluginTags.length > 0
|| (searchMode ?? (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(activePluginType)))
return isSearchMode
}
export function useMarketplaceMoreClick() {
const [,setQ] = useSearchPluginText()
const setSort = useSetAtom(marketplaceSortAtom)
const setSearchMode = useSetAtom(searchModeAtom)
return useCallback((searchParams?: SearchParamsFromCollection) => {
if (!searchParams)
return
setQ(searchParams?.query || '')
setSort({
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
})
setSearchMode(true)
}, [setQ, setSort, setSearchMode])
}
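A usage sketch for the file above: each hook transparently returns the URL-backed (nuqs) state when preserveSearchStateInQueryAtom is true, and the in-memory jotai state otherwise, so consumers never branch on the mode themselves.

const [query, setQuery] = useSearchPluginText()
const [category, setCategory] = useActivePluginType()
const [tags, setTags] = useFilterPluginTags()
const isSearchMode = useMarketplaceSearchMode() // true once any filter is active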

View File

@@ -1,30 +1,6 @@
import { PluginCategoryEnum } from '../types'
export const DEFAULT_SORT = {
sortBy: 'install_count',
sortOrder: 'DESC',
}
export const SCROLL_BOTTOM_THRESHOLD = 100
export const PLUGIN_TYPE_SEARCH_MAP = {
all: 'all',
model: PluginCategoryEnum.model,
tool: PluginCategoryEnum.tool,
agent: PluginCategoryEnum.agent,
extension: PluginCategoryEnum.extension,
datasource: PluginCategoryEnum.datasource,
trigger: PluginCategoryEnum.trigger,
bundle: 'bundle',
} as const
type ValueOf<T> = T[keyof T]
export type ActivePluginType = ValueOf<typeof PLUGIN_TYPE_SEARCH_MAP>
export const PLUGIN_CATEGORY_WITH_COLLECTIONS = new Set<ActivePluginType>(
[
PLUGIN_TYPE_SEARCH_MAP.all,
PLUGIN_TYPE_SEARCH_MAP.tool,
],
)

View File

@@ -0,0 +1,332 @@
'use client'
import type {
ReactNode,
} from 'react'
import type { TagKey } from '../constants'
import type { Plugin } from '../types'
import type {
MarketplaceCollection,
PluginsSort,
SearchParams,
SearchParamsFromCollection,
} from './types'
import { debounce } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import {
createContext,
useContextSelector,
} from 'use-context-selector'
import { useMarketplaceFilters } from '@/hooks/use-query-params'
import { useInstalledPluginList } from '@/service/use-plugins'
import {
getValidCategoryKeys,
getValidTagKeys,
} from '../utils'
import { DEFAULT_SORT } from './constants'
import {
useMarketplaceCollectionsAndPlugins,
useMarketplaceContainerScroll,
useMarketplacePlugins,
} from './hooks'
import { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch'
import {
getMarketplaceListCondition,
getMarketplaceListFilterType,
} from './utils'
export type MarketplaceContextValue = {
searchPluginText: string
handleSearchPluginTextChange: (text: string) => void
filterPluginTags: string[]
handleFilterPluginTagsChange: (tags: string[]) => void
activePluginType: string
handleActivePluginTypeChange: (type: string) => void
page: number
handlePageChange: () => void
plugins?: Plugin[]
pluginsTotal?: number
resetPlugins: () => void
sort: PluginsSort
handleSortChange: (sort: PluginsSort) => void
handleQueryPlugins: () => void
handleMoreClick: (searchParams: SearchParamsFromCollection) => void
marketplaceCollectionsFromClient?: MarketplaceCollection[]
setMarketplaceCollectionsFromClient: (collections: MarketplaceCollection[]) => void
marketplaceCollectionPluginsMapFromClient?: Record<string, Plugin[]>
setMarketplaceCollectionPluginsMapFromClient: (map: Record<string, Plugin[]>) => void
isLoading: boolean
isSuccessCollections: boolean
}
export const MarketplaceContext = createContext<MarketplaceContextValue>({
searchPluginText: '',
handleSearchPluginTextChange: noop,
filterPluginTags: [],
handleFilterPluginTagsChange: noop,
activePluginType: 'all',
handleActivePluginTypeChange: noop,
page: 1,
handlePageChange: noop,
plugins: undefined,
pluginsTotal: 0,
resetPlugins: noop,
sort: DEFAULT_SORT,
handleSortChange: noop,
handleQueryPlugins: noop,
handleMoreClick: noop,
marketplaceCollectionsFromClient: [],
setMarketplaceCollectionsFromClient: noop,
marketplaceCollectionPluginsMapFromClient: {},
setMarketplaceCollectionPluginsMapFromClient: noop,
isLoading: false,
isSuccessCollections: false,
})
type MarketplaceContextProviderProps = {
children: ReactNode
searchParams?: SearchParams
shouldExclude?: boolean
scrollContainerId?: string
showSearchParams?: boolean
}
export function useMarketplaceContext(selector: (value: MarketplaceContextValue) => any) {
return useContextSelector(MarketplaceContext, selector)
}
export const MarketplaceContextProvider = ({
children,
searchParams,
shouldExclude,
scrollContainerId,
showSearchParams,
}: MarketplaceContextProviderProps) => {
// Use nuqs hook for URL-based filter state
const [urlFilters, setUrlFilters] = useMarketplaceFilters()
const { data, isSuccess } = useInstalledPluginList(!shouldExclude)
const exclude = useMemo(() => {
if (shouldExclude)
return data?.plugins.map(plugin => plugin.plugin_id)
}, [data?.plugins, shouldExclude])
// Initialize from URL params (legacy support) or use nuqs state
const queryFromSearchParams = searchParams?.q || urlFilters.q
const tagsFromSearchParams = getValidTagKeys(urlFilters.tags as TagKey[])
const hasValidTags = !!tagsFromSearchParams.length
const hasValidCategory = getValidCategoryKeys(urlFilters.category)
const categoryFromSearchParams = hasValidCategory || PLUGIN_TYPE_SEARCH_MAP.all
const [searchPluginText, setSearchPluginText] = useState(queryFromSearchParams)
const searchPluginTextRef = useRef(searchPluginText)
const [filterPluginTags, setFilterPluginTags] = useState<string[]>(tagsFromSearchParams)
const filterPluginTagsRef = useRef(filterPluginTags)
const [activePluginType, setActivePluginType] = useState(categoryFromSearchParams)
const activePluginTypeRef = useRef(activePluginType)
const [sort, setSort] = useState(DEFAULT_SORT)
const sortRef = useRef(sort)
const {
marketplaceCollections: marketplaceCollectionsFromClient,
setMarketplaceCollections: setMarketplaceCollectionsFromClient,
marketplaceCollectionPluginsMap: marketplaceCollectionPluginsMapFromClient,
setMarketplaceCollectionPluginsMap: setMarketplaceCollectionPluginsMapFromClient,
queryMarketplaceCollectionsAndPlugins,
isLoading,
isSuccess: isSuccessCollections,
} = useMarketplaceCollectionsAndPlugins()
const {
plugins,
total: pluginsTotal,
resetPlugins,
queryPlugins,
queryPluginsWithDebounced,
cancelQueryPluginsWithDebounced,
isLoading: isPluginsLoading,
fetchNextPage: fetchNextPluginsPage,
hasNextPage: hasNextPluginsPage,
page: pluginsPage,
} = useMarketplacePlugins()
const page = Math.max(pluginsPage || 0, 1)
useEffect(() => {
if (queryFromSearchParams || hasValidTags || hasValidCategory) {
queryPlugins({
query: queryFromSearchParams,
category: hasValidCategory,
tags: hasValidTags ? tagsFromSearchParams : [],
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
else {
if (shouldExclude && isSuccess) {
queryMarketplaceCollectionsAndPlugins({
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
}
}, [queryPlugins, queryMarketplaceCollectionsAndPlugins, isSuccess, exclude])
const handleQueryMarketplaceCollectionsAndPlugins = useCallback(() => {
queryMarketplaceCollectionsAndPlugins({
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
condition: getMarketplaceListCondition(activePluginTypeRef.current),
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
resetPlugins()
}, [exclude, queryMarketplaceCollectionsAndPlugins, resetPlugins])
const applyUrlFilters = useCallback(() => {
if (!showSearchParams)
return
const nextFilters = {
q: searchPluginTextRef.current,
category: activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
}
const categoryChanged = urlFilters.category !== nextFilters.category
setUrlFilters(nextFilters, {
history: categoryChanged ? 'push' : 'replace',
})
}, [setUrlFilters, showSearchParams, urlFilters.category])
const debouncedUpdateSearchParams = useMemo(() => debounce(() => {
applyUrlFilters()
}, 500), [applyUrlFilters])
const handleUpdateSearchParams = useCallback((debounced?: boolean) => {
if (debounced) {
debouncedUpdateSearchParams()
}
else {
applyUrlFilters()
}
}, [applyUrlFilters, debouncedUpdateSearchParams])
const handleQueryPlugins = useCallback((debounced?: boolean) => {
handleUpdateSearchParams(debounced)
if (debounced) {
queryPluginsWithDebounced({
query: searchPluginTextRef.current,
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
else {
queryPlugins({
query: searchPluginTextRef.current,
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
}, [exclude, queryPluginsWithDebounced, queryPlugins, handleUpdateSearchParams])
const handleQuery = useCallback((debounced?: boolean) => {
if (!searchPluginTextRef.current && !filterPluginTagsRef.current.length) {
handleUpdateSearchParams(debounced)
cancelQueryPluginsWithDebounced()
handleQueryMarketplaceCollectionsAndPlugins()
return
}
handleQueryPlugins(debounced)
}, [handleQueryMarketplaceCollectionsAndPlugins, handleQueryPlugins, cancelQueryPluginsWithDebounced, handleUpdateSearchParams])
const handleSearchPluginTextChange = useCallback((text: string) => {
setSearchPluginText(text)
searchPluginTextRef.current = text
handleQuery(true)
}, [handleQuery])
const handleFilterPluginTagsChange = useCallback((tags: string[]) => {
setFilterPluginTags(tags)
filterPluginTagsRef.current = tags
handleQuery()
}, [handleQuery])
const handleActivePluginTypeChange = useCallback((type: string) => {
setActivePluginType(type)
activePluginTypeRef.current = type
handleQuery()
}, [handleQuery])
const handleSortChange = useCallback((sort: PluginsSort) => {
setSort(sort)
sortRef.current = sort
handleQueryPlugins()
}, [handleQueryPlugins])
const handlePageChange = useCallback(() => {
if (hasNextPluginsPage)
fetchNextPluginsPage()
}, [fetchNextPluginsPage, hasNextPluginsPage])
const handleMoreClick = useCallback((searchParams: SearchParamsFromCollection) => {
setSearchPluginText(searchParams?.query || '')
searchPluginTextRef.current = searchParams?.query || ''
setSort({
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
})
sortRef.current = {
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
}
handleQueryPlugins()
}, [handleQueryPlugins])
useMarketplaceContainerScroll(handlePageChange, scrollContainerId)
return (
<MarketplaceContext.Provider
value={{
searchPluginText,
handleSearchPluginTextChange,
filterPluginTags,
handleFilterPluginTagsChange,
activePluginType,
handleActivePluginTypeChange,
page,
handlePageChange,
plugins,
pluginsTotal,
resetPlugins,
sort,
handleSortChange,
handleQueryPlugins,
handleMoreClick,
marketplaceCollectionsFromClient,
setMarketplaceCollectionsFromClient,
marketplaceCollectionPluginsMapFromClient,
setMarketplaceCollectionPluginsMapFromClient,
isLoading: isLoading || isPluginsLoading,
isSuccessCollections,
}}
>
{children}
</MarketplaceContext.Provider>
)
}
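A consumer sketch: useMarketplaceContext takes a selector (via use-context-selector), so a component re-renders only when its selected slice changes.

const plugins = useMarketplaceContext(v => v.plugins)
const isLoading = useMarketplaceContext(v => v.isLoading)
const onSearchChange = useMarketplaceContext(v => v.handleSearchPluginTextChange)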

View File

@@ -26,9 +26,6 @@ import {
getMarketplacePluginsByCollectionId,
} from './utils'
/**
* @deprecated Use useMarketplaceCollectionsAndPlugins from query.ts instead
*/
export const useMarketplaceCollectionsAndPlugins = () => {
const [queryParams, setQueryParams] = useState<CollectionsAndPluginsSearchParams>()
const [marketplaceCollectionsOverride, setMarketplaceCollections] = useState<MarketplaceCollection[]>()
@@ -92,9 +89,7 @@ export const useMarketplacePluginsByCollectionId = (
isSuccess,
}
}
/**
* @deprecated Use useMarketplacePlugins from query.ts instead
*/
export const useMarketplacePlugins = () => {
const queryClient = useQueryClient()
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()

View File

@@ -1,15 +0,0 @@
'use client'
import { useHydrateAtoms } from 'jotai/utils'
import { preserveSearchStateInQueryAtom } from './atoms'
export function HydrateMarketplaceAtoms({
preserveSearchStateInQuery,
children,
}: {
preserveSearchStateInQuery: boolean
children: React.ReactNode
}) {
useHydrateAtoms([[preserveSearchStateInQueryAtom, preserveSearchStateInQuery]])
return <>{children}</>
}

View File

@@ -1,45 +0,0 @@
import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
// The server-side logic should move to the marketplace's codebase so that we can get rid of the Next.js dependency
async function getDehydratedState(searchParams?: Promise<SearchParams>) {
if (!searchParams) {
return
}
const loadSearchParams = createLoader(marketplaceSearchParamsParsers)
const params = await loadSearchParams(searchParams)
if (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(params.category)) {
return
}
const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
})
return dehydrate(queryClient)
}
export async function HydrateQueryClient({
searchParams,
children,
}: {
searchParams: Promise<SearchParams> | undefined
children: React.ReactNode
}) {
const dehydratedState = await getDehydratedState(searchParams)
return (
<HydrationBoundary state={dehydratedState}>
{children}
</HydrationBoundary>
)
}

File diff suppressed because it is too large

View File

@@ -1,39 +1,55 @@
import type { SearchParams } from 'nuqs'
import type { MarketplaceCollection, SearchParams } from './types'
import type { Plugin } from '@/app/components/plugins/types'
import { TanstackQueryInitializer } from '@/context/query-client'
import { MarketplaceContextProvider } from './context'
import Description from './description'
import { HydrateMarketplaceAtoms } from './hydration-client'
import { HydrateQueryClient } from './hydration-server'
import ListWrapper from './list/list-wrapper'
import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper'
import { getMarketplaceCollectionsAndPlugins } from './utils'
type MarketplaceProps = {
showInstallButton?: boolean
shouldExclude?: boolean
searchParams?: SearchParams
pluginTypeSwitchClassName?: string
/**
* Pass the search params from the request to prefetch data on the server
* and preserve the search params in the URL.
*/
searchParams?: Promise<SearchParams>
scrollContainerId?: string
showSearchParams?: boolean
}
const Marketplace = async ({
showInstallButton = true,
pluginTypeSwitchClassName,
shouldExclude,
searchParams,
pluginTypeSwitchClassName,
scrollContainerId,
showSearchParams = true,
}: MarketplaceProps) => {
let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
if (!shouldExclude) {
const marketplaceCollectionsAndPluginsData = await getMarketplaceCollectionsAndPlugins()
marketplaceCollections = marketplaceCollectionsAndPluginsData.marketplaceCollections
marketplaceCollectionPluginsMap = marketplaceCollectionsAndPluginsData.marketplaceCollectionPluginsMap
}
return (
<TanstackQueryInitializer>
<HydrateQueryClient searchParams={searchParams}>
<HydrateMarketplaceAtoms preserveSearchStateInQuery={!!searchParams}>
<Description />
<StickySearchAndSwitchWrapper
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
/>
<ListWrapper
showInstallButton={showInstallButton}
/>
</HydrateMarketplaceAtoms>
</HydrateQueryClient>
<MarketplaceContextProvider
searchParams={searchParams}
shouldExclude={shouldExclude}
scrollContainerId={scrollContainerId}
showSearchParams={showSearchParams}
>
<Description />
<StickySearchAndSwitchWrapper
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
showSearchParams={showSearchParams}
/>
<ListWrapper
marketplaceCollections={marketplaceCollections}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap}
showInstallButton={showInstallButton}
/>
</MarketplaceContextProvider>
</TanstackQueryInitializer>
)
}
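
Since searchParams is now a Promise, a page can forward it without awaiting. A minimal sketch of a consuming page (Next.js App Router semantics assumed; the route and import path are hypothetical):

// app/plugins/page.tsx (hypothetical route)
import type { SearchParams } from 'nuqs'
import Marketplace from '@/app/components/plugins/marketplace'

// In the App Router, `searchParams` arrives as a Promise; Marketplace is an
// async server component and awaits it internally before prefetching.
export default function Page(props: { searchParams: Promise<SearchParams> }) {
  return (
    <Marketplace
      searchParams={props.searchParams}
      showInstallButton={false}
    />
  )
}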

View File

@@ -1,6 +1,6 @@
import type { MarketplaceCollection, SearchParamsFromCollection } from '../types'
import type { Plugin } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import List from './index'
@@ -30,27 +30,23 @@ vi.mock('#i18n', () => ({
useLocale: () => 'en-US',
}))
// Mock marketplace state hooks with controllable values
const { mockMarketplaceData, mockMoreClick } = vi.hoisted(() => {
return {
mockMarketplaceData: {
plugins: undefined as Plugin[] | undefined,
pluginsTotal: 0,
marketplaceCollections: undefined as MarketplaceCollection[] | undefined,
marketplaceCollectionPluginsMap: undefined as Record<string, Plugin[]> | undefined,
isLoading: false,
page: 1,
},
mockMoreClick: vi.fn(),
}
})
// Mock useMarketplaceContext with controllable values
const mockContextValues = {
plugins: undefined as Plugin[] | undefined,
pluginsTotal: 0,
marketplaceCollectionsFromClient: undefined as MarketplaceCollection[] | undefined,
marketplaceCollectionPluginsMapFromClient: undefined as Record<string, Plugin[]> | undefined,
isLoading: false,
isSuccessCollections: false,
handleQueryPlugins: vi.fn(),
searchPluginText: '',
filterPluginTags: [] as string[],
page: 1,
handleMoreClick: vi.fn(),
}
vi.mock('../state', () => ({
useMarketplaceData: () => mockMarketplaceData,
}))
vi.mock('../atoms', () => ({
useMarketplaceMoreClick: () => mockMoreClick,
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
}))
// Mock useLocale context
@@ -582,7 +578,7 @@ describe('ListWithCollection', () => {
// View More Button Tests
// ================================
describe('View More Button', () => {
it('should render View More button when collection is searchable', () => {
it('should render View More button when collection is searchable and onMoreClick is provided', () => {
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
@@ -591,12 +587,14 @@ describe('ListWithCollection', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
@@ -611,24 +609,24 @@ describe('ListWithCollection', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
expect(screen.queryByText('View More')).not.toBeInTheDocument()
})
it('should call moreClick hook with search_params when View More is clicked', () => {
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
it('should not render View More button when onMoreClick is not provided', () => {
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: searchParams,
})]
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
@@ -639,13 +637,38 @@ describe('ListWithCollection', () => {
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={undefined}
/>,
)
expect(screen.queryByText('View More')).not.toBeInTheDocument()
})
it('should call onMoreClick with search_params when View More is clicked', () => {
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: searchParams,
})]
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
fireEvent.click(screen.getByText('View More'))
expect(mockMoreClick).toHaveBeenCalledTimes(1)
expect(mockMoreClick).toHaveBeenCalledWith(searchParams)
expect(onMoreClick).toHaveBeenCalledTimes(1)
expect(onMoreClick).toHaveBeenCalledWith(searchParams)
})
})
@@ -779,15 +802,24 @@ describe('ListWithCollection', () => {
// ListWrapper Component Tests
// ================================
describe('ListWrapper', () => {
const defaultProps = {
marketplaceCollections: [] as MarketplaceCollection[],
marketplaceCollectionPluginsMap: {} as Record<string, Plugin[]>,
showInstallButton: false,
}
beforeEach(() => {
vi.clearAllMocks()
// Reset mock data
mockMarketplaceData.plugins = undefined
mockMarketplaceData.pluginsTotal = 0
mockMarketplaceData.marketplaceCollections = undefined
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
// Reset context values
mockContextValues.plugins = undefined
mockContextValues.pluginsTotal = 0
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
mockContextValues.isLoading = false
mockContextValues.isSuccessCollections = false
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
mockContextValues.page = 1
})
// ================================
@@ -795,32 +827,32 @@ describe('ListWrapper', () => {
// ================================
describe('Rendering', () => {
it('should render without crashing', () => {
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(document.body).toBeInTheDocument()
})
it('should render with scrollbarGutter style', () => {
const { container } = render(<ListWrapper />)
const { container } = render(<ListWrapper {...defaultProps} />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveStyle({ scrollbarGutter: 'stable' })
})
it('should render Loading component when isLoading is true and page is 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 1
mockContextValues.isLoading = true
mockContextValues.page = 1
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
})
it('should not render Loading component when page > 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockContextValues.isLoading = true
mockContextValues.page = 2
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
})
@@ -831,26 +863,26 @@ describe('ListWrapper', () => {
// ================================
describe('Plugins Header', () => {
it('should render plugins result count when plugins are present', () => {
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 5
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 5
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
})
it('should render SortDropdown when plugins are present', () => {
mockMarketplaceData.plugins = createMockPluginList(1)
mockContextValues.plugins = createMockPluginList(1)
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('sort-dropdown')).toBeInTheDocument()
})
it('should not render plugins header when plugins is undefined', () => {
mockMarketplaceData.plugins = undefined
mockContextValues.plugins = undefined
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.queryByTestId('sort-dropdown')).not.toBeInTheDocument()
})
@@ -860,60 +892,197 @@ describe('ListWrapper', () => {
// List Rendering Logic Tests
// ================================
describe('List Rendering Logic', () => {
it('should render collections when not loading', () => {
mockMarketplaceData.isLoading = false
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
it('should render List when not loading', () => {
mockContextValues.isLoading = false
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should render List when loading but page > 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
mockContextValues.isLoading = true
mockContextValues.page = 2
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should use client collections when available', () => {
const serverCollections = createMockCollectionList(1)
serverCollections[0].label = { 'en-US': 'Server Collection' }
const clientCollections = createMockCollectionList(1)
clientCollections[0].label = { 'en-US': 'Client Collection' }
const serverPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const clientPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = clientCollections
mockContextValues.marketplaceCollectionPluginsMapFromClient = clientPluginsMap
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={serverCollections}
marketplaceCollectionPluginsMap={serverPluginsMap}
/>,
)
expect(screen.getByText('Client Collection')).toBeInTheDocument()
expect(screen.queryByText('Server Collection')).not.toBeInTheDocument()
})
it('should use server collections when client collections are not available', () => {
const serverCollections = createMockCollectionList(1)
serverCollections[0].label = { 'en-US': 'Server Collection' }
const serverPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={serverCollections}
marketplaceCollectionPluginsMap={serverPluginsMap}
/>,
)
expect(screen.getByText('Server Collection')).toBeInTheDocument()
})
})
// ================================
// Data Integration Tests
// Context Integration Tests
// ================================
describe('Data Integration', () => {
it('should pass plugins from state to List', () => {
mockMarketplaceData.plugins = createMockPluginList(2)
describe('Context Integration', () => {
it('should pass plugins from context to List', () => {
const plugins = createMockPluginList(2)
mockContextValues.plugins = plugins
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('card-plugin-0')).toBeInTheDocument()
expect(screen.getByTestId('card-plugin-1')).toBeInTheDocument()
})
it('should show View More button and call moreClick hook', () => {
mockMarketplaceData.marketplaceCollections = [createMockCollection({
it('should pass handleMoreClick from context to List', () => {
const mockHandleMoreClick = vi.fn()
mockContextValues.handleMoreClick = mockHandleMoreClick
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: { query: 'test' },
})]
mockMarketplaceData.marketplaceCollectionPluginsMap = {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
fireEvent.click(screen.getByText('View More'))
expect(mockMoreClick).toHaveBeenCalled()
expect(mockHandleMoreClick).toHaveBeenCalled()
})
})
// ================================
// Effect Tests (handleQueryPlugins)
// ================================
describe('handleQueryPlugins Effect', () => {
it('should call handleQueryPlugins when conditions are met', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when client collections exist', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = createMockCollectionList(1)
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
// Give time for effect to run
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when search text exists', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = 'search text'
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when filter tags exist', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = ['tag1']
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
})
@@ -921,32 +1090,32 @@ describe('ListWrapper', () => {
// Edge Cases Tests
// ================================
describe('Edge Cases', () => {
it('should handle empty plugins array', () => {
mockMarketplaceData.plugins = []
mockMarketplaceData.pluginsTotal = 0
it('should handle empty plugins array from context', () => {
mockContextValues.plugins = []
mockContextValues.pluginsTotal = 0
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
})
it('should handle large pluginsTotal', () => {
mockMarketplaceData.plugins = createMockPluginList(10)
mockMarketplaceData.pluginsTotal = 10000
mockContextValues.plugins = createMockPluginList(10)
mockContextValues.pluginsTotal = 10000
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('10000 plugins found')).toBeInTheDocument()
})
it('should handle loading while plugins are present', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 50
mockContextValues.isLoading = true
mockContextValues.page = 2
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 50
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
// Should show plugins header and list
expect(screen.getByText('50 plugins found')).toBeInTheDocument()
@@ -1259,72 +1428,106 @@ describe('CardWrapper (via List integration)', () => {
describe('Combined Workflows', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMarketplaceData.plugins = undefined
mockMarketplaceData.pluginsTotal = 0
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
mockMarketplaceData.marketplaceCollections = undefined
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
mockContextValues.plugins = undefined
mockContextValues.pluginsTotal = 0
mockContextValues.isLoading = false
mockContextValues.page = 1
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
})
it('should transition from loading to showing collections', async () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 1
mockContextValues.isLoading = true
mockContextValues.page = 1
const { rerender } = render(<ListWrapper />)
const { rerender } = render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
// Simulate loading complete
mockMarketplaceData.isLoading = false
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
mockContextValues.isLoading = false
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = collections
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
rerender(<ListWrapper />)
rerender(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should transition from collections to search results', async () => {
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = collections
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
const { rerender } = render(<ListWrapper />)
const { rerender } = render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
// Simulate search results
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 5
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 5
rerender(<ListWrapper />)
rerender(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.queryByText('Collection 0')).not.toBeInTheDocument()
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
})
it('should handle empty search results', () => {
mockMarketplaceData.plugins = []
mockMarketplaceData.pluginsTotal = 0
mockContextValues.plugins = []
mockContextValues.pluginsTotal = 0
render(<ListWrapper />)
render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
})
it('should support pagination (page > 1)', () => {
mockMarketplaceData.plugins = createMockPluginList(40)
mockMarketplaceData.pluginsTotal = 80
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockContextValues.plugins = createMockPluginList(40)
mockContextValues.pluginsTotal = 80
mockContextValues.isLoading = true
mockContextValues.page = 2
render(<ListWrapper />)
render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
// Should show existing results while loading more
expect(screen.getByText('80 plugins found')).toBeInTheDocument()
@@ -1339,9 +1542,9 @@ describe('Combined Workflows', () => {
describe('Accessibility', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMarketplaceData.plugins = undefined
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
mockContextValues.plugins = undefined
mockContextValues.isLoading = false
mockContextValues.page = 1
})
it('should have semantic structure with collections', () => {
@@ -1370,11 +1573,13 @@ describe('Accessibility', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)

View File

@@ -13,6 +13,7 @@ type ListProps = {
showInstallButton?: boolean
cardContainerClassName?: string
cardRender?: (plugin: Plugin) => React.JSX.Element | null
onMoreClick?: () => void
emptyClassName?: string
}
const List = ({
@@ -22,6 +23,7 @@ const List = ({
showInstallButton,
cardContainerClassName,
cardRender,
onMoreClick,
emptyClassName,
}: ListProps) => {
return (
@@ -34,6 +36,7 @@ const List = ({
showInstallButton={showInstallButton}
cardContainerClassName={cardContainerClassName}
cardRender={cardRender}
onMoreClick={onMoreClick}
/>
)
}

View File

@@ -1,12 +1,12 @@
'use client'
import type { MarketplaceCollection } from '../types'
import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types'
import type { Plugin } from '@/app/components/plugins/types'
import { useLocale, useTranslation } from '#i18n'
import { RiArrowRightSLine } from '@remixicon/react'
import { getLanguage } from '@/i18n-config/language'
import { cn } from '@/utils/classnames'
import { useMarketplaceMoreClick } from '../atoms'
import CardWrapper from './card-wrapper'
type ListWithCollectionProps = {
@@ -15,6 +15,7 @@ type ListWithCollectionProps = {
showInstallButton?: boolean
cardContainerClassName?: string
cardRender?: (plugin: Plugin) => React.JSX.Element | null
onMoreClick?: (searchParams?: SearchParamsFromCollection) => void
}
const ListWithCollection = ({
marketplaceCollections,
@@ -22,10 +23,10 @@ const ListWithCollection = ({
showInstallButton,
cardContainerClassName,
cardRender,
onMoreClick,
}: ListWithCollectionProps) => {
const { t } = useTranslation()
const locale = useLocale()
const onMoreClick = useMarketplaceMoreClick()
return (
<>
@@ -43,10 +44,10 @@ const ListWithCollection = ({
<div className="system-xs-regular text-text-tertiary">{collection.description[getLanguage(locale)]}</div>
</div>
{
collection.searchable && (
collection.searchable && onMoreClick && (
<div
className="system-xs-medium flex cursor-pointer items-center text-text-accent "
onClick={() => onMoreClick(collection.search_params)}
onClick={() => onMoreClick?.(collection.search_params)}
>
{t('marketplace.viewMore', { ns: 'plugin' })}
<RiArrowRightSLine className="h-4 w-4" />

View File

@@ -1,26 +1,46 @@
'use client'
import type { Plugin } from '../../types'
import type { MarketplaceCollection } from '../types'
import { useTranslation } from '#i18n'
import { useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useMarketplaceContext } from '../context'
import SortDropdown from '../sort-dropdown'
import { useMarketplaceData } from '../state'
import List from './index'
type ListWrapperProps = {
marketplaceCollections: MarketplaceCollection[]
marketplaceCollectionPluginsMap: Record<string, Plugin[]>
showInstallButton?: boolean
}
const ListWrapper = ({
marketplaceCollections,
marketplaceCollectionPluginsMap,
showInstallButton,
}: ListWrapperProps) => {
const { t } = useTranslation()
const plugins = useMarketplaceContext(v => v.plugins)
const pluginsTotal = useMarketplaceContext(v => v.pluginsTotal)
const marketplaceCollectionsFromClient = useMarketplaceContext(v => v.marketplaceCollectionsFromClient)
const marketplaceCollectionPluginsMapFromClient = useMarketplaceContext(v => v.marketplaceCollectionPluginsMapFromClient)
const isLoading = useMarketplaceContext(v => v.isLoading)
const isSuccessCollections = useMarketplaceContext(v => v.isSuccessCollections)
const handleQueryPlugins = useMarketplaceContext(v => v.handleQueryPlugins)
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
const page = useMarketplaceContext(v => v.page)
const handleMoreClick = useMarketplaceContext(v => v.handleMoreClick)
const {
plugins,
pluginsTotal,
marketplaceCollections,
marketplaceCollectionPluginsMap,
isLoading,
page,
} = useMarketplaceData()
useEffect(() => {
if (
!marketplaceCollectionsFromClient?.length
&& isSuccessCollections
&& !searchPluginText
&& !filterPluginTags.length
) {
handleQueryPlugins()
}
}, [handleQueryPlugins, marketplaceCollections, marketplaceCollectionsFromClient, isSuccessCollections, searchPluginText, filterPluginTags])
return (
<div
@@ -46,10 +66,11 @@ const ListWrapper = ({
{
(!isLoading || page > 1) && (
<List
marketplaceCollections={marketplaceCollections || []}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap || {}}
marketplaceCollections={marketplaceCollectionsFromClient || marketplaceCollections}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMapFromClient || marketplaceCollectionPluginsMap}
plugins={plugins}
showInstallButton={showInstallButton}
onMoreClick={handleMoreClick}
/>
)
}
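
Note the precedence in the List props above: client-refetched data wins over the server-rendered props via `||`. A minimal sketch of that fallback, with hypothetical names:

type Collections = Array<{ name: string }>

// `||` (not `??`) means any truthy client value replaces the server
// snapshot — including an empty array from a client refetch, which is
// truthy and therefore still wins.
function pickCollections(
  fromClient: Collections | undefined,
  fromServer: Collections,
): Collections {
  return fromClient || fromServer
}

// pickCollections(undefined, serverList) -> serverList (first render)
// pickCollections([], serverList)        -> []         (client returned none)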

View File

@@ -1,5 +1,4 @@
'use client'
import type { ActivePluginType } from './constants'
import { useTranslation } from '#i18n'
import {
RiArchive2Line,
@@ -9,27 +8,35 @@ import {
RiPuzzle2Line,
RiSpeakAiLine,
} from '@remixicon/react'
import { useSetAtom } from 'jotai'
import { useCallback, useEffect } from 'react'
import { Trigger as TriggerIcon } from '@/app/components/base/icons/src/vender/plugin'
import { cn } from '@/utils/classnames'
import { searchModeAtom, useActivePluginType } from './atoms'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS, PLUGIN_TYPE_SEARCH_MAP } from './constants'
import { PluginCategoryEnum } from '../types'
import { useMarketplaceContext } from './context'
export const PLUGIN_TYPE_SEARCH_MAP = {
all: 'all',
model: PluginCategoryEnum.model,
tool: PluginCategoryEnum.tool,
agent: PluginCategoryEnum.agent,
extension: PluginCategoryEnum.extension,
datasource: PluginCategoryEnum.datasource,
trigger: PluginCategoryEnum.trigger,
bundle: 'bundle',
}
type PluginTypeSwitchProps = {
className?: string
showSearchParams?: boolean
}
const PluginTypeSwitch = ({
className,
showSearchParams,
}: PluginTypeSwitchProps) => {
const { t } = useTranslation()
const [activePluginType, handleActivePluginTypeChange] = useActivePluginType()
const setSearchMode = useSetAtom(searchModeAtom)
const activePluginType = useMarketplaceContext(s => s.activePluginType)
const handleActivePluginTypeChange = useMarketplaceContext(s => s.handleActivePluginTypeChange)
const options: Array<{
value: ActivePluginType
text: string
icon: React.ReactNode | null
}> = [
const options = [
{
value: PLUGIN_TYPE_SEARCH_MAP.all,
text: t('category.all', { ns: 'plugin' }),
@@ -72,6 +79,23 @@ const PluginTypeSwitch = ({
},
]
const handlePopState = useCallback(() => {
if (!showSearchParams)
return
// nuqs restores its own query state on popstate; re-sync the active category manually
const url = new URL(window.location.href)
const category = url.searchParams.get('category') || PLUGIN_TYPE_SEARCH_MAP.all
handleActivePluginTypeChange(category)
}, [showSearchParams, handleActivePluginTypeChange])
useEffect(() => {
// nuqs manages popstate internally, but we keep this listener to re-sync the active category with the URL
window.addEventListener('popstate', handlePopState)
return () => {
window.removeEventListener('popstate', handlePopState)
}
}, [handlePopState])
return (
<div className={cn(
'flex shrink-0 items-center justify-center space-x-2 bg-background-body py-3',
@@ -88,9 +112,6 @@ const PluginTypeSwitch = ({
)}
onClick={() => {
handleActivePluginTypeChange(option.value)
if (PLUGIN_CATEGORY_WITH_COLLECTIONS.has(option.value)) {
setSearchMode(null)
}
}}
>
{option.icon}

View File

@@ -1,38 +0,0 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manually maintaining query keys and improve service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams,
) {
return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
})
}
export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined,
) {
return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,
enabled: queryParams !== undefined,
})
}
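
The removed hook's cursor logic stops requesting pages once page × pageSize reaches the reported total. A minimal sketch of just that calculation (the page shape is taken from the code above):

type PluginsPage = { page: number, pageSize: number, total?: number }

// Mirrors getNextPageParam: ask for the next page only while the items
// loaded so far (page * pageSize) fall short of the total.
function nextPageParam(lastPage: PluginsPage): number | undefined {
  const loaded = lastPage.page * lastPage.pageSize
  return loaded < (lastPage.total || 0) ? lastPage.page + 1 : undefined
}

// nextPageParam({ page: 2, pageSize: 40, total: 100 }) -> 3 (80 < 100)
// nextPageParam({ page: 3, pageSize: 40, total: 100 }) -> undefined (120 >= 100)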

View File

@@ -26,19 +26,16 @@ vi.mock('#i18n', () => ({
}),
}))
// Mock marketplace state hooks
const { mockSearchPluginText, mockHandleSearchPluginTextChange, mockFilterPluginTags, mockHandleFilterPluginTagsChange } = vi.hoisted(() => {
return {
mockSearchPluginText: '',
mockHandleSearchPluginTextChange: vi.fn(),
mockFilterPluginTags: [] as string[],
mockHandleFilterPluginTagsChange: vi.fn(),
}
})
// Mock useMarketplaceContext
const mockContextValues = {
searchPluginText: '',
handleSearchPluginTextChange: vi.fn(),
filterPluginTags: [] as string[],
handleFilterPluginTagsChange: vi.fn(),
}
vi.mock('../atoms', () => ({
useSearchPluginText: () => [mockSearchPluginText, mockHandleSearchPluginTextChange],
useFilterPluginTags: () => [mockFilterPluginTags, mockHandleFilterPluginTagsChange],
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
}))
// Mock useTags hook
@@ -433,6 +430,9 @@ describe('SearchBoxWrapper', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPortalOpenState = false
// Reset context values
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
})
describe('Rendering', () => {
@@ -456,14 +456,28 @@ describe('SearchBoxWrapper', () => {
})
})
describe('Hook Integration', () => {
describe('Context Integration', () => {
it('should use searchPluginText from context', () => {
mockContextValues.searchPluginText = 'context search'
render(<SearchBoxWrapper />)
expect(screen.getByDisplayValue('context search')).toBeInTheDocument()
})
it('should call handleSearchPluginTextChange when search changes', () => {
render(<SearchBoxWrapper />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: 'new search' } })
expect(mockHandleSearchPluginTextChange).toHaveBeenCalledWith('new search')
expect(mockContextValues.handleSearchPluginTextChange).toHaveBeenCalledWith('new search')
})
it('should use filterPluginTags from context', () => {
mockContextValues.filterPluginTags = ['agent', 'rag']
render(<SearchBoxWrapper />)
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
})

View File

@@ -1,13 +1,15 @@
'use client'
import { useTranslation } from '#i18n'
import { useFilterPluginTags, useSearchPluginText } from '../atoms'
import { useMarketplaceContext } from '../context'
import SearchBox from './index'
const SearchBoxWrapper = () => {
const { t } = useTranslation()
const [searchPluginText, handleSearchPluginTextChange] = useSearchPluginText()
const [filterPluginTags, handleFilterPluginTagsChange] = useFilterPluginTags()
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange)
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
const handleFilterPluginTagsChange = useMarketplaceContext(v => v.handleFilterPluginTagsChange)
return (
<SearchBox
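
The wrapper reads individual slices through a selector so unrelated context updates don't re-render it. A minimal sketch of a selector-based hook of this shape (a hypothetical store; the real useMarketplaceContext comes from ../context):

import { useCallback, useSyncExternalStore } from 'react'

type Store<T> = {
  getState: () => T
  subscribe: (listener: () => void) => () => void
}

// Subscribes to the store but re-renders only when the selected slice
// changes (compared with Object.is), which is why callers pass
// `v => v.searchPluginText` rather than consuming the whole value.
function useStoreSelector<T, S>(store: Store<T>, selector: (v: T) => S): S {
  const getSnapshot = useCallback(() => selector(store.getState()), [store, selector])
  return useSyncExternalStore(store.subscribe, getSnapshot, getSnapshot)
}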

View File

@@ -1,9 +0,0 @@
import type { ActivePluginType } from './constants'
import { parseAsArrayOf, parseAsString, parseAsStringEnum } from 'nuqs/server'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
export const marketplaceSearchParamsParsers = {
category: parseAsStringEnum<ActivePluginType>(Object.values(PLUGIN_TYPE_SEARCH_MAP) as ActivePluginType[]).withDefault('all').withOptions({ history: 'replace', clearOnDefault: false }),
q: parseAsString.withDefault('').withOptions({ history: 'replace' }),
tags: parseAsArrayOf(parseAsString).withDefault([]).withOptions({ history: 'replace' }),
}
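
Parsers like these are what createLoader consumes in the hydration-server file above. A minimal sketch of how a nuqs loader resolves raw search params into typed, defaulted values (nuqs v2 server API assumed; the URL is hypothetical):

import { createLoader, parseAsArrayOf, parseAsString } from 'nuqs/server'

const parsers = {
  q: parseAsString.withDefault(''),
  tags: parseAsArrayOf(parseAsString).withDefault([]),
}
const loadSearchParams = createLoader(parsers)

// Loaders accept a URL/Request/string or the page's searchParams promise
// and return parsed values with defaults applied.
const params = await loadSearchParams('https://example.com/?q=rag&tags=agent,tool')
// -> { q: 'rag', tags: ['agent', 'tool'] }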

View File

@@ -1,3 +1,4 @@
import type { MarketplaceContextValue } from '../context'
import { fireEvent, render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
@@ -27,12 +28,18 @@ vi.mock('#i18n', () => ({
}),
}))
// Mock marketplace atoms with controllable values
let mockSort: { sortBy: string, sortOrder: string } = { sortBy: 'install_count', sortOrder: 'DESC' }
// Mock marketplace context with controllable values
let mockSort = { sortBy: 'install_count', sortOrder: 'DESC' }
const mockHandleSortChange = vi.fn()
vi.mock('../atoms', () => ({
useMarketplaceSort: () => [mockSort, mockHandleSortChange],
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (value: MarketplaceContextValue) => unknown) => {
const contextValue = {
sort: mockSort,
handleSortChange: mockHandleSortChange,
} as unknown as MarketplaceContextValue
return selector(contextValue)
},
}))
// Mock portal component with controllable open state

Some files were not shown because too many files have changed in this diff