mirror of
https://github.com/langgenius/dify.git
synced 2026-01-10 00:04:14 +00:00
Compare commits
13 Commits
feat/pull-
...
feat/llm-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d3d8b35d9 | ||
|
|
c323028179 | ||
|
|
70149ea05e | ||
|
|
1d93f41fcf | ||
|
|
04f40303fd | ||
|
|
ececc5ec2c | ||
|
|
e83635ee5a | ||
|
|
d79372a46d | ||
|
|
bbd11c9e89 | ||
|
|
d132abcdb4 | ||
|
|
d60348572e | ||
|
|
0cff94d90e | ||
|
|
a7859de625 |
@@ -1,4 +1,4 @@
|
||||
name: Deploy Agent Dev
|
||||
name: Deploy Trigger Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/agent-dev"
|
||||
- "deploy/trigger-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
47
.github/workflows/translate-i18n-claude.yml
vendored
47
.github/workflows/translate-i18n-claude.yml
vendored
@@ -1,12 +1,10 @@
|
||||
name: Translate i18n Files with Claude Code
|
||||
|
||||
# Note: claude-code-action doesn't support push events directly.
|
||||
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
|
||||
# See: https://github.com/langgenius/dify/issues/30743
|
||||
|
||||
on:
|
||||
repository_dispatch:
|
||||
types: [i18n-sync]
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
files:
|
||||
@@ -89,35 +87,26 @@ jobs:
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
|
||||
# Triggered by push via trigger-i18n-sync.yml workflow
|
||||
# Validate required payload fields
|
||||
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
|
||||
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
|
||||
exit 1
|
||||
else
|
||||
# Push trigger - detect changed files from the push
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
# Handle edge case: first push or force push may have null/zero SHA
|
||||
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
# Fallback to comparing with parent commit
|
||||
BEFORE_SHA="HEAD~1"
|
||||
fi
|
||||
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
|
||||
changed=$(git diff --name-only "$BEFORE_SHA" ${{ github.sha }} -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
|
||||
echo "CHANGED_FILES=$changed" >> $GITHUB_OUTPUT
|
||||
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=incremental" >> $GITHUB_OUTPUT
|
||||
|
||||
# Decode the base64-encoded diff from the trigger workflow
|
||||
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
|
||||
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
|
||||
echo "Warning: Failed to decode base64 diff payload" >&2
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
elif [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
# Generate detailed diff for the push
|
||||
git diff "$BEFORE_SHA"..${{ github.sha }} -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
|
||||
if [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "Unsupported event type: ${{ github.event_name }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Truncate diff if too large (keep first 50KB)
|
||||
|
||||
66
.github/workflows/trigger-i18n-sync.yml
vendored
66
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -1,66 +0,0 @@
|
||||
name: Trigger i18n Sync on Push
|
||||
|
||||
# This workflow bridges the push event to repository_dispatch
|
||||
# because claude-code-action doesn't support push events directly.
|
||||
# See: https://github.com/langgenius/dify/issues/30743
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect
|
||||
run: |
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
# Handle edge case: force push may have null/zero SHA
|
||||
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BEFORE_SHA="HEAD~1"
|
||||
fi
|
||||
|
||||
# Detect changed i18n files
|
||||
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
|
||||
echo "changed_files=$changed" >> $GITHUB_OUTPUT
|
||||
|
||||
# Generate diff for context
|
||||
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
|
||||
|
||||
# Truncate if too large (keep first 50KB to match receiving workflow)
|
||||
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
|
||||
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
|
||||
|
||||
# Base64 encode the diff for safe JSON transport (portable, single-line)
|
||||
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
|
||||
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
|
||||
|
||||
if [ -n "$changed" ]; then
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
echo "Detected changed files: $changed"
|
||||
else
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
echo "No i18n changes detected"
|
||||
fi
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
event-type: i18n-sync
|
||||
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -209,7 +209,6 @@ api/.vscode
|
||||
.history
|
||||
|
||||
.idea/
|
||||
web/migration/
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
@@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
|
||||
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
|
||||
json_schema: dict | None = Field(default=None)
|
||||
json_schema: str | None = Field(default=None)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
@@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
|
||||
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, schema: dict | None) -> dict | None:
|
||||
def validate_json_schema(cls, schema: str | None) -> str | None:
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(schema)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"invalid json_schema value {schema}")
|
||||
|
||||
try:
|
||||
Draft7Validator.check_schema(json_schema)
|
||||
except SchemaError as e:
|
||||
raise ValueError(f"Invalid JSON schema: {e.message}")
|
||||
return schema
|
||||
|
||||
@@ -26,6 +26,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
features_dict = workflow.features_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AdvancedChatAppConfig(
|
||||
tenant_id=app_model.tenant_id,
|
||||
|
||||
@@ -358,6 +358,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
|
||||
# For ANSWER nodes, check if we need to send a message_replace event
|
||||
# Only send if the final output differs from the accumulated task_state.answer
|
||||
# This happens when variables were updated by variable_assigner during workflow execution
|
||||
if event.node_type == NodeType.ANSWER and event.outputs:
|
||||
final_answer = event.outputs.get("answer")
|
||||
if final_answer is not None and final_answer != self._task_state.answer:
|
||||
logger.info(
|
||||
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
|
||||
final_answer,
|
||||
self._task_state.answer,
|
||||
)
|
||||
# Update the task state answer
|
||||
self._task_state.answer = str(final_answer)
|
||||
# Send message_replace event to update the UI
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||
answer=str(final_answer),
|
||||
reason="variable_update",
|
||||
)
|
||||
|
||||
def _handle_node_failed_events(
|
||||
self,
|
||||
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union, final
|
||||
|
||||
@@ -75,24 +76,12 @@ class BaseAppGenerator:
|
||||
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||
|
||||
# Check if all files are converted to File
|
||||
invalid_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, dict)
|
||||
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
|
||||
]
|
||||
if invalid_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
|
||||
|
||||
invalid_list_dict_keys = [
|
||||
k
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, list)
|
||||
and any(isinstance(item, dict) for item in v)
|
||||
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
|
||||
]
|
||||
if invalid_list_dict_keys:
|
||||
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
|
||||
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||
raise ValueError("Invalid input type")
|
||||
if any(
|
||||
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||
):
|
||||
raise ValueError("Invalid input type")
|
||||
|
||||
return user_inputs
|
||||
|
||||
@@ -189,8 +178,12 @@ class BaseAppGenerator:
|
||||
elif value == 0:
|
||||
value = False
|
||||
case VariableEntityType.JSON_OBJECT:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a string")
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
|
||||
case _:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
|
||||
@@ -1,434 +0,0 @@
|
||||
# Memory Module
|
||||
|
||||
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
|
||||
|
||||
## Overview
|
||||
|
||||
The memory module contains two types of memory implementations:
|
||||
|
||||
1. **TokenBufferMemory** - Conversation-level memory (existing)
|
||||
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)
|
||||
|
||||
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
|
||||
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
|
||||
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ Memory Architecture │
|
||||
├─────────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ TokenBufferMemory │ │
|
||||
│ │ Scope: Conversation │ │
|
||||
│ │ Storage: Database (Message table) │ │
|
||||
│ │ Key: conversation_id │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ NodeTokenBufferMemory │ │
|
||||
│ │ Scope: Node within Conversation │ │
|
||||
│ │ Storage: Object Storage (JSON file) │ │
|
||||
│ │ Key: (app_id, conversation_id, node_id) │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TokenBufferMemory (Existing)
|
||||
|
||||
### Purpose
|
||||
|
||||
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Conversation-scoped**: All messages within a conversation are candidates
|
||||
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
|
||||
- **Token-limited**: Truncates history to fit within `max_token_limit`
|
||||
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
Message Table TokenBufferMemory LLM
|
||||
│ │ │
|
||||
│ SELECT * FROM messages │ │
|
||||
│ WHERE conversation_id = ? │ │
|
||||
│ ORDER BY created_at DESC │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ extract_thread_messages() │
|
||||
│ │ │
|
||||
│ build_prompt_message_with_files() │
|
||||
│ │ │
|
||||
│ truncate by max_token_limit │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage]
|
||||
│ ├───────────────────────▶│
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Thread Extraction
|
||||
|
||||
When a user regenerates a response, a new thread is created:
|
||||
|
||||
```
|
||||
Message A (user)
|
||||
└── Message A' (assistant)
|
||||
└── Message B (user)
|
||||
└── Message B' (assistant)
|
||||
└── Message A'' (assistant, regenerated) ← New thread
|
||||
└── Message C (user)
|
||||
└── Message C' (assistant)
|
||||
```
|
||||
|
||||
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## NodeTokenBufferMemory (To Be Implemented)
|
||||
|
||||
### Purpose
|
||||
|
||||
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
### Use Cases
|
||||
|
||||
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
|
||||
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
|
||||
3. **Specialized Agents**: Each agent node maintains its own dialogue history
|
||||
|
||||
### Design Decisions
|
||||
|
||||
#### Storage: Object Storage for Messages (No New Database Table)
|
||||
|
||||
| Aspect | Database | Object Storage |
|
||||
| ------------------------- | -------------------- | ------------------ |
|
||||
| Cost | High | Low |
|
||||
| Query Flexibility | High | Low |
|
||||
| Schema Changes | Migration required | None |
|
||||
| Consistency with existing | ConversationVariable | File uploads, logs |
|
||||
|
||||
**Decision**: Store message data in object storage, but still use existing database tables for file metadata.
|
||||
|
||||
**What is stored in Object Storage:**
|
||||
|
||||
- Message content (text)
|
||||
- Message metadata (role, token_count, created_at)
|
||||
- File references (upload_file_id, tool_file_id, etc.)
|
||||
- Thread relationships (message_id, parent_message_id)
|
||||
|
||||
**What still requires Database queries:**
|
||||
|
||||
- File reconstruction: When reading node memory, file references are used to query
|
||||
`UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
|
||||
complete `File` objects with storage_key, mime_type, etc.
|
||||
|
||||
**Why this hybrid approach:**
|
||||
|
||||
- No database migration required (no new tables)
|
||||
- Message data may be large, object storage is cost-effective
|
||||
- File metadata is already in database, no need to duplicate
|
||||
- Aligns with existing storage patterns (file uploads, logs)
|
||||
|
||||
#### Storage Key Format
|
||||
|
||||
```
|
||||
node_memory/{app_id}/{conversation_id}/{node_id}.json
|
||||
```
|
||||
|
||||
#### Data Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"messages": [
|
||||
{
|
||||
"message_id": "msg-001",
|
||||
"parent_message_id": null,
|
||||
"role": "user",
|
||||
"content": "Analyze this image",
|
||||
"files": [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "file-uuid-123",
|
||||
"belongs_to": "user"
|
||||
}
|
||||
],
|
||||
"token_count": 15,
|
||||
"created_at": "2026-01-07T10:00:00Z"
|
||||
},
|
||||
{
|
||||
"message_id": "msg-002",
|
||||
"parent_message_id": "msg-001",
|
||||
"role": "assistant",
|
||||
"content": "This is a landscape image...",
|
||||
"files": [],
|
||||
"token_count": 50,
|
||||
"created_at": "2026-01-07T10:00:01Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Thread Support
|
||||
|
||||
Node memory also supports thread extraction (for regeneration scenarios):
|
||||
|
||||
```python
|
||||
def _extract_thread(
|
||||
self,
|
||||
messages: list[NodeMemoryMessage],
|
||||
current_message_id: str
|
||||
) -> list[NodeMemoryMessage]:
|
||||
"""
|
||||
Extract messages belonging to the thread of current_message_id.
|
||||
Similar to extract_thread_messages() in TokenBufferMemory.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
### File Handling
|
||||
|
||||
Files are stored as references (not full metadata):
|
||||
|
||||
```python
|
||||
class NodeMemoryFile(BaseModel):
|
||||
type: str # image, audio, video, document, custom
|
||||
transfer_method: str # local_file, remote_url, tool_file
|
||||
upload_file_id: str | None # for local_file
|
||||
tool_file_id: str | None # for tool_file
|
||||
url: str | None # for remote_url
|
||||
belongs_to: str # user / assistant
|
||||
```
|
||||
|
||||
When reading, files are rebuilt using `file_factory.build_from_mapping()`.
|
||||
|
||||
### API Design
|
||||
|
||||
```python
|
||||
class NodeTokenBufferMemory:
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
"""
|
||||
Initialize node-level memory.
|
||||
|
||||
:param app_id: Application ID
|
||||
:param conversation_id: Conversation ID
|
||||
:param node_id: Node ID in the workflow
|
||||
:param model_instance: Model instance for token counting
|
||||
"""
|
||||
...
|
||||
|
||||
def add_messages(
|
||||
self,
|
||||
message_id: str,
|
||||
parent_message_id: str | None,
|
||||
user_content: str,
|
||||
user_files: Sequence[File],
|
||||
assistant_content: str,
|
||||
assistant_files: Sequence[File],
|
||||
) -> None:
|
||||
"""
|
||||
Append a dialogue turn (user + assistant) to node memory.
|
||||
Call this after LLM node execution completes.
|
||||
|
||||
:param message_id: Current message ID (from Message table)
|
||||
:param parent_message_id: Parent message ID (for thread tracking)
|
||||
:param user_content: User's text input
|
||||
:param user_files: Files attached by user
|
||||
:param assistant_content: Assistant's text response
|
||||
:param assistant_files: Files generated by assistant
|
||||
"""
|
||||
...
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
current_message_id: str,
|
||||
tenant_id: str,
|
||||
max_token_limit: int = 2000,
|
||||
file_upload_config: FileUploadConfig | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
|
||||
:param current_message_id: Current message ID (for thread extraction)
|
||||
:param tenant_id: Tenant ID (for file reconstruction)
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param file_upload_config: File upload configuration
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
...
|
||||
|
||||
def flush(self) -> None:
|
||||
"""
|
||||
Persist buffered changes to object storage.
|
||||
Call this at the end of node execution.
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all messages in this node's memory.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
Object Storage NodeTokenBufferMemory LLM Node
|
||||
│ │ │
|
||||
│ │◀── get_history_prompt_messages()
|
||||
│ storage.load(key) │ │
|
||||
│◀─────────────────────────────────┤ │
|
||||
│ │ │
|
||||
│ JSON data │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ _extract_thread() │
|
||||
│ │ │
|
||||
│ _rebuild_files() via file_factory │
|
||||
│ │ │
|
||||
│ _build_prompt_messages() │
|
||||
│ │ │
|
||||
│ _truncate_by_tokens() │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage] │
|
||||
│ ├──────────────────────────▶│
|
||||
│ │ │
|
||||
│ │◀── LLM execution complete │
|
||||
│ │ │
|
||||
│ │◀── add_messages() │
|
||||
│ │ │
|
||||
│ storage.save(key, data) │ │
|
||||
│◀─────────────────────────────────┤ │
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Integration with LLM Node
|
||||
|
||||
```python
|
||||
# In LLM Node execution
|
||||
|
||||
# 1. Fetch memory based on mode
|
||||
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
|
||||
# Node-level memory (Chatflow only)
|
||||
memory = fetch_node_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=app_id,
|
||||
node_id=self.node_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
|
||||
# Conversation-level memory (existing behavior)
|
||||
memory = fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=app_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
memory = None
|
||||
|
||||
# 2. Get history for context
|
||||
if memory:
|
||||
if isinstance(memory, NodeTokenBufferMemory):
|
||||
history = memory.get_history_prompt_messages(
|
||||
current_message_id=current_message_id,
|
||||
tenant_id=tenant_id,
|
||||
max_token_limit=max_token_limit,
|
||||
)
|
||||
else: # TokenBufferMemory
|
||||
history = memory.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
)
|
||||
prompt_messages = [*history, *current_messages]
|
||||
else:
|
||||
prompt_messages = current_messages
|
||||
|
||||
# 3. Call LLM
|
||||
response = model_instance.invoke(prompt_messages)
|
||||
|
||||
# 4. Append to node memory (only for NodeTokenBufferMemory)
|
||||
if isinstance(memory, NodeTokenBufferMemory):
|
||||
memory.add_messages(
|
||||
message_id=message_id,
|
||||
parent_message_id=parent_message_id,
|
||||
user_content=user_input,
|
||||
user_files=user_files,
|
||||
assistant_content=response.content,
|
||||
assistant_files=response_files,
|
||||
)
|
||||
memory.flush()
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
|
||||
|
||||
```python
|
||||
class MemoryMode(StrEnum):
|
||||
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
|
||||
NODE = "node" # Use NodeTokenBufferMemory (new, Chatflow only)
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
# Existing fields
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: MemoryWindowConfig | None = None
|
||||
query_prompt_template: str | None = None
|
||||
|
||||
# Memory mode (new)
|
||||
mode: MemoryMode = MemoryMode.CONVERSATION
|
||||
```
|
||||
|
||||
**Mode Behavior:**
|
||||
|
||||
| Mode | Memory Class | Scope | Availability |
|
||||
| -------------- | --------------------- | ------------------------ | ------------- |
|
||||
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
|
||||
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
|
||||
|
||||
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
|
||||
> fall back to no memory or raise a configuration error.
|
||||
|
||||
---
|
||||
|
||||
## Comparison
|
||||
|
||||
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
|
||||
| -------------- | ------------------------ | ------------------------- |
|
||||
| Scope | Conversation | Node within Conversation |
|
||||
| Storage | Database (Message table) | Object Storage (JSON) |
|
||||
| Thread Support | Yes | Yes |
|
||||
| File Support | Yes (via MessageFile) | Yes (via file references) |
|
||||
| Token Limit | Yes | Yes |
|
||||
| Use Case | Standard chat apps | Complex workflows |
|
||||
|
||||
---
|
||||
|
||||
## Future Considerations
|
||||
|
||||
1. **Cleanup Task**: Add a Celery task to clean up old node memory files
|
||||
2. **Concurrency**: Consider Redis lock for concurrent node executions
|
||||
3. **Compression**: Compress large memory files to reduce storage costs
|
||||
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory
|
||||
@@ -1,15 +0,0 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeMemoryData,
|
||||
NodeMemoryFile,
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeMemoryData",
|
||||
"NodeMemoryFile",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
@@ -1,83 +0,0 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from core.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
@@ -1,353 +0,0 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- Storage is indexed by workflow_run_id (each execution stores one turn)
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeMemoryFile(BaseModel):
|
||||
"""File reference stored in node memory."""
|
||||
|
||||
type: str # image, audio, video, document, custom
|
||||
transfer_method: str # local_file, remote_url, tool_file
|
||||
upload_file_id: str | None = None
|
||||
tool_file_id: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class NodeMemoryTurn(BaseModel):
|
||||
"""A single dialogue turn (user + assistant) in node memory."""
|
||||
|
||||
user_content: str = ""
|
||||
user_files: list[NodeMemoryFile] = []
|
||||
assistant_content: str = ""
|
||||
assistant_files: list[NodeMemoryFile] = []
|
||||
|
||||
|
||||
class NodeMemoryData(BaseModel):
|
||||
"""Root data structure for node memory storage."""
|
||||
|
||||
version: int = 1
|
||||
# Key: workflow_run_id, Value: dialogue turn
|
||||
turns: dict[str, NodeMemoryTurn] = {}
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history, stored in object storage.
|
||||
|
||||
Key design: Thread tracking is delegated to Message table's parent_message_id.
|
||||
Storage is indexed by workflow_run_id for easy filtering.
|
||||
|
||||
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
"""
|
||||
Initialize node-level memory.
|
||||
|
||||
:param app_id: Application ID
|
||||
:param conversation_id: Conversation ID
|
||||
:param node_id: Node ID in the workflow
|
||||
:param tenant_id: Tenant ID for file reconstruction
|
||||
:param model_instance: Model instance for token counting
|
||||
"""
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
|
||||
self._data: NodeMemoryData | None = None
|
||||
self._dirty = False
|
||||
|
||||
def _load(self) -> NodeMemoryData:
|
||||
"""Load data from object storage."""
|
||||
if self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
raw = storage.load_once(self._storage_key)
|
||||
self._data = NodeMemoryData.model_validate_json(raw)
|
||||
except Exception:
|
||||
# File not found or parse error, start fresh
|
||||
self._data = NodeMemoryData()
|
||||
|
||||
return self._data
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save data to object storage."""
|
||||
if self._data is not None:
|
||||
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
|
||||
self._dirty = False
|
||||
|
||||
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
|
||||
"""Convert File object to NodeMemoryFile reference."""
|
||||
return NodeMemoryFile(
|
||||
type=file.type.value if hasattr(file.type, "value") else str(file.type),
|
||||
transfer_method=(
|
||||
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
|
||||
),
|
||||
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
|
||||
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
)
|
||||
|
||||
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
|
||||
"""Convert NodeMemoryFile to mapping for file_factory."""
|
||||
mapping: dict = {
|
||||
"type": memory_file.type,
|
||||
"transfer_method": memory_file.transfer_method,
|
||||
}
|
||||
if memory_file.upload_file_id:
|
||||
mapping["upload_file_id"] = memory_file.upload_file_id
|
||||
if memory_file.tool_file_id:
|
||||
mapping["tool_file_id"] = memory_file.tool_file_id
|
||||
if memory_file.url:
|
||||
mapping["url"] = memory_file.url
|
||||
return mapping
|
||||
|
||||
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
|
||||
"""Rebuild File objects from NodeMemoryFile references."""
|
||||
if not memory_files:
|
||||
return []
|
||||
|
||||
from factories import file_factory
|
||||
|
||||
files = []
|
||||
for mf in memory_files:
|
||||
try:
|
||||
mapping = self._memory_file_to_mapping(mf)
|
||||
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
|
||||
files.append(file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rebuild file from memory: %s", e)
|
||||
continue
|
||||
return files
|
||||
|
||||
def _build_prompt_message(
|
||||
self,
|
||||
role: str,
|
||||
content: str,
|
||||
files: list[File],
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
|
||||
) -> PromptMessage:
|
||||
"""Build PromptMessage from content and files."""
|
||||
from core.file import file_manager
|
||||
|
||||
if not files:
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=content)
|
||||
else:
|
||||
return AssistantPromptMessage(content=content)
|
||||
|
||||
# Build multimodal content
|
||||
prompt_contents: list = []
|
||||
for file in files:
|
||||
try:
|
||||
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
|
||||
prompt_contents.append(prompt_content)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to convert file to prompt content: %s", e)
|
||||
continue
|
||||
|
||||
prompt_contents.append(TextPromptMessageContent(data=content))
|
||||
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=prompt_contents)
|
||||
else:
|
||||
return AssistantPromptMessage(content=prompt_contents)
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
# Query messages for this conversation
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
messages = db.session.scalars(stmt.limit(500)).all()
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
workflow_run_ids = []
|
||||
for msg in reversed(thread_messages):
|
||||
if msg.workflow_run_id:
|
||||
workflow_run_ids.append(msg.workflow_run_id)
|
||||
|
||||
return workflow_run_ids
|
||||
|
||||
def add_messages(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
user_content: str,
|
||||
user_files: Sequence[File] | None = None,
|
||||
assistant_content: str = "",
|
||||
assistant_files: Sequence[File] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a dialogue turn to node memory.
|
||||
Call this after LLM node execution completes.
|
||||
|
||||
:param workflow_run_id: Current workflow execution ID
|
||||
:param user_content: User's text input
|
||||
:param user_files: Files attached by user
|
||||
:param assistant_content: Assistant's text response
|
||||
:param assistant_files: Files generated by assistant
|
||||
"""
|
||||
data = self._load()
|
||||
|
||||
# Convert files to memory file references
|
||||
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
|
||||
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
|
||||
|
||||
# Store the turn indexed by workflow_run_id
|
||||
data.turns[workflow_run_id] = NodeMemoryTurn(
|
||||
user_content=user_content,
|
||||
user_files=user_memory_files,
|
||||
assistant_content=assistant_content,
|
||||
assistant_files=assistant_memory_files,
|
||||
)
|
||||
|
||||
self._dirty = True
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
|
||||
Thread tracking is handled by querying Message table's parent_message_id structure.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: unused, for interface compatibility
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
|
||||
_ = message_limit
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
data = self._load()
|
||||
|
||||
if not data.turns:
|
||||
return []
|
||||
|
||||
# Get workflow_run_ids for current thread from Message table
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Build prompt messages in thread order
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for wf_run_id in thread_workflow_run_ids:
|
||||
turn = data.turns.get(wf_run_id)
|
||||
if not turn:
|
||||
# This workflow execution didn't have node memory stored
|
||||
continue
|
||||
|
||||
# Build user message
|
||||
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
|
||||
user_msg = self._build_prompt_message(
|
||||
role="user",
|
||||
content=turn.user_content,
|
||||
files=user_files,
|
||||
detail=detail,
|
||||
)
|
||||
prompt_messages.append(user_msg)
|
||||
|
||||
# Build assistant message
|
||||
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
|
||||
assistant_msg = self._build_prompt_message(
|
||||
role="assistant",
|
||||
content=turn.assistant_content,
|
||||
files=assistant_files,
|
||||
detail=detail,
|
||||
)
|
||||
prompt_messages.append(assistant_msg)
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def flush(self) -> None:
|
||||
"""
|
||||
Persist buffered changes to object storage.
|
||||
Call this at the end of node execution.
|
||||
"""
|
||||
if self._dirty:
|
||||
self._save()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages in this node's memory."""
|
||||
self._data = NodeMemoryData()
|
||||
self._save()
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if node memory exists in storage."""
|
||||
return storage.exists(self._storage_key)
|
||||
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TokenBufferMemory(BaseMemory):
|
||||
class TokenBufferMemory:
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
@@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
@@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -6,13 +5,6 @@ from pydantic import BaseModel
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
|
||||
|
||||
class MemoryMode(StrEnum):
|
||||
"""Memory mode for LLM nodes."""
|
||||
|
||||
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
|
||||
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
|
||||
|
||||
|
||||
class ChatModelMessage(BaseModel):
|
||||
"""
|
||||
Chat Message.
|
||||
@@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: str | None = None
|
||||
mode: MemoryMode = MemoryMode.CONVERSATION
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -63,7 +63,6 @@ class NodeType(StrEnum):
|
||||
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||
TRIGGER_PLUGIN = "trigger-plugin"
|
||||
HUMAN_INPUT = "human-input"
|
||||
GROUP = "group"
|
||||
|
||||
@property
|
||||
def is_trigger_node(self) -> bool:
|
||||
|
||||
@@ -307,14 +307,7 @@ class Graph:
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
# Filter out UI-only node types:
|
||||
# - custom-note: top-level type (node_config.type == "custom-note")
|
||||
# - group: data-level type (node_config.data.type == "group")
|
||||
node_configs = [
|
||||
node_config for node_config in node_configs
|
||||
if node_config.get("type", "") != "custom-note"
|
||||
and node_config.get("data", {}).get("type", "") != "group"
|
||||
]
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
@@ -125,11 +125,6 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_started(event)
|
||||
return
|
||||
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
@@ -169,11 +164,6 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_success(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
@@ -236,11 +226,6 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Check if this is a virtual node (extraction node)
|
||||
if self._is_virtual_node(event.node_id):
|
||||
self._handle_virtual_node_failed(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
@@ -360,57 +345,3 @@ class EventHandler:
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
|
||||
def _is_virtual_node(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if node_id represents a virtual sub-node.
|
||||
|
||||
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
|
||||
We check if the part before '.' exists in graph nodes.
|
||||
"""
|
||||
if "." in node_id:
|
||||
parent_id = node_id.rsplit(".", 1)[0]
|
||||
return parent_id in self._graph.nodes
|
||||
return False
|
||||
|
||||
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle virtual node started event.
|
||||
|
||||
Virtual nodes don't need full execution tracking, just collect the event.
|
||||
"""
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle virtual node success event.
|
||||
|
||||
Virtual nodes (extraction nodes) need special handling:
|
||||
- Store outputs in variable pool (for reference by other nodes)
|
||||
- Accumulate token usage
|
||||
- Collect the event for logging
|
||||
- Do NOT process edges or enqueue next nodes (parent node handles that)
|
||||
"""
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle virtual node failed event.
|
||||
|
||||
Virtual nodes (extraction nodes) failures are collected for logging,
|
||||
but the parent node is responsible for handling the error.
|
||||
"""
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Collect the event for logging
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
provider_type: str = ""
|
||||
provider_id: str = ""
|
||||
|
||||
# Virtual node fields for extraction
|
||||
is_virtual: bool = False
|
||||
parent_node_id: str | None = None
|
||||
extraction_source: str | None = None # e.g., "llm1.context"
|
||||
extraction_prompt: str | None = None
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
|
||||
# Spec-compliant fields
|
||||
|
||||
@@ -1,13 +1,5 @@
|
||||
from .entities import (
|
||||
BaseIterationNodeData,
|
||||
BaseIterationState,
|
||||
BaseLoopNodeData,
|
||||
BaseLoopState,
|
||||
BaseNodeData,
|
||||
VirtualNodeConfig,
|
||||
)
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .usage_tracking_mixin import LLMUsageTrackingMixin
|
||||
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
@@ -16,7 +8,4 @@ __all__ = [
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"LLMUsageTrackingMixin",
|
||||
"VirtualNodeConfig",
|
||||
"VirtualNodeExecutionError",
|
||||
"VirtualNodeExecutor",
|
||||
]
|
||||
|
||||
@@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class VirtualNodeConfig(BaseModel):
|
||||
"""Configuration for a virtual sub-node embedded within a parent node."""
|
||||
|
||||
# Local ID within parent node (e.g., "ext_1")
|
||||
# Will be converted to global ID: "{parent_id}.{id}"
|
||||
id: str
|
||||
|
||||
# Node type (e.g., "llm", "code", "tool")
|
||||
type: str
|
||||
|
||||
# Full node data configuration
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
def get_global_id(self, parent_node_id: str) -> str:
|
||||
"""Get the global node ID by combining parent ID and local ID."""
|
||||
return f"{parent_node_id}.{self.id}"
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
@@ -193,9 +175,6 @@ class BaseNodeData(ABC, BaseModel):
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
# Virtual sub-nodes that execute before the main node
|
||||
virtual_nodes: list[VirtualNodeConfig] = []
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
|
||||
@@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
@@ -271,52 +270,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual sub-nodes defined in node configuration.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the main node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool (via event handling)
|
||||
- Supports retry via parent node's retry config
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
"""
|
||||
from .virtual_node_executor import VirtualNodeExecutor
|
||||
|
||||
virtual_nodes = self.node_data.virtual_nodes
|
||||
if not virtual_nodes:
|
||||
return {}
|
||||
|
||||
executor = VirtualNodeExecutor(
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
parent_node_id=self._node_id,
|
||||
parent_retry_config=self.retry_config,
|
||||
)
|
||||
|
||||
return (yield from executor.execute_virtual_nodes(virtual_nodes))
|
||||
|
||||
@property
|
||||
def virtual_node_outputs(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get the outputs from virtual sub-nodes.
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
"""
|
||||
return self._virtual_node_outputs
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Step 1: Execute virtual sub-nodes before main node execution
|
||||
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
id=execution_id,
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
"""
|
||||
Virtual Node Executor for running embedded sub-nodes within a parent node.
|
||||
|
||||
This module handles the execution of virtual nodes defined in a parent node's
|
||||
`virtual_nodes` configuration. Virtual nodes are complete node definitions
|
||||
that execute before the parent node.
|
||||
|
||||
Example configuration:
|
||||
virtual_nodes:
|
||||
- id: ext_1
|
||||
type: llm
|
||||
data:
|
||||
model: {...}
|
||||
prompt_template: [...]
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import RetryConfig, VirtualNodeConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class VirtualNodeExecutionError(Exception):
|
||||
"""Error during virtual node execution"""
|
||||
|
||||
def __init__(self, node_id: str, original_error: Exception):
|
||||
self.node_id = node_id
|
||||
self.original_error = original_error
|
||||
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
|
||||
|
||||
|
||||
class VirtualNodeExecutor:
|
||||
"""
|
||||
Executes virtual sub-nodes embedded within a parent node.
|
||||
|
||||
Virtual nodes are complete node definitions that execute before the parent node.
|
||||
Each virtual node:
|
||||
- Has its own global ID: "{parent_id}.{local_id}"
|
||||
- Generates standard node events
|
||||
- Stores outputs in the variable pool
|
||||
- Supports retry via parent node's retry config
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
parent_node_id: str,
|
||||
parent_retry_config: RetryConfig | None = None,
|
||||
):
|
||||
self._graph_init_params = graph_init_params
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._parent_node_id = parent_node_id
|
||||
self._parent_retry_config = parent_retry_config or RetryConfig()
|
||||
|
||||
def execute_virtual_nodes(
|
||||
self,
|
||||
virtual_nodes: list[VirtualNodeConfig],
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute all virtual nodes in order.
|
||||
|
||||
Args:
|
||||
virtual_nodes: List of virtual node configurations
|
||||
|
||||
Yields:
|
||||
Node events from each virtual node execution
|
||||
|
||||
Returns:
|
||||
dict mapping local_id -> outputs dict
|
||||
"""
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
for vnode_config in virtual_nodes:
|
||||
global_id = vnode_config.get_global_id(self._parent_node_id)
|
||||
|
||||
# Execute with retry
|
||||
outputs = yield from self._execute_with_retry(vnode_config, global_id)
|
||||
results[vnode_config.id] = outputs
|
||||
|
||||
return results
|
||||
|
||||
def _execute_with_retry(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute virtual node with retry support.
|
||||
"""
|
||||
retry_config = self._parent_retry_config
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(retry_config.max_retries + 1):
|
||||
try:
|
||||
return (yield from self._execute_single_node(vnode_config, global_id))
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt < retry_config.max_retries:
|
||||
# Yield retry event
|
||||
yield NodeRunRetryEvent(
|
||||
id=str(uuid4()),
|
||||
node_id=global_id,
|
||||
node_type=self._get_node_type(vnode_config.type),
|
||||
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
start_at=naive_utc_now(),
|
||||
error=str(e),
|
||||
retry_index=attempt + 1,
|
||||
)
|
||||
|
||||
time.sleep(retry_config.retry_interval_seconds)
|
||||
continue
|
||||
|
||||
raise VirtualNodeExecutionError(global_id, e) from e
|
||||
|
||||
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
|
||||
|
||||
def _execute_single_node(
|
||||
self,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
global_id: str,
|
||||
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
|
||||
"""
|
||||
Execute a single virtual node by instantiating and running it.
|
||||
"""
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Build node config
|
||||
node_config: dict[str, Any] = {
|
||||
"id": global_id,
|
||||
"data": {
|
||||
**vnode_config.data,
|
||||
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
|
||||
},
|
||||
}
|
||||
|
||||
# Get the node class for this type
|
||||
node_type = self._get_node_type(vnode_config.type)
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
node_version = str(vnode_config.data.get("version", "1"))
|
||||
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
|
||||
if not node_cls:
|
||||
raise ValueError(f"No class found for node type: {node_type}")
|
||||
|
||||
# Instantiate the node
|
||||
node = node_cls(
|
||||
id=global_id,
|
||||
config=node_config,
|
||||
graph_init_params=self._graph_init_params,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
)
|
||||
|
||||
# Run and collect events
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
for event in node.run():
|
||||
# Mark event as coming from virtual node
|
||||
self._mark_event_as_virtual(event, vnode_config)
|
||||
yield event
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
outputs = event.node_run_result.outputs or {}
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
raise Exception(event.error or "Virtual node execution failed")
|
||||
|
||||
return outputs
|
||||
|
||||
def _mark_event_as_virtual(
|
||||
self,
|
||||
event: GraphNodeEventBase,
|
||||
vnode_config: VirtualNodeConfig,
|
||||
) -> None:
|
||||
"""Mark event as coming from a virtual node."""
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
event.is_virtual = True
|
||||
event.parent_node_id = self._parent_node_id
|
||||
|
||||
def _get_node_type(self, type_str: str) -> NodeType:
|
||||
"""Convert type string to NodeType enum."""
|
||||
type_mapping = {
|
||||
"llm": NodeType.LLM,
|
||||
"code": NodeType.CODE,
|
||||
"tool": NodeType.TOOL,
|
||||
"if-else": NodeType.IF_ELSE,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"template-transform": NodeType.TEMPLATE_TRANSFORM,
|
||||
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
|
||||
"http-request": NodeType.HTTP_REQUEST,
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
}
|
||||
return type_mapping.get(type_str, NodeType.LLM)
|
||||
@@ -8,13 +8,12 @@ from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
@@ -87,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
node_data_memory: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
node_id: str = "",
|
||||
) -> BaseMemory | None:
|
||||
"""
|
||||
Fetch memory based on configuration mode.
|
||||
|
||||
Returns TokenBufferMemory for conversation mode (default),
|
||||
or NodeTokenBufferMemory for node mode (Chatflow only).
|
||||
|
||||
:param variable_pool: Variable pool containing system variables
|
||||
:param app_id: Application ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param node_data_memory: Memory configuration
|
||||
:param model_instance: Model instance for token counting
|
||||
:param node_id: Node ID in the workflow (required for node mode)
|
||||
:return: Memory instance or None if not applicable
|
||||
"""
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# Get conversation_id from variable pool (required for both modes in Chatflow)
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# Return appropriate memory type based on mode
|
||||
if node_data_memory.mode == MemoryMode.NODE:
|
||||
# Node-level memory (Chatflow only)
|
||||
if not node_id:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
return NodeTokenBufferMemory(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=node_id,
|
||||
tenant_id=tenant_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
# Conversation-level memory (default)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
|
||||
@@ -16,8 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
@@ -209,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_memory=self.node_data.memory,
|
||||
model_instance=model_instance,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
|
||||
query: str | None = None
|
||||
@@ -304,41 +301,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
# Write to Node Memory if in node memory mode
|
||||
if isinstance(memory, NodeTokenBufferMemory):
|
||||
# Get workflow_run_id as the key for this execution
|
||||
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
|
||||
workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else ""
|
||||
|
||||
if workflow_run_id:
|
||||
# Resolve the query template to get actual user content
|
||||
# query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}"
|
||||
actual_query = variable_pool.convert_template(query or "").text
|
||||
|
||||
# Get user files from sys.files
|
||||
user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
user_files: list[File] = []
|
||||
if isinstance(user_files_var, ArrayFileSegment):
|
||||
user_files = list(user_files_var.value)
|
||||
elif isinstance(user_files_var, FileSegment):
|
||||
user_files = [user_files_var.value]
|
||||
|
||||
memory.add_messages(
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_content=actual_query,
|
||||
user_files=user_files,
|
||||
assistant_content=clean_text,
|
||||
assistant_files=self._file_outputs,
|
||||
)
|
||||
memory.flush()
|
||||
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
@@ -598,22 +566,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Separated mode: always return clean text and reasoning_content
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
@staticmethod
|
||||
def _build_context(
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
assistant_response: str,
|
||||
model_mode: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Build context from prompt messages and assistant response.
|
||||
Excludes system messages and includes the current LLM response.
|
||||
"""
|
||||
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_mode, prompt_messages=context_messages
|
||||
)
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
@@ -826,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: BaseMemory | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
memory_config: MemoryConfig | None = None,
|
||||
@@ -1385,7 +1337,7 @@ def _calculate_rest_token(
|
||||
|
||||
def _handle_memory_chat_mode(
|
||||
*,
|
||||
memory: BaseMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> Sequence[PromptMessage]:
|
||||
@@ -1402,7 +1354,7 @@ def _handle_memory_chat_mode(
|
||||
|
||||
def _handle_memory_completion_mode(
|
||||
*,
|
||||
memory: BaseMemory | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> str:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
@@ -42,22 +43,25 @@ class StartNode(Node[StartNodeData]):
|
||||
if value is None and variable.required:
|
||||
raise ValueError(f"{key} is required in input form")
|
||||
|
||||
# If no value provided, skip further processing for this key
|
||||
if not value:
|
||||
continue
|
||||
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"JSON object for '{key}' must be an object")
|
||||
|
||||
# Overwrite with normalized dict to ensure downstream consistency
|
||||
node_inputs[key] = value
|
||||
|
||||
# If schema exists, then validate against it
|
||||
schema = variable.json_schema
|
||||
if not schema:
|
||||
continue
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
||||
try:
|
||||
Draft7Validator(schema).validate(value)
|
||||
json_schema = json.loads(schema)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{schema} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
json_value = json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"{value} must be a valid JSON object")
|
||||
|
||||
try:
|
||||
Draft7Validator(json_schema).validate(json_value)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
|
||||
node_inputs[key] = json_value
|
||||
|
||||
@@ -89,20 +89,18 @@ class ToolNode(Node[ToolNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters (use virtual_node_outputs from base class)
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
virtual_node_outputs=self.virtual_node_outputs,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
virtual_node_outputs=self.virtual_node_outputs,
|
||||
)
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
@@ -178,7 +176,6 @@ class ToolNode(Node[ToolNodeData]):
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
virtual_node_outputs: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
@@ -187,17 +184,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
for_log (bool): Whether to generate parameters for logging.
|
||||
virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
|
||||
Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
|
||||
with global IDs like "{parent_id}.{local_id}".
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
virtual_node_outputs = virtual_node_outputs or {}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
@@ -207,25 +199,14 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
# Check if this references a virtual node output (local ID like [ext_1, text])
|
||||
selector = tool_input.value
|
||||
if len(selector) >= 2 and selector[0] in virtual_node_outputs:
|
||||
# Reference to virtual node output
|
||||
local_id = selector[0]
|
||||
var_name = selector[1]
|
||||
outputs = virtual_node_outputs.get(local_id, {})
|
||||
parameter_value = outputs.get(var_name)
|
||||
else:
|
||||
# Normal variable reference
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
template = str(tool_input.value)
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
|
||||
266
api/tests/fixtures/pav-test-extraction.yml
vendored
266
api/tests/fixtures/pav-test-extraction.yml
vendored
@@ -1,266 +0,0 @@
|
||||
app:
|
||||
description: Test for variable extraction feature
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: pav-test-extraction
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||
version: null
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
|
||||
version: null
|
||||
kind: app
|
||||
version: 0.5.0
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
id: 1767773675796-llm
|
||||
source: '1767773675796'
|
||||
sourceHandle: source
|
||||
target: llm
|
||||
targetHandle: target
|
||||
type: custom
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: llm
|
||||
targetType: tool
|
||||
id: llm-source-1767773709491-target
|
||||
source: llm
|
||||
sourceHandle: source
|
||||
target: '1767773709491'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: tool
|
||||
targetType: answer
|
||||
id: tool-source-answer-target
|
||||
source: '1767773709491'
|
||||
sourceHandle: source
|
||||
target: answer
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
nodes:
|
||||
- data:
|
||||
selected: false
|
||||
title: User Input
|
||||
type: start
|
||||
variables: []
|
||||
height: 73
|
||||
id: '1767773675796'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
memory:
|
||||
query_prompt_template: ''
|
||||
role_prefix:
|
||||
assistant: ''
|
||||
user: ''
|
||||
window:
|
||||
enabled: true
|
||||
size: 10
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: qwen-max
|
||||
provider: langgenius/tongyi/tongyi
|
||||
prompt_template:
|
||||
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
|
||||
role: system
|
||||
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
|
||||
|
||||
请与用户进行对话,了解他们的搜索需求。
|
||||
|
||||
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
|
||||
|
||||
'
|
||||
selected: false
|
||||
title: LLM
|
||||
type: llm
|
||||
vision:
|
||||
enabled: false
|
||||
height: 88
|
||||
id: llm
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
is_team_authorization: true
|
||||
paramSchemas:
|
||||
- auto_generate: null
|
||||
default: null
|
||||
form: llm
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
ja_JP: used for searching
|
||||
pt_BR: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
label:
|
||||
en_US: Query string
|
||||
ja_JP: Query string
|
||||
pt_BR: Query string
|
||||
zh_Hans: 查询语句
|
||||
llm_description: key words for searching
|
||||
max: null
|
||||
min: null
|
||||
name: query
|
||||
options: []
|
||||
placeholder: null
|
||||
precision: null
|
||||
required: true
|
||||
scope: null
|
||||
template: null
|
||||
type: string
|
||||
params:
|
||||
query: ''
|
||||
plugin_id: langgenius/google
|
||||
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
|
||||
provider_id: langgenius/google/google
|
||||
provider_name: langgenius/google/google
|
||||
provider_type: builtin
|
||||
selected: false
|
||||
title: GoogleSearch
|
||||
tool_configurations: {}
|
||||
tool_description: A tool for performing a Google SERP search and extracting
|
||||
snippets and webpages.Input should be a search query.
|
||||
tool_label: GoogleSearch
|
||||
tool_name: google_search
|
||||
tool_node_version: '2'
|
||||
tool_parameters:
|
||||
query:
|
||||
type: variable
|
||||
value:
|
||||
- ext_1
|
||||
- text
|
||||
type: tool
|
||||
virtual_nodes:
|
||||
- data:
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: qwen-max
|
||||
provider: langgenius/tongyi/tongyi
|
||||
context:
|
||||
enabled: false
|
||||
prompt_template:
|
||||
- role: user
|
||||
text: '{{#llm.context#}}'
|
||||
- role: user
|
||||
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
|
||||
title: 提取搜索关键词
|
||||
id: ext_1
|
||||
type: llm
|
||||
height: 52
|
||||
id: '1767773709491'
|
||||
position:
|
||||
x: 682
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 682
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
answer: '搜索结果:
|
||||
|
||||
{{#1767773709491.text#}}
|
||||
|
||||
'
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
height: 103
|
||||
id: answer
|
||||
position:
|
||||
x: 984
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 984
|
||||
y: 282
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
viewport:
|
||||
x: 151
|
||||
y: 141.5
|
||||
zoom: 1
|
||||
rag_pipeline_variables: []
|
||||
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
|
||||
specifically testing the ANSWER node message_replace logic.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
from core.workflow.enums import NodeType
|
||||
from models import EndUser
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAnswerNodeMessageReplace:
|
||||
"""Test cases for ANSWER node message_replace event logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock(spec=AdvancedChatAppGenerateEntity)
|
||||
entity.task_id = "test-task-id"
|
||||
entity.app_id = "test-app-id"
|
||||
entity.workflow_run_id = "test-workflow-run-id"
|
||||
# minimal app_config used by pipeline internals
|
||||
entity.app_config = SimpleNamespace(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
app_model_config_dict={},
|
||||
additional_features=None,
|
||||
sensitive_word_avoidance=None,
|
||||
)
|
||||
entity.query = "test query"
|
||||
entity.files = []
|
||||
entity.extras = {}
|
||||
entity.trace_manager = None
|
||||
entity.inputs = {}
|
||||
entity.invoke_from = "debugger"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow."""
|
||||
workflow = Mock()
|
||||
workflow.id = "test-workflow-id"
|
||||
workflow.features_dict = {}
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = Mock()
|
||||
manager.listen.return_value = []
|
||||
manager.graph_runtime_state = None
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation(self):
|
||||
"""Create a mock conversation."""
|
||||
conversation = Mock()
|
||||
conversation.id = "test-conversation-id"
|
||||
conversation.mode = "advanced_chat"
|
||||
return conversation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message."""
|
||||
message = Mock()
|
||||
message.id = "test-message-id"
|
||||
message.query = "test query"
|
||||
message.created_at = Mock()
|
||||
message.created_at.timestamp.return_value = 1234567890
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock end user."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = "test-user-id"
|
||||
user.session_id = "test-session-id"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def mock_draft_var_saver_factory(self):
|
||||
"""Create a mock draft variable saver factory."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(
|
||||
self,
|
||||
mock_application_generate_entity,
|
||||
mock_workflow,
|
||||
mock_queue_manager,
|
||||
mock_conversation,
|
||||
mock_message,
|
||||
mock_user,
|
||||
mock_draft_var_saver_factory,
|
||||
):
|
||||
"""Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
|
||||
with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=mock_application_generate_entity,
|
||||
workflow=mock_workflow,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
user=mock_user,
|
||||
stream=True,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=mock_draft_var_saver_factory,
|
||||
)
|
||||
# Initialize workflow run id to avoid validation errors
|
||||
pipeline._workflow_run_id = "test-workflow-run-id"
|
||||
# Mock the message cycle manager methods we need to track
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
|
||||
return pipeline
|
||||
|
||||
def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
|
||||
"""
|
||||
Test that when an ANSWER node's final output differs from accumulated answer,
|
||||
a message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "initial answer"
|
||||
|
||||
# Create ANSWER node succeeded event with different final output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "updated final answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter to avoid extra processing
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert
|
||||
assert pipeline._task_state.answer == "updated final answer"
|
||||
# Verify message_replace was called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="updated final answer", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is the same as accumulated answer,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "same answer"
|
||||
|
||||
# Create ANSWER node succeeded event with same output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "same answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "same answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's output is None or missing 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with None output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": None},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node has empty outputs dict,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with empty outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's outputs don't contain 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event without 'answer' key in outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"other_key": "some value"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_non_answer_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Test with LLM node
|
||||
llm_event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-llm-execution-id",
|
||||
node_id="test-llm-node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(llm_event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_end_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that END nodes don't trigger message_replace events even with 'answer' output.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create END node succeeded event with answer output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-end-execution-id",
|
||||
node_id="test-end-node",
|
||||
node_type=NodeType.END,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is numeric,
|
||||
it gets converted to string properly.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "text answer"
|
||||
|
||||
# Create ANSWER node succeeded event with numeric output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": 12345},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should be converted to string
|
||||
assert pipeline._task_state.answer == "12345"
|
||||
# Verify message_replace was called with string
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="12345", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_files_are_recorded(self, pipeline):
|
||||
"""
|
||||
Test that ANSWER nodes properly record files from outputs.
|
||||
"""
|
||||
# Arrange
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with files
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={
|
||||
"answer": "same answer",
|
||||
"files": [
|
||||
{"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"])
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: files should be recorded
|
||||
assert len(pipeline._recorded_files) == 1
|
||||
assert pipeline._recorded_files[0] == event.outputs["files"][0]
|
||||
@@ -1,77 +0,0 @@
|
||||
"""
|
||||
Unit tests for virtual node configuration.
|
||||
"""
|
||||
|
||||
from core.workflow.nodes.base.entities import VirtualNodeConfig
|
||||
|
||||
|
||||
class TestVirtualNodeConfig:
|
||||
"""Tests for VirtualNodeConfig entity."""
|
||||
|
||||
def test_create_basic_config(self):
|
||||
"""Test creating a basic virtual node config."""
|
||||
config = VirtualNodeConfig(
|
||||
id="ext_1",
|
||||
type="llm",
|
||||
data={
|
||||
"title": "Extract keywords",
|
||||
"model": {"provider": "openai", "name": "gpt-4o-mini"},
|
||||
},
|
||||
)
|
||||
|
||||
assert config.id == "ext_1"
|
||||
assert config.type == "llm"
|
||||
assert config.data["title"] == "Extract keywords"
|
||||
|
||||
def test_get_global_id(self):
|
||||
"""Test generating global ID from parent ID."""
|
||||
config = VirtualNodeConfig(
|
||||
id="ext_1",
|
||||
type="llm",
|
||||
data={},
|
||||
)
|
||||
|
||||
global_id = config.get_global_id("tool1")
|
||||
assert global_id == "tool1.ext_1"
|
||||
|
||||
def test_get_global_id_with_different_parents(self):
|
||||
"""Test global ID generation with different parent IDs."""
|
||||
config = VirtualNodeConfig(id="sub_node", type="code", data={})
|
||||
|
||||
assert config.get_global_id("parent1") == "parent1.sub_node"
|
||||
assert config.get_global_id("node_123") == "node_123.sub_node"
|
||||
|
||||
def test_empty_data(self):
|
||||
"""Test virtual node config with empty data."""
|
||||
config = VirtualNodeConfig(
|
||||
id="test",
|
||||
type="tool",
|
||||
)
|
||||
|
||||
assert config.id == "test"
|
||||
assert config.type == "tool"
|
||||
assert config.data == {}
|
||||
|
||||
def test_complex_data(self):
|
||||
"""Test virtual node config with complex data."""
|
||||
config = VirtualNodeConfig(
|
||||
id="llm_1",
|
||||
type="llm",
|
||||
data={
|
||||
"title": "Generate summary",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7, "max_tokens": 500},
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "user", "text": "{{#llm1.context#}}"},
|
||||
{"role": "user", "text": "Please summarize the conversation"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert config.data["model"]["provider"] == "openai"
|
||||
assert len(config.data["prompt_template"]) == 2
|
||||
|
||||
@@ -58,8 +58,6 @@ def test_json_object_valid_schema():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -70,7 +68,7 @@ def test_json_object_valid_schema():
|
||||
)
|
||||
]
|
||||
|
||||
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
result = node._run()
|
||||
@@ -89,8 +87,6 @@ def test_json_object_invalid_json_string():
|
||||
"required": ["age", "name"],
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -101,12 +97,12 @@ def test_json_object_invalid_json_string():
|
||||
)
|
||||
]
|
||||
|
||||
# Providing a string instead of an object should raise a type error
|
||||
# Missing closing brace makes this invalid JSON
|
||||
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
|
||||
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
|
||||
node._run()
|
||||
|
||||
|
||||
@@ -122,8 +118,6 @@ def test_json_object_does_not_match_schema():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -135,7 +129,7 @@ def test_json_object_does_not_match_schema():
|
||||
]
|
||||
|
||||
# age is a string, which violates the schema (expects number)
|
||||
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
|
||||
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
@@ -155,8 +149,6 @@ def test_json_object_missing_required_schema_field():
|
||||
}
|
||||
)
|
||||
|
||||
schema = json.loads(schema)
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="profile",
|
||||
@@ -168,7 +160,7 @@ def test_json_object_missing_required_schema_field():
|
||||
]
|
||||
|
||||
# Missing required field "name"
|
||||
user_inputs = {"profile": {"age": 20}}
|
||||
user_inputs = {"profile": json.dumps({"age": 20})}
|
||||
|
||||
node = make_start_node(user_inputs, variables)
|
||||
|
||||
|
||||
@@ -2,11 +2,11 @@ import Marketplace from '@/app/components/plugins/marketplace'
|
||||
import PluginPage from '@/app/components/plugins/plugin-page'
|
||||
import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel'
|
||||
|
||||
const PluginList = () => {
|
||||
const PluginList = async () => {
|
||||
return (
|
||||
<PluginPage
|
||||
plugins={<PluginsPanel />}
|
||||
marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" />}
|
||||
marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" showSearchParams={false} />}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
|
||||
import { useInvalidateAppList } from '@/service/use-apps'
|
||||
import { fetchWorkflowDraft } from '@/service/workflow'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import { getRedirection } from '@/utils/app-redirection'
|
||||
@@ -67,7 +66,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
|
||||
const { onPlanInfoChanged } = useProviderContext()
|
||||
const appDetail = useAppStore(state => state.appDetail)
|
||||
const setAppDetail = useAppStore(state => state.setAppDetail)
|
||||
const invalidateAppList = useInvalidateAppList()
|
||||
const [open, setOpen] = useState(openState)
|
||||
const [showEditModal, setShowEditModal] = useState(false)
|
||||
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
|
||||
@@ -193,7 +191,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
|
||||
try {
|
||||
await deleteApp(appDetail.id)
|
||||
notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) })
|
||||
invalidateAppList()
|
||||
onPlanInfoChanged()
|
||||
setAppDetail()
|
||||
replace('/apps')
|
||||
@@ -205,7 +202,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
|
||||
})
|
||||
}
|
||||
setShowConfirmDelete(false)
|
||||
}, [appDetail, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t])
|
||||
}, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])
|
||||
|
||||
const { isCurrentWorkspaceEditor } = useAppContext()
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
||||
if (!isJsonObject || !tempPayload.json_schema)
|
||||
return ''
|
||||
try {
|
||||
return tempPayload.json_schema
|
||||
return JSON.stringify(JSON.parse(tempPayload.json_schema), null, 2)
|
||||
}
|
||||
catch {
|
||||
return ''
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
/**
|
||||
* Tests for race condition prevention logic in chat message loading.
|
||||
* These tests verify the core algorithms used in fetchData and loadMoreMessages
|
||||
* to prevent race conditions, infinite loops, and stale state issues.
|
||||
* See GitHub issue #30259 for context.
|
||||
*/
|
||||
|
||||
// Test the race condition prevention logic in isolation
|
||||
describe('Chat Message Loading Race Condition Prevention', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('Request Deduplication', () => {
|
||||
it('should deduplicate messages with same IDs when merging responses', async () => {
|
||||
// Simulate the deduplication logic used in setAllChatItems
|
||||
const existingItems = [
|
||||
{ id: 'msg-1', isAnswer: false },
|
||||
{ id: 'msg-2', isAnswer: true },
|
||||
]
|
||||
const newItems = [
|
||||
{ id: 'msg-2', isAnswer: true }, // duplicate
|
||||
{ id: 'msg-3', isAnswer: false }, // new
|
||||
]
|
||||
|
||||
const existingIds = new Set(existingItems.map(item => item.id))
|
||||
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
|
||||
const mergedItems = [...uniqueNewItems, ...existingItems]
|
||||
|
||||
expect(uniqueNewItems).toHaveLength(1)
|
||||
expect(uniqueNewItems[0].id).toBe('msg-3')
|
||||
expect(mergedItems).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Retry Counter Logic', () => {
|
||||
const MAX_RETRY_COUNT = 3
|
||||
|
||||
it('should increment retry counter when no unique items found', () => {
|
||||
const state = { retryCount: 0 }
|
||||
const prevItemsLength = 5
|
||||
|
||||
// Simulate the retry logic from loadMoreMessages
|
||||
const uniqueNewItemsLength = 0
|
||||
|
||||
if (uniqueNewItemsLength === 0) {
|
||||
if (state.retryCount < MAX_RETRY_COUNT && prevItemsLength > 1) {
|
||||
state.retryCount++
|
||||
}
|
||||
else {
|
||||
state.retryCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
expect(state.retryCount).toBe(1)
|
||||
})
|
||||
|
||||
it('should reset retry counter after MAX_RETRY_COUNT attempts', () => {
|
||||
const state = { retryCount: MAX_RETRY_COUNT }
|
||||
const prevItemsLength = 5
|
||||
const uniqueNewItemsLength = 0
|
||||
|
||||
if (uniqueNewItemsLength === 0) {
|
||||
if (state.retryCount < MAX_RETRY_COUNT && prevItemsLength > 1) {
|
||||
state.retryCount++
|
||||
}
|
||||
else {
|
||||
state.retryCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
expect(state.retryCount).toBe(0)
|
||||
})
|
||||
|
||||
it('should reset retry counter when unique items are found', () => {
|
||||
const state = { retryCount: 2 }
|
||||
|
||||
// Simulate finding unique items (length > 0)
|
||||
const processRetry = (uniqueCount: number) => {
|
||||
if (uniqueCount === 0) {
|
||||
state.retryCount++
|
||||
}
|
||||
else {
|
||||
state.retryCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
processRetry(3) // Found 3 unique items
|
||||
|
||||
expect(state.retryCount).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Throttling Logic', () => {
|
||||
const SCROLL_DEBOUNCE_MS = 200
|
||||
|
||||
it('should throttle requests within debounce window', () => {
|
||||
const state = { lastLoadTime: 0 }
|
||||
const results: boolean[] = []
|
||||
|
||||
const tryRequest = (now: number): boolean => {
|
||||
if (now - state.lastLoadTime >= SCROLL_DEBOUNCE_MS) {
|
||||
state.lastLoadTime = now
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// First request - should pass
|
||||
results.push(tryRequest(1000))
|
||||
// Second request within debounce - should be blocked
|
||||
results.push(tryRequest(1100))
|
||||
// Third request after debounce - should pass
|
||||
results.push(tryRequest(1300))
|
||||
|
||||
expect(results).toEqual([true, false, true])
|
||||
})
|
||||
})
|
||||
|
||||
describe('AbortController Cancellation', () => {
|
||||
it('should abort previous request when new request starts', () => {
|
||||
const state: { controller: AbortController | null } = { controller: null }
|
||||
const abortedSignals: boolean[] = []
|
||||
|
||||
// First request
|
||||
const controller1 = new AbortController()
|
||||
state.controller = controller1
|
||||
|
||||
// Second request - should abort first
|
||||
if (state.controller) {
|
||||
state.controller.abort()
|
||||
abortedSignals.push(state.controller.signal.aborted)
|
||||
}
|
||||
const controller2 = new AbortController()
|
||||
state.controller = controller2
|
||||
|
||||
expect(abortedSignals).toEqual([true])
|
||||
expect(controller1.signal.aborted).toBe(true)
|
||||
expect(controller2.signal.aborted).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Stale Response Detection', () => {
|
||||
it('should ignore responses from outdated requests', () => {
|
||||
const state = { requestId: 0 }
|
||||
const processedResponses: number[] = []
|
||||
|
||||
// Simulate concurrent requests - each gets its own captured ID
|
||||
const request1Id = ++state.requestId
|
||||
const request2Id = ++state.requestId
|
||||
|
||||
// Request 2 completes first (current requestId is 2)
|
||||
if (request2Id === state.requestId) {
|
||||
processedResponses.push(request2Id)
|
||||
}
|
||||
|
||||
// Request 1 completes later (stale - requestId is still 2)
|
||||
if (request1Id === state.requestId) {
|
||||
processedResponses.push(request1Id)
|
||||
}
|
||||
|
||||
expect(processedResponses).toEqual([2])
|
||||
expect(processedResponses).not.toContain(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Pagination Anchor Management', () => {
|
||||
it('should track oldest answer ID for pagination', () => {
|
||||
let oldestAnswerIdRef: string | undefined
|
||||
|
||||
const chatItems = [
|
||||
{ id: 'question-1', isAnswer: false },
|
||||
{ id: 'answer-1', isAnswer: true },
|
||||
{ id: 'question-2', isAnswer: false },
|
||||
{ id: 'answer-2', isAnswer: true },
|
||||
]
|
||||
|
||||
// Update pagination anchor with oldest answer ID
|
||||
const answerItems = chatItems.filter(item => item.isAnswer)
|
||||
const oldestAnswer = answerItems[answerItems.length - 1]
|
||||
if (oldestAnswer?.id) {
|
||||
oldestAnswerIdRef = oldestAnswer.id
|
||||
}
|
||||
|
||||
expect(oldestAnswerIdRef).toBe('answer-2')
|
||||
})
|
||||
|
||||
it('should use pagination anchor in subsequent requests', () => {
|
||||
const oldestAnswerIdRef = 'answer-123'
|
||||
const params: { conversation_id: string, limit: number, first_id?: string } = {
|
||||
conversation_id: 'conv-1',
|
||||
limit: 10,
|
||||
}
|
||||
|
||||
if (oldestAnswerIdRef) {
|
||||
params.first_id = oldestAnswerIdRef
|
||||
}
|
||||
|
||||
expect(params.first_id).toBe('answer-123')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Functional State Update Pattern', () => {
|
||||
it('should use functional update to avoid stale closures', () => {
|
||||
// Simulate the functional update pattern used in setAllChatItems
|
||||
let state = [{ id: '1' }, { id: '2' }]
|
||||
|
||||
const newItems = [{ id: '3' }, { id: '2' }] // id '2' is duplicate
|
||||
|
||||
// Functional update pattern
|
||||
const updater = (prevItems: { id: string }[]) => {
|
||||
const existingIds = new Set(prevItems.map(item => item.id))
|
||||
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
|
||||
return [...uniqueNewItems, ...prevItems]
|
||||
}
|
||||
|
||||
state = updater(state)
|
||||
|
||||
expect(state).toHaveLength(3)
|
||||
expect(state.map(i => i.id)).toEqual(['3', '1', '2'])
|
||||
})
|
||||
})
|
||||
@@ -209,6 +209,7 @@ type IDetailPanel = {
|
||||
|
||||
function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
const MIN_ITEMS_FOR_SCROLL_LOADING = 8
|
||||
const SCROLL_THRESHOLD_PX = 50
|
||||
const SCROLL_DEBOUNCE_MS = 200
|
||||
const { userProfile: { timezone } } = useAppContext()
|
||||
const { formatTime } = useTimestamp()
|
||||
@@ -227,103 +228,69 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
const [hasMore, setHasMore] = useState(true)
|
||||
const [varValues, setVarValues] = useState<Record<string, string>>({})
|
||||
const isLoadingRef = useRef(false)
|
||||
const abortControllerRef = useRef<AbortController | null>(null)
|
||||
const requestIdRef = useRef(0)
|
||||
const lastLoadTimeRef = useRef(0)
|
||||
const retryCountRef = useRef(0)
|
||||
const oldestAnswerIdRef = useRef<string | undefined>(undefined)
|
||||
const MAX_RETRY_COUNT = 3
|
||||
|
||||
const [allChatItems, setAllChatItems] = useState<IChatItem[]>([])
|
||||
const [chatItemTree, setChatItemTree] = useState<ChatItemInTree[]>([])
|
||||
const [threadChatItems, setThreadChatItems] = useState<IChatItem[]>([])
|
||||
|
||||
const fetchData = useCallback(async () => {
|
||||
if (isLoadingRef.current || !hasMore)
|
||||
if (isLoadingRef.current)
|
||||
return
|
||||
|
||||
// Cancel any in-flight request
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort()
|
||||
}
|
||||
|
||||
const controller = new AbortController()
|
||||
abortControllerRef.current = controller
|
||||
const currentRequestId = ++requestIdRef.current
|
||||
|
||||
try {
|
||||
isLoadingRef.current = true
|
||||
|
||||
if (!hasMore)
|
||||
return
|
||||
|
||||
const params: ChatMessagesRequest = {
|
||||
conversation_id: detail.id,
|
||||
limit: 10,
|
||||
}
|
||||
// Use ref for pagination anchor to avoid stale closure issues
|
||||
if (oldestAnswerIdRef.current)
|
||||
params.first_id = oldestAnswerIdRef.current
|
||||
|
||||
// Use the oldest answer item ID for pagination
|
||||
const answerItems = allChatItems.filter(item => item.isAnswer)
|
||||
const oldestAnswerItem = answerItems[answerItems.length - 1]
|
||||
if (oldestAnswerItem?.id)
|
||||
params.first_id = oldestAnswerItem.id
|
||||
const messageRes = await fetchChatMessages({
|
||||
url: `/apps/${appDetail?.id}/chat-messages`,
|
||||
params,
|
||||
})
|
||||
|
||||
// Ignore stale responses
|
||||
if (currentRequestId !== requestIdRef.current || controller.signal.aborted)
|
||||
return
|
||||
if (messageRes.data.length > 0) {
|
||||
const varValues = messageRes.data.at(-1)!.inputs
|
||||
setVarValues(varValues)
|
||||
}
|
||||
setHasMore(messageRes.has_more)
|
||||
|
||||
const newItems = getFormattedChatList(messageRes.data, detail.id, timezone!, t('dateTimeFormat', { ns: 'appLog' }) as string)
|
||||
const newAllChatItems = [
|
||||
...getFormattedChatList(messageRes.data, detail.id, timezone!, t('dateTimeFormat', { ns: 'appLog' }) as string),
|
||||
...allChatItems,
|
||||
]
|
||||
setAllChatItems(newAllChatItems)
|
||||
|
||||
// Use functional update to avoid stale state issues
|
||||
setAllChatItems((prevItems: IChatItem[]) => {
|
||||
const existingIds = new Set(prevItems.map(item => item.id))
|
||||
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
|
||||
return [...uniqueNewItems, ...prevItems]
|
||||
})
|
||||
let tree = buildChatItemTree(newAllChatItems)
|
||||
if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) {
|
||||
tree = [{
|
||||
id: 'introduction',
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
content: detail?.model_config?.configs?.introduction ?? 'hello',
|
||||
feedbackDisabled: true,
|
||||
children: tree,
|
||||
}]
|
||||
}
|
||||
setChatItemTree(tree)
|
||||
|
||||
const lastMessageId = newAllChatItems.length > 0 ? newAllChatItems[newAllChatItems.length - 1].id : undefined
|
||||
setThreadChatItems(getThreadMessages(tree, lastMessageId))
|
||||
}
|
||||
catch (err: unknown) {
|
||||
if (err instanceof Error && err.name === 'AbortError')
|
||||
return
|
||||
catch (err) {
|
||||
console.error('fetchData execution failed:', err)
|
||||
}
|
||||
finally {
|
||||
isLoadingRef.current = false
|
||||
if (abortControllerRef.current === controller)
|
||||
abortControllerRef.current = null
|
||||
}
|
||||
}, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
|
||||
|
||||
// Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems
|
||||
useEffect(() => {
|
||||
if (allChatItems.length === 0)
|
||||
return
|
||||
|
||||
let tree = buildChatItemTree(allChatItems)
|
||||
if (!hasMore && detail?.model_config?.configs?.introduction) {
|
||||
tree = [{
|
||||
id: 'introduction',
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
content: detail?.model_config?.configs?.introduction ?? 'hello',
|
||||
feedbackDisabled: true,
|
||||
children: tree,
|
||||
}]
|
||||
}
|
||||
setChatItemTree(tree)
|
||||
|
||||
const lastMessageId = allChatItems.length > 0 ? allChatItems[allChatItems.length - 1].id : undefined
|
||||
setThreadChatItems(getThreadMessages(tree, lastMessageId))
|
||||
|
||||
// Update pagination anchor ref with the oldest answer ID
|
||||
const answerItems = allChatItems.filter(item => item.isAnswer)
|
||||
const oldestAnswer = answerItems[answerItems.length - 1]
|
||||
if (oldestAnswer?.id)
|
||||
oldestAnswerIdRef.current = oldestAnswer.id
|
||||
}, [allChatItems, hasMore, detail?.model_config?.configs?.introduction])
|
||||
}, [allChatItems, detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
|
||||
|
||||
const switchSibling = useCallback((siblingMessageId: string) => {
|
||||
const newThreadChatItems = getThreadMessages(chatItemTree, siblingMessageId)
|
||||
@@ -430,12 +397,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
if (isLoading || !hasMore || !appDetail?.id || !detail.id)
|
||||
return
|
||||
|
||||
// Throttle using ref to persist across re-renders
|
||||
const now = Date.now()
|
||||
if (now - lastLoadTimeRef.current < SCROLL_DEBOUNCE_MS)
|
||||
return
|
||||
lastLoadTimeRef.current = now
|
||||
|
||||
setIsLoading(true)
|
||||
|
||||
try {
|
||||
@@ -444,9 +405,15 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
limit: 10,
|
||||
}
|
||||
|
||||
// Use ref for pagination anchor to avoid stale closure issues
|
||||
if (oldestAnswerIdRef.current) {
|
||||
params.first_id = oldestAnswerIdRef.current
|
||||
// Use the earliest response item as the first_id
|
||||
const answerItems = allChatItems.filter(item => item.isAnswer)
|
||||
const oldestAnswerItem = answerItems[answerItems.length - 1]
|
||||
if (oldestAnswerItem?.id) {
|
||||
params.first_id = oldestAnswerItem.id
|
||||
}
|
||||
else if (allChatItems.length > 0 && allChatItems[0]?.id) {
|
||||
const firstId = allChatItems[0].id.replace('question-', '').replace('answer-', '')
|
||||
params.first_id = firstId
|
||||
}
|
||||
|
||||
const messageRes = await fetchChatMessages({
|
||||
@@ -456,7 +423,6 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
|
||||
if (!messageRes.data || messageRes.data.length === 0) {
|
||||
setHasMore(false)
|
||||
retryCountRef.current = 0
|
||||
return
|
||||
}
|
||||
|
||||
@@ -474,36 +440,91 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
t('dateTimeFormat', { ns: 'appLog' }) as string,
|
||||
)
|
||||
|
||||
// Use functional update to get latest state and avoid stale closures
|
||||
setAllChatItems((prevItems: IChatItem[]) => {
|
||||
const existingIds = new Set(prevItems.map(item => item.id))
|
||||
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
|
||||
// Check for duplicate messages
|
||||
const existingIds = new Set(allChatItems.map(item => item.id))
|
||||
const uniqueNewItems = newItems.filter(item => !existingIds.has(item.id))
|
||||
|
||||
// If no unique items and we haven't exceeded retry limit, signal retry needed
|
||||
if (uniqueNewItems.length === 0) {
|
||||
if (retryCountRef.current < MAX_RETRY_COUNT && prevItems.length > 1) {
|
||||
retryCountRef.current++
|
||||
return prevItems
|
||||
if (uniqueNewItems.length === 0) {
|
||||
if (allChatItems.length > 1) {
|
||||
const nextId = allChatItems[1].id.replace('question-', '').replace('answer-', '')
|
||||
|
||||
const retryParams = {
|
||||
...params,
|
||||
first_id: nextId,
|
||||
}
|
||||
else {
|
||||
retryCountRef.current = 0
|
||||
return prevItems
|
||||
|
||||
const retryRes = await fetchChatMessages({
|
||||
url: `/apps/${appDetail.id}/chat-messages`,
|
||||
params: retryParams,
|
||||
})
|
||||
|
||||
if (retryRes.data && retryRes.data.length > 0) {
|
||||
const retryItems = getFormattedChatList(
|
||||
retryRes.data,
|
||||
detail.id,
|
||||
timezone!,
|
||||
t('dateTimeFormat', { ns: 'appLog' }) as string,
|
||||
)
|
||||
|
||||
const retryUniqueItems = retryItems.filter(item => !existingIds.has(item.id))
|
||||
if (retryUniqueItems.length > 0) {
|
||||
const newAllChatItems = [
|
||||
...retryUniqueItems,
|
||||
...allChatItems,
|
||||
]
|
||||
|
||||
setAllChatItems(newAllChatItems)
|
||||
|
||||
let tree = buildChatItemTree(newAllChatItems)
|
||||
if (retryRes.has_more === false && detail?.model_config?.configs?.introduction) {
|
||||
tree = [{
|
||||
id: 'introduction',
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
content: detail?.model_config?.configs?.introduction ?? 'hello',
|
||||
feedbackDisabled: true,
|
||||
children: tree,
|
||||
}]
|
||||
}
|
||||
setChatItemTree(tree)
|
||||
setHasMore(retryRes.has_more)
|
||||
setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
retryCountRef.current = 0
|
||||
return [...uniqueNewItems, ...prevItems]
|
||||
})
|
||||
const newAllChatItems = [
|
||||
...uniqueNewItems,
|
||||
...allChatItems,
|
||||
]
|
||||
|
||||
setAllChatItems(newAllChatItems)
|
||||
|
||||
let tree = buildChatItemTree(newAllChatItems)
|
||||
if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) {
|
||||
tree = [{
|
||||
id: 'introduction',
|
||||
isAnswer: true,
|
||||
isOpeningStatement: true,
|
||||
content: detail?.model_config?.configs?.introduction ?? 'hello',
|
||||
feedbackDisabled: true,
|
||||
children: tree,
|
||||
}]
|
||||
}
|
||||
setChatItemTree(tree)
|
||||
|
||||
setThreadChatItems(getThreadMessages(tree, newAllChatItems.at(-1)?.id))
|
||||
}
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
setHasMore(false)
|
||||
retryCountRef.current = 0
|
||||
}
|
||||
finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
|
||||
}, [allChatItems, detail.id, hasMore, isLoading, timezone, t, appDetail])
|
||||
|
||||
useEffect(() => {
|
||||
const scrollableDiv = document.getElementById('scrollableDiv')
|
||||
@@ -535,11 +556,24 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
if (!scrollContainer)
|
||||
return
|
||||
|
||||
let lastLoadTime = 0
|
||||
const throttleDelay = 200
|
||||
|
||||
const handleScroll = () => {
|
||||
const currentScrollTop = scrollContainer!.scrollTop
|
||||
const isNearTop = currentScrollTop < 30
|
||||
const scrollHeight = scrollContainer!.scrollHeight
|
||||
const clientHeight = scrollContainer!.clientHeight
|
||||
|
||||
if (isNearTop && hasMore && !isLoading) {
|
||||
const distanceFromTop = currentScrollTop
|
||||
const distanceFromBottom = scrollHeight - currentScrollTop - clientHeight
|
||||
|
||||
const now = Date.now()
|
||||
|
||||
const isNearTop = distanceFromTop < 30
|
||||
// eslint-disable-next-line sonarjs/no-unused-vars
|
||||
const _distanceFromBottom = distanceFromBottom < 30
|
||||
if (isNearTop && hasMore && !isLoading && (now - lastLoadTime > throttleDelay)) {
|
||||
lastLoadTime = now
|
||||
loadMoreMessages()
|
||||
}
|
||||
}
|
||||
@@ -585,6 +619,36 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
||||
return () => cancelAnimationFrame(raf)
|
||||
}, [])
|
||||
|
||||
// Add scroll listener to ensure loading is triggered
|
||||
useEffect(() => {
|
||||
if (threadChatItems.length >= MIN_ITEMS_FOR_SCROLL_LOADING && hasMore) {
|
||||
const scrollableDiv = document.getElementById('scrollableDiv')
|
||||
|
||||
if (scrollableDiv) {
|
||||
let loadingTimeout: NodeJS.Timeout | null = null
|
||||
|
||||
const handleScroll = () => {
|
||||
const { scrollTop } = scrollableDiv
|
||||
|
||||
// Trigger loading when scrolling near the top
|
||||
if (scrollTop < SCROLL_THRESHOLD_PX && !isLoadingRef.current) {
|
||||
if (loadingTimeout)
|
||||
clearTimeout(loadingTimeout)
|
||||
|
||||
loadingTimeout = setTimeout(fetchData, SCROLL_DEBOUNCE_MS) // 200ms debounce
|
||||
}
|
||||
}
|
||||
|
||||
scrollableDiv.addEventListener('scroll', handleScroll)
|
||||
return () => {
|
||||
scrollableDiv.removeEventListener('scroll', handleScroll)
|
||||
if (loadingTimeout)
|
||||
clearTimeout(loadingTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [threadChatItems.length, hasMore, fetchData])
|
||||
|
||||
return (
|
||||
<div ref={ref} className="flex h-full flex-col rounded-xl border-[0.5px] border-components-panel-border">
|
||||
{/* Panel Header */}
|
||||
|
||||
@@ -10,7 +10,6 @@ const mockReplace = vi.fn()
|
||||
const mockRouter = { replace: mockReplace }
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => mockRouter,
|
||||
useSearchParams: () => new URLSearchParams(''),
|
||||
}))
|
||||
|
||||
// Mock app context
|
||||
|
||||
@@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks'
|
||||
import dynamic from 'next/dynamic'
|
||||
import {
|
||||
useRouter,
|
||||
useSearchParams,
|
||||
} from 'next/navigation'
|
||||
import { parseAsString, useQueryState } from 'nuqs'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
@@ -37,16 +36,6 @@ import useAppsQueryState from './hooks/use-apps-query-state'
|
||||
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
|
||||
import NewAppCard from './new-app-card'
|
||||
|
||||
// Define valid tabs at module scope to avoid re-creation on each render and stale closures
|
||||
const validTabs = new Set<string | AppModeEnum>([
|
||||
'all',
|
||||
AppModeEnum.WORKFLOW,
|
||||
AppModeEnum.ADVANCED_CHAT,
|
||||
AppModeEnum.CHAT,
|
||||
AppModeEnum.AGENT_CHAT,
|
||||
AppModeEnum.COMPLETION,
|
||||
])
|
||||
|
||||
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
|
||||
ssr: false,
|
||||
})
|
||||
@@ -58,41 +47,12 @@ const List = () => {
|
||||
const { t } = useTranslation()
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
|
||||
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
|
||||
const [activeTab, setActiveTab] = useQueryState(
|
||||
'category',
|
||||
parseAsString.withDefault('all').withOptions({ history: 'push' }),
|
||||
)
|
||||
|
||||
// valid tabs for apps list; anything else should fallback to 'all'
|
||||
|
||||
// 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
|
||||
useEffect(() => {
|
||||
// avoid running on server
|
||||
if (typeof window === 'undefined')
|
||||
return
|
||||
const mode = searchParams.get('mode')
|
||||
if (!mode)
|
||||
return
|
||||
const url = new URL(window.location.href)
|
||||
url.searchParams.delete('mode')
|
||||
if (validTabs.has(mode)) {
|
||||
// migrate to category key
|
||||
url.searchParams.set('category', mode)
|
||||
}
|
||||
else {
|
||||
url.searchParams.set('category', 'all')
|
||||
}
|
||||
router.replace(url.pathname + url.search)
|
||||
}, [router, searchParams])
|
||||
|
||||
// 2) If category has an invalid value (e.g., 'discover'), reset to 'all'
|
||||
useEffect(() => {
|
||||
if (!validTabs.has(activeTab))
|
||||
setActiveTab('all')
|
||||
}, [activeTab, setActiveTab])
|
||||
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
|
||||
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
|
||||
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
|
||||
|
||||
@@ -21,6 +21,7 @@ import BasicContent from './basic-content'
|
||||
import More from './more'
|
||||
import Operation from './operation'
|
||||
import SuggestedQuestions from './suggested-questions'
|
||||
import ToolCalls from './tool-calls'
|
||||
import WorkflowProcessItem from './workflow-process'
|
||||
|
||||
type AnswerProps = {
|
||||
@@ -61,6 +62,7 @@ const Answer: FC<AnswerProps> = ({
|
||||
workflowProcess,
|
||||
allFiles,
|
||||
message_files,
|
||||
toolCalls,
|
||||
} = item
|
||||
const hasAgentThoughts = !!agent_thoughts?.length
|
||||
|
||||
@@ -154,6 +156,11 @@ const Answer: FC<AnswerProps> = ({
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
!!toolCalls?.length && (
|
||||
<ToolCalls toolCalls={toolCalls} />
|
||||
)
|
||||
}
|
||||
{
|
||||
responding && contentIsEmpty && !hasAgentThoughts && (
|
||||
<div className="flex h-5 w-6 items-center justify-center">
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import type { ToolCallItem } from '@/types/workflow'
|
||||
import ToolCallItemComponent from '@/app/components/workflow/run/llm-log/tool-call-item'
|
||||
|
||||
type ToolCallsProps = {
|
||||
toolCalls: ToolCallItem[]
|
||||
}
|
||||
const ToolCalls = ({
|
||||
toolCalls,
|
||||
}: ToolCallsProps) => {
|
||||
return (
|
||||
<div className="my-1 space-y-1">
|
||||
{toolCalls.map((toolCall: ToolCallItem, index: number) => (
|
||||
<ToolCallItemComponent
|
||||
key={index}
|
||||
payload={toolCall}
|
||||
className="bg-background-gradient-bg-fill-chat-bubble-bg-2 shadow-none"
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ToolCalls
|
||||
@@ -45,7 +45,7 @@ const WorkflowProcessItem = ({
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'-mx-1 rounded-xl px-2.5',
|
||||
'rounded-xl px-2.5',
|
||||
collapse ? 'border-l-[0.25px] border-components-panel-border py-[7px]' : 'border-[0.5px] border-components-panel-border-subtle px-1 pb-1 pt-[7px]',
|
||||
running && !collapse && 'bg-background-section-burn',
|
||||
succeeded && !collapse && 'bg-state-success-hover',
|
||||
|
||||
@@ -319,6 +319,9 @@ export const useChat = (
|
||||
return player
|
||||
}
|
||||
|
||||
let toolCallId = ''
|
||||
let thoughtId = ''
|
||||
|
||||
ssePost(
|
||||
url,
|
||||
{
|
||||
@@ -326,7 +329,19 @@ export const useChat = (
|
||||
},
|
||||
{
|
||||
isPublicAPI,
|
||||
onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId, taskId }: any) => {
|
||||
onData: (message: string, isFirstMessage: boolean, {
|
||||
conversationId: newConversationId,
|
||||
messageId,
|
||||
taskId,
|
||||
chunk_type,
|
||||
tool_icon,
|
||||
tool_icon_dark,
|
||||
tool_name,
|
||||
tool_arguments,
|
||||
tool_files,
|
||||
tool_error,
|
||||
tool_elapsed_time,
|
||||
}: any) => {
|
||||
if (!isAgentMode) {
|
||||
responseItem.content = responseItem.content + message
|
||||
}
|
||||
@@ -336,6 +351,57 @@ export const useChat = (
|
||||
lastThought.thought = lastThought.thought + message // need immer setAutoFreeze
|
||||
}
|
||||
|
||||
if (chunk_type === 'tool_call') {
|
||||
if (!responseItem.toolCalls)
|
||||
responseItem.toolCalls = []
|
||||
toolCallId = uuidV4()
|
||||
responseItem.toolCalls?.push({
|
||||
id: toolCallId,
|
||||
type: 'tool',
|
||||
toolName: tool_name,
|
||||
toolArguments: tool_arguments,
|
||||
toolIcon: tool_icon,
|
||||
toolIconDark: tool_icon_dark,
|
||||
})
|
||||
}
|
||||
|
||||
if (chunk_type === 'tool_result') {
|
||||
const currentToolCallIndex = responseItem.toolCalls?.findIndex(item => item.id === toolCallId) ?? -1
|
||||
|
||||
if (currentToolCallIndex > -1) {
|
||||
responseItem.toolCalls![currentToolCallIndex].toolError = tool_error
|
||||
responseItem.toolCalls![currentToolCallIndex].toolDuration = tool_elapsed_time
|
||||
responseItem.toolCalls![currentToolCallIndex].toolFiles = tool_files
|
||||
responseItem.toolCalls![currentToolCallIndex].toolOutput = message
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk_type === 'thought_start') {
|
||||
if (!responseItem.toolCalls)
|
||||
responseItem.toolCalls = []
|
||||
thoughtId = uuidV4()
|
||||
responseItem.toolCalls.push({
|
||||
id: thoughtId,
|
||||
type: 'thought',
|
||||
thoughtOutput: '',
|
||||
})
|
||||
}
|
||||
|
||||
if (chunk_type === 'thought') {
|
||||
const currentThoughtIndex = responseItem.toolCalls?.findIndex(item => item.id === thoughtId) ?? -1
|
||||
if (currentThoughtIndex > -1) {
|
||||
responseItem.toolCalls![currentThoughtIndex].thoughtOutput += message
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk_type === 'thought_end') {
|
||||
const currentThoughtIndex = responseItem.toolCalls?.findIndex(item => item.id === thoughtId) ?? -1
|
||||
if (currentThoughtIndex > -1) {
|
||||
responseItem.toolCalls![currentThoughtIndex].thoughtOutput += message
|
||||
responseItem.toolCalls![currentThoughtIndex].thoughtCompleted = true
|
||||
}
|
||||
}
|
||||
|
||||
if (messageId && !hasSetResponseId) {
|
||||
questionItem.id = `question-${messageId}`
|
||||
responseItem.id = messageId
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { FileEntity } from '@/app/components/base/file-uploader/types'
|
||||
import type { TypeWithI18N } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { InputVarType } from '@/app/components/workflow/types'
|
||||
import type { Annotation, MessageRating } from '@/models/log'
|
||||
import type { FileResponse } from '@/types/workflow'
|
||||
import type { FileResponse, ToolCallItem } from '@/types/workflow'
|
||||
|
||||
export type MessageMore = {
|
||||
time: string
|
||||
@@ -104,6 +104,7 @@ export type IChatItem = {
|
||||
siblingIndex?: number
|
||||
prevSibling?: string
|
||||
nextSibling?: string
|
||||
toolCalls?: ToolCallItem[]
|
||||
}
|
||||
|
||||
export type Metadata = {
|
||||
|
||||
@@ -37,7 +37,7 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
|
||||
return
|
||||
}
|
||||
|
||||
if (inputValue == null)
|
||||
if (!inputValue)
|
||||
return
|
||||
|
||||
if (item.type === InputVarType.singleFile) {
|
||||
@@ -52,20 +52,6 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
|
||||
else
|
||||
processedInputs[item.variable] = getProcessedFiles(inputValue)
|
||||
}
|
||||
else if (item.type === InputVarType.jsonObject) {
|
||||
// Prefer sending an object if the user entered valid JSON; otherwise keep the raw string.
|
||||
try {
|
||||
const v = typeof inputValue === 'string' ? JSON.parse(inputValue) : inputValue
|
||||
if (v && typeof v === 'object' && !Array.isArray(v))
|
||||
processedInputs[item.variable] = v
|
||||
else
|
||||
processedInputs[item.variable] = inputValue
|
||||
}
|
||||
catch {
|
||||
// keep original string; backend will parse/validate
|
||||
processedInputs[item.variable] = inputValue
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return processedInputs
|
||||
|
||||
@@ -66,9 +66,7 @@ const Header: FC<IHeaderProps> = ({
|
||||
const listener = (event: MessageEvent) => handleMessageReceived(event)
|
||||
window.addEventListener('message', listener)
|
||||
|
||||
// Security: Use document.referrer to get parent origin
|
||||
const targetOrigin = document.referrer ? new URL(document.referrer).origin : '*'
|
||||
window.parent.postMessage({ type: 'dify-chatbot-iframe-ready' }, targetOrigin)
|
||||
window.parent.postMessage({ type: 'dify-chatbot-iframe-ready' }, '*')
|
||||
|
||||
return () => window.removeEventListener('message', listener)
|
||||
}, [isIframe, handleMessageReceived])
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
<svg width="12" height="14" viewBox="0 0 12 14" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M2 9.49479C0.782372 8.51826 0 7.01768 0 5.33333C0 2.38782 2.38782 0 5.33333 0C8.20841 0 10.5503 2.27504 10.6608 5.12305L11.888 6.96354C12.0843 7.25794 12.0161 7.65424 11.7331 7.86654L10.6667 8.66602V10C10.6667 10.7364 10.0697 11.3333 9.33333 11.3333H8V13.3333H6.66667V10.6667C6.66667 10.2985 6.96514 10 7.33333 10H9.33333V8.33333C9.33333 8.12349 9.43239 7.92603 9.60026 7.80013L10.4284 7.17838L9.44531 5.70312C9.3723 5.59361 9.33333 5.46495 9.33333 5.33333C9.33333 3.1242 7.54248 1.33333 5.33333 1.33333C3.1242 1.33333 1.33333 3.1242 1.33333 5.33333C1.33333 6.69202 2.0103 7.89261 3.04818 8.61654C3.2269 8.74119 3.33329 8.94552 3.33333 9.16341V13.3333H2V9.49479Z" fill="#354052"/>
|
||||
<path d="M6.04367 4.24012L5.6504 3.21778C5.59993 3.08657 5.47393 3 5.33333 3C5.19273 3 5.06673 3.08657 5.01627 3.21778L4.62303 4.24012C4.55531 4.41618 4.41618 4.55531 4.24012 4.62303L3.21778 5.01624C3.08657 5.0667 3 5.19276 3 5.33333C3 5.47393 3.08657 5.59993 3.21778 5.6504L4.24012 6.04367C4.41618 6.11133 4.55531 6.25047 4.62303 6.42653L5.01627 7.44887C5.06673 7.58007 5.19273 7.66667 5.33333 7.66667C5.47393 7.66667 5.59993 7.58007 5.6504 7.44887L6.04367 6.42653C6.11133 6.25047 6.25047 6.11133 6.42653 6.04367L7.44887 5.6504C7.58007 5.59993 7.66667 5.47393 7.66667 5.33333C7.66667 5.19276 7.58007 5.0667 7.44887 5.01624L6.42653 4.62303C6.25047 4.55531 6.11133 4.41618 6.04367 4.24012Z" fill="#354052"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "12",
|
||||
"height": "14",
|
||||
"viewBox": "0 0 12 14",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M2 9.49479C0.782372 8.51826 0 7.01768 0 5.33333C0 2.38782 2.38782 0 5.33333 0C8.20841 0 10.5503 2.27504 10.6608 5.12305L11.888 6.96354C12.0843 7.25794 12.0161 7.65424 11.7331 7.86654L10.6667 8.66602V10C10.6667 10.7364 10.0697 11.3333 9.33333 11.3333H8V13.3333H6.66667V10.6667C6.66667 10.2985 6.96514 10 7.33333 10H9.33333V8.33333C9.33333 8.12349 9.43239 7.92603 9.60026 7.80013L10.4284 7.17838L9.44531 5.70312C9.3723 5.59361 9.33333 5.46495 9.33333 5.33333C9.33333 3.1242 7.54248 1.33333 5.33333 1.33333C3.1242 1.33333 1.33333 3.1242 1.33333 5.33333C1.33333 6.69202 2.0103 7.89261 3.04818 8.61654C3.2269 8.74119 3.33329 8.94552 3.33333 9.16341V13.3333H2V9.49479Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M6.04367 4.24012L5.6504 3.21778C5.59993 3.08657 5.47393 3 5.33333 3C5.19273 3 5.06673 3.08657 5.01627 3.21778L4.62303 4.24012C4.55531 4.41618 4.41618 4.55531 4.24012 4.62303L3.21778 5.01624C3.08657 5.0667 3 5.19276 3 5.33333C3 5.47393 3.08657 5.59993 3.21778 5.6504L4.24012 6.04367C4.41618 6.11133 4.55531 6.25047 4.62303 6.42653L5.01627 7.44887C5.06673 7.58007 5.19273 7.66667 5.33333 7.66667C5.47393 7.66667 5.59993 7.58007 5.6504 7.44887L6.04367 6.42653C6.11133 6.25047 6.25047 6.11133 6.42653 6.04367L7.44887 5.6504C7.58007 5.59993 7.66667 5.47393 7.66667 5.33333C7.66667 5.19276 7.58007 5.0667 7.44887 5.01624L6.42653 4.62303C6.25047 4.55531 6.11133 4.41618 6.04367 4.24012Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Thinking"
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './Thinking.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'Thinking'
|
||||
|
||||
export default Icon
|
||||
@@ -24,6 +24,7 @@ export { default as ParameterExtractor } from './ParameterExtractor'
|
||||
export { default as QuestionClassifier } from './QuestionClassifier'
|
||||
export { default as Schedule } from './Schedule'
|
||||
export { default as TemplatingTransform } from './TemplatingTransform'
|
||||
export { default as Thinking } from './Thinking'
|
||||
export { default as TriggerAll } from './TriggerAll'
|
||||
export { default as VariableX } from './VariableX'
|
||||
export { default as WebhookLine } from './WebhookLine'
|
||||
|
||||
@@ -5,7 +5,6 @@ import type {
|
||||
} from 'lexical'
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
AgentBlockType,
|
||||
ContextBlockType,
|
||||
CurrentBlockType,
|
||||
ErrorMessageBlockType,
|
||||
@@ -104,7 +103,6 @@ export type PromptEditorProps = {
|
||||
currentBlock?: CurrentBlockType
|
||||
errorMessageBlock?: ErrorMessageBlockType
|
||||
lastRunBlock?: LastRunBlockType
|
||||
agentBlock?: AgentBlockType
|
||||
isSupportFileVar?: boolean
|
||||
}
|
||||
|
||||
@@ -130,7 +128,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
currentBlock,
|
||||
errorMessageBlock,
|
||||
lastRunBlock,
|
||||
agentBlock,
|
||||
isSupportFileVar,
|
||||
}) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
@@ -142,7 +139,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
{
|
||||
replace: TextNode,
|
||||
with: (node: TextNode) => new CustomTextNode(node.__text),
|
||||
withKlass: CustomTextNode,
|
||||
},
|
||||
ContextBlockNode,
|
||||
HistoryBlockNode,
|
||||
@@ -216,22 +212,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
{(!agentBlock || agentBlock.show) && (
|
||||
<ComponentPickerBlock
|
||||
triggerString="@"
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
agentBlock={agentBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
)}
|
||||
<ComponentPickerBlock
|
||||
triggerString="{"
|
||||
contextBlock={contextBlock}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { MenuRenderFn } from '@lexical/react/LexicalTypeaheadMenuPlugin'
|
||||
import type { TextNode } from 'lexical'
|
||||
import type {
|
||||
AgentBlockType,
|
||||
ContextBlockType,
|
||||
CurrentBlockType,
|
||||
ErrorMessageBlockType,
|
||||
@@ -21,11 +20,7 @@ import {
|
||||
} from '@floating-ui/react'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
|
||||
import {
|
||||
$getRoot,
|
||||
$insertNodes,
|
||||
KEY_ESCAPE_COMMAND,
|
||||
} from 'lexical'
|
||||
import { KEY_ESCAPE_COMMAND } from 'lexical'
|
||||
import {
|
||||
Fragment,
|
||||
memo,
|
||||
@@ -34,9 +29,7 @@ import {
|
||||
} from 'react'
|
||||
import ReactDOM from 'react-dom'
|
||||
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
|
||||
import AgentNodeList from '@/app/components/workflow/nodes/_base/components/agent-node-list'
|
||||
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useBasicTypeaheadTriggerMatch } from '../../hooks'
|
||||
import { $splitNodeContainingQuery } from '../../utils'
|
||||
@@ -45,7 +38,6 @@ import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../error-message-block'
|
||||
import { INSERT_LAST_RUN_BLOCK_COMMAND } from '../last-run-block'
|
||||
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '../variable-block'
|
||||
import { INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND } from '../workflow-variable-block'
|
||||
import { $createWorkflowVariableBlockNode } from '../workflow-variable-block/node'
|
||||
import { useOptions } from './hooks'
|
||||
|
||||
type ComponentPickerProps = {
|
||||
@@ -59,7 +51,6 @@ type ComponentPickerProps = {
|
||||
currentBlock?: CurrentBlockType
|
||||
errorMessageBlock?: ErrorMessageBlockType
|
||||
lastRunBlock?: LastRunBlockType
|
||||
agentBlock?: AgentBlockType
|
||||
isSupportFileVar?: boolean
|
||||
}
|
||||
const ComponentPicker = ({
|
||||
@@ -73,7 +64,6 @@ const ComponentPicker = ({
|
||||
currentBlock,
|
||||
errorMessageBlock,
|
||||
lastRunBlock,
|
||||
agentBlock,
|
||||
isSupportFileVar,
|
||||
}: ComponentPickerProps) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
@@ -161,41 +151,12 @@ const ComponentPicker = ({
|
||||
editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
|
||||
}, [editor])
|
||||
|
||||
const handleSelectAgent = useCallback((agent: { id: string, title: string }) => {
|
||||
editor.update(() => {
|
||||
const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!)
|
||||
if (needRemove)
|
||||
needRemove.remove()
|
||||
|
||||
const root = $getRoot()
|
||||
const firstChild = root.getFirstChild()
|
||||
if (firstChild) {
|
||||
const selection = firstChild.selectStart()
|
||||
if (selection) {
|
||||
const workflowVariableBlockNode = $createWorkflowVariableBlockNode([agent.id, 'text'], {}, undefined)
|
||||
$insertNodes([workflowVariableBlockNode])
|
||||
}
|
||||
}
|
||||
})
|
||||
agentBlock?.onSelect?.(agent)
|
||||
handleClose()
|
||||
}, [editor, checkForTriggerMatch, triggerString, agentBlock, handleClose])
|
||||
|
||||
const isAgentTrigger = triggerString === '@' && agentBlock?.show
|
||||
const agentNodes = agentBlock?.agentNodes || []
|
||||
|
||||
const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
|
||||
anchorElementRef,
|
||||
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
|
||||
) => {
|
||||
if (isAgentTrigger) {
|
||||
if (!(anchorElementRef.current && agentNodes.length))
|
||||
return null
|
||||
}
|
||||
else {
|
||||
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
|
||||
return null
|
||||
}
|
||||
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
|
||||
return null
|
||||
|
||||
setTimeout(() => {
|
||||
if (anchorElementRef.current)
|
||||
@@ -206,6 +167,9 @@ const ComponentPicker = ({
|
||||
<>
|
||||
{
|
||||
ReactDOM.createPortal(
|
||||
// The `LexicalMenu` will try to calculate the position of the floating menu based on the first child.
|
||||
// Since we use floating ui, we need to wrap it with a div to prevent the position calculation being affected.
|
||||
// See https://github.com/facebook/lexical/blob/ac97dfa9e14a73ea2d6934ff566282d7f758e8bb/packages/lexical-react/src/shared/LexicalMenu.ts#L493
|
||||
<div className="h-0 w-0">
|
||||
<div
|
||||
className="w-[260px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
|
||||
@@ -215,73 +179,56 @@ const ComponentPicker = ({
|
||||
}}
|
||||
ref={refs.setFloating}
|
||||
>
|
||||
{isAgentTrigger
|
||||
? (
|
||||
<AgentNodeList
|
||||
nodes={agentNodes.map(node => ({
|
||||
...node,
|
||||
type: BlockEnum.Agent,
|
||||
}))}
|
||||
onSelect={handleSelectAgent}
|
||||
{
|
||||
workflowVariableBlock?.show && (
|
||||
<div className="p-1">
|
||||
<VarReferenceVars
|
||||
searchBoxClassName="mt-1"
|
||||
vars={workflowVariableOptions}
|
||||
onChange={(variables: string[]) => {
|
||||
handleSelectWorkflowVariable(variables)
|
||||
}}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
onClose={handleClose}
|
||||
onBlur={handleClose}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
showManageInputField={workflowVariableBlock.showManageInputField}
|
||||
onManageInputField={workflowVariableBlock.onManageInputField}
|
||||
autoFocus={false}
|
||||
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
workflowVariableBlock?.show && !!options.length && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
<div>
|
||||
{
|
||||
options.map((option, index) => (
|
||||
<Fragment key={option.key}>
|
||||
{
|
||||
workflowVariableBlock?.show && (
|
||||
<div className="p-1">
|
||||
<VarReferenceVars
|
||||
searchBoxClassName="mt-1"
|
||||
vars={workflowVariableOptions}
|
||||
onChange={(variables: string[]) => {
|
||||
handleSelectWorkflowVariable(variables)
|
||||
}}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
onClose={handleClose}
|
||||
onBlur={handleClose}
|
||||
showManageInputField={workflowVariableBlock.showManageInputField}
|
||||
onManageInputField={workflowVariableBlock.onManageInputField}
|
||||
autoFocus={false}
|
||||
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
workflowVariableBlock?.show && !!options.length && (
|
||||
// Divider
|
||||
index !== 0 && options.at(index - 1)?.group !== option.group && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
<div>
|
||||
{
|
||||
options.map((option, index) => (
|
||||
<Fragment key={option.key}>
|
||||
{
|
||||
index !== 0 && options.at(index - 1)?.group !== option.group && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
{option.renderMenuOption({
|
||||
queryString,
|
||||
isSelected: selectedIndex === index,
|
||||
onSelect: () => {
|
||||
selectOptionAndCleanUp(option)
|
||||
},
|
||||
onSetHighlight: () => {
|
||||
setHighlightedIndex(index)
|
||||
},
|
||||
})}
|
||||
</Fragment>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{option.renderMenuOption({
|
||||
queryString,
|
||||
isSelected: selectedIndex === index,
|
||||
onSelect: () => {
|
||||
selectOptionAndCleanUp(option)
|
||||
},
|
||||
onSetHighlight: () => {
|
||||
setHighlightedIndex(index)
|
||||
},
|
||||
})}
|
||||
</Fragment>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
</div>,
|
||||
anchorElementRef.current,
|
||||
@@ -289,7 +236,7 @@ const ComponentPicker = ({
|
||||
}
|
||||
</>
|
||||
)
|
||||
}, [isAgentTrigger, agentNodes, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, handleSelectAgent, handleClose, workflowVariableOptions, isSupportFileVar, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
|
||||
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
|
||||
|
||||
return (
|
||||
<LexicalTypeaheadMenuPlugin
|
||||
|
||||
@@ -21,7 +21,6 @@ import {
|
||||
VariableLabelInEditor,
|
||||
} from '@/app/components/workflow/nodes/_base/components/variable/variable-label'
|
||||
import { Type } from '@/app/components/workflow/nodes/llm/types'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { isExceptionVariable } from '@/app/components/workflow/utils'
|
||||
import { useSelectOrDelete } from '../../hooks'
|
||||
import {
|
||||
@@ -67,7 +66,6 @@ const WorkflowVariableBlockComponent = ({
|
||||
)()
|
||||
const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState<WorkflowNodesMap>(workflowNodesMap)
|
||||
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
|
||||
const isAgentVariable = node?.type === BlockEnum.Agent
|
||||
|
||||
const isException = isExceptionVariable(varName, node?.type)
|
||||
const variableValid = useMemo(() => {
|
||||
@@ -136,9 +134,6 @@ const WorkflowVariableBlockComponent = ({
|
||||
})
|
||||
}, [node, reactflow, store])
|
||||
|
||||
if (isAgentVariable)
|
||||
return <span className="hidden" ref={ref} />
|
||||
|
||||
const Item = (
|
||||
<VariableLabelInEditor
|
||||
nodeType={node?.type}
|
||||
|
||||
@@ -73,17 +73,6 @@ export type WorkflowVariableBlockType = {
|
||||
onManageInputField?: () => void
|
||||
}
|
||||
|
||||
export type AgentNode = {
|
||||
id: string
|
||||
title: string
|
||||
}
|
||||
|
||||
export type AgentBlockType = {
|
||||
show?: boolean
|
||||
agentNodes?: AgentNode[]
|
||||
onSelect?: (agent: AgentNode) => void
|
||||
}
|
||||
|
||||
export type MenuTextMatch = {
|
||||
leadOffset: number
|
||||
matchingString: string
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,29 +1,47 @@
|
||||
import type { FC } from 'react'
|
||||
import type { FullDocumentDetail } from '@/models/datasets'
|
||||
import type { RETRIEVE_METHOD } from '@/types/app'
|
||||
import type {
|
||||
DataSourceInfo,
|
||||
FullDocumentDetail,
|
||||
IndexingStatusResponse,
|
||||
LegacyDataSourceInfo,
|
||||
ProcessRuleResponse,
|
||||
} from '@/models/datasets'
|
||||
import {
|
||||
RiArrowRightLine,
|
||||
RiCheckboxCircleFill,
|
||||
RiErrorWarningFill,
|
||||
RiLoader2Fill,
|
||||
RiTerminalBoxLine,
|
||||
} from '@remixicon/react'
|
||||
import Image from 'next/image'
|
||||
import Link from 'next/link'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useMemo } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import PriorityLabel from '@/app/components/billing/priority-label'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
|
||||
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
|
||||
import { DataSourceType, ProcessMode } from '@/models/datasets'
|
||||
import { fetchIndexingStatusBatch as doFetchIndexingStatus } from '@/service/datasets'
|
||||
import { useProcessRule } from '@/service/knowledge/use-dataset'
|
||||
import { useInvalidDocumentList } from '@/service/knowledge/use-document'
|
||||
import IndexingProgressItem from './indexing-progress-item'
|
||||
import RuleDetail from './rule-detail'
|
||||
import UpgradeBanner from './upgrade-banner'
|
||||
import { useIndexingStatusPolling } from './use-indexing-status-polling'
|
||||
import { createDocumentLookup } from './utils'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { sleep } from '@/utils'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import DocumentFileIcon from '../../common/document-file-icon'
|
||||
import { indexMethodIcon, retrievalIcon } from '../icons'
|
||||
import { IndexingType } from '../step-two'
|
||||
|
||||
type EmbeddingProcessProps = {
|
||||
type Props = {
|
||||
datasetId: string
|
||||
batchId: string
|
||||
documents?: FullDocumentDetail[]
|
||||
@@ -31,121 +49,333 @@ type EmbeddingProcessProps = {
|
||||
retrievalMethod?: RETRIEVE_METHOD
|
||||
}
|
||||
|
||||
// Status header component
|
||||
const StatusHeader: FC<{ isEmbedding: boolean, isCompleted: boolean }> = ({
|
||||
isEmbedding,
|
||||
isCompleted,
|
||||
}) => {
|
||||
const RuleDetail: FC<{
|
||||
sourceData?: ProcessRuleResponse
|
||||
indexingType?: string
|
||||
retrievalMethod?: RETRIEVE_METHOD
|
||||
}> = ({ sourceData, indexingType, retrievalMethod }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const segmentationRuleMap = {
|
||||
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
|
||||
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
|
||||
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
|
||||
}
|
||||
|
||||
const getRuleName = (key: string) => {
|
||||
if (key === 'remove_extra_spaces')
|
||||
return t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' })
|
||||
|
||||
if (key === 'remove_urls_emails')
|
||||
return t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' })
|
||||
|
||||
if (key === 'remove_stopwords')
|
||||
return t('stepTwo.removeStopwords', { ns: 'datasetCreation' })
|
||||
}
|
||||
|
||||
const isNumber = (value: unknown) => {
|
||||
return typeof value === 'number'
|
||||
}
|
||||
|
||||
const getValue = useCallback((field: string) => {
|
||||
let value: string | number | undefined = '-'
|
||||
const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens)
|
||||
? sourceData.rules.segmentation.max_tokens
|
||||
: value
|
||||
const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens)
|
||||
? sourceData.rules.subchunk_segmentation.max_tokens
|
||||
: value
|
||||
switch (field) {
|
||||
case 'mode':
|
||||
value = !sourceData?.mode
|
||||
? value
|
||||
: sourceData.mode === ProcessMode.general
|
||||
? (t('embedding.custom', { ns: 'datasetDocuments' }) as string)
|
||||
: `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${sourceData?.rules?.parent_mode === 'paragraph'
|
||||
? t('parentMode.paragraph', { ns: 'dataset' })
|
||||
: t('parentMode.fullDoc', { ns: 'dataset' })}`
|
||||
break
|
||||
case 'segmentLength':
|
||||
value = !sourceData?.mode
|
||||
? value
|
||||
: sourceData.mode === ProcessMode.general
|
||||
? maxTokens
|
||||
: `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
|
||||
break
|
||||
default:
|
||||
value = !sourceData?.mode
|
||||
? value
|
||||
: sourceData?.rules?.pre_processing_rules?.filter(rule =>
|
||||
rule.enabled).map(rule => getRuleName(rule.id)).join(',')
|
||||
break
|
||||
}
|
||||
return value
|
||||
}, [sourceData])
|
||||
|
||||
return (
|
||||
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
|
||||
{isEmbedding && (
|
||||
<>
|
||||
<RiLoader2Fill className="size-4 animate-spin" />
|
||||
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
|
||||
</>
|
||||
)}
|
||||
{isCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
|
||||
<div className="flex flex-col gap-1">
|
||||
{Object.keys(segmentationRuleMap).map((field) => {
|
||||
return (
|
||||
<FieldInfo
|
||||
key={field}
|
||||
label={segmentationRuleMap[field as keyof typeof segmentationRuleMap]}
|
||||
displayedValue={String(getValue(field))}
|
||||
/>
|
||||
)
|
||||
})}
|
||||
<FieldInfo
|
||||
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
displayedValue={t(`stepTwo.${indexingType === IndexingType.ECONOMICAL ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' }) as string}
|
||||
valueIcon={(
|
||||
<Image
|
||||
className="size-4"
|
||||
src={
|
||||
indexingType === IndexingType.ECONOMICAL
|
||||
? indexMethodIcon.economical
|
||||
: indexMethodIcon.high_quality
|
||||
}
|
||||
alt=""
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<FieldInfo
|
||||
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
// displayedValue={t(`datasetSettings.form.retrievalSetting.${retrievalMethod}`) as string}
|
||||
displayedValue={t(`retrieval.${indexingType === IndexingType.ECONOMICAL ? 'keyword_search' : retrievalMethod ?? 'semantic_search'}.title`, { ns: 'dataset' })}
|
||||
valueIcon={(
|
||||
<Image
|
||||
className="size-4"
|
||||
src={
|
||||
retrievalMethod === RETRIEVE_METHOD.fullText
|
||||
? retrievalIcon.fullText
|
||||
: retrievalMethod === RETRIEVE_METHOD.hybrid
|
||||
? retrievalIcon.hybrid
|
||||
: retrievalIcon.vector
|
||||
}
|
||||
alt=""
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Action buttons component
|
||||
const ActionButtons: FC<{
|
||||
apiReferenceUrl: string
|
||||
onNavToDocuments: () => void
|
||||
}> = ({ apiReferenceUrl, onNavToDocuments }) => {
|
||||
const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], indexingType, retrievalMethod }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="mt-6 flex items-center gap-x-2 py-2">
|
||||
<Link href={apiReferenceUrl} target="_blank" rel="noopener noreferrer">
|
||||
<Button className="w-fit gap-x-0.5 px-3">
|
||||
<RiTerminalBoxLine className="size-4" />
|
||||
<span className="px-0.5">Access the API</span>
|
||||
</Button>
|
||||
</Link>
|
||||
<Button
|
||||
className="w-fit gap-x-0.5 px-3"
|
||||
variant="primary"
|
||||
onClick={onNavToDocuments}
|
||||
>
|
||||
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
|
||||
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const EmbeddingProcess: FC<EmbeddingProcessProps> = ({
|
||||
datasetId,
|
||||
batchId,
|
||||
documents = [],
|
||||
indexingType,
|
||||
retrievalMethod,
|
||||
}) => {
|
||||
const { enableBilling, plan } = useProviderContext()
|
||||
|
||||
const getFirstDocument = documents[0]
|
||||
|
||||
const [indexingStatusBatchDetail, setIndexingStatusDetail] = useState<IndexingStatusResponse[]>([])
|
||||
const fetchIndexingStatus = async () => {
|
||||
const status = await doFetchIndexingStatus({ datasetId, batchId })
|
||||
setIndexingStatusDetail(status.data)
|
||||
return status.data
|
||||
}
|
||||
|
||||
const [isStopQuery, setIsStopQuery] = useState(false)
|
||||
const isStopQueryRef = useRef(isStopQuery)
|
||||
useEffect(() => {
|
||||
isStopQueryRef.current = isStopQuery
|
||||
}, [isStopQuery])
|
||||
const stopQueryStatus = () => {
|
||||
setIsStopQuery(true)
|
||||
}
|
||||
|
||||
const startQueryStatus = async () => {
|
||||
if (isStopQueryRef.current)
|
||||
return
|
||||
|
||||
try {
|
||||
const indexingStatusBatchDetail = await fetchIndexingStatus()
|
||||
const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail.indexing_status))
|
||||
if (isCompleted) {
|
||||
stopQueryStatus()
|
||||
return
|
||||
}
|
||||
await sleep(2500)
|
||||
await startQueryStatus()
|
||||
}
|
||||
catch {
|
||||
await sleep(2500)
|
||||
await startQueryStatus()
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
setIsStopQuery(false)
|
||||
startQueryStatus()
|
||||
return () => {
|
||||
stopQueryStatus()
|
||||
}
|
||||
}, [])
|
||||
|
||||
// get rule
|
||||
const { data: ruleDetail } = useProcessRule(getFirstDocument?.id)
|
||||
|
||||
const router = useRouter()
|
||||
const invalidDocumentList = useInvalidDocumentList()
|
||||
const apiReferenceUrl = useDatasetApiAccessUrl()
|
||||
|
||||
// Polling hook for indexing status
|
||||
const { statusList, isEmbedding, isEmbeddingCompleted } = useIndexingStatusPolling({
|
||||
datasetId,
|
||||
batchId,
|
||||
})
|
||||
|
||||
// Get process rule for the first document
|
||||
const firstDocumentId = documents[0]?.id
|
||||
const { data: ruleDetail } = useProcessRule(firstDocumentId)
|
||||
|
||||
// Document lookup utilities - memoized for performance
|
||||
const documentLookup = useMemo(
|
||||
() => createDocumentLookup(documents),
|
||||
[documents],
|
||||
)
|
||||
|
||||
const handleNavToDocuments = () => {
|
||||
const navToDocumentList = () => {
|
||||
invalidDocumentList()
|
||||
router.push(`/datasets/${datasetId}/documents`)
|
||||
}
|
||||
const apiReferenceUrl = useDatasetApiAccessUrl()
|
||||
|
||||
const showUpgradeBanner = enableBilling && plan.type !== Plan.team
|
||||
const isEmbedding = useMemo(() => {
|
||||
return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || ''))
|
||||
}, [indexingStatusBatchDetail])
|
||||
const isEmbeddingCompleted = useMemo(() => {
|
||||
return indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status || ''))
|
||||
}, [indexingStatusBatchDetail])
|
||||
|
||||
const getSourceName = (id: string) => {
|
||||
const doc = documents.find(document => document.id === id)
|
||||
return doc?.name
|
||||
}
|
||||
const getFileType = (name?: string) => name?.split('.').pop() || 'txt'
|
||||
const getSourcePercent = (detail: IndexingStatusResponse) => {
|
||||
const completedCount = detail.completed_segments || 0
|
||||
const totalCount = detail.total_segments || 0
|
||||
if (totalCount === 0)
|
||||
return 0
|
||||
const percent = Math.round(completedCount * 100 / totalCount)
|
||||
return percent > 100 ? 100 : percent
|
||||
}
|
||||
const getSourceType = (id: string) => {
|
||||
const doc = documents.find(document => document.id === id)
|
||||
return doc?.data_source_type as DataSourceType
|
||||
}
|
||||
|
||||
const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
|
||||
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
|
||||
}
|
||||
|
||||
const getIcon = (id: string) => {
|
||||
const doc = documents.find(document => document.id === id)
|
||||
const info = doc?.data_source_info
|
||||
if (info && isLegacyDataSourceInfo(info))
|
||||
return info.notion_page_icon
|
||||
return undefined
|
||||
}
|
||||
const isSourceEmbedding = (detail: IndexingStatusResponse) =>
|
||||
['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '')
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-y-3">
|
||||
<StatusHeader isEmbedding={isEmbedding} isCompleted={isEmbeddingCompleted} />
|
||||
|
||||
{showUpgradeBanner && <UpgradeBanner />}
|
||||
|
||||
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
|
||||
{isEmbedding && (
|
||||
<>
|
||||
<RiLoader2Fill className="size-4 animate-spin" />
|
||||
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
|
||||
</>
|
||||
)}
|
||||
{isEmbeddingCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
|
||||
</div>
|
||||
{
|
||||
enableBilling && plan.type !== Plan.team && (
|
||||
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
|
||||
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
|
||||
<ZapFast className="h-4 w-4 text-[#FB6514]" />
|
||||
</div>
|
||||
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
|
||||
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
|
||||
</div>
|
||||
<UpgradeBtn loc="knowledge-speed-up" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<div className="flex flex-col gap-0.5 pb-2">
|
||||
{statusList.map(detail => (
|
||||
<IndexingProgressItem
|
||||
key={detail.id}
|
||||
detail={detail}
|
||||
name={documentLookup.getName(detail.id)}
|
||||
sourceType={documentLookup.getSourceType(detail.id)}
|
||||
notionIcon={documentLookup.getNotionIcon(detail.id)}
|
||||
enableBilling={enableBilling}
|
||||
/>
|
||||
{indexingStatusBatchDetail.map(indexingStatusDetail => (
|
||||
<div
|
||||
key={indexingStatusDetail.id}
|
||||
className={cn(
|
||||
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
|
||||
indexingStatusDetail.indexing_status === 'error' && 'bg-state-destructive-hover-alt',
|
||||
)}
|
||||
>
|
||||
{isSourceEmbedding(indexingStatusDetail) && (
|
||||
<div
|
||||
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
|
||||
style={{ width: `${getSourcePercent(indexingStatusDetail)}%` }}
|
||||
/>
|
||||
)}
|
||||
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
|
||||
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && (
|
||||
<DocumentFileIcon
|
||||
size="sm"
|
||||
className="shrink-0"
|
||||
name={getSourceName(indexingStatusDetail.id)}
|
||||
extension={getFileType(getSourceName(indexingStatusDetail.id))}
|
||||
/>
|
||||
)}
|
||||
{getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && (
|
||||
<NotionIcon
|
||||
className="shrink-0"
|
||||
type="page"
|
||||
src={getIcon(indexingStatusDetail.id)}
|
||||
/>
|
||||
)}
|
||||
<div className="flex w-0 grow items-center gap-1" title={getSourceName(indexingStatusDetail.id)}>
|
||||
<div className="system-xs-medium truncate text-text-secondary">
|
||||
{getSourceName(indexingStatusDetail.id)}
|
||||
</div>
|
||||
{
|
||||
enableBilling && (
|
||||
<PriorityLabel className="ml-0" />
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{isSourceEmbedding(indexingStatusDetail) && (
|
||||
<div className="shrink-0 text-xs text-text-secondary">{`${getSourcePercent(indexingStatusDetail)}%`}</div>
|
||||
)}
|
||||
{indexingStatusDetail.indexing_status === 'error' && (
|
||||
<Tooltip
|
||||
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
|
||||
offset={4}
|
||||
popupContent={indexingStatusDetail.error}
|
||||
>
|
||||
<span>
|
||||
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
|
||||
</span>
|
||||
</Tooltip>
|
||||
)}
|
||||
{indexingStatusDetail.indexing_status === 'completed' && (
|
||||
<RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<Divider type="horizontal" className="my-0 bg-divider-subtle" />
|
||||
|
||||
<RuleDetail
|
||||
sourceData={ruleDetail}
|
||||
indexingType={indexingType}
|
||||
retrievalMethod={retrievalMethod}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<ActionButtons
|
||||
apiReferenceUrl={apiReferenceUrl}
|
||||
onNavToDocuments={handleNavToDocuments}
|
||||
/>
|
||||
<div className="mt-6 flex items-center gap-x-2 py-2">
|
||||
<Link
|
||||
href={apiReferenceUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<Button
|
||||
className="w-fit gap-x-0.5 px-3"
|
||||
>
|
||||
<RiTerminalBoxLine className="size-4" />
|
||||
<span className="px-0.5">Access the API</span>
|
||||
</Button>
|
||||
</Link>
|
||||
<Button
|
||||
className="w-fit gap-x-0.5 px-3"
|
||||
variant="primary"
|
||||
onClick={navToDocumentList}
|
||||
>
|
||||
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
|
||||
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import type { IndexingStatusResponse } from '@/models/datasets'
|
||||
import {
|
||||
RiCheckboxCircleFill,
|
||||
RiErrorWarningFill,
|
||||
} from '@remixicon/react'
|
||||
import NotionIcon from '@/app/components/base/notion-icon'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import PriorityLabel from '@/app/components/billing/priority-label'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import DocumentFileIcon from '../../common/document-file-icon'
|
||||
import { getFileType, getSourcePercent, isSourceEmbedding } from './utils'
|
||||
|
||||
type IndexingProgressItemProps = {
|
||||
detail: IndexingStatusResponse
|
||||
name?: string
|
||||
sourceType?: DataSourceType
|
||||
notionIcon?: string
|
||||
enableBilling?: boolean
|
||||
}
|
||||
|
||||
// Status icon component for completed/error states
|
||||
const StatusIcon: FC<{ status: string, error?: string }> = ({ status, error }) => {
|
||||
if (status === 'completed')
|
||||
return <RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
|
||||
|
||||
if (status === 'error') {
|
||||
return (
|
||||
<Tooltip
|
||||
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
|
||||
offset={4}
|
||||
popupContent={error}
|
||||
>
|
||||
<span>
|
||||
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
|
||||
</span>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// Source type icon component
|
||||
const SourceTypeIcon: FC<{
|
||||
sourceType?: DataSourceType
|
||||
name?: string
|
||||
notionIcon?: string
|
||||
}> = ({ sourceType, name, notionIcon }) => {
|
||||
if (sourceType === DataSourceType.FILE) {
|
||||
return (
|
||||
<DocumentFileIcon
|
||||
size="sm"
|
||||
className="shrink-0"
|
||||
name={name}
|
||||
extension={getFileType(name)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
if (sourceType === DataSourceType.NOTION) {
|
||||
return (
|
||||
<NotionIcon
|
||||
className="shrink-0"
|
||||
type="page"
|
||||
src={notionIcon}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
const IndexingProgressItem: FC<IndexingProgressItemProps> = ({
|
||||
detail,
|
||||
name,
|
||||
sourceType,
|
||||
notionIcon,
|
||||
enableBilling,
|
||||
}) => {
|
||||
const isEmbedding = isSourceEmbedding(detail)
|
||||
const percent = getSourcePercent(detail)
|
||||
const isError = detail.indexing_status === 'error'
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
|
||||
isError && 'bg-state-destructive-hover-alt',
|
||||
)}
|
||||
>
|
||||
{isEmbedding && (
|
||||
<div
|
||||
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
|
||||
style={{ width: `${percent}%` }}
|
||||
/>
|
||||
)}
|
||||
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
|
||||
<SourceTypeIcon
|
||||
sourceType={sourceType}
|
||||
name={name}
|
||||
notionIcon={notionIcon}
|
||||
/>
|
||||
<div className="flex w-0 grow items-center gap-1" title={name}>
|
||||
<div className="system-xs-medium truncate text-text-secondary">
|
||||
{name}
|
||||
</div>
|
||||
{enableBilling && <PriorityLabel className="ml-0" />}
|
||||
</div>
|
||||
{isEmbedding && (
|
||||
<div className="shrink-0 text-xs text-text-secondary">{`${percent}%`}</div>
|
||||
)}
|
||||
<StatusIcon status={detail.indexing_status} error={detail.error} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default IndexingProgressItem
|
||||
@@ -1,133 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import type { ProcessRuleResponse } from '@/models/datasets'
|
||||
import Image from 'next/image'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
|
||||
import { ProcessMode } from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { indexMethodIcon, retrievalIcon } from '../icons'
|
||||
import { IndexingType } from '../step-two'
|
||||
|
||||
type RuleDetailProps = {
|
||||
sourceData?: ProcessRuleResponse
|
||||
indexingType?: string
|
||||
retrievalMethod?: RETRIEVE_METHOD
|
||||
}
|
||||
|
||||
// Lookup table for pre-processing rule names
|
||||
const PRE_PROCESSING_RULE_KEYS = {
|
||||
remove_extra_spaces: 'stepTwo.removeExtraSpaces',
|
||||
remove_urls_emails: 'stepTwo.removeUrlEmails',
|
||||
remove_stopwords: 'stepTwo.removeStopwords',
|
||||
} as const
|
||||
|
||||
// Lookup table for retrieval method icons
|
||||
const RETRIEVAL_ICON_MAP: Partial<Record<RETRIEVE_METHOD, string>> = {
|
||||
[RETRIEVE_METHOD.fullText]: retrievalIcon.fullText,
|
||||
[RETRIEVE_METHOD.hybrid]: retrievalIcon.hybrid,
|
||||
[RETRIEVE_METHOD.semantic]: retrievalIcon.vector,
|
||||
[RETRIEVE_METHOD.invertedIndex]: retrievalIcon.fullText,
|
||||
[RETRIEVE_METHOD.keywordSearch]: retrievalIcon.fullText,
|
||||
}
|
||||
|
||||
const isNumber = (value: unknown): value is number => typeof value === 'number'
|
||||
|
||||
const RuleDetail: FC<RuleDetailProps> = ({ sourceData, indexingType, retrievalMethod }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const segmentationRuleLabels = {
|
||||
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
|
||||
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
|
||||
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
|
||||
}
|
||||
|
||||
const getRuleName = useCallback((key: string): string | undefined => {
|
||||
const translationKey = PRE_PROCESSING_RULE_KEYS[key as keyof typeof PRE_PROCESSING_RULE_KEYS]
|
||||
return translationKey ? t(translationKey, { ns: 'datasetCreation' }) : undefined
|
||||
}, [t])
|
||||
|
||||
const getModeValue = useCallback((): string => {
|
||||
if (!sourceData?.mode)
|
||||
return '-'
|
||||
|
||||
if (sourceData.mode === ProcessMode.general)
|
||||
return t('embedding.custom', { ns: 'datasetDocuments' })
|
||||
|
||||
const parentModeLabel = sourceData.rules?.parent_mode === 'paragraph'
|
||||
? t('parentMode.paragraph', { ns: 'dataset' })
|
||||
: t('parentMode.fullDoc', { ns: 'dataset' })
|
||||
|
||||
return `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${parentModeLabel}`
|
||||
}, [sourceData, t])
|
||||
|
||||
const getSegmentLengthValue = useCallback((): string | number => {
|
||||
if (!sourceData?.mode)
|
||||
return '-'
|
||||
|
||||
const maxTokens = isNumber(sourceData.rules?.segmentation?.max_tokens)
|
||||
? sourceData.rules.segmentation.max_tokens
|
||||
: '-'
|
||||
|
||||
if (sourceData.mode === ProcessMode.general)
|
||||
return maxTokens
|
||||
|
||||
const childMaxTokens = isNumber(sourceData.rules?.subchunk_segmentation?.max_tokens)
|
||||
? sourceData.rules.subchunk_segmentation.max_tokens
|
||||
: '-'
|
||||
|
||||
return `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
|
||||
}, [sourceData, t])
|
||||
|
||||
const getTextCleaningValue = useCallback((): string => {
|
||||
if (!sourceData?.mode)
|
||||
return '-'
|
||||
|
||||
const enabledRules = sourceData.rules?.pre_processing_rules?.filter(rule => rule.enabled) || []
|
||||
const ruleNames = enabledRules
|
||||
.map((rule) => {
|
||||
const name = getRuleName(rule.id)
|
||||
return typeof name === 'string' ? name : ''
|
||||
})
|
||||
.filter(name => name)
|
||||
return ruleNames.length > 0 ? ruleNames.join(',') : '-'
|
||||
}, [sourceData, getRuleName])
|
||||
|
||||
const fieldValueGetters: Record<string, () => string | number> = {
|
||||
mode: getModeValue,
|
||||
segmentLength: getSegmentLengthValue,
|
||||
textCleaning: getTextCleaningValue,
|
||||
}
|
||||
|
||||
const isEconomical = indexingType === IndexingType.ECONOMICAL
|
||||
const indexMethodIconSrc = isEconomical ? indexMethodIcon.economical : indexMethodIcon.high_quality
|
||||
const indexModeLabel = t(`stepTwo.${isEconomical ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' })
|
||||
|
||||
const effectiveRetrievalMethod = isEconomical ? 'keyword_search' : (retrievalMethod ?? 'semantic_search')
|
||||
const retrievalLabel = t(`retrieval.${effectiveRetrievalMethod}.title`, { ns: 'dataset' })
|
||||
const retrievalIconSrc = RETRIEVAL_ICON_MAP[retrievalMethod as keyof typeof RETRIEVAL_ICON_MAP] ?? retrievalIcon.vector
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-1">
|
||||
{Object.keys(segmentationRuleLabels).map(field => (
|
||||
<FieldInfo
|
||||
key={field}
|
||||
label={segmentationRuleLabels[field as keyof typeof segmentationRuleLabels]}
|
||||
displayedValue={String(fieldValueGetters[field]())}
|
||||
/>
|
||||
))}
|
||||
<FieldInfo
|
||||
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
displayedValue={indexModeLabel}
|
||||
valueIcon={<Image className="size-4" src={indexMethodIconSrc} alt="" />}
|
||||
/>
|
||||
<FieldInfo
|
||||
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
displayedValue={retrievalLabel}
|
||||
valueIcon={<Image className="size-4" src={retrievalIconSrc} alt="" />}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default RuleDetail
|
||||
@@ -1,22 +0,0 @@
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
|
||||
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
|
||||
|
||||
const UpgradeBanner: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
|
||||
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
|
||||
<ZapFast className="h-4 w-4 text-[#FB6514]" />
|
||||
</div>
|
||||
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
|
||||
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
|
||||
</div>
|
||||
<UpgradeBtn loc="knowledge-speed-up" />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default UpgradeBanner
|
||||
@@ -1,90 +0,0 @@
|
||||
import type { IndexingStatusResponse } from '@/models/datasets'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { fetchIndexingStatusBatch } from '@/service/datasets'
|
||||
|
||||
const POLLING_INTERVAL = 2500
|
||||
const COMPLETED_STATUSES = ['completed', 'error', 'paused'] as const
|
||||
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
|
||||
|
||||
type IndexingStatusPollingParams = {
|
||||
datasetId: string
|
||||
batchId: string
|
||||
}
|
||||
|
||||
type IndexingStatusPollingResult = {
|
||||
statusList: IndexingStatusResponse[]
|
||||
isEmbedding: boolean
|
||||
isEmbeddingCompleted: boolean
|
||||
}
|
||||
|
||||
const isStatusCompleted = (status: string): boolean =>
|
||||
COMPLETED_STATUSES.includes(status as typeof COMPLETED_STATUSES[number])
|
||||
|
||||
const isAllCompleted = (statusList: IndexingStatusResponse[]): boolean =>
|
||||
statusList.every(item => isStatusCompleted(item.indexing_status))
|
||||
|
||||
/**
|
||||
* Custom hook for polling indexing status with automatic stop on completion.
|
||||
* Handles the polling lifecycle and provides derived states for UI rendering.
|
||||
*/
|
||||
export const useIndexingStatusPolling = ({
|
||||
datasetId,
|
||||
batchId,
|
||||
}: IndexingStatusPollingParams): IndexingStatusPollingResult => {
|
||||
const [statusList, setStatusList] = useState<IndexingStatusResponse[]>([])
|
||||
const isStopPollingRef = useRef(false)
|
||||
|
||||
useEffect(() => {
|
||||
// Reset polling state on mount
|
||||
isStopPollingRef.current = false
|
||||
let timeoutId: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
const fetchStatus = async (): Promise<IndexingStatusResponse[]> => {
|
||||
const response = await fetchIndexingStatusBatch({ datasetId, batchId })
|
||||
setStatusList(response.data)
|
||||
return response.data
|
||||
}
|
||||
|
||||
const poll = async (): Promise<void> => {
|
||||
if (isStopPollingRef.current)
|
||||
return
|
||||
|
||||
try {
|
||||
const data = await fetchStatus()
|
||||
if (isAllCompleted(data)) {
|
||||
isStopPollingRef.current = true
|
||||
return
|
||||
}
|
||||
}
|
||||
catch {
|
||||
// Continue polling on error
|
||||
}
|
||||
|
||||
if (!isStopPollingRef.current) {
|
||||
timeoutId = setTimeout(() => {
|
||||
poll()
|
||||
}, POLLING_INTERVAL)
|
||||
}
|
||||
}
|
||||
|
||||
poll()
|
||||
|
||||
return () => {
|
||||
isStopPollingRef.current = true
|
||||
if (timeoutId)
|
||||
clearTimeout(timeoutId)
|
||||
}
|
||||
}, [datasetId, batchId])
|
||||
|
||||
const isEmbedding = statusList.some(item =>
|
||||
EMBEDDING_STATUSES.includes(item?.indexing_status as typeof EMBEDDING_STATUSES[number]),
|
||||
)
|
||||
|
||||
const isEmbeddingCompleted = statusList.length > 0 && isAllCompleted(statusList)
|
||||
|
||||
return {
|
||||
statusList,
|
||||
isEmbedding,
|
||||
isEmbeddingCompleted,
|
||||
}
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
import type {
|
||||
DataSourceInfo,
|
||||
DataSourceType,
|
||||
FullDocumentDetail,
|
||||
IndexingStatusResponse,
|
||||
LegacyDataSourceInfo,
|
||||
} from '@/models/datasets'
|
||||
|
||||
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
|
||||
|
||||
/**
|
||||
* Type guard for legacy data source info with upload_file property
|
||||
*/
|
||||
export const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
|
||||
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a status indicates the source is being embedded
|
||||
*/
|
||||
export const isSourceEmbedding = (detail: IndexingStatusResponse): boolean =>
|
||||
EMBEDDING_STATUSES.includes(detail.indexing_status as typeof EMBEDDING_STATUSES[number])
|
||||
|
||||
/**
|
||||
* Calculate the progress percentage for a document
|
||||
*/
|
||||
export const getSourcePercent = (detail: IndexingStatusResponse): number => {
|
||||
const completedCount = detail.completed_segments || 0
|
||||
const totalCount = detail.total_segments || 0
|
||||
|
||||
if (totalCount === 0)
|
||||
return 0
|
||||
|
||||
const percent = Math.round(completedCount * 100 / totalCount)
|
||||
return Math.min(percent, 100)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get file extension from filename, defaults to 'txt'
|
||||
*/
|
||||
export const getFileType = (name?: string): string =>
|
||||
name?.split('.').pop() || 'txt'
|
||||
|
||||
/**
|
||||
* Document lookup utilities - provides document info by ID from a list
|
||||
*/
|
||||
export const createDocumentLookup = (documents: FullDocumentDetail[]) => {
|
||||
const documentMap = new Map(documents.map(doc => [doc.id, doc]))
|
||||
|
||||
return {
|
||||
getDocument: (id: string) => documentMap.get(id),
|
||||
|
||||
getName: (id: string) => documentMap.get(id)?.name,
|
||||
|
||||
getSourceType: (id: string) => documentMap.get(id)?.data_source_type as DataSourceType | undefined,
|
||||
|
||||
getNotionIcon: (id: string) => {
|
||||
const info = documentMap.get(id)?.data_source_info
|
||||
if (info && isLegacyDataSourceInfo(info))
|
||||
return info.notion_page_icon
|
||||
return undefined
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,199 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { PreProcessingRule } from '@/models/datasets'
|
||||
import {
|
||||
RiAlertFill,
|
||||
RiSearchEyeLine,
|
||||
} from '@remixicon/react'
|
||||
import Image from 'next/image'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import SettingCog from '../../assets/setting-gear-mod.svg'
|
||||
import s from '../index.module.css'
|
||||
import LanguageSelect from '../language-select'
|
||||
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
|
||||
import { OptionCard } from './option-card'
|
||||
|
||||
type TextLabelProps = {
|
||||
children: React.ReactNode
|
||||
}
|
||||
|
||||
const TextLabel: FC<TextLabelProps> = ({ children }) => {
|
||||
return <label className="system-sm-semibold text-text-secondary">{children}</label>
|
||||
}
|
||||
|
||||
type GeneralChunkingOptionsProps = {
|
||||
// State
|
||||
segmentIdentifier: string
|
||||
maxChunkLength: number
|
||||
overlap: number
|
||||
rules: PreProcessingRule[]
|
||||
currentDocForm: ChunkingMode
|
||||
docLanguage: string
|
||||
// Flags
|
||||
isActive: boolean
|
||||
isInUpload: boolean
|
||||
isNotUploadInEmptyDataset: boolean
|
||||
hasCurrentDatasetDocForm: boolean
|
||||
// Actions
|
||||
onSegmentIdentifierChange: (value: string) => void
|
||||
onMaxChunkLengthChange: (value: number) => void
|
||||
onOverlapChange: (value: number) => void
|
||||
onRuleToggle: (id: string) => void
|
||||
onDocFormChange: (form: ChunkingMode) => void
|
||||
onDocLanguageChange: (lang: string) => void
|
||||
onPreview: () => void
|
||||
onReset: () => void
|
||||
// Locale
|
||||
locale: string
|
||||
}
|
||||
|
||||
export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
|
||||
segmentIdentifier,
|
||||
maxChunkLength,
|
||||
overlap,
|
||||
rules,
|
||||
currentDocForm,
|
||||
docLanguage,
|
||||
isActive,
|
||||
isInUpload,
|
||||
isNotUploadInEmptyDataset,
|
||||
hasCurrentDatasetDocForm,
|
||||
onSegmentIdentifierChange,
|
||||
onMaxChunkLengthChange,
|
||||
onOverlapChange,
|
||||
onRuleToggle,
|
||||
onDocFormChange,
|
||||
onDocLanguageChange,
|
||||
onPreview,
|
||||
onReset,
|
||||
locale,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getRuleName = (key: string): string => {
|
||||
const ruleNameMap: Record<string, string> = {
|
||||
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
|
||||
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
|
||||
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
|
||||
}
|
||||
return ruleNameMap[key] ?? key
|
||||
}
|
||||
|
||||
return (
|
||||
<OptionCard
|
||||
className="mb-2 bg-background-section"
|
||||
title={t('stepTwo.general', { ns: 'datasetCreation' })}
|
||||
icon={<Image width={20} height={20} src={SettingCog} alt={t('stepTwo.general', { ns: 'datasetCreation' })} />}
|
||||
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
|
||||
description={t('stepTwo.generalTip', { ns: 'datasetCreation' })}
|
||||
isActive={isActive}
|
||||
onSwitched={() => onDocFormChange(ChunkingMode.text)}
|
||||
actions={(
|
||||
<>
|
||||
<Button variant="secondary-accent" onClick={onPreview}>
|
||||
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
|
||||
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
<Button variant="ghost" onClick={onReset}>
|
||||
{t('stepTwo.reset', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
noHighlight={isInUpload && isNotUploadInEmptyDataset}
|
||||
>
|
||||
<div className="flex flex-col gap-y-4">
|
||||
<div className="flex gap-3">
|
||||
<DelimiterInput
|
||||
value={segmentIdentifier}
|
||||
onChange={e => onSegmentIdentifierChange(e.target.value)}
|
||||
/>
|
||||
<MaxLengthInput
|
||||
unit="characters"
|
||||
value={maxChunkLength}
|
||||
onChange={onMaxChunkLengthChange}
|
||||
/>
|
||||
<OverlapInput
|
||||
unit="characters"
|
||||
value={overlap}
|
||||
min={1}
|
||||
onChange={onOverlapChange}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex w-full flex-col">
|
||||
<div className="flex items-center gap-x-2">
|
||||
<div className="inline-flex shrink-0">
|
||||
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
|
||||
</div>
|
||||
<Divider className="grow" bgStyle="gradient" />
|
||||
</div>
|
||||
<div className="mt-1">
|
||||
{rules.map(rule => (
|
||||
<div
|
||||
key={rule.id}
|
||||
className={s.ruleItem}
|
||||
onClick={() => onRuleToggle(rule.id)}
|
||||
>
|
||||
<Checkbox checked={rule.enabled} />
|
||||
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
|
||||
{getRuleName(rule.id)}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
{IS_CE_EDITION && (
|
||||
<>
|
||||
<Divider type="horizontal" className="my-4 bg-divider-subtle" />
|
||||
<div className="flex items-center py-0.5">
|
||||
<div
|
||||
className="flex items-center"
|
||||
onClick={() => {
|
||||
if (hasCurrentDatasetDocForm)
|
||||
return
|
||||
if (currentDocForm === ChunkingMode.qa)
|
||||
onDocFormChange(ChunkingMode.text)
|
||||
else
|
||||
onDocFormChange(ChunkingMode.qa)
|
||||
}}
|
||||
>
|
||||
<Checkbox
|
||||
checked={currentDocForm === ChunkingMode.qa}
|
||||
disabled={hasCurrentDatasetDocForm}
|
||||
/>
|
||||
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
|
||||
{t('stepTwo.useQALanguage', { ns: 'datasetCreation' })}
|
||||
</label>
|
||||
</div>
|
||||
<LanguageSelect
|
||||
currentLanguage={docLanguage || locale}
|
||||
onSelect={onDocLanguageChange}
|
||||
disabled={currentDocForm !== ChunkingMode.qa}
|
||||
/>
|
||||
<Tooltip popupContent={t('stepTwo.QATip', { ns: 'datasetCreation' })} />
|
||||
</div>
|
||||
{currentDocForm === ChunkingMode.qa && (
|
||||
<div
|
||||
style={{
|
||||
background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.1) 0%, rgba(255, 255, 255, 0.00) 100%)',
|
||||
}}
|
||||
className="mt-2 flex h-10 items-center gap-2 rounded-xl border border-components-panel-border px-3 text-xs shadow-xs backdrop-blur-[5px]"
|
||||
>
|
||||
<RiAlertFill className="size-4 text-text-warning-secondary" />
|
||||
<span className="system-xs-medium text-text-primary">
|
||||
{t('stepTwo.QATip', { ns: 'datasetCreation' })}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</OptionCard>
|
||||
)
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
export { GeneralChunkingOptions } from './general-chunking-options'
|
||||
export { IndexingModeSection } from './indexing-mode-section'
|
||||
export { ParentChildOptions } from './parent-child-options'
|
||||
export { PreviewPanel } from './preview-panel'
|
||||
export { StepTwoFooter } from './step-two-footer'
|
||||
@@ -1,253 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import Image from 'next/image'
|
||||
import Link from 'next/link'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import Button from '@/app/components/base/button'
|
||||
import CustomDialog from '@/app/components/base/dialog'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useDocLink } from '@/context/i18n'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { indexMethodIcon } from '../../icons'
|
||||
import { IndexingType } from '../hooks'
|
||||
import s from '../index.module.css'
|
||||
import { OptionCard } from './option-card'
|
||||
|
||||
type IndexingModeSectionProps = {
|
||||
// State
|
||||
indexType: IndexingType
|
||||
hasSetIndexType: boolean
|
||||
docForm: ChunkingMode
|
||||
embeddingModel: DefaultModel
|
||||
embeddingModelList?: Model[]
|
||||
retrievalConfig: RetrievalConfig
|
||||
showMultiModalTip: boolean
|
||||
// Flags
|
||||
isModelAndRetrievalConfigDisabled: boolean
|
||||
datasetId?: string
|
||||
// Modal state
|
||||
isQAConfirmDialogOpen: boolean
|
||||
// Actions
|
||||
onIndexTypeChange: (type: IndexingType) => void
|
||||
onEmbeddingModelChange: (model: DefaultModel) => void
|
||||
onRetrievalConfigChange: (config: RetrievalConfig) => void
|
||||
onQAConfirmDialogClose: () => void
|
||||
onQAConfirmDialogConfirm: () => void
|
||||
}
|
||||
|
||||
export const IndexingModeSection: FC<IndexingModeSectionProps> = ({
|
||||
indexType,
|
||||
hasSetIndexType,
|
||||
docForm,
|
||||
embeddingModel,
|
||||
embeddingModelList,
|
||||
retrievalConfig,
|
||||
showMultiModalTip,
|
||||
isModelAndRetrievalConfigDisabled,
|
||||
datasetId,
|
||||
isQAConfirmDialogOpen,
|
||||
onIndexTypeChange,
|
||||
onEmbeddingModelChange,
|
||||
onRetrievalConfigChange,
|
||||
onQAConfirmDialogClose,
|
||||
onQAConfirmDialogConfirm,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const docLink = useDocLink()
|
||||
|
||||
const getIndexingTechnique = () => indexType
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Index Mode */}
|
||||
<div className="system-md-semibold mb-1 text-text-secondary">
|
||||
{t('stepTwo.indexMode', { ns: 'datasetCreation' })}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{/* Qualified option */}
|
||||
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.QUALIFIED)) && (
|
||||
<OptionCard
|
||||
className="flex-1 self-stretch"
|
||||
title={(
|
||||
<div className="flex items-center">
|
||||
{t('stepTwo.qualified', { ns: 'datasetCreation' })}
|
||||
<Badge
|
||||
className={cn(
|
||||
'ml-1 h-[18px]',
|
||||
(!hasSetIndexType && indexType === IndexingType.QUALIFIED)
|
||||
? 'border-text-accent-secondary text-text-accent-secondary'
|
||||
: '',
|
||||
)}
|
||||
uppercase
|
||||
>
|
||||
{t('stepTwo.recommend', { ns: 'datasetCreation' })}
|
||||
</Badge>
|
||||
<span className="ml-auto">
|
||||
{!hasSetIndexType && <span className={cn(s.radio)} />}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
description={t('stepTwo.qualifiedTip', { ns: 'datasetCreation' })}
|
||||
icon={<Image src={indexMethodIcon.high_quality} alt="" />}
|
||||
isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED}
|
||||
disabled={hasSetIndexType}
|
||||
onSwitched={() => onIndexTypeChange(IndexingType.QUALIFIED)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Economical option */}
|
||||
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.ECONOMICAL)) && (
|
||||
<>
|
||||
<CustomDialog show={isQAConfirmDialogOpen} onClose={onQAConfirmDialogClose} className="w-[432px]">
|
||||
<header className="mb-4 pt-6">
|
||||
<h2 className="text-lg font-semibold text-text-primary">
|
||||
{t('stepTwo.qaSwitchHighQualityTipTitle', { ns: 'datasetCreation' })}
|
||||
</h2>
|
||||
<p className="mt-2 text-sm font-normal text-text-secondary">
|
||||
{t('stepTwo.qaSwitchHighQualityTipContent', { ns: 'datasetCreation' })}
|
||||
</p>
|
||||
</header>
|
||||
<div className="flex gap-2 pb-6">
|
||||
<Button className="ml-auto" onClick={onQAConfirmDialogClose}>
|
||||
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
<Button variant="primary" onClick={onQAConfirmDialogConfirm}>
|
||||
{t('stepTwo.switch', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
</div>
|
||||
</CustomDialog>
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
<div className="rounded-lg border-components-panel-border bg-components-tooltip-bg p-3 text-xs font-medium text-text-secondary shadow-lg">
|
||||
{docForm === ChunkingMode.qa
|
||||
? t('stepTwo.notAvailableForQA', { ns: 'datasetCreation' })
|
||||
: t('stepTwo.notAvailableForParentChild', { ns: 'datasetCreation' })}
|
||||
</div>
|
||||
)}
|
||||
noDecoration
|
||||
position="top"
|
||||
asChild={false}
|
||||
triggerClassName="flex-1 self-stretch"
|
||||
>
|
||||
<OptionCard
|
||||
className="h-full"
|
||||
title={t('stepTwo.economical', { ns: 'datasetCreation' })}
|
||||
description={t('stepTwo.economicalTip', { ns: 'datasetCreation' })}
|
||||
icon={<Image src={indexMethodIcon.economical} alt="" />}
|
||||
isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL}
|
||||
disabled={hasSetIndexType || docForm !== ChunkingMode.text}
|
||||
onSwitched={() => onIndexTypeChange(IndexingType.ECONOMICAL)}
|
||||
/>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* High quality tip */}
|
||||
{!hasSetIndexType && indexType === IndexingType.QUALIFIED && (
|
||||
<div className="mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-2 shadow-xs backdrop-blur-[5px]">
|
||||
<div className="absolute bottom-0 left-0 right-0 top-0 bg-dataset-warning-message-bg opacity-40"></div>
|
||||
<div className="p-1">
|
||||
<AlertTriangle className="size-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<span className="system-xs-medium text-text-primary">
|
||||
{t('stepTwo.highQualityTip', { ns: 'datasetCreation' })}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Economical index setting tip */}
|
||||
{hasSetIndexType && indexType === IndexingType.ECONOMICAL && (
|
||||
<div className="system-xs-medium mt-2 text-text-tertiary">
|
||||
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
|
||||
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
|
||||
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Embedding model */}
|
||||
{indexType === IndexingType.QUALIFIED && (
|
||||
<div className="mt-5">
|
||||
<div className={cn('system-md-semibold mb-1 text-text-secondary', datasetId && 'flex items-center justify-between')}>
|
||||
{t('form.embeddingModel', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<ModelSelector
|
||||
readonly={isModelAndRetrievalConfigDisabled}
|
||||
triggerClassName={isModelAndRetrievalConfigDisabled ? 'opacity-50' : ''}
|
||||
defaultModel={embeddingModel}
|
||||
modelList={embeddingModelList ?? []}
|
||||
onSelect={onEmbeddingModelChange}
|
||||
/>
|
||||
{isModelAndRetrievalConfigDisabled && (
|
||||
<div className="system-xs-medium mt-2 text-text-tertiary">
|
||||
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
|
||||
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
|
||||
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Divider className="my-5" />
|
||||
|
||||
{/* Retrieval Method Config */}
|
||||
<div>
|
||||
{!isModelAndRetrievalConfigDisabled
|
||||
? (
|
||||
<div className="mb-1">
|
||||
<div className="system-md-semibold mb-0.5 text-text-secondary">
|
||||
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
<div className="body-xs-regular text-text-tertiary">
|
||||
<a
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href={docLink('/guides/knowledge-base/create-knowledge-and-upload-documents')}
|
||||
className="text-text-accent"
|
||||
>
|
||||
{t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })}
|
||||
</a>
|
||||
{t('form.retrievalSetting.longDescription', { ns: 'datasetSettings' })}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
: (
|
||||
<div className={cn('system-md-semibold mb-0.5 text-text-secondary', 'flex items-center justify-between')}>
|
||||
<div>{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div>
|
||||
{getIndexingTechnique() === IndexingType.QUALIFIED
|
||||
? (
|
||||
<RetrievalMethodConfig
|
||||
disabled={isModelAndRetrievalConfigDisabled}
|
||||
value={retrievalConfig}
|
||||
onChange={onRetrievalConfigChange}
|
||||
showMultiModalTip={showMultiModalTip}
|
||||
/>
|
||||
)
|
||||
: (
|
||||
<EconomicalRetrievalMethodConfig
|
||||
disabled={isModelAndRetrievalConfigDisabled}
|
||||
value={retrievalConfig}
|
||||
onChange={onRetrievalConfigChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { ParentChildConfig } from '../hooks'
|
||||
import type { ParentMode, PreProcessingRule } from '@/models/datasets'
|
||||
import { RiSearchEyeLine } from '@remixicon/react'
|
||||
import Image from 'next/image'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Checkbox from '@/app/components/base/checkbox'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
|
||||
import RadioCard from '@/app/components/base/radio-card'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import FileList from '../../assets/file-list-3-fill.svg'
|
||||
import Note from '../../assets/note-mod.svg'
|
||||
import BlueEffect from '../../assets/option-card-effect-blue.svg'
|
||||
import s from '../index.module.css'
|
||||
import { DelimiterInput, MaxLengthInput } from './inputs'
|
||||
import { OptionCard } from './option-card'
|
||||
|
||||
type TextLabelProps = {
|
||||
children: React.ReactNode
|
||||
}
|
||||
|
||||
const TextLabel: FC<TextLabelProps> = ({ children }) => {
|
||||
return <label className="system-sm-semibold text-text-secondary">{children}</label>
|
||||
}
|
||||
|
||||
type ParentChildOptionsProps = {
|
||||
// State
|
||||
parentChildConfig: ParentChildConfig
|
||||
rules: PreProcessingRule[]
|
||||
currentDocForm: ChunkingMode
|
||||
// Flags
|
||||
isActive: boolean
|
||||
isInUpload: boolean
|
||||
isNotUploadInEmptyDataset: boolean
|
||||
// Actions
|
||||
onDocFormChange: (form: ChunkingMode) => void
|
||||
onChunkForContextChange: (mode: ParentMode) => void
|
||||
onParentDelimiterChange: (value: string) => void
|
||||
onParentMaxLengthChange: (value: number) => void
|
||||
onChildDelimiterChange: (value: string) => void
|
||||
onChildMaxLengthChange: (value: number) => void
|
||||
onRuleToggle: (id: string) => void
|
||||
onPreview: () => void
|
||||
onReset: () => void
|
||||
}
|
||||
|
||||
export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
|
||||
parentChildConfig,
|
||||
rules,
|
||||
currentDocForm: _currentDocForm,
|
||||
isActive,
|
||||
isInUpload,
|
||||
isNotUploadInEmptyDataset,
|
||||
onDocFormChange,
|
||||
onChunkForContextChange,
|
||||
onParentDelimiterChange,
|
||||
onParentMaxLengthChange,
|
||||
onChildDelimiterChange,
|
||||
onChildMaxLengthChange,
|
||||
onRuleToggle,
|
||||
onPreview,
|
||||
onReset,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const getRuleName = (key: string): string => {
|
||||
const ruleNameMap: Record<string, string> = {
|
||||
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
|
||||
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
|
||||
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
|
||||
}
|
||||
return ruleNameMap[key] ?? key
|
||||
}
|
||||
|
||||
return (
|
||||
<OptionCard
|
||||
title={t('stepTwo.parentChild', { ns: 'datasetCreation' })}
|
||||
icon={<ParentChildChunk className="h-[20px] w-[20px]" />}
|
||||
effectImg={BlueEffect.src}
|
||||
className="text-util-colors-blue-light-blue-light-500"
|
||||
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
|
||||
description={t('stepTwo.parentChildTip', { ns: 'datasetCreation' })}
|
||||
isActive={isActive}
|
||||
onSwitched={() => onDocFormChange(ChunkingMode.parentChild)}
|
||||
actions={(
|
||||
<>
|
||||
<Button variant="secondary-accent" onClick={onPreview}>
|
||||
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
|
||||
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
<Button variant="ghost" onClick={onReset}>
|
||||
{t('stepTwo.reset', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
noHighlight={isInUpload && isNotUploadInEmptyDataset}
|
||||
>
|
||||
<div className="flex flex-col gap-4">
|
||||
{/* Parent chunk for context */}
|
||||
<div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<div className="inline-flex shrink-0">
|
||||
<TextLabel>{t('stepTwo.parentChunkForContext', { ns: 'datasetCreation' })}</TextLabel>
|
||||
</div>
|
||||
<Divider className="grow" bgStyle="gradient" />
|
||||
</div>
|
||||
<RadioCard
|
||||
className="mt-1"
|
||||
icon={<Image src={Note} alt="" />}
|
||||
title={t('stepTwo.paragraph', { ns: 'datasetCreation' })}
|
||||
description={t('stepTwo.paragraphTip', { ns: 'datasetCreation' })}
|
||||
isChosen={parentChildConfig.chunkForContext === 'paragraph'}
|
||||
onChosen={() => onChunkForContextChange('paragraph')}
|
||||
chosenConfig={(
|
||||
<div className="flex gap-3">
|
||||
<DelimiterInput
|
||||
value={parentChildConfig.parent.delimiter}
|
||||
tooltip={t('stepTwo.parentChildDelimiterTip', { ns: 'datasetCreation' })!}
|
||||
onChange={e => onParentDelimiterChange(e.target.value)}
|
||||
/>
|
||||
<MaxLengthInput
|
||||
unit="characters"
|
||||
value={parentChildConfig.parent.maxLength}
|
||||
onChange={onParentMaxLengthChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<RadioCard
|
||||
className="mt-2"
|
||||
icon={<Image src={FileList} alt="" />}
|
||||
title={t('stepTwo.fullDoc', { ns: 'datasetCreation' })}
|
||||
description={t('stepTwo.fullDocTip', { ns: 'datasetCreation' })}
|
||||
onChosen={() => onChunkForContextChange('full-doc')}
|
||||
isChosen={parentChildConfig.chunkForContext === 'full-doc'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Child chunk for retrieval */}
|
||||
<div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<div className="inline-flex shrink-0">
|
||||
<TextLabel>{t('stepTwo.childChunkForRetrieval', { ns: 'datasetCreation' })}</TextLabel>
|
||||
</div>
|
||||
<Divider className="grow" bgStyle="gradient" />
|
||||
</div>
|
||||
<div className="mt-1 flex gap-3">
|
||||
<DelimiterInput
|
||||
value={parentChildConfig.child.delimiter}
|
||||
tooltip={t('stepTwo.parentChildChunkDelimiterTip', { ns: 'datasetCreation' })!}
|
||||
onChange={e => onChildDelimiterChange(e.target.value)}
|
||||
/>
|
||||
<MaxLengthInput
|
||||
unit="characters"
|
||||
value={parentChildConfig.child.maxLength}
|
||||
onChange={onChildMaxLengthChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Rules */}
|
||||
<div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<div className="inline-flex shrink-0">
|
||||
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
|
||||
</div>
|
||||
<Divider className="grow" bgStyle="gradient" />
|
||||
</div>
|
||||
<div className="mt-1">
|
||||
{rules.map(rule => (
|
||||
<div
|
||||
key={rule.id}
|
||||
className={s.ruleItem}
|
||||
onClick={() => onRuleToggle(rule.id)}
|
||||
>
|
||||
<Checkbox checked={rule.enabled} />
|
||||
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
|
||||
{getRuleName(rule.id)}
|
||||
</label>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</OptionCard>
|
||||
)
|
||||
}
|
||||
@@ -1,171 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { ParentChildConfig } from '../hooks'
|
||||
import type { DataSourceType, FileIndexingEstimateResponse } from '@/models/datasets'
|
||||
import { RiSearchEyeLine } from '@remixicon/react'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import FloatRightContainer from '@/app/components/base/float-right-container'
|
||||
import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton'
|
||||
import { FULL_DOC_PREVIEW_LENGTH } from '@/config'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { ChunkContainer, QAPreview } from '../../../chunk'
|
||||
import PreviewDocumentPicker from '../../../common/document-picker/preview-document-picker'
|
||||
import { PreviewSlice } from '../../../formatted-text/flavours/preview-slice'
|
||||
import { FormattedText } from '../../../formatted-text/formatted'
|
||||
import PreviewContainer from '../../../preview/container'
|
||||
import { PreviewHeader } from '../../../preview/header'
|
||||
|
||||
type PreviewPanelProps = {
|
||||
// State
|
||||
isMobile: boolean
|
||||
dataSourceType: DataSourceType
|
||||
currentDocForm: ChunkingMode
|
||||
estimate?: FileIndexingEstimateResponse
|
||||
parentChildConfig: ParentChildConfig
|
||||
isSetting?: boolean
|
||||
// Picker
|
||||
pickerFiles: Array<{ id: string, name: string, extension: string }>
|
||||
pickerValue: { id: string, name: string, extension: string }
|
||||
// Mutation state
|
||||
isIdle: boolean
|
||||
isPending: boolean
|
||||
// Actions
|
||||
onPickerChange: (selected: { id: string, name: string }) => void
|
||||
}
|
||||
|
||||
export const PreviewPanel: FC<PreviewPanelProps> = ({
|
||||
isMobile,
|
||||
dataSourceType: _dataSourceType,
|
||||
currentDocForm,
|
||||
estimate,
|
||||
parentChildConfig,
|
||||
isSetting,
|
||||
pickerFiles,
|
||||
pickerValue,
|
||||
isIdle,
|
||||
isPending,
|
||||
onPickerChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<FloatRightContainer isMobile={isMobile} isOpen={true} onClose={noop} footer={null}>
|
||||
<PreviewContainer
|
||||
header={(
|
||||
<PreviewHeader title={t('stepTwo.preview', { ns: 'datasetCreation' })}>
|
||||
<div className="flex items-center gap-1">
|
||||
<PreviewDocumentPicker
|
||||
files={pickerFiles as Array<Required<{ id: string, name: string, extension: string }>>}
|
||||
onChange={onPickerChange}
|
||||
value={isSetting ? pickerFiles[0] : pickerValue}
|
||||
/>
|
||||
{currentDocForm !== ChunkingMode.qa && (
|
||||
<Badge
|
||||
text={t('stepTwo.previewChunkCount', {
|
||||
ns: 'datasetCreation',
|
||||
count: estimate?.total_segments || 0,
|
||||
}) as string}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</PreviewHeader>
|
||||
)}
|
||||
className={cn('relative flex h-full w-1/2 shrink-0 p-4 pr-0', isMobile && 'w-full max-w-[524px]')}
|
||||
mainClassName="space-y-6"
|
||||
>
|
||||
{/* QA Preview */}
|
||||
{currentDocForm === ChunkingMode.qa && estimate?.qa_preview && (
|
||||
estimate.qa_preview.map((item, index) => (
|
||||
<ChunkContainer
|
||||
key={item.question}
|
||||
label={`Chunk-${index + 1}`}
|
||||
characterCount={item.question.length + item.answer.length}
|
||||
>
|
||||
<QAPreview qa={item} />
|
||||
</ChunkContainer>
|
||||
))
|
||||
)}
|
||||
|
||||
{/* Text Preview */}
|
||||
{currentDocForm === ChunkingMode.text && estimate?.preview && (
|
||||
estimate.preview.map((item, index) => (
|
||||
<ChunkContainer
|
||||
key={item.content}
|
||||
label={`Chunk-${index + 1}`}
|
||||
characterCount={item.content.length}
|
||||
>
|
||||
{item.content}
|
||||
</ChunkContainer>
|
||||
))
|
||||
)}
|
||||
|
||||
{/* Parent-Child Preview */}
|
||||
{currentDocForm === ChunkingMode.parentChild && estimate?.preview && (
|
||||
estimate.preview.map((item, index) => {
|
||||
const indexForLabel = index + 1
|
||||
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
|
||||
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
|
||||
: item.child_chunks
|
||||
return (
|
||||
<ChunkContainer
|
||||
key={item.content}
|
||||
label={`Chunk-${indexForLabel}`}
|
||||
characterCount={item.content.length}
|
||||
>
|
||||
<FormattedText>
|
||||
{childChunks.map((child, childIndex) => {
|
||||
const childIndexForLabel = childIndex + 1
|
||||
return (
|
||||
<PreviewSlice
|
||||
key={`C-${childIndexForLabel}-${child}`}
|
||||
label={`C-${childIndexForLabel}`}
|
||||
text={child}
|
||||
tooltip={`Child-chunk-${childIndexForLabel} · ${child.length} Characters`}
|
||||
labelInnerClassName="text-[10px] font-semibold align-bottom leading-7"
|
||||
dividerClassName="leading-7"
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</FormattedText>
|
||||
</ChunkContainer>
|
||||
)
|
||||
})
|
||||
)}
|
||||
|
||||
{/* Idle State */}
|
||||
{isIdle && (
|
||||
<div className="flex h-full w-full items-center justify-center">
|
||||
<div className="flex flex-col items-center justify-center gap-3">
|
||||
<RiSearchEyeLine className="size-10 text-text-empty-state-icon" />
|
||||
<p className="text-sm text-text-tertiary">
|
||||
{t('stepTwo.previewChunkTip', { ns: 'datasetCreation' })}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Loading State */}
|
||||
{isPending && (
|
||||
<div className="space-y-6">
|
||||
{Array.from({ length: 10 }, (_, i) => (
|
||||
<SkeletonContainer key={i}>
|
||||
<SkeletonRow>
|
||||
<SkeletonRectangle className="w-20" />
|
||||
<SkeletonPoint />
|
||||
<SkeletonRectangle className="w-24" />
|
||||
</SkeletonRow>
|
||||
<SkeletonRectangle className="w-full" />
|
||||
<SkeletonRectangle className="w-full" />
|
||||
<SkeletonRectangle className="w-[422px]" />
|
||||
</SkeletonContainer>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</PreviewContainer>
|
||||
</FloatRightContainer>
|
||||
)
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { RiArrowLeftLine } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
|
||||
type StepTwoFooterProps = {
|
||||
isSetting?: boolean
|
||||
isCreating: boolean
|
||||
onPrevious: () => void
|
||||
onCreate: () => void
|
||||
onCancel?: () => void
|
||||
}
|
||||
|
||||
export const StepTwoFooter: FC<StepTwoFooterProps> = ({
|
||||
isSetting,
|
||||
isCreating,
|
||||
onPrevious,
|
||||
onCreate,
|
||||
onCancel,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
if (!isSetting) {
|
||||
return (
|
||||
<div className="mt-8 flex items-center py-2">
|
||||
<Button onClick={onPrevious}>
|
||||
<RiArrowLeftLine className="mr-1 h-4 w-4" />
|
||||
{t('stepTwo.previousStep', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
<Button
|
||||
className="ml-auto"
|
||||
loading={isCreating}
|
||||
variant="primary"
|
||||
onClick={onCreate}
|
||||
>
|
||||
{t('stepTwo.nextStep', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mt-8 flex items-center py-2">
|
||||
<Button
|
||||
loading={isCreating}
|
||||
variant="primary"
|
||||
onClick={onCreate}
|
||||
>
|
||||
{t('stepTwo.save', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
<Button className="ml-2" onClick={onCancel}>
|
||||
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
export { useDocumentCreation } from './use-document-creation'
|
||||
export type { DocumentCreation, ValidationParams } from './use-document-creation'
|
||||
|
||||
export { IndexingType, useIndexingConfig } from './use-indexing-config'
|
||||
export type { IndexingConfig } from './use-indexing-config'
|
||||
|
||||
export { useIndexingEstimate } from './use-indexing-estimate'
|
||||
export type { IndexingEstimate } from './use-indexing-estimate'
|
||||
|
||||
export { usePreviewState } from './use-preview-state'
|
||||
export type { PreviewState } from './use-preview-state'
|
||||
|
||||
export { DEFAULT_MAXIMUM_CHUNK_LENGTH, DEFAULT_OVERLAP, DEFAULT_SEGMENT_IDENTIFIER, defaultParentChildConfig, MAXIMUM_CHUNK_TOKEN_LENGTH, useSegmentationState } from './use-segmentation-state'
|
||||
export type { ParentChildConfig, SegmentationState } from './use-segmentation-state'
|
||||
@@ -1,279 +0,0 @@
|
||||
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { NotionPage } from '@/models/common'
|
||||
import type {
|
||||
ChunkingMode,
|
||||
CrawlOptions,
|
||||
CrawlResultItem,
|
||||
CreateDocumentReq,
|
||||
createDocumentResponse,
|
||||
CustomFile,
|
||||
FullDocumentDetail,
|
||||
ProcessRule,
|
||||
} from '@/models/datasets'
|
||||
import type { RetrievalConfig, RETRIEVE_METHOD } from '@/types/app'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { DataSourceProvider } from '@/models/common'
|
||||
import {
|
||||
DataSourceType,
|
||||
} from '@/models/datasets'
|
||||
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument } from '@/service/knowledge/use-create-dataset'
|
||||
import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
|
||||
import { IndexingType } from './use-indexing-config'
|
||||
import { MAXIMUM_CHUNK_TOKEN_LENGTH } from './use-segmentation-state'
|
||||
|
||||
export type UseDocumentCreationOptions = {
|
||||
datasetId?: string
|
||||
isSetting?: boolean
|
||||
documentDetail?: FullDocumentDetail
|
||||
dataSourceType: DataSourceType
|
||||
files: CustomFile[]
|
||||
notionPages: NotionPage[]
|
||||
notionCredentialId: string
|
||||
websitePages: CrawlResultItem[]
|
||||
crawlOptions?: CrawlOptions
|
||||
websiteCrawlProvider?: DataSourceProvider
|
||||
websiteCrawlJobId?: string
|
||||
// Callbacks
|
||||
onStepChange?: (delta: number) => void
|
||||
updateIndexingTypeCache?: (type: string) => void
|
||||
updateResultCache?: (res: createDocumentResponse) => void
|
||||
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
|
||||
onSave?: () => void
|
||||
mutateDatasetRes?: () => void
|
||||
}
|
||||
|
||||
export type ValidationParams = {
|
||||
segmentationType: string
|
||||
maxChunkLength: number
|
||||
limitMaxChunkLength: number
|
||||
overlap: number
|
||||
indexType: IndexingType
|
||||
embeddingModel: DefaultModel
|
||||
rerankModelList: Model[]
|
||||
retrievalConfig: RetrievalConfig
|
||||
}
|
||||
|
||||
export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
datasetId,
|
||||
isSetting,
|
||||
documentDetail,
|
||||
dataSourceType,
|
||||
files,
|
||||
notionPages,
|
||||
notionCredentialId,
|
||||
websitePages,
|
||||
crawlOptions,
|
||||
websiteCrawlProvider = DataSourceProvider.jinaReader,
|
||||
websiteCrawlJobId = '',
|
||||
onStepChange,
|
||||
updateIndexingTypeCache,
|
||||
updateResultCache,
|
||||
updateRetrievalMethodCache,
|
||||
onSave,
|
||||
mutateDatasetRes,
|
||||
} = options
|
||||
|
||||
const createFirstDocumentMutation = useCreateFirstDocument()
|
||||
const createDocumentMutation = useCreateDocument(datasetId!)
|
||||
const invalidDatasetList = useInvalidDatasetList()
|
||||
|
||||
const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending
|
||||
|
||||
// Validate creation params
|
||||
const validateParams = useCallback((params: ValidationParams): boolean => {
|
||||
const {
|
||||
segmentationType,
|
||||
maxChunkLength,
|
||||
limitMaxChunkLength,
|
||||
overlap,
|
||||
indexType,
|
||||
embeddingModel,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
} = params
|
||||
|
||||
if (segmentationType === 'general' && overlap > maxChunkLength) {
|
||||
Toast.notify({ type: 'error', message: t('stepTwo.overlapCheck', { ns: 'datasetCreation' }) })
|
||||
return false
|
||||
}
|
||||
|
||||
if (segmentationType === 'general' && maxChunkLength > limitMaxChunkLength) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: limitMaxChunkLength }),
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
if (!isSetting) {
|
||||
if (indexType === IndexingType.QUALIFIED && (!embeddingModel.model || !embeddingModel.provider)) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('datasetConfig.embeddingModelRequired', { ns: 'appDebug' }),
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
if (!isReRankModelSelected({
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod: indexType,
|
||||
})) {
|
||||
Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) })
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}, [t, isSetting])
|
||||
|
||||
// Build creation params
|
||||
const buildCreationParams = useCallback((
|
||||
currentDocForm: ChunkingMode,
|
||||
docLanguage: string,
|
||||
processRule: ProcessRule,
|
||||
retrievalConfig: RetrievalConfig,
|
||||
embeddingModel: DefaultModel,
|
||||
indexingTechnique: string,
|
||||
): CreateDocumentReq | null => {
|
||||
if (isSetting) {
|
||||
return {
|
||||
original_document_id: documentDetail?.id,
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
process_rule: processRule,
|
||||
retrieval_model: retrievalConfig,
|
||||
embedding_model: embeddingModel.model,
|
||||
embedding_model_provider: embeddingModel.provider,
|
||||
indexing_technique: indexingTechnique,
|
||||
} as CreateDocumentReq
|
||||
}
|
||||
|
||||
const params: CreateDocumentReq = {
|
||||
data_source: {
|
||||
type: dataSourceType,
|
||||
info_list: {
|
||||
data_source_type: dataSourceType,
|
||||
},
|
||||
},
|
||||
indexing_technique: indexingTechnique,
|
||||
process_rule: processRule,
|
||||
doc_form: currentDocForm,
|
||||
doc_language: docLanguage,
|
||||
retrieval_model: retrievalConfig,
|
||||
embedding_model: embeddingModel.model,
|
||||
embedding_model_provider: embeddingModel.provider,
|
||||
} as CreateDocumentReq
|
||||
|
||||
// Add data source specific info
|
||||
if (dataSourceType === DataSourceType.FILE) {
|
||||
params.data_source!.info_list.file_info_list = {
|
||||
file_ids: files.map(file => file.id || '').filter(Boolean),
|
||||
}
|
||||
}
|
||||
if (dataSourceType === DataSourceType.NOTION)
|
||||
params.data_source!.info_list.notion_info_list = getNotionInfo(notionPages, notionCredentialId)
|
||||
|
||||
if (dataSourceType === DataSourceType.WEB) {
|
||||
params.data_source!.info_list.website_info_list = getWebsiteInfo({
|
||||
websiteCrawlProvider,
|
||||
websiteCrawlJobId,
|
||||
websitePages,
|
||||
crawlOptions,
|
||||
})
|
||||
}
|
||||
|
||||
return params
|
||||
}, [
|
||||
isSetting,
|
||||
documentDetail,
|
||||
dataSourceType,
|
||||
files,
|
||||
notionPages,
|
||||
notionCredentialId,
|
||||
websitePages,
|
||||
websiteCrawlProvider,
|
||||
websiteCrawlJobId,
|
||||
crawlOptions,
|
||||
])
|
||||
|
||||
// Execute creation
|
||||
const executeCreation = useCallback(async (
|
||||
params: CreateDocumentReq,
|
||||
indexType: IndexingType,
|
||||
retrievalConfig: RetrievalConfig,
|
||||
) => {
|
||||
if (!datasetId) {
|
||||
await createFirstDocumentMutation.mutateAsync(params, {
|
||||
onSuccess(data) {
|
||||
updateIndexingTypeCache?.(indexType)
|
||||
updateResultCache?.(data)
|
||||
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
|
||||
},
|
||||
})
|
||||
}
|
||||
else {
|
||||
await createDocumentMutation.mutateAsync(params, {
|
||||
onSuccess(data) {
|
||||
updateIndexingTypeCache?.(indexType)
|
||||
updateResultCache?.(data)
|
||||
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
mutateDatasetRes?.()
|
||||
invalidDatasetList()
|
||||
|
||||
trackEvent('create_datasets', {
|
||||
data_source_type: dataSourceType,
|
||||
indexing_technique: indexType,
|
||||
})
|
||||
|
||||
onStepChange?.(+1)
|
||||
|
||||
if (isSetting)
|
||||
onSave?.()
|
||||
}, [
|
||||
datasetId,
|
||||
createFirstDocumentMutation,
|
||||
createDocumentMutation,
|
||||
updateIndexingTypeCache,
|
||||
updateResultCache,
|
||||
updateRetrievalMethodCache,
|
||||
mutateDatasetRes,
|
||||
invalidDatasetList,
|
||||
dataSourceType,
|
||||
onStepChange,
|
||||
isSetting,
|
||||
onSave,
|
||||
])
|
||||
|
||||
// Validate preview params
|
||||
const validatePreviewParams = useCallback((maxChunkLength: number): boolean => {
|
||||
if (maxChunkLength > MAXIMUM_CHUNK_TOKEN_LENGTH) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: MAXIMUM_CHUNK_TOKEN_LENGTH }),
|
||||
})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}, [t])
|
||||
|
||||
return {
|
||||
isCreating,
|
||||
validateParams,
|
||||
buildCreationParams,
|
||||
executeCreation,
|
||||
validatePreviewParams,
|
||||
}
|
||||
}
|
||||
|
||||
export type DocumentCreation = ReturnType<typeof useDocumentCreation>
|
||||
@@ -1,143 +0,0 @@
|
||||
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
|
||||
export enum IndexingType {
|
||||
QUALIFIED = 'high_quality',
|
||||
ECONOMICAL = 'economy',
|
||||
}
|
||||
|
||||
const DEFAULT_RETRIEVAL_CONFIG: RetrievalConfig = {
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: false,
|
||||
reranking_model: {
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
}
|
||||
|
||||
export type UseIndexingConfigOptions = {
|
||||
initialIndexType?: IndexingType
|
||||
initialEmbeddingModel?: DefaultModel
|
||||
initialRetrievalConfig?: RetrievalConfig
|
||||
isAPIKeySet: boolean
|
||||
hasSetIndexType: boolean
|
||||
}
|
||||
|
||||
export const useIndexingConfig = (options: UseIndexingConfigOptions) => {
|
||||
const {
|
||||
initialIndexType,
|
||||
initialEmbeddingModel,
|
||||
initialRetrievalConfig,
|
||||
isAPIKeySet,
|
||||
hasSetIndexType,
|
||||
} = options
|
||||
|
||||
// Rerank model
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelValid,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
// Embedding model list
|
||||
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
|
||||
const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
|
||||
|
||||
// Index type state
|
||||
const [indexType, setIndexType] = useState<IndexingType>(() => {
|
||||
if (initialIndexType)
|
||||
return initialIndexType
|
||||
return isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL
|
||||
})
|
||||
|
||||
// Embedding model state
|
||||
const [embeddingModel, setEmbeddingModel] = useState<DefaultModel>(
|
||||
initialEmbeddingModel ?? {
|
||||
provider: defaultEmbeddingModel?.provider.provider || '',
|
||||
model: defaultEmbeddingModel?.model || '',
|
||||
},
|
||||
)
|
||||
|
||||
// Retrieval config state
|
||||
const [retrievalConfig, setRetrievalConfig] = useState<RetrievalConfig>(
|
||||
initialRetrievalConfig ?? DEFAULT_RETRIEVAL_CONFIG,
|
||||
)
|
||||
|
||||
// Sync retrieval config with rerank model when available
|
||||
useEffect(() => {
|
||||
if (initialRetrievalConfig)
|
||||
return
|
||||
|
||||
setRetrievalConfig({
|
||||
search_method: RETRIEVE_METHOD.semantic,
|
||||
reranking_enable: !!isRerankDefaultModelValid,
|
||||
reranking_model: {
|
||||
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
|
||||
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
|
||||
},
|
||||
top_k: 3,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: 0.5,
|
||||
})
|
||||
}, [rerankDefaultModel, isRerankDefaultModelValid, initialRetrievalConfig])
|
||||
|
||||
// Sync index type with props
|
||||
useEffect(() => {
|
||||
if (initialIndexType)
|
||||
setIndexType(initialIndexType)
|
||||
else
|
||||
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
|
||||
}, [isAPIKeySet, initialIndexType])
|
||||
|
||||
// Show multimodal tip
|
||||
const showMultiModalTip = useMemo(() => {
|
||||
return checkShowMultiModalTip({
|
||||
embeddingModel,
|
||||
rerankingEnable: retrievalConfig.reranking_enable,
|
||||
rerankModel: {
|
||||
rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name,
|
||||
rerankingModelName: retrievalConfig.reranking_model.reranking_model_name,
|
||||
},
|
||||
indexMethod: indexType,
|
||||
embeddingModelList,
|
||||
rerankModelList,
|
||||
})
|
||||
}, [embeddingModel, retrievalConfig, indexType, embeddingModelList, rerankModelList])
|
||||
|
||||
// Get effective indexing technique
|
||||
const getIndexingTechnique = () => initialIndexType || indexType
|
||||
|
||||
return {
|
||||
// Index type
|
||||
indexType,
|
||||
setIndexType,
|
||||
hasSetIndexType,
|
||||
getIndexingTechnique,
|
||||
|
||||
// Embedding model
|
||||
embeddingModel,
|
||||
setEmbeddingModel,
|
||||
embeddingModelList,
|
||||
defaultEmbeddingModel,
|
||||
|
||||
// Retrieval config
|
||||
retrievalConfig,
|
||||
setRetrievalConfig,
|
||||
rerankModelList,
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelValid,
|
||||
|
||||
// Computed
|
||||
showMultiModalTip,
|
||||
}
|
||||
}
|
||||
|
||||
export type IndexingConfig = ReturnType<typeof useIndexingConfig>
|
||||
@@ -1,123 +0,0 @@
|
||||
import type { IndexingType } from './use-indexing-config'
|
||||
import type { NotionPage } from '@/models/common'
|
||||
import type { ChunkingMode, CrawlOptions, CrawlResultItem, CustomFile, ProcessRule } from '@/models/datasets'
|
||||
import { useCallback } from 'react'
|
||||
import { DataSourceProvider } from '@/models/common'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import {
|
||||
useFetchFileIndexingEstimateForFile,
|
||||
useFetchFileIndexingEstimateForNotion,
|
||||
useFetchFileIndexingEstimateForWeb,
|
||||
} from '@/service/knowledge/use-create-dataset'
|
||||
|
||||
export type UseIndexingEstimateOptions = {
|
||||
dataSourceType: DataSourceType
|
||||
datasetId?: string
|
||||
// Document settings
|
||||
currentDocForm: ChunkingMode
|
||||
docLanguage: string
|
||||
// File data source
|
||||
files: CustomFile[]
|
||||
previewFileName?: string
|
||||
// Notion data source
|
||||
previewNotionPage: NotionPage
|
||||
notionCredentialId: string
|
||||
// Website data source
|
||||
previewWebsitePage: CrawlResultItem
|
||||
crawlOptions?: CrawlOptions
|
||||
websiteCrawlProvider?: DataSourceProvider
|
||||
websiteCrawlJobId?: string
|
||||
// Processing
|
||||
indexingTechnique: IndexingType
|
||||
processRule: ProcessRule
|
||||
}
|
||||
|
||||
export const useIndexingEstimate = (options: UseIndexingEstimateOptions) => {
|
||||
const {
|
||||
dataSourceType,
|
||||
datasetId,
|
||||
currentDocForm,
|
||||
docLanguage,
|
||||
files,
|
||||
previewFileName,
|
||||
previewNotionPage,
|
||||
notionCredentialId,
|
||||
previewWebsitePage,
|
||||
crawlOptions,
|
||||
websiteCrawlProvider,
|
||||
websiteCrawlJobId,
|
||||
indexingTechnique,
|
||||
processRule,
|
||||
} = options
|
||||
|
||||
// File indexing estimate
|
||||
const fileQuery = useFetchFileIndexingEstimateForFile({
|
||||
docForm: currentDocForm,
|
||||
docLanguage,
|
||||
dataSourceType: DataSourceType.FILE,
|
||||
files: previewFileName
|
||||
? [files.find(file => file.name === previewFileName)!]
|
||||
: files,
|
||||
indexingTechnique,
|
||||
processRule,
|
||||
dataset_id: datasetId!,
|
||||
})
|
||||
|
||||
// Notion indexing estimate
|
||||
const notionQuery = useFetchFileIndexingEstimateForNotion({
|
||||
docForm: currentDocForm,
|
||||
docLanguage,
|
||||
dataSourceType: DataSourceType.NOTION,
|
||||
notionPages: [previewNotionPage],
|
||||
indexingTechnique,
|
||||
processRule,
|
||||
dataset_id: datasetId || '',
|
||||
credential_id: notionCredentialId,
|
||||
})
|
||||
|
||||
// Website indexing estimate
|
||||
const websiteQuery = useFetchFileIndexingEstimateForWeb({
|
||||
docForm: currentDocForm,
|
||||
docLanguage,
|
||||
dataSourceType: DataSourceType.WEB,
|
||||
websitePages: [previewWebsitePage],
|
||||
crawlOptions,
|
||||
websiteCrawlProvider: websiteCrawlProvider ?? DataSourceProvider.jinaReader,
|
||||
websiteCrawlJobId: websiteCrawlJobId ?? '',
|
||||
indexingTechnique,
|
||||
processRule,
|
||||
dataset_id: datasetId || '',
|
||||
})
|
||||
|
||||
// Get current mutation based on data source type
|
||||
const getCurrentMutation = useCallback(() => {
|
||||
if (dataSourceType === DataSourceType.FILE)
|
||||
return fileQuery
|
||||
if (dataSourceType === DataSourceType.NOTION)
|
||||
return notionQuery
|
||||
return websiteQuery
|
||||
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
|
||||
|
||||
const currentMutation = getCurrentMutation()
|
||||
|
||||
// Trigger estimate fetch
|
||||
const fetchEstimate = useCallback(() => {
|
||||
if (dataSourceType === DataSourceType.FILE)
|
||||
fileQuery.mutate()
|
||||
else if (dataSourceType === DataSourceType.NOTION)
|
||||
notionQuery.mutate()
|
||||
else
|
||||
websiteQuery.mutate()
|
||||
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
|
||||
|
||||
return {
|
||||
currentMutation,
|
||||
estimate: currentMutation.data,
|
||||
isIdle: currentMutation.isIdle,
|
||||
isPending: currentMutation.isPending,
|
||||
fetchEstimate,
|
||||
reset: currentMutation.reset,
|
||||
}
|
||||
}
|
||||
|
||||
export type IndexingEstimate = ReturnType<typeof useIndexingEstimate>
|
||||
@@ -1,127 +0,0 @@
|
||||
import type { NotionPage } from '@/models/common'
|
||||
import type { CrawlResultItem, CustomFile, DocumentItem, FullDocumentDetail } from '@/models/datasets'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
|
||||
export type UsePreviewStateOptions = {
|
||||
dataSourceType: DataSourceType
|
||||
files: CustomFile[]
|
||||
notionPages: NotionPage[]
|
||||
websitePages: CrawlResultItem[]
|
||||
documentDetail?: FullDocumentDetail
|
||||
datasetId?: string
|
||||
}
|
||||
|
||||
export const usePreviewState = (options: UsePreviewStateOptions) => {
|
||||
const {
|
||||
dataSourceType,
|
||||
files,
|
||||
notionPages,
|
||||
websitePages,
|
||||
documentDetail,
|
||||
datasetId,
|
||||
} = options
|
||||
|
||||
// File preview state
|
||||
const [previewFile, setPreviewFile] = useState<DocumentItem>(
|
||||
(datasetId && documentDetail)
|
||||
? documentDetail.file
|
||||
: files[0],
|
||||
)
|
||||
|
||||
// Notion page preview state
|
||||
const [previewNotionPage, setPreviewNotionPage] = useState<NotionPage>(
|
||||
(datasetId && documentDetail)
|
||||
? documentDetail.notion_page
|
||||
: notionPages[0],
|
||||
)
|
||||
|
||||
// Website page preview state
|
||||
const [previewWebsitePage, setPreviewWebsitePage] = useState<CrawlResultItem>(
|
||||
(datasetId && documentDetail)
|
||||
? documentDetail.website_page
|
||||
: websitePages[0],
|
||||
)
|
||||
|
||||
// Get preview items for document picker based on data source type
|
||||
const getPreviewPickerItems = useCallback(() => {
|
||||
if (dataSourceType === DataSourceType.FILE) {
|
||||
return files as Array<Required<CustomFile>>
|
||||
}
|
||||
if (dataSourceType === DataSourceType.NOTION) {
|
||||
return notionPages.map(page => ({
|
||||
id: page.page_id,
|
||||
name: page.page_name,
|
||||
extension: 'md',
|
||||
}))
|
||||
}
|
||||
if (dataSourceType === DataSourceType.WEB) {
|
||||
return websitePages.map(page => ({
|
||||
id: page.source_url,
|
||||
name: page.title,
|
||||
extension: 'md',
|
||||
}))
|
||||
}
|
||||
return []
|
||||
}, [dataSourceType, files, notionPages, websitePages])
|
||||
|
||||
// Get current preview value for picker
|
||||
const getPreviewPickerValue = useCallback(() => {
|
||||
if (dataSourceType === DataSourceType.FILE) {
|
||||
return previewFile as Required<CustomFile>
|
||||
}
|
||||
if (dataSourceType === DataSourceType.NOTION) {
|
||||
return {
|
||||
id: previewNotionPage?.page_id || '',
|
||||
name: previewNotionPage?.page_name || '',
|
||||
extension: 'md',
|
||||
}
|
||||
}
|
||||
if (dataSourceType === DataSourceType.WEB) {
|
||||
return {
|
||||
id: previewWebsitePage?.source_url || '',
|
||||
name: previewWebsitePage?.title || '',
|
||||
extension: 'md',
|
||||
}
|
||||
}
|
||||
return { id: '', name: '', extension: '' }
|
||||
}, [dataSourceType, previewFile, previewNotionPage, previewWebsitePage])
|
||||
|
||||
// Handle preview change
|
||||
const handlePreviewChange = useCallback((selected: { id: string, name: string }) => {
|
||||
if (dataSourceType === DataSourceType.FILE) {
|
||||
setPreviewFile(selected as DocumentItem)
|
||||
}
|
||||
else if (dataSourceType === DataSourceType.NOTION) {
|
||||
const selectedPage = notionPages.find(page => page.page_id === selected.id)
|
||||
if (selectedPage)
|
||||
setPreviewNotionPage(selectedPage)
|
||||
}
|
||||
else if (dataSourceType === DataSourceType.WEB) {
|
||||
const selectedPage = websitePages.find(page => page.source_url === selected.id)
|
||||
if (selectedPage)
|
||||
setPreviewWebsitePage(selectedPage)
|
||||
}
|
||||
}, [dataSourceType, notionPages, websitePages])
|
||||
|
||||
return {
|
||||
// File preview
|
||||
previewFile,
|
||||
setPreviewFile,
|
||||
|
||||
// Notion preview
|
||||
previewNotionPage,
|
||||
setPreviewNotionPage,
|
||||
|
||||
// Website preview
|
||||
previewWebsitePage,
|
||||
setPreviewWebsitePage,
|
||||
|
||||
// Picker helpers
|
||||
getPreviewPickerItems,
|
||||
getPreviewPickerValue,
|
||||
handlePreviewChange,
|
||||
}
|
||||
}
|
||||
|
||||
export type PreviewState = ReturnType<typeof usePreviewState>
|
||||
@@ -1,222 +0,0 @@
|
||||
import type { ParentMode, PreProcessingRule, ProcessRule, Rules } from '@/models/datasets'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { ChunkingMode, ProcessMode } from '@/models/datasets'
|
||||
import escape from './escape'
|
||||
import unescape from './unescape'
|
||||
|
||||
// Constants
|
||||
export const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
|
||||
export const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024
|
||||
export const DEFAULT_OVERLAP = 50
|
||||
export const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(
|
||||
globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000',
|
||||
10,
|
||||
)
|
||||
|
||||
export type ParentChildConfig = {
|
||||
chunkForContext: ParentMode
|
||||
parent: {
|
||||
delimiter: string
|
||||
maxLength: number
|
||||
}
|
||||
child: {
|
||||
delimiter: string
|
||||
maxLength: number
|
||||
}
|
||||
}
|
||||
|
||||
export const defaultParentChildConfig: ParentChildConfig = {
|
||||
chunkForContext: 'paragraph',
|
||||
parent: {
|
||||
delimiter: '\\n\\n',
|
||||
maxLength: 1024,
|
||||
},
|
||||
child: {
|
||||
delimiter: '\\n',
|
||||
maxLength: 512,
|
||||
},
|
||||
}
|
||||
|
||||
export type UseSegmentationStateOptions = {
|
||||
initialSegmentationType?: ProcessMode
|
||||
}
|
||||
|
||||
export const useSegmentationState = (options: UseSegmentationStateOptions = {}) => {
|
||||
const { initialSegmentationType } = options
|
||||
|
||||
// Segmentation type (general or parent-child)
|
||||
const [segmentationType, setSegmentationType] = useState<ProcessMode>(
|
||||
initialSegmentationType ?? ProcessMode.general,
|
||||
)
|
||||
|
||||
// General chunking settings
|
||||
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
|
||||
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH)
|
||||
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(MAXIMUM_CHUNK_TOKEN_LENGTH)
|
||||
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
|
||||
|
||||
// Pre-processing rules
|
||||
const [rules, setRules] = useState<PreProcessingRule[]>([])
|
||||
const [defaultConfig, setDefaultConfig] = useState<Rules>()
|
||||
|
||||
// Parent-child config
|
||||
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
|
||||
|
||||
// Escaped segment identifier setter
|
||||
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
|
||||
if (value) {
|
||||
doSetSegmentIdentifier(escape(value))
|
||||
}
|
||||
else {
|
||||
doSetSegmentIdentifier(canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)
|
||||
}
|
||||
}, [])
|
||||
|
||||
// Rule toggle handler
|
||||
const toggleRule = useCallback((id: string) => {
|
||||
setRules(prev => prev.map(rule =>
|
||||
rule.id === id ? { ...rule, enabled: !rule.enabled } : rule,
|
||||
))
|
||||
}, [])
|
||||
|
||||
// Reset to defaults
|
||||
const resetToDefaults = useCallback(() => {
|
||||
if (defaultConfig) {
|
||||
setSegmentIdentifier(defaultConfig.segmentation.separator)
|
||||
setMaxChunkLength(defaultConfig.segmentation.max_tokens)
|
||||
setOverlap(defaultConfig.segmentation.chunk_overlap!)
|
||||
setRules(defaultConfig.pre_processing_rules)
|
||||
}
|
||||
setParentChildConfig(defaultParentChildConfig)
|
||||
}, [defaultConfig, setSegmentIdentifier])
|
||||
|
||||
// Apply config from document detail
|
||||
const applyConfigFromRules = useCallback((rulesConfig: Rules, isHierarchical: boolean) => {
|
||||
const separator = rulesConfig.segmentation.separator
|
||||
const max = rulesConfig.segmentation.max_tokens
|
||||
const chunkOverlap = rulesConfig.segmentation.chunk_overlap
|
||||
|
||||
setSegmentIdentifier(separator)
|
||||
setMaxChunkLength(max)
|
||||
setOverlap(chunkOverlap!)
|
||||
setRules(rulesConfig.pre_processing_rules)
|
||||
setDefaultConfig(rulesConfig)
|
||||
|
||||
if (isHierarchical) {
|
||||
setParentChildConfig({
|
||||
chunkForContext: rulesConfig.parent_mode || 'paragraph',
|
||||
parent: {
|
||||
delimiter: escape(rulesConfig.segmentation.separator),
|
||||
maxLength: rulesConfig.segmentation.max_tokens,
|
||||
},
|
||||
child: {
|
||||
delimiter: escape(rulesConfig.subchunk_segmentation!.separator),
|
||||
maxLength: rulesConfig.subchunk_segmentation!.max_tokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
}, [setSegmentIdentifier])
|
||||
|
||||
// Get process rule for API
|
||||
const getProcessRule = useCallback((docForm: ChunkingMode): ProcessRule => {
|
||||
if (docForm === ChunkingMode.parentChild) {
|
||||
return {
|
||||
rules: {
|
||||
pre_processing_rules: rules,
|
||||
segmentation: {
|
||||
separator: unescape(parentChildConfig.parent.delimiter),
|
||||
max_tokens: parentChildConfig.parent.maxLength,
|
||||
},
|
||||
parent_mode: parentChildConfig.chunkForContext,
|
||||
subchunk_segmentation: {
|
||||
separator: unescape(parentChildConfig.child.delimiter),
|
||||
max_tokens: parentChildConfig.child.maxLength,
|
||||
},
|
||||
},
|
||||
mode: 'hierarchical',
|
||||
} as ProcessRule
|
||||
}
|
||||
|
||||
return {
|
||||
rules: {
|
||||
pre_processing_rules: rules,
|
||||
segmentation: {
|
||||
separator: unescape(segmentIdentifier),
|
||||
max_tokens: maxChunkLength,
|
||||
chunk_overlap: overlap,
|
||||
},
|
||||
},
|
||||
mode: segmentationType,
|
||||
} as ProcessRule
|
||||
}, [rules, parentChildConfig, segmentIdentifier, maxChunkLength, overlap, segmentationType])
|
||||
|
||||
// Update parent config field
|
||||
const updateParentConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
|
||||
setParentChildConfig((prev) => {
|
||||
let newValue: string | number
|
||||
if (field === 'delimiter')
|
||||
newValue = value ? escape(value as string) : ''
|
||||
else
|
||||
newValue = value
|
||||
return {
|
||||
...prev,
|
||||
parent: { ...prev.parent, [field]: newValue },
|
||||
}
|
||||
})
|
||||
}, [])
|
||||
|
||||
// Update child config field
|
||||
const updateChildConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
|
||||
setParentChildConfig((prev) => {
|
||||
let newValue: string | number
|
||||
if (field === 'delimiter')
|
||||
newValue = value ? escape(value as string) : ''
|
||||
else
|
||||
newValue = value
|
||||
return {
|
||||
...prev,
|
||||
child: { ...prev.child, [field]: newValue },
|
||||
}
|
||||
})
|
||||
}, [])
|
||||
|
||||
// Set chunk for context mode
|
||||
const setChunkForContext = useCallback((mode: ParentMode) => {
|
||||
setParentChildConfig(prev => ({ ...prev, chunkForContext: mode }))
|
||||
}, [])
|
||||
|
||||
return {
|
||||
// General chunking state
|
||||
segmentationType,
|
||||
setSegmentationType,
|
||||
segmentIdentifier,
|
||||
setSegmentIdentifier,
|
||||
maxChunkLength,
|
||||
setMaxChunkLength,
|
||||
limitMaxChunkLength,
|
||||
setLimitMaxChunkLength,
|
||||
overlap,
|
||||
setOverlap,
|
||||
|
||||
// Rules
|
||||
rules,
|
||||
setRules,
|
||||
defaultConfig,
|
||||
setDefaultConfig,
|
||||
toggleRule,
|
||||
|
||||
// Parent-child config
|
||||
parentChildConfig,
|
||||
setParentChildConfig,
|
||||
updateParentConfig,
|
||||
updateChildConfig,
|
||||
setChunkForContext,
|
||||
|
||||
// Actions
|
||||
resetToDefaults,
|
||||
applyConfigFromRules,
|
||||
getProcessRule,
|
||||
}
|
||||
}
|
||||
|
||||
export type SegmentationState = ReturnType<typeof useSegmentationState>
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,28 +0,0 @@
|
||||
import type { IndexingType } from './hooks'
|
||||
import type { DataSourceProvider, NotionPage } from '@/models/common'
|
||||
import type { CrawlOptions, CrawlResultItem, createDocumentResponse, CustomFile, DataSourceType, FullDocumentDetail } from '@/models/datasets'
|
||||
import type { RETRIEVE_METHOD } from '@/types/app'
|
||||
|
||||
export type StepTwoProps = {
|
||||
isSetting?: boolean
|
||||
documentDetail?: FullDocumentDetail
|
||||
isAPIKeySet: boolean
|
||||
onSetting: () => void
|
||||
datasetId?: string
|
||||
indexingType?: IndexingType
|
||||
retrievalMethod?: string
|
||||
dataSourceType: DataSourceType
|
||||
files: CustomFile[]
|
||||
notionPages?: NotionPage[]
|
||||
notionCredentialId: string
|
||||
websitePages?: CrawlResultItem[]
|
||||
crawlOptions?: CrawlOptions
|
||||
websiteCrawlProvider?: DataSourceProvider
|
||||
websiteCrawlJobId?: string
|
||||
onStepChange?: (delta: number) => void
|
||||
updateIndexingTypeCache?: (type: string) => void
|
||||
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
|
||||
updateResultCache?: (res: createDocumentResponse) => void
|
||||
onSave?: () => void
|
||||
onCancel?: () => void
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import type { FC } from 'react'
|
||||
import {
|
||||
RiFileTextLine,
|
||||
RiFilmAiLine,
|
||||
RiHammerLine,
|
||||
RiImageCircleAiLine,
|
||||
RiVoiceAiFill,
|
||||
} from '@remixicon/react'
|
||||
@@ -38,17 +39,33 @@ const FeatureIcon: FC<FeatureIconProps> = ({
|
||||
// )
|
||||
// }
|
||||
|
||||
// if (feature === ModelFeatureEnum.toolCall) {
|
||||
// return (
|
||||
// <Tooltip
|
||||
// popupContent={t('common.modelProvider.featureSupported', { feature: ModelFeatureTextEnum.toolCall })}
|
||||
// >
|
||||
// <ModelBadge className={`mr-0.5 !px-0 w-[18px] justify-center text-gray-500 ${className}`}>
|
||||
// <MagicWand className='w-3 h-3' />
|
||||
// </ModelBadge>
|
||||
// </Tooltip>
|
||||
// )
|
||||
// }
|
||||
if (feature === ModelFeatureEnum.toolCall) {
|
||||
if (showFeaturesLabel) {
|
||||
return (
|
||||
<ModelBadge className={cn('gap-x-0.5', className)}>
|
||||
<RiHammerLine className="size-3" />
|
||||
<span>{ModelFeatureTextEnum.toolCall}</span>
|
||||
</ModelBadge>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
popupContent={t('modelProvider.featureSupported', { ns: 'common', feature: ModelFeatureTextEnum.toolCall })}
|
||||
>
|
||||
<div className="inline-block cursor-help">
|
||||
<ModelBadge
|
||||
className={cn(
|
||||
'w-[18px] justify-center !px-0',
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<RiHammerLine className="size-3" />
|
||||
</ModelBadge>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
// if (feature === ModelFeatureEnum.multiToolCall) {
|
||||
// return (
|
||||
|
||||
@@ -96,6 +96,14 @@ const PopupItem: FC<PopupItemProps> = ({
|
||||
<div className='text-text-tertiary system-xs-regular'>{currentProvider?.description?.[language] || currentProvider?.description?.en_US}</div>
|
||||
)} */}
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{
|
||||
modelItem.features?.includes(ModelFeatureEnum.toolCall) && (
|
||||
<FeatureIcon
|
||||
feature={ModelFeatureEnum.toolCall}
|
||||
showFeaturesLabel
|
||||
/>
|
||||
)
|
||||
}
|
||||
{modelItem.model_type && (
|
||||
<ModelBadge>
|
||||
{modelTypeFormat(modelItem.model_type)}
|
||||
@@ -118,7 +126,7 @@ const PopupItem: FC<PopupItemProps> = ({
|
||||
<div className="pt-2">
|
||||
<div className="system-2xs-medium-uppercase mb-1 text-text-tertiary">{t('model.capabilities', { ns: 'common' })}</div>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{modelItem.features?.map(feature => (
|
||||
{modelItem.features?.filter(feature => feature !== ModelFeatureEnum.toolCall).map(feature => (
|
||||
<FeatureIcon
|
||||
key={feature}
|
||||
feature={feature}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import type { ActivePluginType } from './constants'
|
||||
import type { PluginsSort, SearchParamsFromCollection } from './types'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { useQueryState } from 'nuqs'
|
||||
import { useCallback } from 'react'
|
||||
import { DEFAULT_SORT, PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
|
||||
import { marketplaceSearchParamsParsers } from './search-params'
|
||||
|
||||
const marketplaceSortAtom = atom<PluginsSort>(DEFAULT_SORT)
|
||||
export function useMarketplaceSort() {
|
||||
return useAtom(marketplaceSortAtom)
|
||||
}
|
||||
export function useMarketplaceSortValue() {
|
||||
return useAtomValue(marketplaceSortAtom)
|
||||
}
|
||||
export function useSetMarketplaceSort() {
|
||||
return useSetAtom(marketplaceSortAtom)
|
||||
}
|
||||
|
||||
/**
|
||||
* Preserve the state for marketplace
|
||||
*/
|
||||
export const preserveSearchStateInQueryAtom = atom<boolean>(false)
|
||||
|
||||
const searchPluginTextAtom = atom<string>('')
|
||||
const activePluginTypeAtom = atom<ActivePluginType>('all')
|
||||
const filterPluginTagsAtom = atom<string[]>([])
|
||||
|
||||
export function useSearchPluginText() {
|
||||
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
|
||||
const queryState = useQueryState('q', marketplaceSearchParamsParsers.q)
|
||||
const atomState = useAtom(searchPluginTextAtom)
|
||||
return preserveSearchStateInQuery ? queryState : atomState
|
||||
}
|
||||
export function useActivePluginType() {
|
||||
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
|
||||
const queryState = useQueryState('category', marketplaceSearchParamsParsers.category)
|
||||
const atomState = useAtom(activePluginTypeAtom)
|
||||
return preserveSearchStateInQuery ? queryState : atomState
|
||||
}
|
||||
export function useFilterPluginTags() {
|
||||
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
|
||||
const queryState = useQueryState('tags', marketplaceSearchParamsParsers.tags)
|
||||
const atomState = useAtom(filterPluginTagsAtom)
|
||||
return preserveSearchStateInQuery ? queryState : atomState
|
||||
}
|
||||
|
||||
/**
|
||||
* Not all categories have collections, so we need to
|
||||
* force the search mode for those categories.
|
||||
*/
|
||||
export const searchModeAtom = atom<true | null>(null)
|
||||
|
||||
export function useMarketplaceSearchMode() {
|
||||
const [searchPluginText] = useSearchPluginText()
|
||||
const [filterPluginTags] = useFilterPluginTags()
|
||||
const [activePluginType] = useActivePluginType()
|
||||
|
||||
const searchMode = useAtomValue(searchModeAtom)
|
||||
const isSearchMode = !!searchPluginText
|
||||
|| filterPluginTags.length > 0
|
||||
|| (searchMode ?? (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(activePluginType)))
|
||||
return isSearchMode
|
||||
}
|
||||
|
||||
export function useMarketplaceMoreClick() {
|
||||
const [,setQ] = useSearchPluginText()
|
||||
const setSort = useSetAtom(marketplaceSortAtom)
|
||||
const setSearchMode = useSetAtom(searchModeAtom)
|
||||
|
||||
return useCallback((searchParams?: SearchParamsFromCollection) => {
|
||||
if (!searchParams)
|
||||
return
|
||||
setQ(searchParams?.query || '')
|
||||
setSort({
|
||||
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
|
||||
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
|
||||
})
|
||||
setSearchMode(true)
|
||||
}, [setQ, setSort, setSearchMode])
|
||||
}
|
||||
@@ -1,30 +1,6 @@
|
||||
import { PluginCategoryEnum } from '../types'
|
||||
|
||||
export const DEFAULT_SORT = {
|
||||
sortBy: 'install_count',
|
||||
sortOrder: 'DESC',
|
||||
}
|
||||
|
||||
export const SCROLL_BOTTOM_THRESHOLD = 100
|
||||
|
||||
export const PLUGIN_TYPE_SEARCH_MAP = {
|
||||
all: 'all',
|
||||
model: PluginCategoryEnum.model,
|
||||
tool: PluginCategoryEnum.tool,
|
||||
agent: PluginCategoryEnum.agent,
|
||||
extension: PluginCategoryEnum.extension,
|
||||
datasource: PluginCategoryEnum.datasource,
|
||||
trigger: PluginCategoryEnum.trigger,
|
||||
bundle: 'bundle',
|
||||
} as const
|
||||
|
||||
type ValueOf<T> = T[keyof T]
|
||||
|
||||
export type ActivePluginType = ValueOf<typeof PLUGIN_TYPE_SEARCH_MAP>
|
||||
|
||||
export const PLUGIN_CATEGORY_WITH_COLLECTIONS = new Set<ActivePluginType>(
|
||||
[
|
||||
PLUGIN_TYPE_SEARCH_MAP.all,
|
||||
PLUGIN_TYPE_SEARCH_MAP.tool,
|
||||
],
|
||||
)
|
||||
|
||||
332
web/app/components/plugins/marketplace/context.tsx
Normal file
332
web/app/components/plugins/marketplace/context.tsx
Normal file
@@ -0,0 +1,332 @@
|
||||
'use client'
|
||||
|
||||
import type {
|
||||
ReactNode,
|
||||
} from 'react'
|
||||
import type { TagKey } from '../constants'
|
||||
import type { Plugin } from '../types'
|
||||
import type {
|
||||
MarketplaceCollection,
|
||||
PluginsSort,
|
||||
SearchParams,
|
||||
SearchParamsFromCollection,
|
||||
} from './types'
|
||||
import { debounce } from 'es-toolkit/compat'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import {
|
||||
createContext,
|
||||
useContextSelector,
|
||||
} from 'use-context-selector'
|
||||
import { useMarketplaceFilters } from '@/hooks/use-query-params'
|
||||
import { useInstalledPluginList } from '@/service/use-plugins'
|
||||
import {
|
||||
getValidCategoryKeys,
|
||||
getValidTagKeys,
|
||||
} from '../utils'
|
||||
import { DEFAULT_SORT } from './constants'
|
||||
import {
|
||||
useMarketplaceCollectionsAndPlugins,
|
||||
useMarketplaceContainerScroll,
|
||||
useMarketplacePlugins,
|
||||
} from './hooks'
|
||||
import { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch'
|
||||
import {
|
||||
getMarketplaceListCondition,
|
||||
getMarketplaceListFilterType,
|
||||
} from './utils'
|
||||
|
||||
export type MarketplaceContextValue = {
|
||||
searchPluginText: string
|
||||
handleSearchPluginTextChange: (text: string) => void
|
||||
filterPluginTags: string[]
|
||||
handleFilterPluginTagsChange: (tags: string[]) => void
|
||||
activePluginType: string
|
||||
handleActivePluginTypeChange: (type: string) => void
|
||||
page: number
|
||||
handlePageChange: () => void
|
||||
plugins?: Plugin[]
|
||||
pluginsTotal?: number
|
||||
resetPlugins: () => void
|
||||
sort: PluginsSort
|
||||
handleSortChange: (sort: PluginsSort) => void
|
||||
handleQueryPlugins: () => void
|
||||
handleMoreClick: (searchParams: SearchParamsFromCollection) => void
|
||||
marketplaceCollectionsFromClient?: MarketplaceCollection[]
|
||||
setMarketplaceCollectionsFromClient: (collections: MarketplaceCollection[]) => void
|
||||
marketplaceCollectionPluginsMapFromClient?: Record<string, Plugin[]>
|
||||
setMarketplaceCollectionPluginsMapFromClient: (map: Record<string, Plugin[]>) => void
|
||||
isLoading: boolean
|
||||
isSuccessCollections: boolean
|
||||
}
|
||||
|
||||
export const MarketplaceContext = createContext<MarketplaceContextValue>({
|
||||
searchPluginText: '',
|
||||
handleSearchPluginTextChange: noop,
|
||||
filterPluginTags: [],
|
||||
handleFilterPluginTagsChange: noop,
|
||||
activePluginType: 'all',
|
||||
handleActivePluginTypeChange: noop,
|
||||
page: 1,
|
||||
handlePageChange: noop,
|
||||
plugins: undefined,
|
||||
pluginsTotal: 0,
|
||||
resetPlugins: noop,
|
||||
sort: DEFAULT_SORT,
|
||||
handleSortChange: noop,
|
||||
handleQueryPlugins: noop,
|
||||
handleMoreClick: noop,
|
||||
marketplaceCollectionsFromClient: [],
|
||||
setMarketplaceCollectionsFromClient: noop,
|
||||
marketplaceCollectionPluginsMapFromClient: {},
|
||||
setMarketplaceCollectionPluginsMapFromClient: noop,
|
||||
isLoading: false,
|
||||
isSuccessCollections: false,
|
||||
})
|
||||
|
||||
type MarketplaceContextProviderProps = {
|
||||
children: ReactNode
|
||||
searchParams?: SearchParams
|
||||
shouldExclude?: boolean
|
||||
scrollContainerId?: string
|
||||
showSearchParams?: boolean
|
||||
}
|
||||
|
||||
export function useMarketplaceContext(selector: (value: MarketplaceContextValue) => any) {
|
||||
return useContextSelector(MarketplaceContext, selector)
|
||||
}
|
||||
|
||||
export const MarketplaceContextProvider = ({
|
||||
children,
|
||||
searchParams,
|
||||
shouldExclude,
|
||||
scrollContainerId,
|
||||
showSearchParams,
|
||||
}: MarketplaceContextProviderProps) => {
|
||||
// Use nuqs hook for URL-based filter state
|
||||
const [urlFilters, setUrlFilters] = useMarketplaceFilters()
|
||||
|
||||
const { data, isSuccess } = useInstalledPluginList(!shouldExclude)
|
||||
const exclude = useMemo(() => {
|
||||
if (shouldExclude)
|
||||
return data?.plugins.map(plugin => plugin.plugin_id)
|
||||
}, [data?.plugins, shouldExclude])
|
||||
|
||||
// Initialize from URL params (legacy support) or use nuqs state
|
||||
const queryFromSearchParams = searchParams?.q || urlFilters.q
|
||||
const tagsFromSearchParams = getValidTagKeys(urlFilters.tags as TagKey[])
|
||||
const hasValidTags = !!tagsFromSearchParams.length
|
||||
const hasValidCategory = getValidCategoryKeys(urlFilters.category)
|
||||
const categoryFromSearchParams = hasValidCategory || PLUGIN_TYPE_SEARCH_MAP.all
|
||||
|
||||
const [searchPluginText, setSearchPluginText] = useState(queryFromSearchParams)
|
||||
const searchPluginTextRef = useRef(searchPluginText)
|
||||
const [filterPluginTags, setFilterPluginTags] = useState<string[]>(tagsFromSearchParams)
|
||||
const filterPluginTagsRef = useRef(filterPluginTags)
|
||||
const [activePluginType, setActivePluginType] = useState(categoryFromSearchParams)
|
||||
const activePluginTypeRef = useRef(activePluginType)
|
||||
const [sort, setSort] = useState(DEFAULT_SORT)
|
||||
const sortRef = useRef(sort)
|
||||
const {
|
||||
marketplaceCollections: marketplaceCollectionsFromClient,
|
||||
setMarketplaceCollections: setMarketplaceCollectionsFromClient,
|
||||
marketplaceCollectionPluginsMap: marketplaceCollectionPluginsMapFromClient,
|
||||
setMarketplaceCollectionPluginsMap: setMarketplaceCollectionPluginsMapFromClient,
|
||||
queryMarketplaceCollectionsAndPlugins,
|
||||
isLoading,
|
||||
isSuccess: isSuccessCollections,
|
||||
} = useMarketplaceCollectionsAndPlugins()
|
||||
const {
|
||||
plugins,
|
||||
total: pluginsTotal,
|
||||
resetPlugins,
|
||||
queryPlugins,
|
||||
queryPluginsWithDebounced,
|
||||
cancelQueryPluginsWithDebounced,
|
||||
isLoading: isPluginsLoading,
|
||||
fetchNextPage: fetchNextPluginsPage,
|
||||
hasNextPage: hasNextPluginsPage,
|
||||
page: pluginsPage,
|
||||
} = useMarketplacePlugins()
|
||||
const page = Math.max(pluginsPage || 0, 1)
|
||||
|
||||
useEffect(() => {
|
||||
if (queryFromSearchParams || hasValidTags || hasValidCategory) {
|
||||
queryPlugins({
|
||||
query: queryFromSearchParams,
|
||||
category: hasValidCategory,
|
||||
tags: hasValidTags ? tagsFromSearchParams : [],
|
||||
sortBy: sortRef.current.sortBy,
|
||||
sortOrder: sortRef.current.sortOrder,
|
||||
type: getMarketplaceListFilterType(activePluginTypeRef.current),
|
||||
})
|
||||
}
|
||||
else {
|
||||
if (shouldExclude && isSuccess) {
|
||||
queryMarketplaceCollectionsAndPlugins({
|
||||
exclude,
|
||||
type: getMarketplaceListFilterType(activePluginTypeRef.current),
|
||||
})
|
||||
}
|
||||
}
|
||||
}, [queryPlugins, queryMarketplaceCollectionsAndPlugins, isSuccess, exclude])
|
||||
|
||||
const handleQueryMarketplaceCollectionsAndPlugins = useCallback(() => {
|
||||
queryMarketplaceCollectionsAndPlugins({
|
||||
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
|
||||
condition: getMarketplaceListCondition(activePluginTypeRef.current),
|
||||
exclude,
|
||||
type: getMarketplaceListFilterType(activePluginTypeRef.current),
|
||||
})
|
||||
resetPlugins()
|
||||
}, [exclude, queryMarketplaceCollectionsAndPlugins, resetPlugins])
|
||||
|
||||
const applyUrlFilters = useCallback(() => {
|
||||
if (!showSearchParams)
|
||||
return
|
||||
const nextFilters = {
|
||||
q: searchPluginTextRef.current,
|
||||
category: activePluginTypeRef.current,
|
||||
tags: filterPluginTagsRef.current,
|
||||
}
|
||||
const categoryChanged = urlFilters.category !== nextFilters.category
|
||||
setUrlFilters(nextFilters, {
|
||||
history: categoryChanged ? 'push' : 'replace',
|
||||
})
|
||||
}, [setUrlFilters, showSearchParams, urlFilters.category])
|
||||
|
||||
const debouncedUpdateSearchParams = useMemo(() => debounce(() => {
|
||||
applyUrlFilters()
|
||||
}, 500), [applyUrlFilters])
|
||||
|
||||
const handleUpdateSearchParams = useCallback((debounced?: boolean) => {
|
||||
if (debounced) {
|
||||
debouncedUpdateSearchParams()
|
||||
}
|
||||
else {
|
||||
applyUrlFilters()
|
||||
}
|
||||
}, [applyUrlFilters, debouncedUpdateSearchParams])
|
||||
|
||||
const handleQueryPlugins = useCallback((debounced?: boolean) => {
|
||||
handleUpdateSearchParams(debounced)
|
||||
if (debounced) {
|
||||
queryPluginsWithDebounced({
|
||||
query: searchPluginTextRef.current,
|
||||
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
|
||||
tags: filterPluginTagsRef.current,
|
||||
sortBy: sortRef.current.sortBy,
|
||||
sortOrder: sortRef.current.sortOrder,
|
||||
exclude,
|
||||
type: getMarketplaceListFilterType(activePluginTypeRef.current),
|
||||
})
|
||||
}
|
||||
else {
|
||||
queryPlugins({
|
||||
query: searchPluginTextRef.current,
|
||||
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
|
||||
tags: filterPluginTagsRef.current,
|
||||
sortBy: sortRef.current.sortBy,
|
||||
sortOrder: sortRef.current.sortOrder,
|
||||
exclude,
|
||||
type: getMarketplaceListFilterType(activePluginTypeRef.current),
|
||||
})
|
||||
}
|
||||
}, [exclude, queryPluginsWithDebounced, queryPlugins, handleUpdateSearchParams])
|
||||
|
||||
const handleQuery = useCallback((debounced?: boolean) => {
|
||||
if (!searchPluginTextRef.current && !filterPluginTagsRef.current.length) {
|
||||
handleUpdateSearchParams(debounced)
|
||||
cancelQueryPluginsWithDebounced()
|
||||
handleQueryMarketplaceCollectionsAndPlugins()
|
||||
return
|
||||
}
|
||||
|
||||
handleQueryPlugins(debounced)
|
||||
}, [handleQueryMarketplaceCollectionsAndPlugins, handleQueryPlugins, cancelQueryPluginsWithDebounced, handleUpdateSearchParams])
|
||||
|
||||
const handleSearchPluginTextChange = useCallback((text: string) => {
|
||||
setSearchPluginText(text)
|
||||
searchPluginTextRef.current = text
|
||||
|
||||
handleQuery(true)
|
||||
}, [handleQuery])
|
||||
|
||||
const handleFilterPluginTagsChange = useCallback((tags: string[]) => {
|
||||
setFilterPluginTags(tags)
|
||||
filterPluginTagsRef.current = tags
|
||||
|
||||
handleQuery()
|
||||
}, [handleQuery])
|
||||
|
||||
const handleActivePluginTypeChange = useCallback((type: string) => {
|
||||
setActivePluginType(type)
|
||||
activePluginTypeRef.current = type
|
||||
|
||||
handleQuery()
|
||||
}, [handleQuery])
|
||||
|
||||
const handleSortChange = useCallback((sort: PluginsSort) => {
|
||||
setSort(sort)
|
||||
sortRef.current = sort
|
||||
|
||||
handleQueryPlugins()
|
||||
}, [handleQueryPlugins])
|
||||
|
||||
const handlePageChange = useCallback(() => {
|
||||
if (hasNextPluginsPage)
|
||||
fetchNextPluginsPage()
|
||||
}, [fetchNextPluginsPage, hasNextPluginsPage])
|
||||
|
||||
const handleMoreClick = useCallback((searchParams: SearchParamsFromCollection) => {
|
||||
setSearchPluginText(searchParams?.query || '')
|
||||
searchPluginTextRef.current = searchParams?.query || ''
|
||||
setSort({
|
||||
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
|
||||
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
|
||||
})
|
||||
sortRef.current = {
|
||||
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
|
||||
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
|
||||
}
|
||||
handleQueryPlugins()
|
||||
}, [handleQueryPlugins])
|
||||
|
||||
useMarketplaceContainerScroll(handlePageChange, scrollContainerId)
|
||||
|
||||
return (
|
||||
<MarketplaceContext.Provider
|
||||
value={{
|
||||
searchPluginText,
|
||||
handleSearchPluginTextChange,
|
||||
filterPluginTags,
|
||||
handleFilterPluginTagsChange,
|
||||
activePluginType,
|
||||
handleActivePluginTypeChange,
|
||||
page,
|
||||
handlePageChange,
|
||||
plugins,
|
||||
pluginsTotal,
|
||||
resetPlugins,
|
||||
sort,
|
||||
handleSortChange,
|
||||
handleQueryPlugins,
|
||||
handleMoreClick,
|
||||
marketplaceCollectionsFromClient,
|
||||
setMarketplaceCollectionsFromClient,
|
||||
marketplaceCollectionPluginsMapFromClient,
|
||||
setMarketplaceCollectionPluginsMapFromClient,
|
||||
isLoading: isLoading || isPluginsLoading,
|
||||
isSuccessCollections,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</MarketplaceContext.Provider>
|
||||
)
|
||||
}
|
||||
@@ -26,9 +26,6 @@ import {
|
||||
getMarketplacePluginsByCollectionId,
|
||||
} from './utils'
|
||||
|
||||
/**
|
||||
* @deprecated Use useMarketplaceCollectionsAndPlugins from query.ts instead
|
||||
*/
|
||||
export const useMarketplaceCollectionsAndPlugins = () => {
|
||||
const [queryParams, setQueryParams] = useState<CollectionsAndPluginsSearchParams>()
|
||||
const [marketplaceCollectionsOverride, setMarketplaceCollections] = useState<MarketplaceCollection[]>()
|
||||
@@ -92,9 +89,7 @@ export const useMarketplacePluginsByCollectionId = (
|
||||
isSuccess,
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @deprecated Use useMarketplacePlugins from query.ts instead
|
||||
*/
|
||||
|
||||
export const useMarketplacePlugins = () => {
|
||||
const queryClient = useQueryClient()
|
||||
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
'use client'
|
||||
|
||||
import { useHydrateAtoms } from 'jotai/utils'
|
||||
import { preserveSearchStateInQueryAtom } from './atoms'
|
||||
|
||||
export function HydrateMarketplaceAtoms({
|
||||
preserveSearchStateInQuery,
|
||||
children,
|
||||
}: {
|
||||
preserveSearchStateInQuery: boolean
|
||||
children: React.ReactNode
|
||||
}) {
|
||||
useHydrateAtoms([[preserveSearchStateInQueryAtom, preserveSearchStateInQuery]])
|
||||
return <>{children}</>
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
import type { SearchParams } from 'nuqs'
|
||||
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
|
||||
import { createLoader } from 'nuqs/server'
|
||||
import { getQueryClientServer } from '@/context/query-client-server'
|
||||
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
|
||||
import { marketplaceKeys } from './query'
|
||||
import { marketplaceSearchParamsParsers } from './search-params'
|
||||
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
|
||||
|
||||
// The server side logic should move to marketplace's codebase so that we can get rid of Next.js
|
||||
|
||||
async function getDehydratedState(searchParams?: Promise<SearchParams>) {
|
||||
if (!searchParams) {
|
||||
return
|
||||
}
|
||||
const loadSearchParams = createLoader(marketplaceSearchParamsParsers)
|
||||
const params = await loadSearchParams(searchParams)
|
||||
|
||||
if (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(params.category)) {
|
||||
return
|
||||
}
|
||||
|
||||
const queryClient = getQueryClientServer()
|
||||
|
||||
await queryClient.prefetchQuery({
|
||||
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
|
||||
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
|
||||
})
|
||||
return dehydrate(queryClient)
|
||||
}
|
||||
|
||||
export async function HydrateQueryClient({
|
||||
searchParams,
|
||||
children,
|
||||
}: {
|
||||
searchParams: Promise<SearchParams> | undefined
|
||||
children: React.ReactNode
|
||||
}) {
|
||||
const dehydratedState = await getDehydratedState(searchParams)
|
||||
return (
|
||||
<HydrationBoundary state={dehydratedState}>
|
||||
{children}
|
||||
</HydrationBoundary>
|
||||
)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,39 +1,55 @@
|
||||
import type { SearchParams } from 'nuqs'
|
||||
import type { MarketplaceCollection, SearchParams } from './types'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import { TanstackQueryInitializer } from '@/context/query-client'
|
||||
import { MarketplaceContextProvider } from './context'
|
||||
import Description from './description'
|
||||
import { HydrateMarketplaceAtoms } from './hydration-client'
|
||||
import { HydrateQueryClient } from './hydration-server'
|
||||
import ListWrapper from './list/list-wrapper'
|
||||
import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper'
|
||||
import { getMarketplaceCollectionsAndPlugins } from './utils'
|
||||
|
||||
type MarketplaceProps = {
|
||||
showInstallButton?: boolean
|
||||
shouldExclude?: boolean
|
||||
searchParams?: SearchParams
|
||||
pluginTypeSwitchClassName?: string
|
||||
/**
|
||||
* Pass the search params from the request to prefetch data on the server
|
||||
* and preserve the search params in the URL.
|
||||
*/
|
||||
searchParams?: Promise<SearchParams>
|
||||
scrollContainerId?: string
|
||||
showSearchParams?: boolean
|
||||
}
|
||||
|
||||
const Marketplace = async ({
|
||||
showInstallButton = true,
|
||||
pluginTypeSwitchClassName,
|
||||
shouldExclude,
|
||||
searchParams,
|
||||
pluginTypeSwitchClassName,
|
||||
scrollContainerId,
|
||||
showSearchParams = true,
|
||||
}: MarketplaceProps) => {
|
||||
let marketplaceCollections: MarketplaceCollection[] = []
|
||||
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
|
||||
if (!shouldExclude) {
|
||||
const marketplaceCollectionsAndPluginsData = await getMarketplaceCollectionsAndPlugins()
|
||||
marketplaceCollections = marketplaceCollectionsAndPluginsData.marketplaceCollections
|
||||
marketplaceCollectionPluginsMap = marketplaceCollectionsAndPluginsData.marketplaceCollectionPluginsMap
|
||||
}
|
||||
|
||||
return (
|
||||
<TanstackQueryInitializer>
|
||||
<HydrateQueryClient searchParams={searchParams}>
|
||||
<HydrateMarketplaceAtoms preserveSearchStateInQuery={!!searchParams}>
|
||||
<Description />
|
||||
<StickySearchAndSwitchWrapper
|
||||
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
|
||||
/>
|
||||
<ListWrapper
|
||||
showInstallButton={showInstallButton}
|
||||
/>
|
||||
</HydrateMarketplaceAtoms>
|
||||
</HydrateQueryClient>
|
||||
<MarketplaceContextProvider
|
||||
searchParams={searchParams}
|
||||
shouldExclude={shouldExclude}
|
||||
scrollContainerId={scrollContainerId}
|
||||
showSearchParams={showSearchParams}
|
||||
>
|
||||
<Description />
|
||||
<StickySearchAndSwitchWrapper
|
||||
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
|
||||
showSearchParams={showSearchParams}
|
||||
/>
|
||||
<ListWrapper
|
||||
marketplaceCollections={marketplaceCollections}
|
||||
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap}
|
||||
showInstallButton={showInstallButton}
|
||||
/>
|
||||
</MarketplaceContextProvider>
|
||||
</TanstackQueryInitializer>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { MarketplaceCollection, SearchParamsFromCollection } from '../types'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import List from './index'
|
||||
@@ -30,27 +30,23 @@ vi.mock('#i18n', () => ({
|
||||
useLocale: () => 'en-US',
|
||||
}))
|
||||
|
||||
// Mock marketplace state hooks with controllable values
|
||||
const { mockMarketplaceData, mockMoreClick } = vi.hoisted(() => {
|
||||
return {
|
||||
mockMarketplaceData: {
|
||||
plugins: undefined as Plugin[] | undefined,
|
||||
pluginsTotal: 0,
|
||||
marketplaceCollections: undefined as MarketplaceCollection[] | undefined,
|
||||
marketplaceCollectionPluginsMap: undefined as Record<string, Plugin[]> | undefined,
|
||||
isLoading: false,
|
||||
page: 1,
|
||||
},
|
||||
mockMoreClick: vi.fn(),
|
||||
}
|
||||
})
|
||||
// Mock useMarketplaceContext with controllable values
|
||||
const mockContextValues = {
|
||||
plugins: undefined as Plugin[] | undefined,
|
||||
pluginsTotal: 0,
|
||||
marketplaceCollectionsFromClient: undefined as MarketplaceCollection[] | undefined,
|
||||
marketplaceCollectionPluginsMapFromClient: undefined as Record<string, Plugin[]> | undefined,
|
||||
isLoading: false,
|
||||
isSuccessCollections: false,
|
||||
handleQueryPlugins: vi.fn(),
|
||||
searchPluginText: '',
|
||||
filterPluginTags: [] as string[],
|
||||
page: 1,
|
||||
handleMoreClick: vi.fn(),
|
||||
}
|
||||
|
||||
vi.mock('../state', () => ({
|
||||
useMarketplaceData: () => mockMarketplaceData,
|
||||
}))
|
||||
|
||||
vi.mock('../atoms', () => ({
|
||||
useMarketplaceMoreClick: () => mockMoreClick,
|
||||
vi.mock('../context', () => ({
|
||||
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
|
||||
}))
|
||||
|
||||
// Mock useLocale context
|
||||
@@ -582,7 +578,7 @@ describe('ListWithCollection', () => {
|
||||
// View More Button Tests
|
||||
// ================================
|
||||
describe('View More Button', () => {
|
||||
it('should render View More button when collection is searchable', () => {
|
||||
it('should render View More button when collection is searchable and onMoreClick is provided', () => {
|
||||
const collections = [createMockCollection({
|
||||
name: 'collection-0',
|
||||
searchable: true,
|
||||
@@ -591,12 +587,14 @@ describe('ListWithCollection', () => {
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
const onMoreClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ListWithCollection
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
onMoreClick={onMoreClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
@@ -611,24 +609,24 @@ describe('ListWithCollection', () => {
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
const onMoreClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ListWithCollection
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
onMoreClick={onMoreClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('View More')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call moreClick hook with search_params when View More is clicked', () => {
|
||||
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
|
||||
it('should not render View More button when onMoreClick is not provided', () => {
|
||||
const collections = [createMockCollection({
|
||||
name: 'collection-0',
|
||||
searchable: true,
|
||||
search_params: searchParams,
|
||||
})]
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
@@ -639,13 +637,38 @@ describe('ListWithCollection', () => {
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
onMoreClick={undefined}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('View More')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onMoreClick with search_params when View More is clicked', () => {
|
||||
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
|
||||
const collections = [createMockCollection({
|
||||
name: 'collection-0',
|
||||
searchable: true,
|
||||
search_params: searchParams,
|
||||
})]
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
const onMoreClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ListWithCollection
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
onMoreClick={onMoreClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('View More'))
|
||||
|
||||
expect(mockMoreClick).toHaveBeenCalledTimes(1)
|
||||
expect(mockMoreClick).toHaveBeenCalledWith(searchParams)
|
||||
expect(onMoreClick).toHaveBeenCalledTimes(1)
|
||||
expect(onMoreClick).toHaveBeenCalledWith(searchParams)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -779,15 +802,24 @@ describe('ListWithCollection', () => {
|
||||
// ListWrapper Component Tests
|
||||
// ================================
|
||||
describe('ListWrapper', () => {
|
||||
const defaultProps = {
|
||||
marketplaceCollections: [] as MarketplaceCollection[],
|
||||
marketplaceCollectionPluginsMap: {} as Record<string, Plugin[]>,
|
||||
showInstallButton: false,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
// Reset mock data
|
||||
mockMarketplaceData.plugins = undefined
|
||||
mockMarketplaceData.pluginsTotal = 0
|
||||
mockMarketplaceData.marketplaceCollections = undefined
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
|
||||
mockMarketplaceData.isLoading = false
|
||||
mockMarketplaceData.page = 1
|
||||
// Reset context values
|
||||
mockContextValues.plugins = undefined
|
||||
mockContextValues.pluginsTotal = 0
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
|
||||
mockContextValues.isLoading = false
|
||||
mockContextValues.isSuccessCollections = false
|
||||
mockContextValues.searchPluginText = ''
|
||||
mockContextValues.filterPluginTags = []
|
||||
mockContextValues.page = 1
|
||||
})
|
||||
|
||||
// ================================
|
||||
@@ -795,32 +827,32 @@ describe('ListWrapper', () => {
|
||||
// ================================
|
||||
describe('Rendering', () => {
|
||||
it('should render without crashing', () => {
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(document.body).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render with scrollbarGutter style', () => {
|
||||
const { container } = render(<ListWrapper />)
|
||||
const { container } = render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper).toHaveStyle({ scrollbarGutter: 'stable' })
|
||||
})
|
||||
|
||||
it('should render Loading component when isLoading is true and page is 1', () => {
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 1
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 1
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render Loading component when page > 1', () => {
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 2
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 2
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
|
||||
})
|
||||
@@ -831,26 +863,26 @@ describe('ListWrapper', () => {
|
||||
// ================================
|
||||
describe('Plugins Header', () => {
|
||||
it('should render plugins result count when plugins are present', () => {
|
||||
mockMarketplaceData.plugins = createMockPluginList(5)
|
||||
mockMarketplaceData.pluginsTotal = 5
|
||||
mockContextValues.plugins = createMockPluginList(5)
|
||||
mockContextValues.pluginsTotal = 5
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render SortDropdown when plugins are present', () => {
|
||||
mockMarketplaceData.plugins = createMockPluginList(1)
|
||||
mockContextValues.plugins = createMockPluginList(1)
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('sort-dropdown')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render plugins header when plugins is undefined', () => {
|
||||
mockMarketplaceData.plugins = undefined
|
||||
mockContextValues.plugins = undefined
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.queryByTestId('sort-dropdown')).not.toBeInTheDocument()
|
||||
})
|
||||
@@ -860,60 +892,197 @@ describe('ListWrapper', () => {
|
||||
// List Rendering Logic Tests
|
||||
// ================================
|
||||
describe('List Rendering Logic', () => {
|
||||
it('should render collections when not loading', () => {
|
||||
mockMarketplaceData.isLoading = false
|
||||
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = {
|
||||
it('should render List when not loading', () => {
|
||||
mockContextValues.isLoading = false
|
||||
const collections = createMockCollectionList(1)
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(
|
||||
<ListWrapper
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Collection 0')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render List when loading but page > 1', () => {
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 2
|
||||
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = {
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 2
|
||||
const collections = createMockCollectionList(1)
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(
|
||||
<ListWrapper
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Collection 0')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use client collections when available', () => {
|
||||
const serverCollections = createMockCollectionList(1)
|
||||
serverCollections[0].label = { 'en-US': 'Server Collection' }
|
||||
const clientCollections = createMockCollectionList(1)
|
||||
clientCollections[0].label = { 'en-US': 'Client Collection' }
|
||||
|
||||
const serverPluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
const clientPluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
|
||||
mockContextValues.marketplaceCollectionsFromClient = clientCollections
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = clientPluginsMap
|
||||
|
||||
render(
|
||||
<ListWrapper
|
||||
{...defaultProps}
|
||||
marketplaceCollections={serverCollections}
|
||||
marketplaceCollectionPluginsMap={serverPluginsMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Client Collection')).toBeInTheDocument()
|
||||
expect(screen.queryByText('Server Collection')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should use server collections when client collections are not available', () => {
|
||||
const serverCollections = createMockCollectionList(1)
|
||||
serverCollections[0].label = { 'en-US': 'Server Collection' }
|
||||
const serverPluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
|
||||
|
||||
render(
|
||||
<ListWrapper
|
||||
{...defaultProps}
|
||||
marketplaceCollections={serverCollections}
|
||||
marketplaceCollectionPluginsMap={serverPluginsMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Server Collection')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// ================================
|
||||
// Data Integration Tests
|
||||
// Context Integration Tests
|
||||
// ================================
|
||||
describe('Data Integration', () => {
|
||||
it('should pass plugins from state to List', () => {
|
||||
mockMarketplaceData.plugins = createMockPluginList(2)
|
||||
describe('Context Integration', () => {
|
||||
it('should pass plugins from context to List', () => {
|
||||
const plugins = createMockPluginList(2)
|
||||
mockContextValues.plugins = plugins
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByTestId('card-plugin-0')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('card-plugin-1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show View More button and call moreClick hook', () => {
|
||||
mockMarketplaceData.marketplaceCollections = [createMockCollection({
|
||||
it('should pass handleMoreClick from context to List', () => {
|
||||
const mockHandleMoreClick = vi.fn()
|
||||
mockContextValues.handleMoreClick = mockHandleMoreClick
|
||||
|
||||
const collections = [createMockCollection({
|
||||
name: 'collection-0',
|
||||
searchable: true,
|
||||
search_params: { query: 'test' },
|
||||
})]
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = {
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(
|
||||
<ListWrapper
|
||||
{...defaultProps}
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('View More'))
|
||||
|
||||
expect(mockMoreClick).toHaveBeenCalled()
|
||||
expect(mockHandleMoreClick).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// ================================
|
||||
// Effect Tests (handleQueryPlugins)
|
||||
// ================================
|
||||
describe('handleQueryPlugins Effect', () => {
|
||||
it('should call handleQueryPlugins when conditions are met', async () => {
|
||||
const mockHandleQueryPlugins = vi.fn()
|
||||
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
|
||||
mockContextValues.isSuccessCollections = true
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.searchPluginText = ''
|
||||
mockContextValues.filterPluginTags = []
|
||||
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleQueryPlugins).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call handleQueryPlugins when client collections exist', async () => {
|
||||
const mockHandleQueryPlugins = vi.fn()
|
||||
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
|
||||
mockContextValues.isSuccessCollections = true
|
||||
mockContextValues.marketplaceCollectionsFromClient = createMockCollectionList(1)
|
||||
mockContextValues.searchPluginText = ''
|
||||
mockContextValues.filterPluginTags = []
|
||||
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
// Give time for effect to run
|
||||
await waitFor(() => {
|
||||
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call handleQueryPlugins when search text exists', async () => {
|
||||
const mockHandleQueryPlugins = vi.fn()
|
||||
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
|
||||
mockContextValues.isSuccessCollections = true
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.searchPluginText = 'search text'
|
||||
mockContextValues.filterPluginTags = []
|
||||
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call handleQueryPlugins when filter tags exist', async () => {
|
||||
const mockHandleQueryPlugins = vi.fn()
|
||||
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
|
||||
mockContextValues.isSuccessCollections = true
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.searchPluginText = ''
|
||||
mockContextValues.filterPluginTags = ['tag1']
|
||||
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -921,32 +1090,32 @@ describe('ListWrapper', () => {
|
||||
// Edge Cases Tests
|
||||
// ================================
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle empty plugins array', () => {
|
||||
mockMarketplaceData.plugins = []
|
||||
mockMarketplaceData.pluginsTotal = 0
|
||||
it('should handle empty plugins array from context', () => {
|
||||
mockContextValues.plugins = []
|
||||
mockContextValues.pluginsTotal = 0
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle large pluginsTotal', () => {
|
||||
mockMarketplaceData.plugins = createMockPluginList(10)
|
||||
mockMarketplaceData.pluginsTotal = 10000
|
||||
mockContextValues.plugins = createMockPluginList(10)
|
||||
mockContextValues.pluginsTotal = 10000
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('10000 plugins found')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle both loading and has plugins', () => {
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 2
|
||||
mockMarketplaceData.plugins = createMockPluginList(5)
|
||||
mockMarketplaceData.pluginsTotal = 50
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 2
|
||||
mockContextValues.plugins = createMockPluginList(5)
|
||||
mockContextValues.pluginsTotal = 50
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(<ListWrapper {...defaultProps} />)
|
||||
|
||||
// Should show plugins header and list
|
||||
expect(screen.getByText('50 plugins found')).toBeInTheDocument()
|
||||
@@ -1259,72 +1428,106 @@ describe('CardWrapper (via List integration)', () => {
|
||||
describe('Combined Workflows', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockMarketplaceData.plugins = undefined
|
||||
mockMarketplaceData.pluginsTotal = 0
|
||||
mockMarketplaceData.isLoading = false
|
||||
mockMarketplaceData.page = 1
|
||||
mockMarketplaceData.marketplaceCollections = undefined
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
|
||||
mockContextValues.plugins = undefined
|
||||
mockContextValues.pluginsTotal = 0
|
||||
mockContextValues.isLoading = false
|
||||
mockContextValues.page = 1
|
||||
mockContextValues.marketplaceCollectionsFromClient = undefined
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
|
||||
})
|
||||
|
||||
it('should transition from loading to showing collections', async () => {
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 1
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 1
|
||||
|
||||
const { rerender } = render(<ListWrapper />)
|
||||
const { rerender } = render(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
|
||||
|
||||
// Simulate loading complete
|
||||
mockMarketplaceData.isLoading = false
|
||||
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = {
|
||||
mockContextValues.isLoading = false
|
||||
const collections = createMockCollectionList(1)
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
mockContextValues.marketplaceCollectionsFromClient = collections
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
|
||||
|
||||
rerender(<ListWrapper />)
|
||||
rerender(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Collection 0')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should transition from collections to search results', async () => {
|
||||
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
|
||||
mockMarketplaceData.marketplaceCollectionPluginsMap = {
|
||||
const collections = createMockCollectionList(1)
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
mockContextValues.marketplaceCollectionsFromClient = collections
|
||||
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
|
||||
|
||||
const { rerender } = render(<ListWrapper />)
|
||||
const { rerender } = render(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('Collection 0')).toBeInTheDocument()
|
||||
|
||||
// Simulate search results
|
||||
mockMarketplaceData.plugins = createMockPluginList(5)
|
||||
mockMarketplaceData.pluginsTotal = 5
|
||||
mockContextValues.plugins = createMockPluginList(5)
|
||||
mockContextValues.pluginsTotal = 5
|
||||
|
||||
rerender(<ListWrapper />)
|
||||
rerender(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByText('Collection 0')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle empty search results', () => {
|
||||
mockMarketplaceData.plugins = []
|
||||
mockMarketplaceData.pluginsTotal = 0
|
||||
mockContextValues.plugins = []
|
||||
mockContextValues.pluginsTotal = 0
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
|
||||
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should support pagination (page > 1)', () => {
|
||||
mockMarketplaceData.plugins = createMockPluginList(40)
|
||||
mockMarketplaceData.pluginsTotal = 80
|
||||
mockMarketplaceData.isLoading = true
|
||||
mockMarketplaceData.page = 2
|
||||
mockContextValues.plugins = createMockPluginList(40)
|
||||
mockContextValues.pluginsTotal = 80
|
||||
mockContextValues.isLoading = true
|
||||
mockContextValues.page = 2
|
||||
|
||||
render(<ListWrapper />)
|
||||
render(
|
||||
<ListWrapper
|
||||
marketplaceCollections={[]}
|
||||
marketplaceCollectionPluginsMap={{}}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Should show existing results while loading more
|
||||
expect(screen.getByText('80 plugins found')).toBeInTheDocument()
|
||||
@@ -1339,9 +1542,9 @@ describe('Combined Workflows', () => {
|
||||
describe('Accessibility', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockMarketplaceData.plugins = undefined
|
||||
mockMarketplaceData.isLoading = false
|
||||
mockMarketplaceData.page = 1
|
||||
mockContextValues.plugins = undefined
|
||||
mockContextValues.isLoading = false
|
||||
mockContextValues.page = 1
|
||||
})
|
||||
|
||||
it('should have semantic structure with collections', () => {
|
||||
@@ -1370,11 +1573,13 @@ describe('Accessibility', () => {
|
||||
const pluginsMap: Record<string, Plugin[]> = {
|
||||
'collection-0': createMockPluginList(1),
|
||||
}
|
||||
const onMoreClick = vi.fn()
|
||||
|
||||
render(
|
||||
<ListWithCollection
|
||||
marketplaceCollections={collections}
|
||||
marketplaceCollectionPluginsMap={pluginsMap}
|
||||
onMoreClick={onMoreClick}
|
||||
/>,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ type ListProps = {
|
||||
showInstallButton?: boolean
|
||||
cardContainerClassName?: string
|
||||
cardRender?: (plugin: Plugin) => React.JSX.Element | null
|
||||
onMoreClick?: () => void
|
||||
emptyClassName?: string
|
||||
}
|
||||
const List = ({
|
||||
@@ -22,6 +23,7 @@ const List = ({
|
||||
showInstallButton,
|
||||
cardContainerClassName,
|
||||
cardRender,
|
||||
onMoreClick,
|
||||
emptyClassName,
|
||||
}: ListProps) => {
|
||||
return (
|
||||
@@ -34,6 +36,7 @@ const List = ({
|
||||
showInstallButton={showInstallButton}
|
||||
cardContainerClassName={cardContainerClassName}
|
||||
cardRender={cardRender}
|
||||
onMoreClick={onMoreClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
'use client'
|
||||
|
||||
import type { MarketplaceCollection } from '../types'
|
||||
import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types'
|
||||
import type { Plugin } from '@/app/components/plugins/types'
|
||||
import { useLocale, useTranslation } from '#i18n'
|
||||
import { RiArrowRightSLine } from '@remixicon/react'
|
||||
import { getLanguage } from '@/i18n-config/language'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useMarketplaceMoreClick } from '../atoms'
|
||||
import CardWrapper from './card-wrapper'
|
||||
|
||||
type ListWithCollectionProps = {
|
||||
@@ -15,6 +15,7 @@ type ListWithCollectionProps = {
|
||||
showInstallButton?: boolean
|
||||
cardContainerClassName?: string
|
||||
cardRender?: (plugin: Plugin) => React.JSX.Element | null
|
||||
onMoreClick?: (searchParams?: SearchParamsFromCollection) => void
|
||||
}
|
||||
const ListWithCollection = ({
|
||||
marketplaceCollections,
|
||||
@@ -22,10 +23,10 @@ const ListWithCollection = ({
|
||||
showInstallButton,
|
||||
cardContainerClassName,
|
||||
cardRender,
|
||||
onMoreClick,
|
||||
}: ListWithCollectionProps) => {
|
||||
const { t } = useTranslation()
|
||||
const locale = useLocale()
|
||||
const onMoreClick = useMarketplaceMoreClick()
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -43,10 +44,10 @@ const ListWithCollection = ({
|
||||
<div className="system-xs-regular text-text-tertiary">{collection.description[getLanguage(locale)]}</div>
|
||||
</div>
|
||||
{
|
||||
collection.searchable && (
|
||||
collection.searchable && onMoreClick && (
|
||||
<div
|
||||
className="system-xs-medium flex cursor-pointer items-center text-text-accent "
|
||||
onClick={() => onMoreClick(collection.search_params)}
|
||||
onClick={() => onMoreClick?.(collection.search_params)}
|
||||
>
|
||||
{t('marketplace.viewMore', { ns: 'plugin' })}
|
||||
<RiArrowRightSLine className="h-4 w-4" />
|
||||
|
||||
@@ -1,26 +1,46 @@
|
||||
'use client'
|
||||
import type { Plugin } from '../../types'
|
||||
import type { MarketplaceCollection } from '../types'
|
||||
import { useTranslation } from '#i18n'
|
||||
import { useEffect } from 'react'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useMarketplaceContext } from '../context'
|
||||
import SortDropdown from '../sort-dropdown'
|
||||
import { useMarketplaceData } from '../state'
|
||||
import List from './index'
|
||||
|
||||
type ListWrapperProps = {
|
||||
marketplaceCollections: MarketplaceCollection[]
|
||||
marketplaceCollectionPluginsMap: Record<string, Plugin[]>
|
||||
showInstallButton?: boolean
|
||||
}
|
||||
const ListWrapper = ({
|
||||
marketplaceCollections,
|
||||
marketplaceCollectionPluginsMap,
|
||||
showInstallButton,
|
||||
}: ListWrapperProps) => {
|
||||
const { t } = useTranslation()
|
||||
const plugins = useMarketplaceContext(v => v.plugins)
|
||||
const pluginsTotal = useMarketplaceContext(v => v.pluginsTotal)
|
||||
const marketplaceCollectionsFromClient = useMarketplaceContext(v => v.marketplaceCollectionsFromClient)
|
||||
const marketplaceCollectionPluginsMapFromClient = useMarketplaceContext(v => v.marketplaceCollectionPluginsMapFromClient)
|
||||
const isLoading = useMarketplaceContext(v => v.isLoading)
|
||||
const isSuccessCollections = useMarketplaceContext(v => v.isSuccessCollections)
|
||||
const handleQueryPlugins = useMarketplaceContext(v => v.handleQueryPlugins)
|
||||
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
|
||||
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
|
||||
const page = useMarketplaceContext(v => v.page)
|
||||
const handleMoreClick = useMarketplaceContext(v => v.handleMoreClick)
|
||||
|
||||
const {
|
||||
plugins,
|
||||
pluginsTotal,
|
||||
marketplaceCollections,
|
||||
marketplaceCollectionPluginsMap,
|
||||
isLoading,
|
||||
page,
|
||||
} = useMarketplaceData()
|
||||
useEffect(() => {
|
||||
if (
|
||||
!marketplaceCollectionsFromClient?.length
|
||||
&& isSuccessCollections
|
||||
&& !searchPluginText
|
||||
&& !filterPluginTags.length
|
||||
) {
|
||||
handleQueryPlugins()
|
||||
}
|
||||
}, [handleQueryPlugins, marketplaceCollections, marketplaceCollectionsFromClient, isSuccessCollections, searchPluginText, filterPluginTags])
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -46,10 +66,11 @@ const ListWrapper = ({
|
||||
{
|
||||
(!isLoading || page > 1) && (
|
||||
<List
|
||||
marketplaceCollections={marketplaceCollections || []}
|
||||
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap || {}}
|
||||
marketplaceCollections={marketplaceCollectionsFromClient || marketplaceCollections}
|
||||
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMapFromClient || marketplaceCollectionPluginsMap}
|
||||
plugins={plugins}
|
||||
showInstallButton={showInstallButton}
|
||||
onMoreClick={handleMoreClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
'use client'
|
||||
import type { ActivePluginType } from './constants'
|
||||
import { useTranslation } from '#i18n'
|
||||
import {
|
||||
RiArchive2Line,
|
||||
@@ -9,27 +8,35 @@ import {
|
||||
RiPuzzle2Line,
|
||||
RiSpeakAiLine,
|
||||
} from '@remixicon/react'
|
||||
import { useSetAtom } from 'jotai'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { Trigger as TriggerIcon } from '@/app/components/base/icons/src/vender/plugin'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { searchModeAtom, useActivePluginType } from './atoms'
|
||||
import { PLUGIN_CATEGORY_WITH_COLLECTIONS, PLUGIN_TYPE_SEARCH_MAP } from './constants'
|
||||
import { PluginCategoryEnum } from '../types'
|
||||
import { useMarketplaceContext } from './context'
|
||||
|
||||
export const PLUGIN_TYPE_SEARCH_MAP = {
|
||||
all: 'all',
|
||||
model: PluginCategoryEnum.model,
|
||||
tool: PluginCategoryEnum.tool,
|
||||
agent: PluginCategoryEnum.agent,
|
||||
extension: PluginCategoryEnum.extension,
|
||||
datasource: PluginCategoryEnum.datasource,
|
||||
trigger: PluginCategoryEnum.trigger,
|
||||
bundle: 'bundle',
|
||||
}
|
||||
type PluginTypeSwitchProps = {
|
||||
className?: string
|
||||
showSearchParams?: boolean
|
||||
}
|
||||
const PluginTypeSwitch = ({
|
||||
className,
|
||||
showSearchParams,
|
||||
}: PluginTypeSwitchProps) => {
|
||||
const { t } = useTranslation()
|
||||
const [activePluginType, handleActivePluginTypeChange] = useActivePluginType()
|
||||
const setSearchMode = useSetAtom(searchModeAtom)
|
||||
const activePluginType = useMarketplaceContext(s => s.activePluginType)
|
||||
const handleActivePluginTypeChange = useMarketplaceContext(s => s.handleActivePluginTypeChange)
|
||||
|
||||
const options: Array<{
|
||||
value: ActivePluginType
|
||||
text: string
|
||||
icon: React.ReactNode | null
|
||||
}> = [
|
||||
const options = [
|
||||
{
|
||||
value: PLUGIN_TYPE_SEARCH_MAP.all,
|
||||
text: t('category.all', { ns: 'plugin' }),
|
||||
@@ -72,6 +79,23 @@ const PluginTypeSwitch = ({
|
||||
},
|
||||
]
|
||||
|
||||
const handlePopState = useCallback(() => {
|
||||
if (!showSearchParams)
|
||||
return
|
||||
// nuqs handles popstate automatically
|
||||
const url = new URL(window.location.href)
|
||||
const category = url.searchParams.get('category') || PLUGIN_TYPE_SEARCH_MAP.all
|
||||
handleActivePluginTypeChange(category)
|
||||
}, [showSearchParams, handleActivePluginTypeChange])
|
||||
|
||||
useEffect(() => {
|
||||
// nuqs manages popstate internally, but we keep this for URL sync
|
||||
window.addEventListener('popstate', handlePopState)
|
||||
return () => {
|
||||
window.removeEventListener('popstate', handlePopState)
|
||||
}
|
||||
}, [handlePopState])
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
'flex shrink-0 items-center justify-center space-x-2 bg-background-body py-3',
|
||||
@@ -88,9 +112,6 @@ const PluginTypeSwitch = ({
|
||||
)}
|
||||
onClick={() => {
|
||||
handleActivePluginTypeChange(option.value)
|
||||
if (PLUGIN_CATEGORY_WITH_COLLECTIONS.has(option.value)) {
|
||||
setSearchMode(null)
|
||||
}
|
||||
}}
|
||||
>
|
||||
{option.icon}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
|
||||
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
|
||||
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
|
||||
|
||||
// TODO: Avoid manual maintenance of query keys and better service management,
|
||||
// https://github.com/langgenius/dify/issues/30342
|
||||
|
||||
export const marketplaceKeys = {
|
||||
all: ['marketplace'] as const,
|
||||
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
|
||||
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
|
||||
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
|
||||
}
|
||||
|
||||
export function useMarketplaceCollectionsAndPlugins(
|
||||
collectionsParams: CollectionsAndPluginsSearchParams,
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: marketplaceKeys.collections(collectionsParams),
|
||||
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
|
||||
})
|
||||
}
|
||||
|
||||
export function useMarketplacePlugins(
|
||||
queryParams: PluginsSearchParams | undefined,
|
||||
) {
|
||||
return useInfiniteQuery({
|
||||
queryKey: marketplaceKeys.plugins(queryParams),
|
||||
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
|
||||
getNextPageParam: (lastPage) => {
|
||||
const nextPage = lastPage.page + 1
|
||||
const loaded = lastPage.page * lastPage.pageSize
|
||||
return loaded < (lastPage.total || 0) ? nextPage : undefined
|
||||
},
|
||||
initialPageParam: 1,
|
||||
enabled: queryParams !== undefined,
|
||||
})
|
||||
}
|
||||
@@ -26,19 +26,16 @@ vi.mock('#i18n', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock marketplace state hooks
|
||||
const { mockSearchPluginText, mockHandleSearchPluginTextChange, mockFilterPluginTags, mockHandleFilterPluginTagsChange } = vi.hoisted(() => {
|
||||
return {
|
||||
mockSearchPluginText: '',
|
||||
mockHandleSearchPluginTextChange: vi.fn(),
|
||||
mockFilterPluginTags: [] as string[],
|
||||
mockHandleFilterPluginTagsChange: vi.fn(),
|
||||
}
|
||||
})
|
||||
// Mock useMarketplaceContext
|
||||
const mockContextValues = {
|
||||
searchPluginText: '',
|
||||
handleSearchPluginTextChange: vi.fn(),
|
||||
filterPluginTags: [] as string[],
|
||||
handleFilterPluginTagsChange: vi.fn(),
|
||||
}
|
||||
|
||||
vi.mock('../atoms', () => ({
|
||||
useSearchPluginText: () => [mockSearchPluginText, mockHandleSearchPluginTextChange],
|
||||
useFilterPluginTags: () => [mockFilterPluginTags, mockHandleFilterPluginTagsChange],
|
||||
vi.mock('../context', () => ({
|
||||
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
|
||||
}))
|
||||
|
||||
// Mock useTags hook
|
||||
@@ -433,6 +430,9 @@ describe('SearchBoxWrapper', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpenState = false
|
||||
// Reset context values
|
||||
mockContextValues.searchPluginText = ''
|
||||
mockContextValues.filterPluginTags = []
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
@@ -456,14 +456,28 @@ describe('SearchBoxWrapper', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Hook Integration', () => {
|
||||
describe('Context Integration', () => {
|
||||
it('should use searchPluginText from context', () => {
|
||||
mockContextValues.searchPluginText = 'context search'
|
||||
render(<SearchBoxWrapper />)
|
||||
|
||||
expect(screen.getByDisplayValue('context search')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call handleSearchPluginTextChange when search changes', () => {
|
||||
render(<SearchBoxWrapper />)
|
||||
|
||||
const input = screen.getByRole('textbox')
|
||||
fireEvent.change(input, { target: { value: 'new search' } })
|
||||
|
||||
expect(mockHandleSearchPluginTextChange).toHaveBeenCalledWith('new search')
|
||||
expect(mockContextValues.handleSearchPluginTextChange).toHaveBeenCalledWith('new search')
|
||||
})
|
||||
|
||||
it('should use filterPluginTags from context', () => {
|
||||
mockContextValues.filterPluginTags = ['agent', 'rag']
|
||||
render(<SearchBoxWrapper />)
|
||||
|
||||
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
'use client'
|
||||
|
||||
import { useTranslation } from '#i18n'
|
||||
import { useFilterPluginTags, useSearchPluginText } from '../atoms'
|
||||
import { useMarketplaceContext } from '../context'
|
||||
import SearchBox from './index'
|
||||
|
||||
const SearchBoxWrapper = () => {
|
||||
const { t } = useTranslation()
|
||||
const [searchPluginText, handleSearchPluginTextChange] = useSearchPluginText()
|
||||
const [filterPluginTags, handleFilterPluginTagsChange] = useFilterPluginTags()
|
||||
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
|
||||
const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange)
|
||||
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
|
||||
const handleFilterPluginTagsChange = useMarketplaceContext(v => v.handleFilterPluginTagsChange)
|
||||
|
||||
return (
|
||||
<SearchBox
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
import type { ActivePluginType } from './constants'
|
||||
import { parseAsArrayOf, parseAsString, parseAsStringEnum } from 'nuqs/server'
|
||||
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
|
||||
|
||||
export const marketplaceSearchParamsParsers = {
|
||||
category: parseAsStringEnum<ActivePluginType>(Object.values(PLUGIN_TYPE_SEARCH_MAP) as ActivePluginType[]).withDefault('all').withOptions({ history: 'replace', clearOnDefault: false }),
|
||||
q: parseAsString.withDefault('').withOptions({ history: 'replace' }),
|
||||
tags: parseAsArrayOf(parseAsString).withDefault([]).withOptions({ history: 'replace' }),
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { MarketplaceContextValue } from '../context'
|
||||
import { fireEvent, render, screen, within } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
@@ -27,12 +28,18 @@ vi.mock('#i18n', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock marketplace atoms with controllable values
|
||||
let mockSort: { sortBy: string, sortOrder: string } = { sortBy: 'install_count', sortOrder: 'DESC' }
|
||||
// Mock marketplace context with controllable values
|
||||
let mockSort = { sortBy: 'install_count', sortOrder: 'DESC' }
|
||||
const mockHandleSortChange = vi.fn()
|
||||
|
||||
vi.mock('../atoms', () => ({
|
||||
useMarketplaceSort: () => [mockSort, mockHandleSortChange],
|
||||
vi.mock('../context', () => ({
|
||||
useMarketplaceContext: (selector: (value: MarketplaceContextValue) => unknown) => {
|
||||
const contextValue = {
|
||||
sort: mockSort,
|
||||
handleSortChange: mockHandleSortChange,
|
||||
} as unknown as MarketplaceContextValue
|
||||
return selector(contextValue)
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock portal component with controllable open state
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user