Compare commits


1 Commit

152 changed files with 5048 additions and 18090 deletions


@@ -1,4 +1,4 @@
name: Deploy Agent Dev
name: Deploy Trigger Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/agent-dev"
- "deploy/trigger-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
host: ${{ secrets.TRIGGER_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

.gitignore

@@ -209,7 +209,6 @@ api/.vscode
.history
.idea/
web/migration/
# pnpm
/.pnpm-store

.nvmrc

@@ -1 +1 @@
24
22.11.0


@@ -70,6 +70,13 @@ class ActivateCheckApi(Resource):
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
# Check workspace permission
if tenant:
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None


@@ -107,6 +107,12 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(inviter.current_tenant.id)
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL


@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -288,3 +290,31 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/current/permission")
class WorkspacePermissionApi(Resource):
"""Get workspace permissions for the current workspace."""
@setup_required
@login_required
@account_initialization_required
@only_edition_enterprise
def get(self):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")
# Get workspace permissions from enterprise service
permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
return {
"workspace_id": permission.workspace_id,
"allow_member_invite": permission.allow_member_invite,
"allow_owner_transfer": permission.allow_owner_transfer,
}, 200
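For reference, a minimal sketch of how a console client might call the new endpoint. The base URL and cookie name are placeholders; only the route and the three response fields come from the handler above.

```python
import requests

# Hypothetical console URL and session cookie; the route and the response fields
# ("workspace_id", "allow_member_invite", "allow_owner_transfer") are taken from
# the handler above, everything else is a placeholder.
resp = requests.get(
    "https://console.example.com/console/api/workspaces/current/permission",
    cookies={"session_id": "<console-session>"},
    timeout=10,
)
resp.raise_for_status()
permission = resp.json()
if not permission["allow_member_invite"]:
    print("Member invitations are disabled by workspace policy")
```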


@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
from libs.workspace_permission import check_workspace_owner_transfer_permission
# otherwise, return 403
abort(403)
_, current_tenant_id = current_account_with_tenant()
# Check both billing/plan level and workspace policy level permissions
check_workspace_owner_transfer_permission(current_tenant_id)
return view(*args, **kwargs)
return decorated
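The decorator now delegates both checks to a single helper. A minimal sketch of what `check_workspace_owner_transfer_permission` could look like, assuming it combines the plan-level flag the old decorator consulted with the enterprise workspace policy shown elsewhere in this diff; the error messages are invented.

```python
from werkzeug.exceptions import Forbidden

from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService


def check_workspace_owner_transfer_permission(tenant_id: str) -> None:
    # Billing/plan level: the same flag the previous decorator read directly.
    features = FeatureService.get_features(tenant_id)
    if not features.is_allow_transfer_workspace:
        raise Forbidden("Owner transfer is not available on the current plan")

    # Workspace policy level: the enterprise permission used by WorkspacePermissionApi.
    permission = EnterpriseService.WorkspacePermissionService.get_permission(tenant_id)
    if not permission.allow_owner_transfer:
        raise Forbidden("Owner transfer is disabled by workspace policy")
```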


@@ -1,3 +1,4 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -120,7 +121,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
json_schema: str | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -134,11 +135,17 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
def validate_json_schema(cls, schema: str | None) -> str | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
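A standalone illustration of the new contract: `json_schema` is now a JSON-encoded string that must both parse and pass a Draft 7 schema check. The example schema is made up.

```python
import json

from jsonschema import Draft7Validator

# The field now carries the schema as a string rather than a dict.
json_schema = json.dumps(
    {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }
)

parsed = json.loads(json_schema)      # malformed strings raise json.JSONDecodeError
Draft7Validator.check_schema(parsed)  # invalid schemas raise jsonschema.SchemaError
```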


@@ -26,6 +26,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,


@@ -358,6 +358,25 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],


@@ -1,3 +1,4 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -75,24 +76,12 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
return user_inputs
@@ -189,8 +178,12 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")
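To make the behaviour change concrete, a small hedged example (the input name is invented): a `JSON_OBJECT` input must now be submitted as a JSON string rather than a dict.

```python
import json

# Accepted after this change: the value arrives as a string and must parse as JSON.
user_inputs = {"profile": json.dumps({"name": "Alice", "age": 30})}
json.loads(user_inputs["profile"])  # mirrors the json.loads check added above

# Rejected after this change: dict values now fail the generic isinstance(v, dict) guard.
# user_inputs = {"profile": {"name": "Alice", "age": 30}}
```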


@@ -1,434 +0,0 @@
# Memory Module
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
## Overview
The memory module contains two types of memory implementations:
1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Memory Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ TokenBufferMemory │ │
│ │ Scope: Conversation │ │
│ │ Storage: Database (Message table) │ │
│ │ Key: conversation_id │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: Object Storage (JSON file) │ │
│ │ Key: (app_id, conversation_id, node_id) │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
```
---
## TokenBufferMemory (Existing)
### Purpose
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
### Key Features
- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
### Data Flow
```
Message Table TokenBufferMemory LLM
│ │ │
│ SELECT * FROM messages │ │
│ WHERE conversation_id = ? │ │
│ ORDER BY created_at DESC │ │
├─────────────────────────────────▶│ │
│ │ │
│ extract_thread_messages() │
│ │ │
│ build_prompt_message_with_files() │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage]
│ ├───────────────────────▶│
│ │ │
```
### Thread Extraction
When a user regenerates a response, a new thread is created:
```
Message A (user)
└── Message A' (assistant)
└── Message B (user)
└── Message B' (assistant)
└── Message A'' (assistant, regenerated) ← New thread
└── Message C (user)
└── Message C' (assistant)
```
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
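A hedged sketch of that back-trace (the simplified message shape is an assumption; the real logic lives in `core.prompt.utils.extract_thread_messages`):

```python
def extract_thread(messages_by_id: dict[str, dict], latest_id: str) -> list[dict]:
    """Follow parent_message_id links from the newest message back to the root."""
    thread: list[dict] = []
    current = messages_by_id.get(latest_id)
    while current is not None:
        thread.append(current)
        parent_id = current.get("parent_message_id")
        current = messages_by_id.get(parent_id) if parent_id else None
    return list(reversed(thread))  # chronological order, e.g. [A, A'', C, C'] above
```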
### Usage
```python
from core.memory.token_buffer_memory import TokenBufferMemory
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```
---
## NodeTokenBufferMemory (To Be Implemented)
### Purpose
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
### Use Cases
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design Decisions
#### Storage: Object Storage for Messages (No New Database Table)
| Aspect | Database | Object Storage |
| ------------------------- | -------------------- | ------------------ |
| Cost | High | Low |
| Query Flexibility | High | Low |
| Schema Changes | Migration required | None |
| Consistency with existing | ConversationVariable | File uploads, logs |
**Decision**: Store message data in object storage, but still use existing database tables for file metadata.
**What is stored in Object Storage:**
- Message content (text)
- Message metadata (role, token_count, created_at)
- File references (upload_file_id, tool_file_id, etc.)
- Thread relationships (message_id, parent_message_id)
**What still requires Database queries:**
- File reconstruction: When reading node memory, file references are used to query
`UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
complete `File` objects with storage_key, mime_type, etc.
**Why this hybrid approach:**
- No database migration required (no new tables)
- Message data may be large, object storage is cost-effective
- File metadata is already in database, no need to duplicate
- Aligns with existing storage patterns (file uploads, logs)
#### Storage Key Format
```
node_memory/{app_id}/{conversation_id}/{node_id}.json
```
#### Data Structure
```json
{
"version": 1,
"messages": [
{
"message_id": "msg-001",
"parent_message_id": null,
"role": "user",
"content": "Analyze this image",
"files": [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": "file-uuid-123",
"belongs_to": "user"
}
],
"token_count": 15,
"created_at": "2026-01-07T10:00:00Z"
},
{
"message_id": "msg-002",
"parent_message_id": "msg-001",
"role": "assistant",
"content": "This is a landscape image...",
"files": [],
"token_count": 50,
"created_at": "2026-01-07T10:00:01Z"
}
]
}
```
### Thread Support
Node memory also supports thread extraction (for regeneration scenarios):
```python
def _extract_thread(
self,
messages: list[NodeMemoryMessage],
current_message_id: str
) -> list[NodeMemoryMessage]:
"""
Extract messages belonging to the thread of current_message_id.
Similar to extract_thread_messages() in TokenBufferMemory.
"""
...
```
### File Handling
Files are stored as references (not full metadata):
```python
class NodeMemoryFile(BaseModel):
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None # for local_file
tool_file_id: str | None # for tool_file
url: str | None # for remote_url
belongs_to: str # user / assistant
```
When reading, files are rebuilt using `file_factory.build_from_mapping()`.
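For illustration, the kind of mapping handed to `file_factory.build_from_mapping()` when node memory is read back (IDs are placeholders; the keys mirror `NodeMemoryFile`):

```python
from factories import file_factory

tenant_id = "tenant-uuid"  # placeholder: the tenant that owns the conversation
mapping = {
    "type": "image",
    "transfer_method": "local_file",
    "upload_file_id": "file-uuid-123",
}
# Resolves the UploadFile row and rebuilds a full File (storage_key, mime_type, ...).
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
```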
### API Design
```python
class NodeTokenBufferMemory:
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param model_instance: Model instance for token counting
"""
...
def add_messages(
self,
message_id: str,
parent_message_id: str | None,
user_content: str,
user_files: Sequence[File],
assistant_content: str,
assistant_files: Sequence[File],
) -> None:
"""
Append a dialogue turn (user + assistant) to node memory.
Call this after LLM node execution completes.
:param message_id: Current message ID (from Message table)
:param parent_message_id: Parent message ID (for thread tracking)
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
...
def get_history_prompt_messages(
self,
current_message_id: str,
tenant_id: str,
max_token_limit: int = 2000,
file_upload_config: FileUploadConfig | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
:param current_message_id: Current message ID (for thread extraction)
:param tenant_id: Tenant ID (for file reconstruction)
:param max_token_limit: Maximum tokens for history
:param file_upload_config: File upload configuration
:return: Sequence of PromptMessage for LLM context
"""
...
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
...
def clear(self) -> None:
"""
Clear all messages in this node's memory.
"""
...
```
### Data Flow
```
Object Storage NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ storage.load(key) │ │
│◀─────────────────────────────────┤ │
│ │ │
│ JSON data │ │
├─────────────────────────────────▶│ │
│ │ │
│ _extract_thread() │
│ │ │
│ _rebuild_files() via file_factory │
│ │ │
│ _build_prompt_messages() │
│ │ │
│ _truncate_by_tokens() │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
│ │◀── LLM execution complete │
│ │ │
│ │◀── add_messages() │
│ │ │
│ storage.save(key, data) │ │
│◀─────────────────────────────────┤ │
│ │ │
```
### Integration with LLM Node
```python
# In LLM Node execution
# 1. Fetch memory based on mode
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
memory = fetch_node_memory(
variable_pool=variable_pool,
app_id=app_id,
node_id=self.node_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
# Conversation-level memory (existing behavior)
memory = fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
else:
memory = None
# 2. Get history for context
if memory:
if isinstance(memory, NodeTokenBufferMemory):
history = memory.get_history_prompt_messages(
current_message_id=current_message_id,
tenant_id=tenant_id,
max_token_limit=max_token_limit,
)
else: # TokenBufferMemory
history = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
)
prompt_messages = [*history, *current_messages]
else:
prompt_messages = current_messages
# 3. Call LLM
response = model_instance.invoke(prompt_messages)
# 4. Append to node memory (only for NodeTokenBufferMemory)
if isinstance(memory, NodeTokenBufferMemory):
memory.add_messages(
message_id=message_id,
parent_message_id=parent_message_id,
user_content=user_input,
user_files=user_files,
assistant_content=response.content,
assistant_files=response_files,
)
memory.flush()
```
### Configuration
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
```python
class MemoryMode(StrEnum):
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (new, Chatflow only)
class MemoryConfig(BaseModel):
# Existing fields
role_prefix: RolePrefix | None = None
window: MemoryWindowConfig | None = None
query_prompt_template: str | None = None
# Memory mode (new)
mode: MemoryMode = MemoryMode.CONVERSATION
```
**Mode Behavior:**
| Mode | Memory Class | Scope | Availability |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
> fall back to no memory or raise a configuration error.
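A hypothetical node-level memory configuration using the new field (the `window` sub-fields are assumed; the other keys follow `MemoryConfig`):

```python
memory_config = {
    "mode": "node",                             # NodeTokenBufferMemory, Chatflow only
    "window": {"enabled": True, "size": 10},    # assumed WindowConfig shape
    "query_prompt_template": "{{#sys.query#}}",
}
```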
---
## Comparison
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | Object Storage (JSON) |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via file references) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
---
## Future Considerations
1. **Cleanup Task**: Add a Celery task to clean up old node memory files
2. **Concurrency**: Consider Redis lock for concurrent node executions
3. **Compression**: Compress large memory files to reduce storage costs
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory


@@ -1,15 +0,0 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeMemoryData,
NodeMemoryFile,
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeMemoryData",
"NodeMemoryFile",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]


@@ -1,83 +0,0 @@
"""
Base memory interfaces and types.
This module defines the common protocol for memory implementations.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
class BaseMemory(ABC):
"""
Abstract base class for memory implementations.
Provides a common interface for both conversation-level and node-level memory.
"""
@abstractmethod
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Sequence of PromptMessage for LLM context
"""
pass
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt as formatted text.
:param human_prefix: Prefix for human messages
:param ai_prefix: Prefix for assistant messages
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Formatted history text
"""
from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)


@@ -1,353 +0,0 @@
"""
Node-level Token Buffer Memory for Chatflow.
This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.
Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.
Design:
- Storage is indexed by workflow_run_id (each execution stores one turn)
- Thread tracking leverages Message table's parent_message_id structure
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
"""
import logging
from collections.abc import Sequence
from pydantic import BaseModel
from sqlalchemy import select
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message
logger = logging.getLogger(__name__)
class NodeMemoryFile(BaseModel):
"""File reference stored in node memory."""
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None = None
tool_file_id: str | None = None
url: str | None = None
class NodeMemoryTurn(BaseModel):
"""A single dialogue turn (user + assistant) in node memory."""
user_content: str = ""
user_files: list[NodeMemoryFile] = []
assistant_content: str = ""
assistant_files: list[NodeMemoryFile] = []
class NodeMemoryData(BaseModel):
"""Root data structure for node memory storage."""
version: int = 1
# Key: workflow_run_id, Value: dialogue turn
turns: dict[str, NodeMemoryTurn] = {}
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history, stored in object storage.
Key design: Thread tracking is delegated to Message table's parent_message_id.
Storage is indexed by workflow_run_id for easy filtering.
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
"""
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param tenant_id: Tenant ID for file reconstruction
:param model_instance: Model instance for token counting
"""
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
self._data: NodeMemoryData | None = None
self._dirty = False
def _load(self) -> NodeMemoryData:
"""Load data from object storage."""
if self._data is not None:
return self._data
try:
raw = storage.load_once(self._storage_key)
self._data = NodeMemoryData.model_validate_json(raw)
except Exception:
# File not found or parse error, start fresh
self._data = NodeMemoryData()
return self._data
def _save(self) -> None:
"""Save data to object storage."""
if self._data is not None:
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
self._dirty = False
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
"""Convert File object to NodeMemoryFile reference."""
return NodeMemoryFile(
type=file.type.value if hasattr(file.type, "value") else str(file.type),
transfer_method=(
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
),
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
)
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
"""Convert NodeMemoryFile to mapping for file_factory."""
mapping: dict = {
"type": memory_file.type,
"transfer_method": memory_file.transfer_method,
}
if memory_file.upload_file_id:
mapping["upload_file_id"] = memory_file.upload_file_id
if memory_file.tool_file_id:
mapping["tool_file_id"] = memory_file.tool_file_id
if memory_file.url:
mapping["url"] = memory_file.url
return mapping
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
"""Rebuild File objects from NodeMemoryFile references."""
if not memory_files:
return []
from factories import file_factory
files = []
for mf in memory_files:
try:
mapping = self._memory_file_to_mapping(mf)
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
files.append(file)
except Exception as e:
logger.warning("Failed to rebuild file from memory: %s", e)
continue
return files
def _build_prompt_message(
self,
role: str,
content: str,
files: list[File],
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
) -> PromptMessage:
"""Build PromptMessage from content and files."""
from core.file import file_manager
if not files:
if role == "user":
return UserPromptMessage(content=content)
else:
return AssistantPromptMessage(content=content)
# Build multimodal content
prompt_contents: list = []
for file in files:
try:
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
prompt_contents.append(prompt_content)
except Exception as e:
logger.warning("Failed to convert file to prompt content: %s", e)
continue
prompt_contents.append(TextPromptMessageContent(data=content))
if role == "user":
return UserPromptMessage(content=prompt_contents)
else:
return AssistantPromptMessage(content=prompt_contents)
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
# Query messages for this conversation
stmt = (
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
)
messages = db.session.scalars(stmt.limit(500)).all()
if not messages:
return []
# Extract thread messages using existing logic
thread_messages = extract_thread_messages(messages)
# For newly created message, its answer is temporarily empty, skip it
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
workflow_run_ids = []
for msg in reversed(thread_messages):
if msg.workflow_run_id:
workflow_run_ids.append(msg.workflow_run_id)
return workflow_run_ids
def add_messages(
self,
workflow_run_id: str,
user_content: str,
user_files: Sequence[File] | None = None,
assistant_content: str = "",
assistant_files: Sequence[File] | None = None,
) -> None:
"""
Add a dialogue turn to node memory.
Call this after LLM node execution completes.
:param workflow_run_id: Current workflow execution ID
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
data = self._load()
# Convert files to memory file references
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
# Store the turn indexed by workflow_run_id
data.turns[workflow_run_id] = NodeMemoryTurn(
user_content=user_content,
user_files=user_memory_files,
assistant_content=assistant_content,
assistant_files=assistant_memory_files,
)
self._dirty = True
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
Thread tracking is handled by querying Message table's parent_message_id structure.
:param max_token_limit: Maximum tokens for history
:param message_limit: unused, for interface compatibility
:return: Sequence of PromptMessage for LLM context
"""
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
_ = message_limit
detail = ImagePromptMessageContent.DETAIL.HIGH
data = self._load()
if not data.turns:
return []
# Get workflow_run_ids for current thread from Message table
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Build prompt messages in thread order
prompt_messages: list[PromptMessage] = []
for wf_run_id in thread_workflow_run_ids:
turn = data.turns.get(wf_run_id)
if not turn:
# This workflow execution didn't have node memory stored
continue
# Build user message
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
user_msg = self._build_prompt_message(
role="user",
content=turn.user_content,
files=user_files,
detail=detail,
)
prompt_messages.append(user_msg)
# Build assistant message
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
assistant_msg = self._build_prompt_message(
role="assistant",
content=turn.assistant_content,
files=assistant_files,
detail=detail,
)
prompt_messages.append(assistant_msg)
if not prompt_messages:
return []
# Truncate by token limit
try:
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
while current_tokens > max_token_limit and len(prompt_messages) > 1:
prompt_messages.pop(0)
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
except Exception as e:
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
if self._dirty:
self._save()
def clear(self) -> None:
"""Clear all messages in this node's memory."""
self._data = NodeMemoryData()
self._save()
def exists(self) -> bool:
"""Check if node memory exists in storage."""
return storage.exists(self._storage_key)
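Read alongside the class above, a hedged usage sketch (IDs are placeholders and `model_instance` is assumed to be a `ModelInstance` provided by the workflow runtime, as in the existing `TokenBufferMemory` example):

```python
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory

memory = NodeTokenBufferMemory(
    app_id="app-1",
    conversation_id="conv-1",
    node_id="llm_1",
    tenant_id="tenant-1",
    model_instance=model_instance,
)
memory.add_messages(
    workflow_run_id="run-001",
    user_content="Summarize the report",
    assistant_content="The report covers Q3 revenue...",
)
memory.flush()  # persists to node_memory/app-1/conv-1/llm_1.json
history = memory.get_history_prompt_messages(max_token_limit=2000)
```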


@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory(BaseMemory):
class TokenBufferMemory:
def __init__(
self,
conversation: Conversation,
@@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
@@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)


@@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
@@ -6,13 +5,6 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole
class MemoryMode(StrEnum):
"""Memory mode for LLM nodes."""
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class ChatModelMessage(BaseModel):
"""
Chat Message.
@@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION

File diff suppressed because it is too large


@@ -63,7 +63,6 @@ class NodeType(StrEnum):
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
GROUP = "group"
@property
def is_trigger_node(self) -> bool:


@@ -307,14 +307,7 @@ class Graph:
if not node_configs:
raise ValueError("Graph must have at least one node")
# Filter out UI-only node types:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config for node_config in node_configs
if node_config.get("type", "") != "custom-note"
and node_config.get("data", {}).get("type", "") != "group"
]
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)


@@ -125,11 +125,6 @@ class EventHandler:
Args:
event: The node started event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_started(event)
return
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
@@ -169,11 +164,6 @@ class EventHandler:
Args:
event: The node succeeded event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_success(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
@@ -236,11 +226,6 @@ class EventHandler:
Args:
event: The node failed event
"""
# Check if this is a virtual node (extraction node)
if self._is_virtual_node(event.node_id):
self._handle_virtual_node_failed(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
@@ -360,57 +345,3 @@ class EventHandler:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
def _is_virtual_node(self, node_id: str) -> bool:
"""
Check if node_id represents a virtual sub-node.
Virtual nodes have IDs in the format: {parent_node_id}.{local_id}
We check if the part before '.' exists in graph nodes.
"""
if "." in node_id:
parent_id = node_id.rsplit(".", 1)[0]
return parent_id in self._graph.nodes
return False
def _handle_virtual_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle virtual node started event.
Virtual nodes don't need full execution tracking, just collect the event.
"""
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_success(self, event: NodeRunSucceededEvent) -> None:
"""
Handle virtual node success event.
Virtual nodes (extraction nodes) need special handling:
- Store outputs in variable pool (for reference by other nodes)
- Accumulate token usage
- Collect the event for logging
- Do NOT process edges or enqueue next nodes (parent node handles that)
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
def _handle_virtual_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle virtual node failed event.
Virtual nodes (extraction nodes) failures are collected for logging,
but the parent node is responsible for handling the error.
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Collect the event for logging
self._event_collector.collect(event)


@@ -20,12 +20,6 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_type: str = ""
provider_id: str = ""
# Virtual node fields for extraction
is_virtual: bool = False
parent_node_id: str | None = None
extraction_source: str | None = None # e.g., "llm1.context"
extraction_prompt: str | None = None
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields


@@ -1,13 +1,5 @@
from .entities import (
BaseIterationNodeData,
BaseIterationState,
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
VirtualNodeConfig,
)
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
from .virtual_node_executor import VirtualNodeExecutionError, VirtualNodeExecutor
__all__ = [
"BaseIterationNodeData",
@@ -16,7 +8,4 @@ __all__ = [
"BaseLoopState",
"BaseNodeData",
"LLMUsageTrackingMixin",
"VirtualNodeConfig",
"VirtualNodeExecutionError",
"VirtualNodeExecutor",
]


@@ -167,24 +167,6 @@ class DefaultValue(BaseModel):
return self
class VirtualNodeConfig(BaseModel):
"""Configuration for a virtual sub-node embedded within a parent node."""
# Local ID within parent node (e.g., "ext_1")
# Will be converted to global ID: "{parent_id}.{id}"
id: str
# Node type (e.g., "llm", "code", "tool")
type: str
# Full node data configuration
data: dict[str, Any] = {}
def get_global_id(self, parent_node_id: str) -> str:
"""Get the global node ID by combining parent ID and local ID."""
return f"{parent_node_id}.{self.id}"
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
@@ -193,9 +175,6 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Virtual sub-nodes that execute before the main node
virtual_nodes: list[VirtualNodeConfig] = []
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:


@@ -229,7 +229,6 @@ class Node(Generic[NodeDataT]):
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
self._virtual_node_outputs: dict[str, Any] = {} # Outputs from virtual sub-nodes
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
@@ -271,52 +270,10 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _execute_virtual_nodes(self) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute all virtual sub-nodes defined in node configuration.
Virtual nodes are complete node definitions that execute before the main node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool (via event handling)
- Supports retry via parent node's retry config
Returns:
dict mapping local_id -> outputs dict
"""
from .virtual_node_executor import VirtualNodeExecutor
virtual_nodes = self.node_data.virtual_nodes
if not virtual_nodes:
return {}
executor = VirtualNodeExecutor(
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
parent_node_id=self._node_id,
parent_retry_config=self.retry_config,
)
return (yield from executor.execute_virtual_nodes(virtual_nodes))
@property
def virtual_node_outputs(self) -> dict[str, Any]:
"""
Get the outputs from virtual sub-nodes.
Returns:
dict mapping local_id -> outputs dict
"""
return self._virtual_node_outputs
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute virtual sub-nodes before main node execution
self._virtual_node_outputs = yield from self._execute_virtual_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=execution_id,


@@ -1,213 +0,0 @@
"""
Virtual Node Executor for running embedded sub-nodes within a parent node.
This module handles the execution of virtual nodes defined in a parent node's
`virtual_nodes` configuration. Virtual nodes are complete node definitions
that execute before the parent node.
Example configuration:
virtual_nodes:
- id: ext_1
type: llm
data:
model: {...}
prompt_template: [...]
"""
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Any
from uuid import uuid4
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunFailedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from libs.datetime_utils import naive_utc_now
from .entities import RetryConfig, VirtualNodeConfig
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class VirtualNodeExecutionError(Exception):
"""Error during virtual node execution"""
def __init__(self, node_id: str, original_error: Exception):
self.node_id = node_id
self.original_error = original_error
super().__init__(f"Virtual node {node_id} execution failed: {original_error}")
class VirtualNodeExecutor:
"""
Executes virtual sub-nodes embedded within a parent node.
Virtual nodes are complete node definitions that execute before the parent node.
Each virtual node:
- Has its own global ID: "{parent_id}.{local_id}"
- Generates standard node events
- Stores outputs in the variable pool
- Supports retry via parent node's retry config
"""
def __init__(
self,
*,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
parent_node_id: str,
parent_retry_config: RetryConfig | None = None,
):
self._graph_init_params = graph_init_params
self._graph_runtime_state = graph_runtime_state
self._parent_node_id = parent_node_id
self._parent_retry_config = parent_retry_config or RetryConfig()
def execute_virtual_nodes(
self,
virtual_nodes: list[VirtualNodeConfig],
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute all virtual nodes in order.
Args:
virtual_nodes: List of virtual node configurations
Yields:
Node events from each virtual node execution
Returns:
dict mapping local_id -> outputs dict
"""
results: dict[str, Any] = {}
for vnode_config in virtual_nodes:
global_id = vnode_config.get_global_id(self._parent_node_id)
# Execute with retry
outputs = yield from self._execute_with_retry(vnode_config, global_id)
results[vnode_config.id] = outputs
return results
def _execute_with_retry(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute virtual node with retry support.
"""
retry_config = self._parent_retry_config
last_error: Exception | None = None
for attempt in range(retry_config.max_retries + 1):
try:
return (yield from self._execute_single_node(vnode_config, global_id))
except Exception as e:
last_error = e
if attempt < retry_config.max_retries:
# Yield retry event
yield NodeRunRetryEvent(
id=str(uuid4()),
node_id=global_id,
node_type=self._get_node_type(vnode_config.type),
node_title=vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
start_at=naive_utc_now(),
error=str(e),
retry_index=attempt + 1,
)
time.sleep(retry_config.retry_interval_seconds)
continue
raise VirtualNodeExecutionError(global_id, e) from e
raise last_error or VirtualNodeExecutionError(global_id, Exception("Unknown error"))
def _execute_single_node(
self,
vnode_config: VirtualNodeConfig,
global_id: str,
) -> Generator[GraphNodeEventBase, None, dict[str, Any]]:
"""
Execute a single virtual node by instantiating and running it.
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
# Build node config
node_config: dict[str, Any] = {
"id": global_id,
"data": {
**vnode_config.data,
"title": vnode_config.data.get("title", f"Virtual: {vnode_config.id}"),
},
}
# Get the node class for this type
node_type = self._get_node_type(vnode_config.type)
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
node_version = str(vnode_config.data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
raise ValueError(f"No class found for node type: {node_type}")
# Instantiate the node
node = node_cls(
id=global_id,
config=node_config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self._graph_runtime_state,
)
# Run and collect events
outputs: dict[str, Any] = {}
for event in node.run():
# Mark event as coming from virtual node
self._mark_event_as_virtual(event, vnode_config)
yield event
if isinstance(event, NodeRunSucceededEvent):
outputs = event.node_run_result.outputs or {}
elif isinstance(event, NodeRunFailedEvent):
raise Exception(event.error or "Virtual node execution failed")
return outputs
def _mark_event_as_virtual(
self,
event: GraphNodeEventBase,
vnode_config: VirtualNodeConfig,
) -> None:
"""Mark event as coming from a virtual node."""
if isinstance(event, NodeRunStartedEvent):
event.is_virtual = True
event.parent_node_id = self._parent_node_id
def _get_node_type(self, type_str: str) -> NodeType:
"""Convert type string to NodeType enum."""
type_mapping = {
"llm": NodeType.LLM,
"code": NodeType.CODE,
"tool": NodeType.TOOL,
"if-else": NodeType.IF_ELSE,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"template-transform": NodeType.TEMPLATE_TRANSFORM,
"variable-assigner": NodeType.VARIABLE_ASSIGNER,
"http-request": NodeType.HTTP_REQUEST,
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
}
return type_mapping.get(type_str, NodeType.LLM)


@@ -8,13 +8,12 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
@@ -87,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool,
app_id: str,
tenant_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
node_id: str = "",
) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
:param variable_pool: Variable pool containing system variables
:param app_id: Application ID
:param tenant_id: Tenant ID
:param node_data_memory: Memory configuration
:param model_instance: Model instance for token counting
:param node_id: Node ID in the workflow (required for node mode)
:return: Memory instance or None if not applicable
"""
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
# Get conversation_id from variable pool (required for both modes in Chatflow)
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if node_data_memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
if not node_id:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return NodeTokenBufferMemory(
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
tenant_id=tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):


@@ -16,8 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
@@ -209,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
query: str | None = None
@@ -304,41 +301,12 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Write to Node Memory if in node memory mode
if isinstance(memory, NodeTokenBufferMemory):
# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
workflow_run_id = workflow_run_id_var.value if isinstance(workflow_run_id_var, StringSegment) else ""
if workflow_run_id:
# Resolve the query template to get actual user content
# query may be a template like "{{#sys.query#}}" or "{{#node_id.output#}}"
actual_query = variable_pool.convert_template(query or "").text
# Get user files from sys.files
user_files_var = variable_pool.get(["sys", SystemVariableKey.FILES])
user_files: list[File] = []
if isinstance(user_files_var, ArrayFileSegment):
user_files = list(user_files_var.value)
elif isinstance(user_files_var, FileSegment):
user_files = [user_files_var.value]
memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=actual_query,
user_files=user_files,
assistant_content=clean_text,
assistant_files=self._file_outputs,
)
memory.flush()
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@@ -598,22 +566,6 @@ class LLMNode(Node[LLMNodeData]):
# Separated mode: always return clean text and reasoning_content
return clean_text, reasoning_content or ""
@staticmethod
def _build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
model_mode: str,
) -> list[dict[str, Any]]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
"""
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_mode, prompt_messages=context_messages
)
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -826,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: BaseMemory | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
@@ -1385,7 +1337,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1402,7 +1354,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:

View File

@@ -1,3 +1,4 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -42,22 +43,25 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
# If no value provided, skip further processing for this key
if not value:
continue
if not isinstance(value, dict):
raise ValueError(f"JSON object for '{key}' must be an object")
# Overwrite with normalized dict to ensure downstream consistency
node_inputs[key] = value
# If a schema exists, validate against it
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
Draft7Validator(schema).validate(value)
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = json_value
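
The reworked StartNode validation expects a json_object input to arrive as a JSON string: it parses the schema string, parses the value string, and only then runs Draft7Validator. A minimal standalone sketch of the same parse-then-validate flow (the function name is illustrative; the error messages mirror the diff, but this is not the node itself):

import json
from jsonschema import Draft7Validator, ValidationError

def validate_json_input(key: str, raw_value: str, raw_schema: str) -> dict:
    # Parse the schema string first; a malformed schema is reported on its own
    try:
        schema = json.loads(raw_schema)
    except json.JSONDecodeError:
        raise ValueError(f"{raw_schema} must be a valid JSON object")
    # Parse the user-supplied value, which now arrives as a JSON string
    try:
        value = json.loads(raw_value)
    except json.JSONDecodeError:
        raise ValueError(f"{raw_value} must be a valid JSON object")
    # Only a successfully parsed value is checked against the schema
    try:
        Draft7Validator(schema).validate(value)
    except ValidationError as e:
        raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
    return value

# Example:
# validate_json_input("profile", '{"age": 20, "name": "Tom"}',
#                     '{"type": "object", "required": ["age", "name"]}')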

View File

@@ -89,20 +89,18 @@ class ToolNode(Node[ToolNodeData]):
)
return
# get parameters (use virtual_node_outputs from base class)
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
virtual_node_outputs=self.virtual_node_outputs,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
virtual_node_outputs=self.virtual_node_outputs,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -178,7 +176,6 @@ class ToolNode(Node[ToolNodeData]):
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
virtual_node_outputs: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -187,17 +184,12 @@ class ToolNode(Node[ToolNodeData]):
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
for_log (bool): Whether to generate parameters for logging.
virtual_node_outputs (dict[str, Any] | None): Outputs from virtual sub-nodes.
Maps local_id -> outputs dict. Virtual node outputs are also in variable_pool
with global IDs like "{parent_id}.{local_id}".
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
virtual_node_outputs = virtual_node_outputs or {}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
@@ -207,25 +199,14 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
# Check if this references a virtual node output (local ID like [ext_1, text])
selector = tool_input.value
if len(selector) >= 2 and selector[0] in virtual_node_outputs:
# Reference to virtual node output
local_id = selector[0]
var_name = selector[1]
outputs = virtual_node_outputs.get(local_id, {})
parameter_value = outputs.get(var_name)
else:
# Normal variable reference
variable = variable_pool.get(selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {selector} does not exist")
continue
parameter_value = variable.value
variable = variable_pool.get(tool_input.value)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
template = str(tool_input.value)
segment_group = variable_pool.convert_template(template)
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")

View File

@@ -0,0 +1,74 @@
"""
Workspace permission helper functions.
These helpers check both billing/plan level and workspace-specific policy level permissions.
Checks are performed at two levels:
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
"""
import logging
from werkzeug.exceptions import Forbidden
from configs import dify_config
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
def check_workspace_member_invite_permission(workspace_id: str) -> None:
"""
Check if workspace allows member invitations at both billing and policy levels.
Checks performed:
1. Billing/plan level - For future expansion (currently no plan-level restriction)
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits member invitations
"""
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_member_invite:
raise Forbidden("Workspace policy prohibits member invitations")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace invite permission for %s", workspace_id)
def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
"""
Check if workspace allows owner transfer at both billing and policy levels.
Checks performed:
1. Billing/plan level - SANDBOX plan blocks owner transfer
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits ownership transfer
"""
features = FeatureService.get_features(workspace_id)
if not features.is_allow_transfer_workspace:
raise Forbidden("Your current plan does not allow workspace ownership transfer")
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_owner_transfer:
raise Forbidden("Workspace policy prohibits ownership transfer")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace transfer permission for %s", workspace_id)

View File

@@ -1364,6 +1364,11 @@ class RegisterService:
raise ValueError("Inviter is required")
"""Invite new member"""
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
account = session.query(Account).filter_by(email=email).first()

View File

@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
)
class WorkspacePermission(BaseModel):
workspace_id: str = Field(
description="The ID of the workspace.",
alias="workspaceId",
)
allow_member_invite: bool = Field(
description="Whether to allow members to invite new members to the workspace.",
default=False,
alias="allowMemberInvite",
)
allow_owner_transfer: bool = Field(
description="Whether to allow owners to transfer ownership of the workspace.",
default=False,
alias="allowOwnerTransfer",
)
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WorkspacePermissionService:
@classmethod
def get_permission(cls, workspace_id: str):
if not workspace_id:
raise ValueError("workspace_id must be provided.")
data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
if not data or "permission" not in data:
raise ValueError("No data found.")
return WorkspacePermission.model_validate(data["permission"])
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
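
Because the WorkspacePermission model declares camelCase aliases, the enterprise API payload validates directly with Pydantic. A small sketch under the assumption that the remote response matches the aliases above; the sample payload is made up:

from pydantic import BaseModel, Field

class WorkspacePermission(BaseModel):
    workspace_id: str = Field(alias="workspaceId")
    allow_member_invite: bool = Field(default=False, alias="allowMemberInvite")
    allow_owner_transfer: bool = Field(default=False, alias="allowOwnerTransfer")

# Shape returned by GET /workspaces/{workspace_id}/permission
data = {"permission": {"workspaceId": "ws-1", "allowMemberInvite": True, "allowOwnerTransfer": False}}
permission = WorkspacePermission.model_validate(data["permission"])
assert permission.allow_member_invite and not permission.allow_owner_transfer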

View File

@@ -1,266 +0,0 @@
app:
description: Test for variable extraction feature
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: pav-test-extraction
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1767773675796-llm
source: '1767773675796'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
isInIteration: false
isInLoop: false
sourceType: llm
targetType: tool
id: llm-source-1767773709491-target
source: llm
sourceHandle: source
target: '1767773709491'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: answer
id: tool-source-answer-target
source: '1767773709491'
sourceHandle: source
target: answer
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1767773675796'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
query_prompt_template: ''
role_prefix:
assistant: ''
user: ''
window:
enabled: true
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
prompt_template:
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
role: system
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
请与用户进行对话,了解他们的搜索需求。
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
'
selected: false
title: LLM
type: llm
vision:
enabled: false
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: used for searching
ja_JP: used for searching
pt_BR: used for searching
zh_Hans: 用于搜索网页内容
label:
en_US: Query string
ja_JP: Query string
pt_BR: Query string
zh_Hans: 查询语句
llm_description: key words for searching
max: null
min: null
name: query
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
params:
query: ''
plugin_id: langgenius/google
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
provider_id: langgenius/google/google
provider_name: langgenius/google/google
provider_type: builtin
selected: false
title: GoogleSearch
tool_configurations: {}
tool_description: A tool for performing a Google SERP search and extracting
snippets and webpages. Input should be a search query.
tool_label: GoogleSearch
tool_name: google_search
tool_node_version: '2'
tool_parameters:
query:
type: variable
value:
- ext_1
- text
type: tool
virtual_nodes:
- data:
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
context:
enabled: false
prompt_template:
- role: user
text: '{{#llm.context#}}'
- role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
title: 提取搜索关键词
id: ext_1
type: llm
height: 52
id: '1767773709491'
position:
x: 682
y: 282
positionAbsolute:
x: 682
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '搜索结果:
{{#1767773709491.text#}}
'
selected: false
title: Answer
type: answer
height: 103
id: answer
position:
x: 984
y: 282
positionAbsolute:
x: 984
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: 151
y: 141.5
zoom: 1
rag_pipeline_variables: []

View File

@@ -0,0 +1,390 @@
"""
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
specifically testing the ANSWER node message_replace logic.
"""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeSucceededEvent
from core.workflow.enums import NodeType
from models import EndUser
from models.model import AppMode
class TestAnswerNodeMessageReplace:
"""Test cases for ANSWER node message_replace event logic."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock(spec=AdvancedChatAppGenerateEntity)
entity.task_id = "test-task-id"
entity.app_id = "test-app-id"
entity.workflow_run_id = "test-workflow-run-id"
# minimal app_config used by pipeline internals
entity.app_config = SimpleNamespace(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.ADVANCED_CHAT,
app_model_config_dict={},
additional_features=None,
sensitive_word_avoidance=None,
)
entity.query = "test query"
entity.files = []
entity.extras = {}
entity.trace_manager = None
entity.inputs = {}
entity.invoke_from = "debugger"
return entity
@pytest.fixture
def mock_workflow(self):
"""Create a mock workflow."""
workflow = Mock()
workflow.id = "test-workflow-id"
workflow.features_dict = {}
return workflow
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = Mock()
manager.listen.return_value = []
manager.graph_runtime_state = None
return manager
@pytest.fixture
def mock_conversation(self):
"""Create a mock conversation."""
conversation = Mock()
conversation.id = "test-conversation-id"
conversation.mode = "advanced_chat"
return conversation
@pytest.fixture
def mock_message(self):
"""Create a mock message."""
message = Mock()
message.id = "test-message-id"
message.query = "test query"
message.created_at = Mock()
message.created_at.timestamp.return_value = 1234567890
return message
@pytest.fixture
def mock_user(self):
"""Create a mock end user."""
user = MagicMock(spec=EndUser)
user.id = "test-user-id"
user.session_id = "test-session-id"
return user
@pytest.fixture
def mock_draft_var_saver_factory(self):
"""Create a mock draft variable saver factory."""
return Mock()
@pytest.fixture
def pipeline(
self,
mock_application_generate_entity,
mock_workflow,
mock_queue_manager,
mock_conversation,
mock_message,
mock_user,
mock_draft_var_saver_factory,
):
"""Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=mock_application_generate_entity,
workflow=mock_workflow,
queue_manager=mock_queue_manager,
conversation=mock_conversation,
message=mock_message,
user=mock_user,
stream=True,
dialogue_count=1,
draft_var_saver_factory=mock_draft_var_saver_factory,
)
# Initialize workflow run id to avoid validation errors
pipeline._workflow_run_id = "test-workflow-run-id"
# Mock the message cycle manager methods we need to track
pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
return pipeline
def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
"""
Test that when an ANSWER node's final output differs from accumulated answer,
a message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "initial answer"
# Create ANSWER node succeeded event with different final output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": "updated final answer"},
)
# Mock the workflow response converter to avoid extra processing
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
responses = list(pipeline._handle_node_succeeded_event(event))
# Assert
assert pipeline._task_state.answer == "updated final answer"
# Verify message_replace was called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
answer="updated final answer", reason="variable_update"
)
def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node's final output is the same as accumulated answer,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "same answer"
# Create ANSWER node succeeded event with same output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": "same answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "same answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node's output is None or missing 'answer' key,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with None output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": None},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
"""
Test that when an ANSWER node has empty outputs dict,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with empty outputs
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
"""
Test that when an ANSWER node's outputs don't contain 'answer' key,
no message_replace event is sent.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event without 'answer' key in outputs
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"other_key": "some value"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_non_answer_node_does_not_send_message_replace(self, pipeline):
"""
Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Test with LLM node
llm_event = QueueNodeSucceededEvent(
node_execution_id="test-llm-execution-id",
node_id="test-llm-node",
node_type=NodeType.LLM,
start_at=datetime.now(),
outputs={"answer": "different answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(llm_event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_end_node_does_not_send_message_replace(self, pipeline):
"""
Test that END nodes don't trigger message_replace events even with 'answer' output.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "existing answer"
# Create END node succeeded event with answer output
event = QueueNodeSucceededEvent(
node_execution_id="test-end-execution-id",
node_id="test-end-node",
node_type=NodeType.END,
start_at=datetime.now(),
outputs={"answer": "different answer"},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should remain unchanged
assert pipeline._task_state.answer == "existing answer"
# Verify message_replace was NOT called
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
"""
Test that when an ANSWER node's final output is numeric,
it gets converted to string properly.
"""
# Arrange: Set initial accumulated answer
pipeline._task_state.answer = "text answer"
# Create ANSWER node succeeded event with numeric output
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={"answer": 12345},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: answer should be converted to string
assert pipeline._task_state.answer == "12345"
# Verify message_replace was called with string
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
answer="12345", reason="variable_update"
)
def test_answer_node_files_are_recorded(self, pipeline):
"""
Test that ANSWER nodes properly record files from outputs.
"""
# Arrange
pipeline._task_state.answer = "existing answer"
# Create ANSWER node succeeded event with files
event = QueueNodeSucceededEvent(
node_execution_id="test-node-execution-id",
node_id="test-answer-node",
node_type=NodeType.ANSWER,
start_at=datetime.now(),
outputs={
"answer": "same answer",
"files": [
{"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
],
},
)
# Mock the workflow response converter
pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"])
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
pipeline._save_output_for_event = Mock()
# Act
list(pipeline._handle_node_succeeded_event(event))
# Assert: files should be recorded
assert len(pipeline._recorded_files) == 1
assert pipeline._recorded_files[0] == event.outputs["files"][0]

View File

@@ -1,77 +0,0 @@
"""
Unit tests for virtual node configuration.
"""
from core.workflow.nodes.base.entities import VirtualNodeConfig
class TestVirtualNodeConfig:
"""Tests for VirtualNodeConfig entity."""
def test_create_basic_config(self):
"""Test creating a basic virtual node config."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={
"title": "Extract keywords",
"model": {"provider": "openai", "name": "gpt-4o-mini"},
},
)
assert config.id == "ext_1"
assert config.type == "llm"
assert config.data["title"] == "Extract keywords"
def test_get_global_id(self):
"""Test generating global ID from parent ID."""
config = VirtualNodeConfig(
id="ext_1",
type="llm",
data={},
)
global_id = config.get_global_id("tool1")
assert global_id == "tool1.ext_1"
def test_get_global_id_with_different_parents(self):
"""Test global ID generation with different parent IDs."""
config = VirtualNodeConfig(id="sub_node", type="code", data={})
assert config.get_global_id("parent1") == "parent1.sub_node"
assert config.get_global_id("node_123") == "node_123.sub_node"
def test_empty_data(self):
"""Test virtual node config with empty data."""
config = VirtualNodeConfig(
id="test",
type="tool",
)
assert config.id == "test"
assert config.type == "tool"
assert config.data == {}
def test_complex_data(self):
"""Test virtual node config with complex data."""
config = VirtualNodeConfig(
id="llm_1",
type="llm",
data={
"title": "Generate summary",
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {"temperature": 0.7, "max_tokens": 500},
},
"prompt_template": [
{"role": "user", "text": "{{#llm1.context#}}"},
{"role": "user", "text": "Please summarize the conversation"},
],
},
)
assert config.data["model"]["provider"] == "openai"
assert len(config.data["prompt_template"]) == 2

View File

@@ -58,8 +58,6 @@ def test_json_object_valid_schema():
}
)
schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -70,7 +68,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
node = make_start_node(user_inputs, variables)
result = node._run()
@@ -89,8 +87,6 @@ def test_json_object_invalid_json_string():
"required": ["age", "name"],
}
)
schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -101,12 +97,12 @@ def test_json_object_invalid_json_string():
)
]
# Providing a string instead of an object should raise a type error
# Missing closing brace makes this invalid JSON
user_inputs = {"profile": '{"age": 20, "name": "Tom"'}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
node._run()
@@ -122,8 +118,6 @@ def test_json_object_does_not_match_schema():
}
)
schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -135,7 +129,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
node = make_start_node(user_inputs, variables)
@@ -155,8 +149,6 @@ def test_json_object_missing_required_schema_field():
}
)
schema = json.loads(schema)
variables = [
VariableEntity(
variable="profile",
@@ -168,7 +160,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": {"age": 20}}
user_inputs = {"profile": json.dumps({"age": 20})}
node = make_start_node(user_inputs, variables)

View File

@@ -0,0 +1,142 @@
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.workspace_permission import (
check_workspace_member_invite_permission,
check_workspace_owner_transfer_permission,
)
class TestWorkspacePermissionHelper:
"""Test workspace permission helper functions."""
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.EnterpriseService")
def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
"""Community edition should always allow invitations without calling any service."""
mock_config.ENTERPRISE_ENABLED = False
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
# EnterpriseService should NOT be called in community edition
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
"""Community edition should check billing plan but not call enterprise service."""
mock_config.ENTERPRISE_ENABLED = False
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True
mock_feature_service.get_features.return_value = mock_features
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_feature_service.get_features.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should block invitations when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = False
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should allow invitations when workspace policy is True."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = True
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
"""SANDBOX billing plan should block owner transfer before checking enterprise policy."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = False # SANDBOX plan
mock_feature_service.get_features.return_value = mock_features
with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
# Enterprise service should NOT be called since billing plan already blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
"""Enterprise edition should block transfer when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = False # Workspace policy blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_allows_transfer_when_both_enabled(
self, mock_feature_service, mock_config, mock_enterprise_service
):
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = True # Workspace policy allows
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.logger")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
"""On enterprise service error, should fail-open (allow) and log error."""
mock_config.ENTERPRISE_ENABLED = True
# Simulate enterprise service error
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
# Should not raise (fail-open)
check_workspace_member_invite_permission("test-workspace-id")
# Should log the error
mock_logger.exception.assert_called_once()
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)

View File

@@ -2,11 +2,11 @@ import Marketplace from '@/app/components/plugins/marketplace'
import PluginPage from '@/app/components/plugins/plugin-page'
import PluginsPanel from '@/app/components/plugins/plugin-page/plugins-panel'
const PluginList = () => {
const PluginList = async () => {
return (
<PluginPage
plugins={<PluginsPanel />}
marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" />}
marketplace={<Marketplace pluginTypeSwitchClassName="top-[60px]" showSearchParams={false} />}
/>
)
}

View File

@@ -26,7 +26,6 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps'
import { useInvalidateAppList } from '@/service/use-apps'
import { fetchWorkflowDraft } from '@/service/workflow'
import { AppModeEnum } from '@/types/app'
import { getRedirection } from '@/utils/app-redirection'
@@ -67,7 +66,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
const { onPlanInfoChanged } = useProviderContext()
const appDetail = useAppStore(state => state.appDetail)
const setAppDetail = useAppStore(state => state.setAppDetail)
const invalidateAppList = useInvalidateAppList()
const [open, setOpen] = useState(openState)
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
@@ -193,7 +191,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
try {
await deleteApp(appDetail.id)
notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) })
invalidateAppList()
onPlanInfoChanged()
setAppDetail()
replace('/apps')
@@ -205,7 +202,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
})
}
setShowConfirmDelete(false)
}, [appDetail, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t])
}, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t])
const { isCurrentWorkspaceEditor } = useAppContext()

View File

@@ -83,7 +83,7 @@ const ConfigModal: FC<IConfigModalProps> = ({
if (!isJsonObject || !tempPayload.json_schema)
return ''
try {
return tempPayload.json_schema
return JSON.stringify(JSON.parse(tempPayload.json_schema), null, 2)
}
catch {
return ''

View File

@@ -10,7 +10,6 @@ const mockReplace = vi.fn()
const mockRouter = { replace: mockReplace }
vi.mock('next/navigation', () => ({
useRouter: () => mockRouter,
useSearchParams: () => new URLSearchParams(''),
}))
// Mock app context

View File

@@ -12,7 +12,6 @@ import { useDebounceFn } from 'ahooks'
import dynamic from 'next/dynamic'
import {
useRouter,
useSearchParams,
} from 'next/navigation'
import { parseAsString, useQueryState } from 'nuqs'
import { useCallback, useEffect, useRef, useState } from 'react'
@@ -37,16 +36,6 @@ import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import NewAppCard from './new-app-card'
// Define valid tabs at module scope to avoid re-creation on each render and stale closures
const validTabs = new Set<string | AppModeEnum>([
'all',
AppModeEnum.WORKFLOW,
AppModeEnum.ADVANCED_CHAT,
AppModeEnum.CHAT,
AppModeEnum.AGENT_CHAT,
AppModeEnum.COMPLETION,
])
const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), {
ssr: false,
})
@@ -58,41 +47,12 @@ const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
const searchParams = useSearchParams()
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext()
const showTagManagementModal = useTagStore(s => s.showTagManagementModal)
const [activeTab, setActiveTab] = useQueryState(
'category',
parseAsString.withDefault('all').withOptions({ history: 'push' }),
)
// valid tabs for apps list; anything else should fall back to 'all'
// 1) Normalize legacy/incorrect query params like ?mode=discover -> ?category=all
useEffect(() => {
// avoid running on server
if (typeof window === 'undefined')
return
const mode = searchParams.get('mode')
if (!mode)
return
const url = new URL(window.location.href)
url.searchParams.delete('mode')
if (validTabs.has(mode)) {
// migrate to category key
url.searchParams.set('category', mode)
}
else {
url.searchParams.set('category', 'all')
}
router.replace(url.pathname + url.search)
}, [router, searchParams])
// 2) If category has an invalid value (e.g., 'discover'), reset to 'all'
useEffect(() => {
if (!validTabs.has(activeTab))
setActiveTab('all')
}, [activeTab, setActiveTab])
const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState()
const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe)
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)

View File

@@ -37,7 +37,7 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
return
}
if (inputValue == null)
if (!inputValue)
return
if (item.type === InputVarType.singleFile) {
@@ -52,20 +52,6 @@ export const getProcessedInputs = (inputs: Record<string, any>, inputsForm: Inpu
else
processedInputs[item.variable] = getProcessedFiles(inputValue)
}
else if (item.type === InputVarType.jsonObject) {
// Prefer sending an object if the user entered valid JSON; otherwise keep the raw string.
try {
const v = typeof inputValue === 'string' ? JSON.parse(inputValue) : inputValue
if (v && typeof v === 'object' && !Array.isArray(v))
processedInputs[item.variable] = v
else
processedInputs[item.variable] = inputValue
}
catch {
// keep original string; backend will parse/validate
processedInputs[item.variable] = inputValue
}
}
})
return processedInputs

View File

@@ -5,7 +5,6 @@ import type {
} from 'lexical'
import type { FC } from 'react'
import type {
AgentBlockType,
ContextBlockType,
CurrentBlockType,
ErrorMessageBlockType,
@@ -104,7 +103,6 @@ export type PromptEditorProps = {
currentBlock?: CurrentBlockType
errorMessageBlock?: ErrorMessageBlockType
lastRunBlock?: LastRunBlockType
agentBlock?: AgentBlockType
isSupportFileVar?: boolean
}
@@ -130,7 +128,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
currentBlock,
errorMessageBlock,
lastRunBlock,
agentBlock,
isSupportFileVar,
}) => {
const { eventEmitter } = useEventEmitterContextContext()
@@ -142,7 +139,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
{
replace: TextNode,
with: (node: TextNode) => new CustomTextNode(node.__text),
withKlass: CustomTextNode,
},
ContextBlockNode,
HistoryBlockNode,
@@ -216,22 +212,6 @@ const PromptEditor: FC<PromptEditorProps> = ({
lastRunBlock={lastRunBlock}
isSupportFileVar={isSupportFileVar}
/>
{(!agentBlock || agentBlock.show) && (
<ComponentPickerBlock
triggerString="@"
contextBlock={contextBlock}
historyBlock={historyBlock}
queryBlock={queryBlock}
variableBlock={variableBlock}
externalToolBlock={externalToolBlock}
workflowVariableBlock={workflowVariableBlock}
currentBlock={currentBlock}
errorMessageBlock={errorMessageBlock}
lastRunBlock={lastRunBlock}
agentBlock={agentBlock}
isSupportFileVar={isSupportFileVar}
/>
)}
<ComponentPickerBlock
triggerString="{"
contextBlock={contextBlock}

View File

@@ -1,7 +1,6 @@
import type { MenuRenderFn } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import type { TextNode } from 'lexical'
import type {
AgentBlockType,
ContextBlockType,
CurrentBlockType,
ErrorMessageBlockType,
@@ -21,11 +20,7 @@ import {
} from '@floating-ui/react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import {
$getRoot,
$insertNodes,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import { KEY_ESCAPE_COMMAND } from 'lexical'
import {
Fragment,
memo,
@@ -34,9 +29,7 @@ import {
} from 'react'
import ReactDOM from 'react-dom'
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
import AgentNodeList from '@/app/components/workflow/nodes/_base/components/agent-node-list'
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
import { BlockEnum } from '@/app/components/workflow/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useBasicTypeaheadTriggerMatch } from '../../hooks'
import { $splitNodeContainingQuery } from '../../utils'
@@ -45,7 +38,6 @@ import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../error-message-block'
import { INSERT_LAST_RUN_BLOCK_COMMAND } from '../last-run-block'
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '../variable-block'
import { INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND } from '../workflow-variable-block'
import { $createWorkflowVariableBlockNode } from '../workflow-variable-block/node'
import { useOptions } from './hooks'
type ComponentPickerProps = {
@@ -59,7 +51,6 @@ type ComponentPickerProps = {
currentBlock?: CurrentBlockType
errorMessageBlock?: ErrorMessageBlockType
lastRunBlock?: LastRunBlockType
agentBlock?: AgentBlockType
isSupportFileVar?: boolean
}
const ComponentPicker = ({
@@ -73,7 +64,6 @@ const ComponentPicker = ({
currentBlock,
errorMessageBlock,
lastRunBlock,
agentBlock,
isSupportFileVar,
}: ComponentPickerProps) => {
const { eventEmitter } = useEventEmitterContextContext()
@@ -161,41 +151,12 @@ const ComponentPicker = ({
editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
}, [editor])
const handleSelectAgent = useCallback((agent: { id: string, title: string }) => {
editor.update(() => {
const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!)
if (needRemove)
needRemove.remove()
const root = $getRoot()
const firstChild = root.getFirstChild()
if (firstChild) {
const selection = firstChild.selectStart()
if (selection) {
const workflowVariableBlockNode = $createWorkflowVariableBlockNode([agent.id, 'text'], {}, undefined)
$insertNodes([workflowVariableBlockNode])
}
}
})
agentBlock?.onSelect?.(agent)
handleClose()
}, [editor, checkForTriggerMatch, triggerString, agentBlock, handleClose])
const isAgentTrigger = triggerString === '@' && agentBlock?.show
const agentNodes = agentBlock?.agentNodes || []
const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
anchorElementRef,
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
) => {
if (isAgentTrigger) {
if (!(anchorElementRef.current && agentNodes.length))
return null
}
else {
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
}
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
setTimeout(() => {
if (anchorElementRef.current)
@@ -206,6 +167,9 @@ const ComponentPicker = ({
<>
{
ReactDOM.createPortal(
// The `LexicalMenu` will try to calculate the position of the floating menu based on the first child.
// Since we use floating ui, we need to wrap it with a div to prevent the position calculation being affected.
// See https://github.com/facebook/lexical/blob/ac97dfa9e14a73ea2d6934ff566282d7f758e8bb/packages/lexical-react/src/shared/LexicalMenu.ts#L493
<div className="h-0 w-0">
<div
className="w-[260px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
@@ -215,73 +179,56 @@ const ComponentPicker = ({
}}
ref={refs.setFloating}
>
{isAgentTrigger
? (
<AgentNodeList
nodes={agentNodes.map(node => ({
...node,
type: BlockEnum.Agent,
}))}
onSelect={handleSelectAgent}
{
workflowVariableBlock?.show && (
<div className="p-1">
<VarReferenceVars
searchBoxClassName="mt-1"
vars={workflowVariableOptions}
onChange={(variables: string[]) => {
handleSelectWorkflowVariable(variables)
}}
maxHeightClass="max-h-[34vh]"
isSupportFileVar={isSupportFileVar}
onClose={handleClose}
onBlur={handleClose}
maxHeightClass="max-h-[34vh]"
showManageInputField={workflowVariableBlock.showManageInputField}
onManageInputField={workflowVariableBlock.onManageInputField}
autoFocus={false}
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
/>
)
: (
<>
</div>
)
}
{
workflowVariableBlock?.show && !!options.length && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
<div>
{
options.map((option, index) => (
<Fragment key={option.key}>
{
workflowVariableBlock?.show && (
<div className="p-1">
<VarReferenceVars
searchBoxClassName="mt-1"
vars={workflowVariableOptions}
onChange={(variables: string[]) => {
handleSelectWorkflowVariable(variables)
}}
maxHeightClass="max-h-[34vh]"
isSupportFileVar={isSupportFileVar}
onClose={handleClose}
onBlur={handleClose}
showManageInputField={workflowVariableBlock.showManageInputField}
onManageInputField={workflowVariableBlock.onManageInputField}
autoFocus={false}
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
/>
</div>
)
}
{
workflowVariableBlock?.show && !!options.length && (
// Divider
index !== 0 && options.at(index - 1)?.group !== option.group && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
<div>
{
options.map((option, index) => (
<Fragment key={option.key}>
{
index !== 0 && options.at(index - 1)?.group !== option.group && (
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
)
}
{option.renderMenuOption({
queryString,
isSelected: selectedIndex === index,
onSelect: () => {
selectOptionAndCleanUp(option)
},
onSetHighlight: () => {
setHighlightedIndex(index)
},
})}
</Fragment>
))
}
</div>
</>
)}
{option.renderMenuOption({
queryString,
isSelected: selectedIndex === index,
onSelect: () => {
selectOptionAndCleanUp(option)
},
onSetHighlight: () => {
setHighlightedIndex(index)
},
})}
</Fragment>
))
}
</div>
</div>
</div>,
anchorElementRef.current,
@@ -289,7 +236,7 @@ const ComponentPicker = ({
}
</>
)
}, [isAgentTrigger, agentNodes, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, handleSelectAgent, handleClose, workflowVariableOptions, isSupportFileVar, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
return (
<LexicalTypeaheadMenuPlugin

View File

@@ -21,7 +21,6 @@ import {
VariableLabelInEditor,
} from '@/app/components/workflow/nodes/_base/components/variable/variable-label'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import { BlockEnum } from '@/app/components/workflow/types'
import { isExceptionVariable } from '@/app/components/workflow/utils'
import { useSelectOrDelete } from '../../hooks'
import {
@@ -67,7 +66,6 @@ const WorkflowVariableBlockComponent = ({
)()
const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState<WorkflowNodesMap>(workflowNodesMap)
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
const isAgentVariable = node?.type === BlockEnum.Agent
const isException = isExceptionVariable(varName, node?.type)
const variableValid = useMemo(() => {
@@ -136,9 +134,6 @@ const WorkflowVariableBlockComponent = ({
})
}, [node, reactflow, store])
if (isAgentVariable)
return <span className="hidden" ref={ref} />
const Item = (
<VariableLabelInEditor
nodeType={node?.type}

View File

@@ -73,17 +73,6 @@ export type WorkflowVariableBlockType = {
onManageInputField?: () => void
}
export type AgentNode = {
id: string
title: string
}
export type AgentBlockType = {
show?: boolean
agentNodes?: AgentNode[]
onSelect?: (agent: AgentNode) => void
}
export type MenuTextMatch = {
leadOffset: number
matchingString: string

View File

@@ -1,29 +1,47 @@
import type { FC } from 'react'
import type { FullDocumentDetail } from '@/models/datasets'
import type { RETRIEVE_METHOD } from '@/types/app'
import type {
DataSourceInfo,
FullDocumentDetail,
IndexingStatusResponse,
LegacyDataSourceInfo,
ProcessRuleResponse,
} from '@/models/datasets'
import {
RiArrowRightLine,
RiCheckboxCircleFill,
RiErrorWarningFill,
RiLoader2Fill,
RiTerminalBoxLine,
} from '@remixicon/react'
import Image from 'next/image'
import Link from 'next/link'
import { useRouter } from 'next/navigation'
import { useMemo } from 'react'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider'
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
import NotionIcon from '@/app/components/base/notion-icon'
import Tooltip from '@/app/components/base/tooltip'
import PriorityLabel from '@/app/components/billing/priority-label'
import { Plan } from '@/app/components/billing/type'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
import { useProviderContext } from '@/context/provider-context'
import { useDatasetApiAccessUrl } from '@/hooks/use-api-access-url'
import { DataSourceType, ProcessMode } from '@/models/datasets'
import { fetchIndexingStatusBatch as doFetchIndexingStatus } from '@/service/datasets'
import { useProcessRule } from '@/service/knowledge/use-dataset'
import { useInvalidDocumentList } from '@/service/knowledge/use-document'
import IndexingProgressItem from './indexing-progress-item'
import RuleDetail from './rule-detail'
import UpgradeBanner from './upgrade-banner'
import { useIndexingStatusPolling } from './use-indexing-status-polling'
import { createDocumentLookup } from './utils'
import { RETRIEVE_METHOD } from '@/types/app'
import { sleep } from '@/utils'
import { cn } from '@/utils/classnames'
import DocumentFileIcon from '../../common/document-file-icon'
import { indexMethodIcon, retrievalIcon } from '../icons'
import { IndexingType } from '../step-two'
type EmbeddingProcessProps = {
type Props = {
datasetId: string
batchId: string
documents?: FullDocumentDetail[]
@@ -31,121 +49,333 @@ type EmbeddingProcessProps = {
retrievalMethod?: RETRIEVE_METHOD
}
// Status header component
const StatusHeader: FC<{ isEmbedding: boolean, isCompleted: boolean }> = ({
isEmbedding,
isCompleted,
}) => {
const RuleDetail: FC<{
sourceData?: ProcessRuleResponse
indexingType?: string
retrievalMethod?: RETRIEVE_METHOD
}> = ({ sourceData, indexingType, retrievalMethod }) => {
const { t } = useTranslation()
const segmentationRuleMap = {
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
}
const getRuleName = (key: string) => {
if (key === 'remove_extra_spaces')
return t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' })
if (key === 'remove_urls_emails')
return t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' })
if (key === 'remove_stopwords')
return t('stepTwo.removeStopwords', { ns: 'datasetCreation' })
}
const isNumber = (value: unknown) => {
return typeof value === 'number'
}
const getValue = useCallback((field: string) => {
let value: string | number | undefined = '-'
const maxTokens = isNumber(sourceData?.rules?.segmentation?.max_tokens)
? sourceData.rules.segmentation.max_tokens
: value
const childMaxTokens = isNumber(sourceData?.rules?.subchunk_segmentation?.max_tokens)
? sourceData.rules.subchunk_segmentation.max_tokens
: value
switch (field) {
case 'mode':
value = !sourceData?.mode
? value
: sourceData.mode === ProcessMode.general
? (t('embedding.custom', { ns: 'datasetDocuments' }) as string)
: `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${sourceData?.rules?.parent_mode === 'paragraph'
? t('parentMode.paragraph', { ns: 'dataset' })
: t('parentMode.fullDoc', { ns: 'dataset' })}`
break
case 'segmentLength':
value = !sourceData?.mode
? value
: sourceData.mode === ProcessMode.general
? maxTokens
: `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
break
default:
value = !sourceData?.mode
? value
: sourceData?.rules?.pre_processing_rules?.filter(rule =>
rule.enabled).map(rule => getRuleName(rule.id)).join(',')
break
}
return value
}, [sourceData])
return (
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
{isEmbedding && (
<>
<RiLoader2Fill className="size-4 animate-spin" />
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
</>
)}
{isCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
<div className="flex flex-col gap-1">
{Object.keys(segmentationRuleMap).map((field) => {
return (
<FieldInfo
key={field}
label={segmentationRuleMap[field as keyof typeof segmentationRuleMap]}
displayedValue={String(getValue(field))}
/>
)
})}
<FieldInfo
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
displayedValue={t(`stepTwo.${indexingType === IndexingType.ECONOMICAL ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' }) as string}
valueIcon={(
<Image
className="size-4"
src={
indexingType === IndexingType.ECONOMICAL
? indexMethodIcon.economical
: indexMethodIcon.high_quality
}
alt=""
/>
)}
/>
<FieldInfo
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
// displayedValue={t(`datasetSettings.form.retrievalSetting.${retrievalMethod}`) as string}
displayedValue={t(`retrieval.${indexingType === IndexingType.ECONOMICAL ? 'keyword_search' : retrievalMethod ?? 'semantic_search'}.title`, { ns: 'dataset' })}
valueIcon={(
<Image
className="size-4"
src={
retrievalMethod === RETRIEVE_METHOD.fullText
? retrievalIcon.fullText
: retrievalMethod === RETRIEVE_METHOD.hybrid
? retrievalIcon.hybrid
: retrievalIcon.vector
}
alt=""
/>
)}
/>
</div>
)
}
// Action buttons component
const ActionButtons: FC<{
apiReferenceUrl: string
onNavToDocuments: () => void
}> = ({ apiReferenceUrl, onNavToDocuments }) => {
const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], indexingType, retrievalMethod }) => {
const { t } = useTranslation()
return (
<div className="mt-6 flex items-center gap-x-2 py-2">
<Link href={apiReferenceUrl} target="_blank" rel="noopener noreferrer">
<Button className="w-fit gap-x-0.5 px-3">
<RiTerminalBoxLine className="size-4" />
<span className="px-0.5">Access the API</span>
</Button>
</Link>
<Button
className="w-fit gap-x-0.5 px-3"
variant="primary"
onClick={onNavToDocuments}
>
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
</Button>
</div>
)
}
const EmbeddingProcess: FC<EmbeddingProcessProps> = ({
datasetId,
batchId,
documents = [],
indexingType,
retrievalMethod,
}) => {
const { enableBilling, plan } = useProviderContext()
const getFirstDocument = documents[0]
const [indexingStatusBatchDetail, setIndexingStatusDetail] = useState<IndexingStatusResponse[]>([])
const fetchIndexingStatus = async () => {
const status = await doFetchIndexingStatus({ datasetId, batchId })
setIndexingStatusDetail(status.data)
return status.data
}
const [isStopQuery, setIsStopQuery] = useState(false)
const isStopQueryRef = useRef(isStopQuery)
useEffect(() => {
isStopQueryRef.current = isStopQuery
}, [isStopQuery])
const stopQueryStatus = () => {
setIsStopQuery(true)
}
// Recursively polls indexing status every 2.5 seconds until every document is completed, errored or paused,
// or until polling is stopped via isStopQueryRef (for example on unmount).
const startQueryStatus = async () => {
if (isStopQueryRef.current)
return
try {
const indexingStatusBatchDetail = await fetchIndexingStatus()
const isCompleted = indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail.indexing_status))
if (isCompleted) {
stopQueryStatus()
return
}
await sleep(2500)
await startQueryStatus()
}
catch {
// Retry after the same interval if a status request fails.
await sleep(2500)
await startQueryStatus()
}
}
useEffect(() => {
setIsStopQuery(false)
startQueryStatus()
return () => {
stopQueryStatus()
}
}, [])
// Get the process rule for the first created document
const { data: ruleDetail } = useProcessRule(getFirstDocument?.id)
const router = useRouter()
const invalidDocumentList = useInvalidDocumentList()
const apiReferenceUrl = useDatasetApiAccessUrl()
// Polling hook for indexing status
const { statusList, isEmbedding, isEmbeddingCompleted } = useIndexingStatusPolling({
datasetId,
batchId,
})
// Get process rule for the first document
const firstDocumentId = documents[0]?.id
const { data: ruleDetail } = useProcessRule(firstDocumentId)
// Document lookup utilities - memoized for performance
const documentLookup = useMemo(
() => createDocumentLookup(documents),
[documents],
)
const handleNavToDocuments = () => {
const navToDocumentList = () => {
invalidDocumentList()
router.push(`/datasets/${datasetId}/documents`)
}
const apiReferenceUrl = useDatasetApiAccessUrl()
const showUpgradeBanner = enableBilling && plan.type !== Plan.team
const isEmbedding = useMemo(() => {
return indexingStatusBatchDetail.some(indexingStatusDetail => ['indexing', 'splitting', 'parsing', 'cleaning'].includes(indexingStatusDetail?.indexing_status || ''))
}, [indexingStatusBatchDetail])
const isEmbeddingCompleted = useMemo(() => {
return indexingStatusBatchDetail.every(indexingStatusDetail => ['completed', 'error', 'paused'].includes(indexingStatusDetail?.indexing_status || ''))
}, [indexingStatusBatchDetail])
const getSourceName = (id: string) => {
const doc = documents.find(document => document.id === id)
return doc?.name
}
const getFileType = (name?: string) => name?.split('.').pop() || 'txt'
const getSourcePercent = (detail: IndexingStatusResponse) => {
const completedCount = detail.completed_segments || 0
const totalCount = detail.total_segments || 0
if (totalCount === 0)
return 0
const percent = Math.round(completedCount * 100 / totalCount)
return percent > 100 ? 100 : percent
}
const getSourceType = (id: string) => {
const doc = documents.find(document => document.id === id)
return doc?.data_source_type as DataSourceType
}
const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
}
const getIcon = (id: string) => {
const doc = documents.find(document => document.id === id)
const info = doc?.data_source_info
if (info && isLegacyDataSourceInfo(info))
return info.notion_page_icon
return undefined
}
const isSourceEmbedding = (detail: IndexingStatusResponse) =>
['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'].includes(detail.indexing_status || '')
return (
<>
<div className="flex flex-col gap-y-3">
<StatusHeader isEmbedding={isEmbedding} isCompleted={isEmbeddingCompleted} />
{showUpgradeBanner && <UpgradeBanner />}
<div className="system-md-semibold-uppercase flex items-center gap-x-1 text-text-secondary">
{isEmbedding && (
<>
<RiLoader2Fill className="size-4 animate-spin" />
<span>{t('embedding.processing', { ns: 'datasetDocuments' })}</span>
</>
)}
{isEmbeddingCompleted && t('embedding.completed', { ns: 'datasetDocuments' })}
</div>
{
enableBilling && plan.type !== Plan.team && (
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
<ZapFast className="h-4 w-4 text-[#FB6514]" />
</div>
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
</div>
<UpgradeBtn loc="knowledge-speed-up" />
</div>
)
}
<div className="flex flex-col gap-0.5 pb-2">
{statusList.map(detail => (
<IndexingProgressItem
key={detail.id}
detail={detail}
name={documentLookup.getName(detail.id)}
sourceType={documentLookup.getSourceType(detail.id)}
notionIcon={documentLookup.getNotionIcon(detail.id)}
enableBilling={enableBilling}
/>
{indexingStatusBatchDetail.map(indexingStatusDetail => (
<div
key={indexingStatusDetail.id}
className={cn(
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
indexingStatusDetail.indexing_status === 'error' && 'bg-state-destructive-hover-alt',
)}
>
{isSourceEmbedding(indexingStatusDetail) && (
<div
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
style={{ width: `${getSourcePercent(indexingStatusDetail)}%` }}
/>
)}
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
{getSourceType(indexingStatusDetail.id) === DataSourceType.FILE && (
<DocumentFileIcon
size="sm"
className="shrink-0"
name={getSourceName(indexingStatusDetail.id)}
extension={getFileType(getSourceName(indexingStatusDetail.id))}
/>
)}
{getSourceType(indexingStatusDetail.id) === DataSourceType.NOTION && (
<NotionIcon
className="shrink-0"
type="page"
src={getIcon(indexingStatusDetail.id)}
/>
)}
<div className="flex w-0 grow items-center gap-1" title={getSourceName(indexingStatusDetail.id)}>
<div className="system-xs-medium truncate text-text-secondary">
{getSourceName(indexingStatusDetail.id)}
</div>
{
enableBilling && (
<PriorityLabel className="ml-0" />
)
}
</div>
{isSourceEmbedding(indexingStatusDetail) && (
<div className="shrink-0 text-xs text-text-secondary">{`${getSourcePercent(indexingStatusDetail)}%`}</div>
)}
{indexingStatusDetail.indexing_status === 'error' && (
<Tooltip
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
offset={4}
popupContent={indexingStatusDetail.error}
>
<span>
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
</span>
</Tooltip>
)}
{indexingStatusDetail.indexing_status === 'completed' && (
<RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
)}
</div>
</div>
))}
</div>
<Divider type="horizontal" className="my-0 bg-divider-subtle" />
<RuleDetail
sourceData={ruleDetail}
indexingType={indexingType}
retrievalMethod={retrievalMethod}
/>
</div>
<ActionButtons
apiReferenceUrl={apiReferenceUrl}
onNavToDocuments={handleNavToDocuments}
/>
<div className="mt-6 flex items-center gap-x-2 py-2">
<Link
href={apiReferenceUrl}
target="_blank"
rel="noopener noreferrer"
>
<Button
className="w-fit gap-x-0.5 px-3"
>
<RiTerminalBoxLine className="size-4" />
<span className="px-0.5">Access the API</span>
</Button>
</Link>
<Button
className="w-fit gap-x-0.5 px-3"
variant="primary"
onClick={navToDocumentList}
>
<span className="px-0.5">{t('stepThree.navTo', { ns: 'datasetCreation' })}</span>
<RiArrowRightLine className="size-4 stroke-current stroke-1" />
</Button>
</div>
</>
)
}
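For orientation, a minimal mounting sketch for the component above; the prop values are placeholders assumed for illustration, not values taken from this changeset.
// Illustrative only — prop values below are placeholders.
<EmbeddingProcess
  datasetId={datasetId}
  batchId={batchId}
  documents={documents}
  indexingType={IndexingType.QUALIFIED}
  retrievalMethod={RETRIEVE_METHOD.semantic}
/>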

View File

@@ -1,120 +0,0 @@
import type { FC } from 'react'
import type { IndexingStatusResponse } from '@/models/datasets'
import {
RiCheckboxCircleFill,
RiErrorWarningFill,
} from '@remixicon/react'
import NotionIcon from '@/app/components/base/notion-icon'
import Tooltip from '@/app/components/base/tooltip'
import PriorityLabel from '@/app/components/billing/priority-label'
import { DataSourceType } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import DocumentFileIcon from '../../common/document-file-icon'
import { getFileType, getSourcePercent, isSourceEmbedding } from './utils'
type IndexingProgressItemProps = {
detail: IndexingStatusResponse
name?: string
sourceType?: DataSourceType
notionIcon?: string
enableBilling?: boolean
}
// Status icon component for completed/error states
const StatusIcon: FC<{ status: string, error?: string }> = ({ status, error }) => {
if (status === 'completed')
return <RiCheckboxCircleFill className="size-4 shrink-0 text-text-success" />
if (status === 'error') {
return (
<Tooltip
popupClassName="px-4 py-[14px] max-w-60 body-xs-regular text-text-secondary border-[0.5px] border-components-panel-border rounded-xl"
offset={4}
popupContent={error}
>
<span>
<RiErrorWarningFill className="size-4 shrink-0 text-text-destructive" />
</span>
</Tooltip>
)
}
return null
}
// Source type icon component
const SourceTypeIcon: FC<{
sourceType?: DataSourceType
name?: string
notionIcon?: string
}> = ({ sourceType, name, notionIcon }) => {
if (sourceType === DataSourceType.FILE) {
return (
<DocumentFileIcon
size="sm"
className="shrink-0"
name={name}
extension={getFileType(name)}
/>
)
}
if (sourceType === DataSourceType.NOTION) {
return (
<NotionIcon
className="shrink-0"
type="page"
src={notionIcon}
/>
)
}
return null
}
const IndexingProgressItem: FC<IndexingProgressItemProps> = ({
detail,
name,
sourceType,
notionIcon,
enableBilling,
}) => {
const isEmbedding = isSourceEmbedding(detail)
const percent = getSourcePercent(detail)
const isError = detail.indexing_status === 'error'
return (
<div
className={cn(
'relative h-[26px] overflow-hidden rounded-md bg-components-progress-bar-bg',
isError && 'bg-state-destructive-hover-alt',
)}
>
{isEmbedding && (
<div
className="absolute left-0 top-0 h-full min-w-0.5 border-r-[2px] border-r-components-progress-bar-progress-highlight bg-components-progress-bar-progress"
style={{ width: `${percent}%` }}
/>
)}
<div className="z-[1] flex h-full items-center gap-1 pl-[6px] pr-2">
<SourceTypeIcon
sourceType={sourceType}
name={name}
notionIcon={notionIcon}
/>
<div className="flex w-0 grow items-center gap-1" title={name}>
<div className="system-xs-medium truncate text-text-secondary">
{name}
</div>
{enableBilling && <PriorityLabel className="ml-0" />}
</div>
{isEmbedding && (
<div className="shrink-0 text-xs text-text-secondary">{`${percent}%`}</div>
)}
<StatusIcon status={detail.indexing_status} error={detail.error} />
</div>
</div>
)
}
export default IndexingProgressItem
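A small rendering sketch for the item above; the `detail` values are made-up figures chosen to show the progress math, not data from this changeset.
// Illustrative only — made-up status payload.
const detail = {
  id: 'doc-1',
  indexing_status: 'indexing',
  completed_segments: 42,
  total_segments: 120,
} as IndexingStatusResponse

// Renders an in-progress row: file icon derived from the '.pdf' extension and a 35% bar (42 / 120, rounded).
<IndexingProgressItem
  detail={detail}
  name="handbook.pdf"
  sourceType={DataSourceType.FILE}
  enableBilling={false}
/>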

View File

@@ -1,133 +0,0 @@
import type { FC } from 'react'
import type { ProcessRuleResponse } from '@/models/datasets'
import Image from 'next/image'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
import { ProcessMode } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { indexMethodIcon, retrievalIcon } from '../icons'
import { IndexingType } from '../step-two'
type RuleDetailProps = {
sourceData?: ProcessRuleResponse
indexingType?: string
retrievalMethod?: RETRIEVE_METHOD
}
// Lookup table for pre-processing rule names
const PRE_PROCESSING_RULE_KEYS = {
remove_extra_spaces: 'stepTwo.removeExtraSpaces',
remove_urls_emails: 'stepTwo.removeUrlEmails',
remove_stopwords: 'stepTwo.removeStopwords',
} as const
// Lookup table for retrieval method icons
const RETRIEVAL_ICON_MAP: Partial<Record<RETRIEVE_METHOD, string>> = {
[RETRIEVE_METHOD.fullText]: retrievalIcon.fullText,
[RETRIEVE_METHOD.hybrid]: retrievalIcon.hybrid,
[RETRIEVE_METHOD.semantic]: retrievalIcon.vector,
[RETRIEVE_METHOD.invertedIndex]: retrievalIcon.fullText,
[RETRIEVE_METHOD.keywordSearch]: retrievalIcon.fullText,
}
const isNumber = (value: unknown): value is number => typeof value === 'number'
const RuleDetail: FC<RuleDetailProps> = ({ sourceData, indexingType, retrievalMethod }) => {
const { t } = useTranslation()
const segmentationRuleLabels = {
mode: t('embedding.mode', { ns: 'datasetDocuments' }),
segmentLength: t('embedding.segmentLength', { ns: 'datasetDocuments' }),
textCleaning: t('embedding.textCleaning', { ns: 'datasetDocuments' }),
}
const getRuleName = useCallback((key: string): string | undefined => {
const translationKey = PRE_PROCESSING_RULE_KEYS[key as keyof typeof PRE_PROCESSING_RULE_KEYS]
return translationKey ? t(translationKey, { ns: 'datasetCreation' }) : undefined
}, [t])
const getModeValue = useCallback((): string => {
if (!sourceData?.mode)
return '-'
if (sourceData.mode === ProcessMode.general)
return t('embedding.custom', { ns: 'datasetDocuments' })
const parentModeLabel = sourceData.rules?.parent_mode === 'paragraph'
? t('parentMode.paragraph', { ns: 'dataset' })
: t('parentMode.fullDoc', { ns: 'dataset' })
return `${t('embedding.hierarchical', { ns: 'datasetDocuments' })} · ${parentModeLabel}`
}, [sourceData, t])
const getSegmentLengthValue = useCallback((): string | number => {
if (!sourceData?.mode)
return '-'
const maxTokens = isNumber(sourceData.rules?.segmentation?.max_tokens)
? sourceData.rules.segmentation.max_tokens
: '-'
if (sourceData.mode === ProcessMode.general)
return maxTokens
const childMaxTokens = isNumber(sourceData.rules?.subchunk_segmentation?.max_tokens)
? sourceData.rules.subchunk_segmentation.max_tokens
: '-'
return `${t('embedding.parentMaxTokens', { ns: 'datasetDocuments' })} ${maxTokens}; ${t('embedding.childMaxTokens', { ns: 'datasetDocuments' })} ${childMaxTokens}`
}, [sourceData, t])
const getTextCleaningValue = useCallback((): string => {
if (!sourceData?.mode)
return '-'
const enabledRules = sourceData.rules?.pre_processing_rules?.filter(rule => rule.enabled) || []
const ruleNames = enabledRules
.map((rule) => {
const name = getRuleName(rule.id)
return typeof name === 'string' ? name : ''
})
.filter(name => name)
return ruleNames.length > 0 ? ruleNames.join(',') : '-'
}, [sourceData, getRuleName])
const fieldValueGetters: Record<string, () => string | number> = {
mode: getModeValue,
segmentLength: getSegmentLengthValue,
textCleaning: getTextCleaningValue,
}
const isEconomical = indexingType === IndexingType.ECONOMICAL
const indexMethodIconSrc = isEconomical ? indexMethodIcon.economical : indexMethodIcon.high_quality
const indexModeLabel = t(`stepTwo.${isEconomical ? 'economical' : 'qualified'}`, { ns: 'datasetCreation' })
const effectiveRetrievalMethod = isEconomical ? 'keyword_search' : (retrievalMethod ?? 'semantic_search')
const retrievalLabel = t(`retrieval.${effectiveRetrievalMethod}.title`, { ns: 'dataset' })
const retrievalIconSrc = RETRIEVAL_ICON_MAP[retrievalMethod as keyof typeof RETRIEVAL_ICON_MAP] ?? retrievalIcon.vector
return (
<div className="flex flex-col gap-1">
{Object.keys(segmentationRuleLabels).map(field => (
<FieldInfo
key={field}
label={segmentationRuleLabels[field as keyof typeof segmentationRuleLabels]}
displayedValue={String(fieldValueGetters[field]())}
/>
))}
<FieldInfo
label={t('stepTwo.indexMode', { ns: 'datasetCreation' })}
displayedValue={indexModeLabel}
valueIcon={<Image className="size-4" src={indexMethodIconSrc} alt="" />}
/>
<FieldInfo
label={t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
displayedValue={retrievalLabel}
valueIcon={<Image className="size-4" src={retrievalIconSrc} alt="" />}
/>
</div>
)
}
export default RuleDetail
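A quick trace of how the lookups above resolve for an economical index with no retrieval method set; the scenario is an assumption for illustration, but the resulting values follow directly from the code.
// Illustrative trace — not part of this changeset.
// Given indexingType = IndexingType.ECONOMICAL and retrievalMethod = undefined:
//   isEconomical             -> true
//   indexMethodIconSrc       -> indexMethodIcon.economical
//   effectiveRetrievalMethod -> 'keyword_search'
//   retrievalIconSrc         -> retrievalIcon.vector (fallback, since undefined has no entry in RETRIEVAL_ICON_MAP)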

View File

@@ -1,22 +0,0 @@
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general'
import UpgradeBtn from '@/app/components/billing/upgrade-btn'
const UpgradeBanner: FC = () => {
const { t } = useTranslation()
return (
<div className="flex h-14 items-center rounded-xl border-[0.5px] border-black/5 bg-white p-3 shadow-md">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg bg-[#FFF6ED]">
<ZapFast className="h-4 w-4 text-[#FB6514]" />
</div>
<div className="mx-3 grow text-[13px] font-medium text-gray-700">
{t('plansCommon.documentProcessingPriorityUpgrade', { ns: 'billing' })}
</div>
<UpgradeBtn loc="knowledge-speed-up" />
</div>
)
}
export default UpgradeBanner

View File

@@ -1,90 +0,0 @@
import type { IndexingStatusResponse } from '@/models/datasets'
import { useEffect, useRef, useState } from 'react'
import { fetchIndexingStatusBatch } from '@/service/datasets'
const POLLING_INTERVAL = 2500
const COMPLETED_STATUSES = ['completed', 'error', 'paused'] as const
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
type IndexingStatusPollingParams = {
datasetId: string
batchId: string
}
type IndexingStatusPollingResult = {
statusList: IndexingStatusResponse[]
isEmbedding: boolean
isEmbeddingCompleted: boolean
}
const isStatusCompleted = (status: string): boolean =>
COMPLETED_STATUSES.includes(status as typeof COMPLETED_STATUSES[number])
const isAllCompleted = (statusList: IndexingStatusResponse[]): boolean =>
statusList.every(item => isStatusCompleted(item.indexing_status))
/**
* Custom hook for polling indexing status with automatic stop on completion.
* Handles the polling lifecycle and provides derived states for UI rendering.
*/
export const useIndexingStatusPolling = ({
datasetId,
batchId,
}: IndexingStatusPollingParams): IndexingStatusPollingResult => {
const [statusList, setStatusList] = useState<IndexingStatusResponse[]>([])
const isStopPollingRef = useRef(false)
useEffect(() => {
// Reset polling state on mount
isStopPollingRef.current = false
let timeoutId: ReturnType<typeof setTimeout> | null = null
const fetchStatus = async (): Promise<IndexingStatusResponse[]> => {
const response = await fetchIndexingStatusBatch({ datasetId, batchId })
setStatusList(response.data)
return response.data
}
const poll = async (): Promise<void> => {
if (isStopPollingRef.current)
return
try {
const data = await fetchStatus()
if (isAllCompleted(data)) {
isStopPollingRef.current = true
return
}
}
catch {
// Continue polling on error
}
if (!isStopPollingRef.current) {
timeoutId = setTimeout(() => {
poll()
}, POLLING_INTERVAL)
}
}
poll()
return () => {
isStopPollingRef.current = true
if (timeoutId)
clearTimeout(timeoutId)
}
}, [datasetId, batchId])
const isEmbedding = statusList.some(item =>
EMBEDDING_STATUSES.includes(item?.indexing_status as typeof EMBEDDING_STATUSES[number]),
)
const isEmbeddingCompleted = statusList.length > 0 && isAllCompleted(statusList)
return {
statusList,
isEmbedding,
isEmbeddingCompleted,
}
}
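A minimal consumer sketch for the hook above; the component name and markup are assumptions for illustration only.
// Illustrative only — component name and markup are placeholders.
const EmbeddingStatusSummary = ({ datasetId, batchId }: { datasetId: string, batchId: string }) => {
  // Polls every 2.5s and stops automatically once every document is completed, errored or paused.
  const { statusList, isEmbedding, isEmbeddingCompleted } = useIndexingStatusPolling({ datasetId, batchId })
  if (isEmbedding)
    return <span>{`${statusList.length} documents processing…`}</span>
  return <span>{isEmbeddingCompleted ? 'All documents processed' : 'Waiting for status…'}</span>
}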

View File

@@ -1,64 +0,0 @@
import type {
DataSourceInfo,
DataSourceType,
FullDocumentDetail,
IndexingStatusResponse,
LegacyDataSourceInfo,
} from '@/models/datasets'
const EMBEDDING_STATUSES = ['indexing', 'splitting', 'parsing', 'cleaning', 'waiting'] as const
/**
* Type guard for legacy data source info with upload_file property
*/
export const isLegacyDataSourceInfo = (info: DataSourceInfo): info is LegacyDataSourceInfo => {
return info != null && typeof (info as LegacyDataSourceInfo).upload_file === 'object'
}
/**
* Check if a status indicates the source is being embedded
*/
export const isSourceEmbedding = (detail: IndexingStatusResponse): boolean =>
EMBEDDING_STATUSES.includes(detail.indexing_status as typeof EMBEDDING_STATUSES[number])
/**
* Calculate the progress percentage for a document
*/
export const getSourcePercent = (detail: IndexingStatusResponse): number => {
const completedCount = detail.completed_segments || 0
const totalCount = detail.total_segments || 0
if (totalCount === 0)
return 0
const percent = Math.round(completedCount * 100 / totalCount)
return Math.min(percent, 100)
}
/**
* Get file extension from filename, defaults to 'txt'
*/
export const getFileType = (name?: string): string =>
name?.split('.').pop() || 'txt'
/**
* Document lookup utilities - provides document info by ID from a list
*/
export const createDocumentLookup = (documents: FullDocumentDetail[]) => {
const documentMap = new Map(documents.map(doc => [doc.id, doc]))
return {
getDocument: (id: string) => documentMap.get(id),
getName: (id: string) => documentMap.get(id)?.name,
getSourceType: (id: string) => documentMap.get(id)?.data_source_type as DataSourceType | undefined,
getNotionIcon: (id: string) => {
const info = documentMap.get(id)?.data_source_info
if (info && isLegacyDataSourceInfo(info))
return info.notion_page_icon
return undefined
},
}
}
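A short usage sketch of the helpers above; `documents` and `detail` are assumed to come from the consuming component.
// Illustrative only — `documents` and `detail` come from the consuming component.
const lookup = createDocumentLookup(documents)
const name = lookup.getName(detail.id) // e.g. 'handbook.pdf'
const extension = getFileType(name) // 'pdf'; falls back to 'txt' when the name is undefined
const percent = getSourcePercent(detail) // completed_segments / total_segments, rounded and capped at 100
const embedding = isSourceEmbedding(detail) // true while the status is indexing/splitting/parsing/cleaning/waiting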

View File

@@ -1,199 +0,0 @@
'use client'
import type { FC } from 'react'
import type { PreProcessingRule } from '@/models/datasets'
import {
RiAlertFill,
RiSearchEyeLine,
} from '@remixicon/react'
import Image from 'next/image'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import Tooltip from '@/app/components/base/tooltip'
import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import SettingCog from '../../assets/setting-gear-mod.svg'
import s from '../index.module.css'
import LanguageSelect from '../language-select'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import { OptionCard } from './option-card'
type TextLabelProps = {
children: React.ReactNode
}
const TextLabel: FC<TextLabelProps> = ({ children }) => {
return <label className="system-sm-semibold text-text-secondary">{children}</label>
}
type GeneralChunkingOptionsProps = {
// State
segmentIdentifier: string
maxChunkLength: number
overlap: number
rules: PreProcessingRule[]
currentDocForm: ChunkingMode
docLanguage: string
// Flags
isActive: boolean
isInUpload: boolean
isNotUploadInEmptyDataset: boolean
hasCurrentDatasetDocForm: boolean
// Actions
onSegmentIdentifierChange: (value: string) => void
onMaxChunkLengthChange: (value: number) => void
onOverlapChange: (value: number) => void
onRuleToggle: (id: string) => void
onDocFormChange: (form: ChunkingMode) => void
onDocLanguageChange: (lang: string) => void
onPreview: () => void
onReset: () => void
// Locale
locale: string
}
export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
segmentIdentifier,
maxChunkLength,
overlap,
rules,
currentDocForm,
docLanguage,
isActive,
isInUpload,
isNotUploadInEmptyDataset,
hasCurrentDatasetDocForm,
onSegmentIdentifierChange,
onMaxChunkLengthChange,
onOverlapChange,
onRuleToggle,
onDocFormChange,
onDocLanguageChange,
onPreview,
onReset,
locale,
}) => {
const { t } = useTranslation()
const getRuleName = (key: string): string => {
const ruleNameMap: Record<string, string> = {
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
}
return ruleNameMap[key] ?? key
}
return (
<OptionCard
className="mb-2 bg-background-section"
title={t('stepTwo.general', { ns: 'datasetCreation' })}
icon={<Image width={20} height={20} src={SettingCog} alt={t('stepTwo.general', { ns: 'datasetCreation' })} />}
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
description={t('stepTwo.generalTip', { ns: 'datasetCreation' })}
isActive={isActive}
onSwitched={() => onDocFormChange(ChunkingMode.text)}
actions={(
<>
<Button variant="secondary-accent" onClick={onPreview}>
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
</Button>
<Button variant="ghost" onClick={onReset}>
{t('stepTwo.reset', { ns: 'datasetCreation' })}
</Button>
</>
)}
noHighlight={isInUpload && isNotUploadInEmptyDataset}
>
<div className="flex flex-col gap-y-4">
<div className="flex gap-3">
<DelimiterInput
value={segmentIdentifier}
onChange={e => onSegmentIdentifierChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={maxChunkLength}
onChange={onMaxChunkLengthChange}
/>
<OverlapInput
unit="characters"
value={overlap}
min={1}
onChange={onOverlapChange}
/>
</div>
<div className="flex w-full flex-col">
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1">
{rules.map(rule => (
<div
key={rule.id}
className={s.ruleItem}
onClick={() => onRuleToggle(rule.id)}
>
<Checkbox checked={rule.enabled} />
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{getRuleName(rule.id)}
</label>
</div>
))}
{IS_CE_EDITION && (
<>
<Divider type="horizontal" className="my-4 bg-divider-subtle" />
<div className="flex items-center py-0.5">
<div
className="flex items-center"
onClick={() => {
if (hasCurrentDatasetDocForm)
return
if (currentDocForm === ChunkingMode.qa)
onDocFormChange(ChunkingMode.text)
else
onDocFormChange(ChunkingMode.qa)
}}
>
<Checkbox
checked={currentDocForm === ChunkingMode.qa}
disabled={hasCurrentDatasetDocForm}
/>
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{t('stepTwo.useQALanguage', { ns: 'datasetCreation' })}
</label>
</div>
<LanguageSelect
currentLanguage={docLanguage || locale}
onSelect={onDocLanguageChange}
disabled={currentDocForm !== ChunkingMode.qa}
/>
<Tooltip popupContent={t('stepTwo.QATip', { ns: 'datasetCreation' })} />
</div>
{currentDocForm === ChunkingMode.qa && (
<div
style={{
background: 'linear-gradient(92deg, rgba(247, 144, 9, 0.1) 0%, rgba(255, 255, 255, 0.00) 100%)',
}}
className="mt-2 flex h-10 items-center gap-2 rounded-xl border border-components-panel-border px-3 text-xs shadow-xs backdrop-blur-[5px]"
>
<RiAlertFill className="size-4 text-text-warning-secondary" />
<span className="system-xs-medium text-text-primary">
{t('stepTwo.QATip', { ns: 'datasetCreation' })}
</span>
</div>
)}
</>
)}
</div>
</div>
</div>
</OptionCard>
)
}

View File

@@ -1,5 +0,0 @@
export { GeneralChunkingOptions } from './general-chunking-options'
export { IndexingModeSection } from './indexing-mode-section'
export { ParentChildOptions } from './parent-child-options'
export { PreviewPanel } from './preview-panel'
export { StepTwoFooter } from './step-two-footer'

View File

@@ -1,253 +0,0 @@
'use client'
import type { FC } from 'react'
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import Image from 'next/image'
import Link from 'next/link'
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import Button from '@/app/components/base/button'
import CustomDialog from '@/app/components/base/dialog'
import Divider from '@/app/components/base/divider'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { useDocLink } from '@/context/i18n'
import { ChunkingMode } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import { indexMethodIcon } from '../../icons'
import { IndexingType } from '../hooks'
import s from '../index.module.css'
import { OptionCard } from './option-card'
type IndexingModeSectionProps = {
// State
indexType: IndexingType
hasSetIndexType: boolean
docForm: ChunkingMode
embeddingModel: DefaultModel
embeddingModelList?: Model[]
retrievalConfig: RetrievalConfig
showMultiModalTip: boolean
// Flags
isModelAndRetrievalConfigDisabled: boolean
datasetId?: string
// Modal state
isQAConfirmDialogOpen: boolean
// Actions
onIndexTypeChange: (type: IndexingType) => void
onEmbeddingModelChange: (model: DefaultModel) => void
onRetrievalConfigChange: (config: RetrievalConfig) => void
onQAConfirmDialogClose: () => void
onQAConfirmDialogConfirm: () => void
}
export const IndexingModeSection: FC<IndexingModeSectionProps> = ({
indexType,
hasSetIndexType,
docForm,
embeddingModel,
embeddingModelList,
retrievalConfig,
showMultiModalTip,
isModelAndRetrievalConfigDisabled,
datasetId,
isQAConfirmDialogOpen,
onIndexTypeChange,
onEmbeddingModelChange,
onRetrievalConfigChange,
onQAConfirmDialogClose,
onQAConfirmDialogConfirm,
}) => {
const { t } = useTranslation()
const docLink = useDocLink()
const getIndexingTechnique = () => indexType
return (
<>
{/* Index Mode */}
<div className="system-md-semibold mb-1 text-text-secondary">
{t('stepTwo.indexMode', { ns: 'datasetCreation' })}
</div>
<div className="flex items-center gap-2">
{/* Qualified option */}
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.QUALIFIED)) && (
<OptionCard
className="flex-1 self-stretch"
title={(
<div className="flex items-center">
{t('stepTwo.qualified', { ns: 'datasetCreation' })}
<Badge
className={cn(
'ml-1 h-[18px]',
(!hasSetIndexType && indexType === IndexingType.QUALIFIED)
? 'border-text-accent-secondary text-text-accent-secondary'
: '',
)}
uppercase
>
{t('stepTwo.recommend', { ns: 'datasetCreation' })}
</Badge>
<span className="ml-auto">
{!hasSetIndexType && <span className={cn(s.radio)} />}
</span>
</div>
)}
description={t('stepTwo.qualifiedTip', { ns: 'datasetCreation' })}
icon={<Image src={indexMethodIcon.high_quality} alt="" />}
isActive={!hasSetIndexType && indexType === IndexingType.QUALIFIED}
disabled={hasSetIndexType}
onSwitched={() => onIndexTypeChange(IndexingType.QUALIFIED)}
/>
)}
{/* Economical option */}
{(!hasSetIndexType || (hasSetIndexType && indexType === IndexingType.ECONOMICAL)) && (
<>
<CustomDialog show={isQAConfirmDialogOpen} onClose={onQAConfirmDialogClose} className="w-[432px]">
<header className="mb-4 pt-6">
<h2 className="text-lg font-semibold text-text-primary">
{t('stepTwo.qaSwitchHighQualityTipTitle', { ns: 'datasetCreation' })}
</h2>
<p className="mt-2 text-sm font-normal text-text-secondary">
{t('stepTwo.qaSwitchHighQualityTipContent', { ns: 'datasetCreation' })}
</p>
</header>
<div className="flex gap-2 pb-6">
<Button className="ml-auto" onClick={onQAConfirmDialogClose}>
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
</Button>
<Button variant="primary" onClick={onQAConfirmDialogConfirm}>
{t('stepTwo.switch', { ns: 'datasetCreation' })}
</Button>
</div>
</CustomDialog>
<Tooltip
popupContent={(
<div className="rounded-lg border-components-panel-border bg-components-tooltip-bg p-3 text-xs font-medium text-text-secondary shadow-lg">
{docForm === ChunkingMode.qa
? t('stepTwo.notAvailableForQA', { ns: 'datasetCreation' })
: t('stepTwo.notAvailableForParentChild', { ns: 'datasetCreation' })}
</div>
)}
noDecoration
position="top"
asChild={false}
triggerClassName="flex-1 self-stretch"
>
<OptionCard
className="h-full"
title={t('stepTwo.economical', { ns: 'datasetCreation' })}
description={t('stepTwo.economicalTip', { ns: 'datasetCreation' })}
icon={<Image src={indexMethodIcon.economical} alt="" />}
isActive={!hasSetIndexType && indexType === IndexingType.ECONOMICAL}
disabled={hasSetIndexType || docForm !== ChunkingMode.text}
onSwitched={() => onIndexTypeChange(IndexingType.ECONOMICAL)}
/>
</Tooltip>
</>
)}
</div>
{/* High quality tip */}
{!hasSetIndexType && indexType === IndexingType.QUALIFIED && (
<div className="mt-2 flex h-10 items-center gap-x-0.5 overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-2 shadow-xs backdrop-blur-[5px]">
<div className="absolute bottom-0 left-0 right-0 top-0 bg-dataset-warning-message-bg opacity-40"></div>
<div className="p-1">
<AlertTriangle className="size-4 text-text-warning-secondary" />
</div>
<span className="system-xs-medium text-text-primary">
{t('stepTwo.highQualityTip', { ns: 'datasetCreation' })}
</span>
</div>
)}
{/* Economical index setting tip */}
{hasSetIndexType && indexType === IndexingType.ECONOMICAL && (
<div className="system-xs-medium mt-2 text-text-tertiary">
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
</Link>
</div>
)}
{/* Embedding model */}
{indexType === IndexingType.QUALIFIED && (
<div className="mt-5">
<div className={cn('system-md-semibold mb-1 text-text-secondary', datasetId && 'flex items-center justify-between')}>
{t('form.embeddingModel', { ns: 'datasetSettings' })}
</div>
<ModelSelector
readonly={isModelAndRetrievalConfigDisabled}
triggerClassName={isModelAndRetrievalConfigDisabled ? 'opacity-50' : ''}
defaultModel={embeddingModel}
modelList={embeddingModelList ?? []}
onSelect={onEmbeddingModelChange}
/>
{isModelAndRetrievalConfigDisabled && (
<div className="system-xs-medium mt-2 text-text-tertiary">
{t('stepTwo.indexSettingTip', { ns: 'datasetCreation' })}
<Link className="text-text-accent" href={`/datasets/${datasetId}/settings`}>
{t('stepTwo.datasetSettingLink', { ns: 'datasetCreation' })}
</Link>
</div>
)}
</div>
)}
<Divider className="my-5" />
{/* Retrieval Method Config */}
<div>
{!isModelAndRetrievalConfigDisabled
? (
<div className="mb-1">
<div className="system-md-semibold mb-0.5 text-text-secondary">
{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}
</div>
<div className="body-xs-regular text-text-tertiary">
<a
target="_blank"
rel="noopener noreferrer"
href={docLink('/guides/knowledge-base/create-knowledge-and-upload-documents')}
className="text-text-accent"
>
{t('form.retrievalSetting.learnMore', { ns: 'datasetSettings' })}
</a>
{t('form.retrievalSetting.longDescription', { ns: 'datasetSettings' })}
</div>
</div>
)
: (
<div className={cn('system-md-semibold mb-0.5 text-text-secondary', 'flex items-center justify-between')}>
<div>{t('form.retrievalSetting.title', { ns: 'datasetSettings' })}</div>
</div>
)}
<div>
{getIndexingTechnique() === IndexingType.QUALIFIED
? (
<RetrievalMethodConfig
disabled={isModelAndRetrievalConfigDisabled}
value={retrievalConfig}
onChange={onRetrievalConfigChange}
showMultiModalTip={showMultiModalTip}
/>
)
: (
<EconomicalRetrievalMethodConfig
disabled={isModelAndRetrievalConfigDisabled}
value={retrievalConfig}
onChange={onRetrievalConfigChange}
/>
)}
</div>
</div>
</>
)
}

View File

@@ -1,191 +0,0 @@
'use client'
import type { FC } from 'react'
import type { ParentChildConfig } from '../hooks'
import type { ParentMode, PreProcessingRule } from '@/models/datasets'
import { RiSearchEyeLine } from '@remixicon/react'
import Image from 'next/image'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Checkbox from '@/app/components/base/checkbox'
import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
import BlueEffect from '../../assets/option-card-effect-blue.svg'
import s from '../index.module.css'
import { DelimiterInput, MaxLengthInput } from './inputs'
import { OptionCard } from './option-card'
type TextLabelProps = {
children: React.ReactNode
}
const TextLabel: FC<TextLabelProps> = ({ children }) => {
return <label className="system-sm-semibold text-text-secondary">{children}</label>
}
type ParentChildOptionsProps = {
// State
parentChildConfig: ParentChildConfig
rules: PreProcessingRule[]
currentDocForm: ChunkingMode
// Flags
isActive: boolean
isInUpload: boolean
isNotUploadInEmptyDataset: boolean
// Actions
onDocFormChange: (form: ChunkingMode) => void
onChunkForContextChange: (mode: ParentMode) => void
onParentDelimiterChange: (value: string) => void
onParentMaxLengthChange: (value: number) => void
onChildDelimiterChange: (value: string) => void
onChildMaxLengthChange: (value: number) => void
onRuleToggle: (id: string) => void
onPreview: () => void
onReset: () => void
}
export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
parentChildConfig,
rules,
currentDocForm: _currentDocForm,
isActive,
isInUpload,
isNotUploadInEmptyDataset,
onDocFormChange,
onChunkForContextChange,
onParentDelimiterChange,
onParentMaxLengthChange,
onChildDelimiterChange,
onChildMaxLengthChange,
onRuleToggle,
onPreview,
onReset,
}) => {
const { t } = useTranslation()
const getRuleName = (key: string): string => {
const ruleNameMap: Record<string, string> = {
remove_extra_spaces: t('stepTwo.removeExtraSpaces', { ns: 'datasetCreation' }),
remove_urls_emails: t('stepTwo.removeUrlEmails', { ns: 'datasetCreation' }),
remove_stopwords: t('stepTwo.removeStopwords', { ns: 'datasetCreation' }),
}
return ruleNameMap[key] ?? key
}
return (
<OptionCard
title={t('stepTwo.parentChild', { ns: 'datasetCreation' })}
icon={<ParentChildChunk className="h-[20px] w-[20px]" />}
effectImg={BlueEffect.src}
className="text-util-colors-blue-light-blue-light-500"
activeHeaderClassName="bg-dataset-option-card-blue-gradient"
description={t('stepTwo.parentChildTip', { ns: 'datasetCreation' })}
isActive={isActive}
onSwitched={() => onDocFormChange(ChunkingMode.parentChild)}
actions={(
<>
<Button variant="secondary-accent" onClick={onPreview}>
<RiSearchEyeLine className="mr-0.5 h-4 w-4" />
{t('stepTwo.previewChunk', { ns: 'datasetCreation' })}
</Button>
<Button variant="ghost" onClick={onReset}>
{t('stepTwo.reset', { ns: 'datasetCreation' })}
</Button>
</>
)}
noHighlight={isInUpload && isNotUploadInEmptyDataset}
>
<div className="flex flex-col gap-4">
{/* Parent chunk for context */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.parentChunkForContext', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<RadioCard
className="mt-1"
icon={<Image src={Note} alt="" />}
title={t('stepTwo.paragraph', { ns: 'datasetCreation' })}
description={t('stepTwo.paragraphTip', { ns: 'datasetCreation' })}
isChosen={parentChildConfig.chunkForContext === 'paragraph'}
onChosen={() => onChunkForContextChange('paragraph')}
chosenConfig={(
<div className="flex gap-3">
<DelimiterInput
value={parentChildConfig.parent.delimiter}
tooltip={t('stepTwo.parentChildDelimiterTip', { ns: 'datasetCreation' })!}
onChange={e => onParentDelimiterChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={parentChildConfig.parent.maxLength}
onChange={onParentMaxLengthChange}
/>
</div>
)}
/>
<RadioCard
className="mt-2"
icon={<Image src={FileList} alt="" />}
title={t('stepTwo.fullDoc', { ns: 'datasetCreation' })}
description={t('stepTwo.fullDocTip', { ns: 'datasetCreation' })}
onChosen={() => onChunkForContextChange('full-doc')}
isChosen={parentChildConfig.chunkForContext === 'full-doc'}
/>
</div>
{/* Child chunk for retrieval */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.childChunkForRetrieval', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1 flex gap-3">
<DelimiterInput
value={parentChildConfig.child.delimiter}
tooltip={t('stepTwo.parentChildChunkDelimiterTip', { ns: 'datasetCreation' })!}
onChange={e => onChildDelimiterChange(e.target.value)}
/>
<MaxLengthInput
unit="characters"
value={parentChildConfig.child.maxLength}
onChange={onChildMaxLengthChange}
/>
</div>
</div>
{/* Rules */}
<div>
<div className="flex items-center gap-x-2">
<div className="inline-flex shrink-0">
<TextLabel>{t('stepTwo.rules', { ns: 'datasetCreation' })}</TextLabel>
</div>
<Divider className="grow" bgStyle="gradient" />
</div>
<div className="mt-1">
{rules.map(rule => (
<div
key={rule.id}
className={s.ruleItem}
onClick={() => onRuleToggle(rule.id)}
>
<Checkbox checked={rule.enabled} />
<label className="system-sm-regular ml-2 cursor-pointer text-text-secondary">
{getRuleName(rule.id)}
</label>
</div>
))}
</div>
</div>
</div>
</OptionCard>
)
}
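For orientation, the shape of `parentChildConfig` that the component above reads, reconstructed from its property accesses; the concrete values are assumptions, not defaults from this changeset.
// Illustrative shape only — field names follow the reads above; the values are assumptions.
const parentChildConfig: ParentChildConfig = {
  chunkForContext: 'paragraph', // or 'full-doc'
  parent: { delimiter: '\n\n', maxLength: 500 },
  child: { delimiter: '\n', maxLength: 200 },
}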

View File

@@ -1,171 +0,0 @@
'use client'
import type { FC } from 'react'
import type { ParentChildConfig } from '../hooks'
import type { DataSourceType, FileIndexingEstimateResponse } from '@/models/datasets'
import { RiSearchEyeLine } from '@remixicon/react'
import { noop } from 'es-toolkit/function'
import { useTranslation } from 'react-i18next'
import Badge from '@/app/components/base/badge'
import FloatRightContainer from '@/app/components/base/float-right-container'
import { SkeletonContainer, SkeletonPoint, SkeletonRectangle, SkeletonRow } from '@/app/components/base/skeleton'
import { FULL_DOC_PREVIEW_LENGTH } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import { cn } from '@/utils/classnames'
import { ChunkContainer, QAPreview } from '../../../chunk'
import PreviewDocumentPicker from '../../../common/document-picker/preview-document-picker'
import { PreviewSlice } from '../../../formatted-text/flavours/preview-slice'
import { FormattedText } from '../../../formatted-text/formatted'
import PreviewContainer from '../../../preview/container'
import { PreviewHeader } from '../../../preview/header'
type PreviewPanelProps = {
// State
isMobile: boolean
dataSourceType: DataSourceType
currentDocForm: ChunkingMode
estimate?: FileIndexingEstimateResponse
parentChildConfig: ParentChildConfig
isSetting?: boolean
// Picker
pickerFiles: Array<{ id: string, name: string, extension: string }>
pickerValue: { id: string, name: string, extension: string }
// Mutation state
isIdle: boolean
isPending: boolean
// Actions
onPickerChange: (selected: { id: string, name: string }) => void
}
export const PreviewPanel: FC<PreviewPanelProps> = ({
isMobile,
dataSourceType: _dataSourceType,
currentDocForm,
estimate,
parentChildConfig,
isSetting,
pickerFiles,
pickerValue,
isIdle,
isPending,
onPickerChange,
}) => {
const { t } = useTranslation()
return (
<FloatRightContainer isMobile={isMobile} isOpen={true} onClose={noop} footer={null}>
<PreviewContainer
header={(
<PreviewHeader title={t('stepTwo.preview', { ns: 'datasetCreation' })}>
<div className="flex items-center gap-1">
<PreviewDocumentPicker
files={pickerFiles as Array<Required<{ id: string, name: string, extension: string }>>}
onChange={onPickerChange}
value={isSetting ? pickerFiles[0] : pickerValue}
/>
{currentDocForm !== ChunkingMode.qa && (
<Badge
text={t('stepTwo.previewChunkCount', {
ns: 'datasetCreation',
count: estimate?.total_segments || 0,
}) as string}
/>
)}
</div>
</PreviewHeader>
)}
className={cn('relative flex h-full w-1/2 shrink-0 p-4 pr-0', isMobile && 'w-full max-w-[524px]')}
mainClassName="space-y-6"
>
{/* QA Preview */}
{currentDocForm === ChunkingMode.qa && estimate?.qa_preview && (
estimate.qa_preview.map((item, index) => (
<ChunkContainer
key={item.question}
label={`Chunk-${index + 1}`}
characterCount={item.question.length + item.answer.length}
>
<QAPreview qa={item} />
</ChunkContainer>
))
)}
{/* Text Preview */}
{currentDocForm === ChunkingMode.text && estimate?.preview && (
estimate.preview.map((item, index) => (
<ChunkContainer
key={item.content}
label={`Chunk-${index + 1}`}
characterCount={item.content.length}
>
{item.content}
</ChunkContainer>
))
)}
{/* Parent-Child Preview */}
{currentDocForm === ChunkingMode.parentChild && estimate?.preview && (
estimate.preview.map((item, index) => {
const indexForLabel = index + 1
const childChunks = parentChildConfig.chunkForContext === 'full-doc'
? item.child_chunks.slice(0, FULL_DOC_PREVIEW_LENGTH)
: item.child_chunks
return (
<ChunkContainer
key={item.content}
label={`Chunk-${indexForLabel}`}
characterCount={item.content.length}
>
<FormattedText>
{childChunks.map((child, childIndex) => {
const childIndexForLabel = childIndex + 1
return (
<PreviewSlice
key={`C-${childIndexForLabel}-${child}`}
label={`C-${childIndexForLabel}`}
text={child}
tooltip={`Child-chunk-${childIndexForLabel} · ${child.length} Characters`}
labelInnerClassName="text-[10px] font-semibold align-bottom leading-7"
dividerClassName="leading-7"
/>
)
})}
</FormattedText>
</ChunkContainer>
)
})
)}
{/* Idle State */}
{isIdle && (
<div className="flex h-full w-full items-center justify-center">
<div className="flex flex-col items-center justify-center gap-3">
<RiSearchEyeLine className="size-10 text-text-empty-state-icon" />
<p className="text-sm text-text-tertiary">
{t('stepTwo.previewChunkTip', { ns: 'datasetCreation' })}
</p>
</div>
</div>
)}
{/* Loading State */}
{isPending && (
<div className="space-y-6">
{Array.from({ length: 10 }, (_, i) => (
<SkeletonContainer key={i}>
<SkeletonRow>
<SkeletonRectangle className="w-20" />
<SkeletonPoint />
<SkeletonRectangle className="w-24" />
</SkeletonRow>
<SkeletonRectangle className="w-full" />
<SkeletonRectangle className="w-full" />
<SkeletonRectangle className="w-[422px]" />
</SkeletonContainer>
))}
</div>
)}
</PreviewContainer>
</FloatRightContainer>
)
}

View File

@@ -1,58 +0,0 @@
'use client'
import type { FC } from 'react'
import { RiArrowLeftLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
type StepTwoFooterProps = {
isSetting?: boolean
isCreating: boolean
onPrevious: () => void
onCreate: () => void
onCancel?: () => void
}
export const StepTwoFooter: FC<StepTwoFooterProps> = ({
isSetting,
isCreating,
onPrevious,
onCreate,
onCancel,
}) => {
const { t } = useTranslation()
if (!isSetting) {
return (
<div className="mt-8 flex items-center py-2">
<Button onClick={onPrevious}>
<RiArrowLeftLine className="mr-1 h-4 w-4" />
{t('stepTwo.previousStep', { ns: 'datasetCreation' })}
</Button>
<Button
className="ml-auto"
loading={isCreating}
variant="primary"
onClick={onCreate}
>
{t('stepTwo.nextStep', { ns: 'datasetCreation' })}
</Button>
</div>
)
}
return (
<div className="mt-8 flex items-center py-2">
<Button
loading={isCreating}
variant="primary"
onClick={onCreate}
>
{t('stepTwo.save', { ns: 'datasetCreation' })}
</Button>
<Button className="ml-2" onClick={onCancel}>
{t('stepTwo.cancel', { ns: 'datasetCreation' })}
</Button>
</div>
)
}

View File

@@ -1,14 +0,0 @@
export { useDocumentCreation } from './use-document-creation'
export type { DocumentCreation, ValidationParams } from './use-document-creation'
export { IndexingType, useIndexingConfig } from './use-indexing-config'
export type { IndexingConfig } from './use-indexing-config'
export { useIndexingEstimate } from './use-indexing-estimate'
export type { IndexingEstimate } from './use-indexing-estimate'
export { usePreviewState } from './use-preview-state'
export type { PreviewState } from './use-preview-state'
export { DEFAULT_MAXIMUM_CHUNK_LENGTH, DEFAULT_OVERLAP, DEFAULT_SEGMENT_IDENTIFIER, defaultParentChildConfig, MAXIMUM_CHUNK_TOKEN_LENGTH, useSegmentationState } from './use-segmentation-state'
export type { ParentChildConfig, SegmentationState } from './use-segmentation-state'

View File

@@ -1,279 +0,0 @@
import type { DefaultModel, Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { NotionPage } from '@/models/common'
import type {
ChunkingMode,
CrawlOptions,
CrawlResultItem,
CreateDocumentReq,
createDocumentResponse,
CustomFile,
FullDocumentDetail,
ProcessRule,
} from '@/models/datasets'
import type { RetrievalConfig, RETRIEVE_METHOD } from '@/types/app'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Toast from '@/app/components/base/toast'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { DataSourceProvider } from '@/models/common'
import {
DataSourceType,
} from '@/models/datasets'
import { getNotionInfo, getWebsiteInfo, useCreateDocument, useCreateFirstDocument } from '@/service/knowledge/use-create-dataset'
import { useInvalidDatasetList } from '@/service/knowledge/use-dataset'
import { IndexingType } from './use-indexing-config'
import { MAXIMUM_CHUNK_TOKEN_LENGTH } from './use-segmentation-state'
export type UseDocumentCreationOptions = {
datasetId?: string
isSetting?: boolean
documentDetail?: FullDocumentDetail
dataSourceType: DataSourceType
files: CustomFile[]
notionPages: NotionPage[]
notionCredentialId: string
websitePages: CrawlResultItem[]
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
// Callbacks
onStepChange?: (delta: number) => void
updateIndexingTypeCache?: (type: string) => void
updateResultCache?: (res: createDocumentResponse) => void
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
onSave?: () => void
mutateDatasetRes?: () => void
}
export type ValidationParams = {
segmentationType: string
maxChunkLength: number
limitMaxChunkLength: number
overlap: number
indexType: IndexingType
embeddingModel: DefaultModel
rerankModelList: Model[]
retrievalConfig: RetrievalConfig
}
export const useDocumentCreation = (options: UseDocumentCreationOptions) => {
const { t } = useTranslation()
const {
datasetId,
isSetting,
documentDetail,
dataSourceType,
files,
notionPages,
notionCredentialId,
websitePages,
crawlOptions,
websiteCrawlProvider = DataSourceProvider.jinaReader,
websiteCrawlJobId = '',
onStepChange,
updateIndexingTypeCache,
updateResultCache,
updateRetrievalMethodCache,
onSave,
mutateDatasetRes,
} = options
const createFirstDocumentMutation = useCreateFirstDocument()
const createDocumentMutation = useCreateDocument(datasetId!)
const invalidDatasetList = useInvalidDatasetList()
const isCreating = createFirstDocumentMutation.isPending || createDocumentMutation.isPending
// Validate creation params
const validateParams = useCallback((params: ValidationParams): boolean => {
const {
segmentationType,
maxChunkLength,
limitMaxChunkLength,
overlap,
indexType,
embeddingModel,
rerankModelList,
retrievalConfig,
} = params
if (segmentationType === 'general' && overlap > maxChunkLength) {
Toast.notify({ type: 'error', message: t('stepTwo.overlapCheck', { ns: 'datasetCreation' }) })
return false
}
if (segmentationType === 'general' && maxChunkLength > limitMaxChunkLength) {
Toast.notify({
type: 'error',
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: limitMaxChunkLength }),
})
return false
}
if (!isSetting) {
if (indexType === IndexingType.QUALIFIED && (!embeddingModel.model || !embeddingModel.provider)) {
Toast.notify({
type: 'error',
message: t('datasetConfig.embeddingModelRequired', { ns: 'appDebug' }),
})
return false
}
if (!isReRankModelSelected({
rerankModelList,
retrievalConfig,
indexMethod: indexType,
})) {
Toast.notify({ type: 'error', message: t('datasetConfig.rerankModelRequired', { ns: 'appDebug' }) })
return false
}
}
return true
}, [t, isSetting])
// Build creation params
const buildCreationParams = useCallback((
currentDocForm: ChunkingMode,
docLanguage: string,
processRule: ProcessRule,
retrievalConfig: RetrievalConfig,
embeddingModel: DefaultModel,
indexingTechnique: string,
): CreateDocumentReq | null => {
if (isSetting) {
return {
original_document_id: documentDetail?.id,
doc_form: currentDocForm,
doc_language: docLanguage,
process_rule: processRule,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
indexing_technique: indexingTechnique,
} as CreateDocumentReq
}
const params: CreateDocumentReq = {
data_source: {
type: dataSourceType,
info_list: {
data_source_type: dataSourceType,
},
},
indexing_technique: indexingTechnique,
process_rule: processRule,
doc_form: currentDocForm,
doc_language: docLanguage,
retrieval_model: retrievalConfig,
embedding_model: embeddingModel.model,
embedding_model_provider: embeddingModel.provider,
} as CreateDocumentReq
// Add data source specific info
if (dataSourceType === DataSourceType.FILE) {
params.data_source!.info_list.file_info_list = {
file_ids: files.map(file => file.id || '').filter(Boolean),
}
}
if (dataSourceType === DataSourceType.NOTION)
params.data_source!.info_list.notion_info_list = getNotionInfo(notionPages, notionCredentialId)
if (dataSourceType === DataSourceType.WEB) {
params.data_source!.info_list.website_info_list = getWebsiteInfo({
websiteCrawlProvider,
websiteCrawlJobId,
websitePages,
crawlOptions,
})
}
return params
}, [
isSetting,
documentDetail,
dataSourceType,
files,
notionPages,
notionCredentialId,
websitePages,
websiteCrawlProvider,
websiteCrawlJobId,
crawlOptions,
])
// Execute creation
const executeCreation = useCallback(async (
params: CreateDocumentReq,
indexType: IndexingType,
retrievalConfig: RetrievalConfig,
) => {
if (!datasetId) {
await createFirstDocumentMutation.mutateAsync(params, {
onSuccess(data) {
updateIndexingTypeCache?.(indexType)
updateResultCache?.(data)
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
},
})
}
else {
await createDocumentMutation.mutateAsync(params, {
onSuccess(data) {
updateIndexingTypeCache?.(indexType)
updateResultCache?.(data)
updateRetrievalMethodCache?.(retrievalConfig.search_method as RETRIEVE_METHOD)
},
})
}
mutateDatasetRes?.()
invalidDatasetList()
trackEvent('create_datasets', {
data_source_type: dataSourceType,
indexing_technique: indexType,
})
onStepChange?.(+1)
if (isSetting)
onSave?.()
}, [
datasetId,
createFirstDocumentMutation,
createDocumentMutation,
updateIndexingTypeCache,
updateResultCache,
updateRetrievalMethodCache,
mutateDatasetRes,
invalidDatasetList,
dataSourceType,
onStepChange,
isSetting,
onSave,
])
// Validate preview params
const validatePreviewParams = useCallback((maxChunkLength: number): boolean => {
if (maxChunkLength > MAXIMUM_CHUNK_TOKEN_LENGTH) {
Toast.notify({
type: 'error',
message: t('stepTwo.maxLengthCheck', { ns: 'datasetCreation', limit: MAXIMUM_CHUNK_TOKEN_LENGTH }),
})
return false
}
return true
}, [t])
return {
isCreating,
validateParams,
buildCreationParams,
executeCreation,
validatePreviewParams,
}
}
export type DocumentCreation = ReturnType<typeof useDocumentCreation>
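A condensed sketch of how a caller can chain the steps returned by the hook above; every argument value here is a placeholder drawn from the surrounding step-two state and assumed for illustration.
// Illustrative only — all argument values are placeholders.
const creation = useDocumentCreation({ datasetId, dataSourceType, files, notionPages, notionCredentialId, websitePages })

const handleCreate = async () => {
  const valid = creation.validateParams({
    segmentationType, maxChunkLength, limitMaxChunkLength, overlap,
    indexType, embeddingModel, rerankModelList, retrievalConfig,
  })
  if (!valid)
    return
  const params = creation.buildCreationParams(currentDocForm, docLanguage, processRule, retrievalConfig, embeddingModel, indexType)
  if (params)
    await creation.executeCreation(params, indexType, retrievalConfig)
}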

View File

@@ -1,143 +0,0 @@
import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import { useEffect, useMemo, useState } from 'react'
import { checkShowMultiModalTip } from '@/app/components/datasets/settings/utils'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { RETRIEVE_METHOD } from '@/types/app'
export enum IndexingType {
QUALIFIED = 'high_quality',
ECONOMICAL = 'economy',
}
const DEFAULT_RETRIEVAL_CONFIG: RetrievalConfig = {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
}
export type UseIndexingConfigOptions = {
initialIndexType?: IndexingType
initialEmbeddingModel?: DefaultModel
initialRetrievalConfig?: RetrievalConfig
isAPIKeySet: boolean
hasSetIndexType: boolean
}
export const useIndexingConfig = (options: UseIndexingConfigOptions) => {
const {
initialIndexType,
initialEmbeddingModel,
initialRetrievalConfig,
isAPIKeySet,
hasSetIndexType,
} = options
// Rerank model
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
// Embedding model list
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding)
// Index type state
const [indexType, setIndexType] = useState<IndexingType>(() => {
if (initialIndexType)
return initialIndexType
return isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL
})
// Embedding model state
const [embeddingModel, setEmbeddingModel] = useState<DefaultModel>(
initialEmbeddingModel ?? {
provider: defaultEmbeddingModel?.provider.provider || '',
model: defaultEmbeddingModel?.model || '',
},
)
// Retrieval config state
const [retrievalConfig, setRetrievalConfig] = useState<RetrievalConfig>(
initialRetrievalConfig ?? DEFAULT_RETRIEVAL_CONFIG,
)
// Sync retrieval config with rerank model when available
useEffect(() => {
if (initialRetrievalConfig)
return
setRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: !!isRerankDefaultModelValid,
reranking_model: {
reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider.provider ?? '' : '',
reranking_model_name: isRerankDefaultModelValid ? rerankDefaultModel?.model ?? '' : '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
})
}, [rerankDefaultModel, isRerankDefaultModelValid, initialRetrievalConfig])
// Sync index type with props
useEffect(() => {
if (initialIndexType)
setIndexType(initialIndexType)
else
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, initialIndexType])
// Show multimodal tip
const showMultiModalTip = useMemo(() => {
return checkShowMultiModalTip({
embeddingModel,
rerankingEnable: retrievalConfig.reranking_enable,
rerankModel: {
rerankingProviderName: retrievalConfig.reranking_model.reranking_provider_name,
rerankingModelName: retrievalConfig.reranking_model.reranking_model_name,
},
indexMethod: indexType,
embeddingModelList,
rerankModelList,
})
}, [embeddingModel, retrievalConfig, indexType, embeddingModelList, rerankModelList])
// Get effective indexing technique
const getIndexingTechnique = () => initialIndexType || indexType
return {
// Index type
indexType,
setIndexType,
hasSetIndexType,
getIndexingTechnique,
// Embedding model
embeddingModel,
setEmbeddingModel,
embeddingModelList,
defaultEmbeddingModel,
// Retrieval config
retrievalConfig,
setRetrievalConfig,
rerankModelList,
rerankDefaultModel,
isRerankDefaultModelValid,
// Computed
showMultiModalTip,
}
}
export type IndexingConfig = ReturnType<typeof useIndexingConfig>
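A usage sketch for this hook as it stood before its removal, inside a hypothetical creation-step component; only isAPIKeySet and hasSetIndexType are required options, everything else falls back to the rerank/embedding default-model hooks.

const {
  indexType,
  setIndexType,
  embeddingModel,
  setEmbeddingModel,
  retrievalConfig,
  setRetrievalConfig,
  showMultiModalTip,
  getIndexingTechnique,
} = useIndexingConfig({
  isAPIKeySet: true,
  hasSetIndexType: false,
})

// getIndexingTechnique() prefers the caller-supplied initialIndexType over local state,
// so a dataset whose index type is already fixed cannot be switched from this step.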

View File

@@ -1,123 +0,0 @@
import type { IndexingType } from './use-indexing-config'
import type { NotionPage } from '@/models/common'
import type { ChunkingMode, CrawlOptions, CrawlResultItem, CustomFile, ProcessRule } from '@/models/datasets'
import { useCallback } from 'react'
import { DataSourceProvider } from '@/models/common'
import { DataSourceType } from '@/models/datasets'
import {
useFetchFileIndexingEstimateForFile,
useFetchFileIndexingEstimateForNotion,
useFetchFileIndexingEstimateForWeb,
} from '@/service/knowledge/use-create-dataset'
export type UseIndexingEstimateOptions = {
dataSourceType: DataSourceType
datasetId?: string
// Document settings
currentDocForm: ChunkingMode
docLanguage: string
// File data source
files: CustomFile[]
previewFileName?: string
// Notion data source
previewNotionPage: NotionPage
notionCredentialId: string
// Website data source
previewWebsitePage: CrawlResultItem
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
// Processing
indexingTechnique: IndexingType
processRule: ProcessRule
}
export const useIndexingEstimate = (options: UseIndexingEstimateOptions) => {
const {
dataSourceType,
datasetId,
currentDocForm,
docLanguage,
files,
previewFileName,
previewNotionPage,
notionCredentialId,
previewWebsitePage,
crawlOptions,
websiteCrawlProvider,
websiteCrawlJobId,
indexingTechnique,
processRule,
} = options
// File indexing estimate
const fileQuery = useFetchFileIndexingEstimateForFile({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.FILE,
files: previewFileName
? [files.find(file => file.name === previewFileName)!]
: files,
indexingTechnique,
processRule,
dataset_id: datasetId!,
})
// Notion indexing estimate
const notionQuery = useFetchFileIndexingEstimateForNotion({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.NOTION,
notionPages: [previewNotionPage],
indexingTechnique,
processRule,
dataset_id: datasetId || '',
credential_id: notionCredentialId,
})
// Website indexing estimate
const websiteQuery = useFetchFileIndexingEstimateForWeb({
docForm: currentDocForm,
docLanguage,
dataSourceType: DataSourceType.WEB,
websitePages: [previewWebsitePage],
crawlOptions,
websiteCrawlProvider: websiteCrawlProvider ?? DataSourceProvider.jinaReader,
websiteCrawlJobId: websiteCrawlJobId ?? '',
indexingTechnique,
processRule,
dataset_id: datasetId || '',
})
// Get current mutation based on data source type
const getCurrentMutation = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
return fileQuery
if (dataSourceType === DataSourceType.NOTION)
return notionQuery
return websiteQuery
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
const currentMutation = getCurrentMutation()
// Trigger estimate fetch
const fetchEstimate = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
fileQuery.mutate()
else if (dataSourceType === DataSourceType.NOTION)
notionQuery.mutate()
else
websiteQuery.mutate()
}, [dataSourceType, fileQuery, notionQuery, websiteQuery])
return {
currentMutation,
estimate: currentMutation.data,
isIdle: currentMutation.isIdle,
isPending: currentMutation.isPending,
fetchEstimate,
reset: currentMutation.reset,
}
}
export type IndexingEstimate = ReturnType<typeof useIndexingEstimate>
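A usage sketch, assuming the surrounding step component supplies the preview targets and process rule from the sibling preview/segmentation hooks; the variable names simply mirror the option fields above and are otherwise assumptions.

const {
  estimate,
  isPending,
  fetchEstimate,
  reset,
} = useIndexingEstimate({
  dataSourceType: DataSourceType.FILE,
  currentDocForm,
  docLanguage,
  files,
  previewFileName,
  previewNotionPage,
  notionCredentialId,
  previewWebsitePage,
  indexingTechnique,
  processRule,
})

// Only the mutation matching dataSourceType actually fires; call fetchEstimate() whenever the
// preview target or process rule changes, and reset() before switching data sources.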

View File

@@ -1,127 +0,0 @@
import type { NotionPage } from '@/models/common'
import type { CrawlResultItem, CustomFile, DocumentItem, FullDocumentDetail } from '@/models/datasets'
import { useCallback, useState } from 'react'
import { DataSourceType } from '@/models/datasets'
export type UsePreviewStateOptions = {
dataSourceType: DataSourceType
files: CustomFile[]
notionPages: NotionPage[]
websitePages: CrawlResultItem[]
documentDetail?: FullDocumentDetail
datasetId?: string
}
export const usePreviewState = (options: UsePreviewStateOptions) => {
const {
dataSourceType,
files,
notionPages,
websitePages,
documentDetail,
datasetId,
} = options
// File preview state
const [previewFile, setPreviewFile] = useState<DocumentItem>(
(datasetId && documentDetail)
? documentDetail.file
: files[0],
)
// Notion page preview state
const [previewNotionPage, setPreviewNotionPage] = useState<NotionPage>(
(datasetId && documentDetail)
? documentDetail.notion_page
: notionPages[0],
)
// Website page preview state
const [previewWebsitePage, setPreviewWebsitePage] = useState<CrawlResultItem>(
(datasetId && documentDetail)
? documentDetail.website_page
: websitePages[0],
)
// Get preview items for document picker based on data source type
const getPreviewPickerItems = useCallback(() => {
if (dataSourceType === DataSourceType.FILE) {
return files as Array<Required<CustomFile>>
}
if (dataSourceType === DataSourceType.NOTION) {
return notionPages.map(page => ({
id: page.page_id,
name: page.page_name,
extension: 'md',
}))
}
if (dataSourceType === DataSourceType.WEB) {
return websitePages.map(page => ({
id: page.source_url,
name: page.title,
extension: 'md',
}))
}
return []
}, [dataSourceType, files, notionPages, websitePages])
// Get current preview value for picker
const getPreviewPickerValue = useCallback(() => {
if (dataSourceType === DataSourceType.FILE) {
return previewFile as Required<CustomFile>
}
if (dataSourceType === DataSourceType.NOTION) {
return {
id: previewNotionPage?.page_id || '',
name: previewNotionPage?.page_name || '',
extension: 'md',
}
}
if (dataSourceType === DataSourceType.WEB) {
return {
id: previewWebsitePage?.source_url || '',
name: previewWebsitePage?.title || '',
extension: 'md',
}
}
return { id: '', name: '', extension: '' }
}, [dataSourceType, previewFile, previewNotionPage, previewWebsitePage])
// Handle preview change
const handlePreviewChange = useCallback((selected: { id: string, name: string }) => {
if (dataSourceType === DataSourceType.FILE) {
setPreviewFile(selected as DocumentItem)
}
else if (dataSourceType === DataSourceType.NOTION) {
const selectedPage = notionPages.find(page => page.page_id === selected.id)
if (selectedPage)
setPreviewNotionPage(selectedPage)
}
else if (dataSourceType === DataSourceType.WEB) {
const selectedPage = websitePages.find(page => page.source_url === selected.id)
if (selectedPage)
setPreviewWebsitePage(selectedPage)
}
}, [dataSourceType, notionPages, websitePages])
return {
// File preview
previewFile,
setPreviewFile,
// Notion preview
previewNotionPage,
setPreviewNotionPage,
// Website preview
previewWebsitePage,
setPreviewWebsitePage,
// Picker helpers
getPreviewPickerItems,
getPreviewPickerValue,
handlePreviewChange,
}
}
export type PreviewState = ReturnType<typeof usePreviewState>
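A usage sketch of the preview helpers; the picker component name is a placeholder. The three getPreviewPicker* helpers normalise files, Notion pages and crawled pages into the same { id, name, extension } shape so a single picker can serve every data source type.

const {
  getPreviewPickerItems,
  getPreviewPickerValue,
  handlePreviewChange,
} = usePreviewState({ dataSourceType, files, notionPages, websitePages, documentDetail, datasetId })

// <DocumentPicker /> is a hypothetical name for the consuming picker component.
<DocumentPicker
  items={getPreviewPickerItems()}
  value={getPreviewPickerValue()}
  onChange={handlePreviewChange}
/>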

View File

@@ -1,222 +0,0 @@
import type { ParentMode, PreProcessingRule, ProcessRule, Rules } from '@/models/datasets'
import { useCallback, useState } from 'react'
import { ChunkingMode, ProcessMode } from '@/models/datasets'
import escape from './escape'
import unescape from './unescape'
// Constants
export const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
export const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024
export const DEFAULT_OVERLAP = 50
export const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(
globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000',
10,
)
export type ParentChildConfig = {
chunkForContext: ParentMode
parent: {
delimiter: string
maxLength: number
}
child: {
delimiter: string
maxLength: number
}
}
export const defaultParentChildConfig: ParentChildConfig = {
chunkForContext: 'paragraph',
parent: {
delimiter: '\\n\\n',
maxLength: 1024,
},
child: {
delimiter: '\\n',
maxLength: 512,
},
}
export type UseSegmentationStateOptions = {
initialSegmentationType?: ProcessMode
}
export const useSegmentationState = (options: UseSegmentationStateOptions = {}) => {
const { initialSegmentationType } = options
// Segmentation type (general or parent-child)
const [segmentationType, setSegmentationType] = useState<ProcessMode>(
initialSegmentationType ?? ProcessMode.general,
)
// General chunking settings
const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER)
const [maxChunkLength, setMaxChunkLength] = useState(DEFAULT_MAXIMUM_CHUNK_LENGTH)
const [limitMaxChunkLength, setLimitMaxChunkLength] = useState(MAXIMUM_CHUNK_TOKEN_LENGTH)
const [overlap, setOverlap] = useState(DEFAULT_OVERLAP)
// Pre-processing rules
const [rules, setRules] = useState<PreProcessingRule[]>([])
const [defaultConfig, setDefaultConfig] = useState<Rules>()
// Parent-child config
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
// Escaped segment identifier setter
const setSegmentIdentifier = useCallback((value: string, canEmpty?: boolean) => {
if (value) {
doSetSegmentIdentifier(escape(value))
}
else {
doSetSegmentIdentifier(canEmpty ? '' : DEFAULT_SEGMENT_IDENTIFIER)
}
}, [])
// Rule toggle handler
const toggleRule = useCallback((id: string) => {
setRules(prev => prev.map(rule =>
rule.id === id ? { ...rule, enabled: !rule.enabled } : rule,
))
}, [])
// Reset to defaults
const resetToDefaults = useCallback(() => {
if (defaultConfig) {
setSegmentIdentifier(defaultConfig.segmentation.separator)
setMaxChunkLength(defaultConfig.segmentation.max_tokens)
setOverlap(defaultConfig.segmentation.chunk_overlap!)
setRules(defaultConfig.pre_processing_rules)
}
setParentChildConfig(defaultParentChildConfig)
}, [defaultConfig, setSegmentIdentifier])
// Apply config from document detail
const applyConfigFromRules = useCallback((rulesConfig: Rules, isHierarchical: boolean) => {
const separator = rulesConfig.segmentation.separator
const max = rulesConfig.segmentation.max_tokens
const chunkOverlap = rulesConfig.segmentation.chunk_overlap
setSegmentIdentifier(separator)
setMaxChunkLength(max)
setOverlap(chunkOverlap!)
setRules(rulesConfig.pre_processing_rules)
setDefaultConfig(rulesConfig)
if (isHierarchical) {
setParentChildConfig({
chunkForContext: rulesConfig.parent_mode || 'paragraph',
parent: {
delimiter: escape(rulesConfig.segmentation.separator),
maxLength: rulesConfig.segmentation.max_tokens,
},
child: {
delimiter: escape(rulesConfig.subchunk_segmentation!.separator),
maxLength: rulesConfig.subchunk_segmentation!.max_tokens,
},
})
}
}, [setSegmentIdentifier])
// Get process rule for API
const getProcessRule = useCallback((docForm: ChunkingMode): ProcessRule => {
if (docForm === ChunkingMode.parentChild) {
return {
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(parentChildConfig.parent.delimiter),
max_tokens: parentChildConfig.parent.maxLength,
},
parent_mode: parentChildConfig.chunkForContext,
subchunk_segmentation: {
separator: unescape(parentChildConfig.child.delimiter),
max_tokens: parentChildConfig.child.maxLength,
},
},
mode: 'hierarchical',
} as ProcessRule
}
return {
rules: {
pre_processing_rules: rules,
segmentation: {
separator: unescape(segmentIdentifier),
max_tokens: maxChunkLength,
chunk_overlap: overlap,
},
},
mode: segmentationType,
} as ProcessRule
}, [rules, parentChildConfig, segmentIdentifier, maxChunkLength, overlap, segmentationType])
// Update parent config field
const updateParentConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
setParentChildConfig((prev) => {
let newValue: string | number
if (field === 'delimiter')
newValue = value ? escape(value as string) : ''
else
newValue = value
return {
...prev,
parent: { ...prev.parent, [field]: newValue },
}
})
}, [])
// Update child config field
const updateChildConfig = useCallback((field: 'delimiter' | 'maxLength', value: string | number) => {
setParentChildConfig((prev) => {
let newValue: string | number
if (field === 'delimiter')
newValue = value ? escape(value as string) : ''
else
newValue = value
return {
...prev,
child: { ...prev.child, [field]: newValue },
}
})
}, [])
// Set chunk for context mode
const setChunkForContext = useCallback((mode: ParentMode) => {
setParentChildConfig(prev => ({ ...prev, chunkForContext: mode }))
}, [])
return {
// General chunking state
segmentationType,
setSegmentationType,
segmentIdentifier,
setSegmentIdentifier,
maxChunkLength,
setMaxChunkLength,
limitMaxChunkLength,
setLimitMaxChunkLength,
overlap,
setOverlap,
// Rules
rules,
setRules,
defaultConfig,
setDefaultConfig,
toggleRule,
// Parent-child config
parentChildConfig,
setParentChildConfig,
updateParentConfig,
updateChildConfig,
setChunkForContext,
// Actions
resetToDefaults,
applyConfigFromRules,
getProcessRule,
}
}
export type SegmentationState = ReturnType<typeof useSegmentationState>
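A usage sketch, assuming a chunking-settings component; the detail worth noting is that delimiters are kept escaped in state (via escape) and only unescaped when getProcessRule builds the API payload.

const {
  setSegmentIdentifier,
  setMaxChunkLength,
  setOverlap,
  applyConfigFromRules,
  getProcessRule,
} = useSegmentationState({ initialSegmentationType: ProcessMode.general })

// docForm comes from the chunking-mode selector: for ChunkingMode.parentChild the rule is built
// from parentChildConfig with mode 'hierarchical', otherwise from the general segmentation
// state with mode set to segmentationType.
const processRule = getProcessRule(docForm)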

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,28 +0,0 @@
import type { IndexingType } from './hooks'
import type { DataSourceProvider, NotionPage } from '@/models/common'
import type { CrawlOptions, CrawlResultItem, createDocumentResponse, CustomFile, DataSourceType, FullDocumentDetail } from '@/models/datasets'
import type { RETRIEVE_METHOD } from '@/types/app'
export type StepTwoProps = {
isSetting?: boolean
documentDetail?: FullDocumentDetail
isAPIKeySet: boolean
onSetting: () => void
datasetId?: string
indexingType?: IndexingType
retrievalMethod?: string
dataSourceType: DataSourceType
files: CustomFile[]
notionPages?: NotionPage[]
notionCredentialId: string
websitePages?: CrawlResultItem[]
crawlOptions?: CrawlOptions
websiteCrawlProvider?: DataSourceProvider
websiteCrawlJobId?: string
onStepChange?: (delta: number) => void
updateIndexingTypeCache?: (type: string) => void
updateRetrievalMethodCache?: (method: RETRIEVE_METHOD | '') => void
updateResultCache?: (res: createDocumentResponse) => void
onSave?: () => void
onCancel?: () => void
}

View File

@@ -1,81 +0,0 @@
import type { ActivePluginType } from './constants'
import type { PluginsSort, SearchParamsFromCollection } from './types'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { useQueryState } from 'nuqs'
import { useCallback } from 'react'
import { DEFAULT_SORT, PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceSearchParamsParsers } from './search-params'
const marketplaceSortAtom = atom<PluginsSort>(DEFAULT_SORT)
export function useMarketplaceSort() {
return useAtom(marketplaceSortAtom)
}
export function useMarketplaceSortValue() {
return useAtomValue(marketplaceSortAtom)
}
export function useSetMarketplaceSort() {
return useSetAtom(marketplaceSortAtom)
}
/**
* Preserve the state for marketplace
*/
export const preserveSearchStateInQueryAtom = atom<boolean>(false)
const searchPluginTextAtom = atom<string>('')
const activePluginTypeAtom = atom<ActivePluginType>('all')
const filterPluginTagsAtom = atom<string[]>([])
export function useSearchPluginText() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('q', marketplaceSearchParamsParsers.q)
const atomState = useAtom(searchPluginTextAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
export function useActivePluginType() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('category', marketplaceSearchParamsParsers.category)
const atomState = useAtom(activePluginTypeAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
export function useFilterPluginTags() {
const preserveSearchStateInQuery = useAtomValue(preserveSearchStateInQueryAtom)
const queryState = useQueryState('tags', marketplaceSearchParamsParsers.tags)
const atomState = useAtom(filterPluginTagsAtom)
return preserveSearchStateInQuery ? queryState : atomState
}
/**
* Not all categories have collections, so we need to
* force the search mode for those categories.
*/
export const searchModeAtom = atom<true | null>(null)
export function useMarketplaceSearchMode() {
const [searchPluginText] = useSearchPluginText()
const [filterPluginTags] = useFilterPluginTags()
const [activePluginType] = useActivePluginType()
const searchMode = useAtomValue(searchModeAtom)
const isSearchMode = !!searchPluginText
|| filterPluginTags.length > 0
|| (searchMode ?? (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(activePluginType)))
return isSearchMode
}
export function useMarketplaceMoreClick() {
const [,setQ] = useSearchPluginText()
const setSort = useSetAtom(marketplaceSortAtom)
const setSearchMode = useSetAtom(searchModeAtom)
return useCallback((searchParams?: SearchParamsFromCollection) => {
if (!searchParams)
return
setQ(searchParams?.query || '')
setSort({
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
})
setSearchMode(true)
}, [setQ, setSort, setSearchMode])
}
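A sketch of how these atom-backed hooks were consumed before their removal, inside a hypothetical marketplace component; each hook transparently switches between URL query state (nuqs) and in-memory atoms depending on preserveSearchStateInQueryAtom.

const [searchText, setSearchText] = useSearchPluginText()
const [activeType, setActiveType] = useActivePluginType()
const [tags, setTags] = useFilterPluginTags()
const isSearchMode = useMarketplaceSearchMode()
const onMoreClick = useMarketplaceMoreClick()

// Selecting a category without collections forces search mode; a collection's "View More"
// button seeds the query text and sort, then pins searchModeAtom to true.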

View File

@@ -1,30 +1,6 @@
import { PluginCategoryEnum } from '../types'
export const DEFAULT_SORT = {
sortBy: 'install_count',
sortOrder: 'DESC',
}
export const SCROLL_BOTTOM_THRESHOLD = 100
export const PLUGIN_TYPE_SEARCH_MAP = {
all: 'all',
model: PluginCategoryEnum.model,
tool: PluginCategoryEnum.tool,
agent: PluginCategoryEnum.agent,
extension: PluginCategoryEnum.extension,
datasource: PluginCategoryEnum.datasource,
trigger: PluginCategoryEnum.trigger,
bundle: 'bundle',
} as const
type ValueOf<T> = T[keyof T]
export type ActivePluginType = ValueOf<typeof PLUGIN_TYPE_SEARCH_MAP>
export const PLUGIN_CATEGORY_WITH_COLLECTIONS = new Set<ActivePluginType>(
[
PLUGIN_TYPE_SEARCH_MAP.all,
PLUGIN_TYPE_SEARCH_MAP.tool,
],
)
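For illustration, as these constants stood before this change: only the 'all' and 'tool' categories render curated collections, every other category falls through to plain search results.

PLUGIN_CATEGORY_WITH_COLLECTIONS.has(PLUGIN_TYPE_SEARCH_MAP.tool)  // true
PLUGIN_CATEGORY_WITH_COLLECTIONS.has(PLUGIN_TYPE_SEARCH_MAP.model) // false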

View File

@@ -0,0 +1,332 @@
'use client'
import type {
ReactNode,
} from 'react'
import type { TagKey } from '../constants'
import type { Plugin } from '../types'
import type {
MarketplaceCollection,
PluginsSort,
SearchParams,
SearchParamsFromCollection,
} from './types'
import { debounce } from 'es-toolkit/compat'
import { noop } from 'es-toolkit/function'
import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import {
createContext,
useContextSelector,
} from 'use-context-selector'
import { useMarketplaceFilters } from '@/hooks/use-query-params'
import { useInstalledPluginList } from '@/service/use-plugins'
import {
getValidCategoryKeys,
getValidTagKeys,
} from '../utils'
import { DEFAULT_SORT } from './constants'
import {
useMarketplaceCollectionsAndPlugins,
useMarketplaceContainerScroll,
useMarketplacePlugins,
} from './hooks'
import { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch'
import {
getMarketplaceListCondition,
getMarketplaceListFilterType,
} from './utils'
export type MarketplaceContextValue = {
searchPluginText: string
handleSearchPluginTextChange: (text: string) => void
filterPluginTags: string[]
handleFilterPluginTagsChange: (tags: string[]) => void
activePluginType: string
handleActivePluginTypeChange: (type: string) => void
page: number
handlePageChange: () => void
plugins?: Plugin[]
pluginsTotal?: number
resetPlugins: () => void
sort: PluginsSort
handleSortChange: (sort: PluginsSort) => void
handleQueryPlugins: () => void
handleMoreClick: (searchParams: SearchParamsFromCollection) => void
marketplaceCollectionsFromClient?: MarketplaceCollection[]
setMarketplaceCollectionsFromClient: (collections: MarketplaceCollection[]) => void
marketplaceCollectionPluginsMapFromClient?: Record<string, Plugin[]>
setMarketplaceCollectionPluginsMapFromClient: (map: Record<string, Plugin[]>) => void
isLoading: boolean
isSuccessCollections: boolean
}
export const MarketplaceContext = createContext<MarketplaceContextValue>({
searchPluginText: '',
handleSearchPluginTextChange: noop,
filterPluginTags: [],
handleFilterPluginTagsChange: noop,
activePluginType: 'all',
handleActivePluginTypeChange: noop,
page: 1,
handlePageChange: noop,
plugins: undefined,
pluginsTotal: 0,
resetPlugins: noop,
sort: DEFAULT_SORT,
handleSortChange: noop,
handleQueryPlugins: noop,
handleMoreClick: noop,
marketplaceCollectionsFromClient: [],
setMarketplaceCollectionsFromClient: noop,
marketplaceCollectionPluginsMapFromClient: {},
setMarketplaceCollectionPluginsMapFromClient: noop,
isLoading: false,
isSuccessCollections: false,
})
type MarketplaceContextProviderProps = {
children: ReactNode
searchParams?: SearchParams
shouldExclude?: boolean
scrollContainerId?: string
showSearchParams?: boolean
}
export function useMarketplaceContext(selector: (value: MarketplaceContextValue) => any) {
return useContextSelector(MarketplaceContext, selector)
}
export const MarketplaceContextProvider = ({
children,
searchParams,
shouldExclude,
scrollContainerId,
showSearchParams,
}: MarketplaceContextProviderProps) => {
// Use nuqs hook for URL-based filter state
const [urlFilters, setUrlFilters] = useMarketplaceFilters()
const { data, isSuccess } = useInstalledPluginList(!shouldExclude)
const exclude = useMemo(() => {
if (shouldExclude)
return data?.plugins.map(plugin => plugin.plugin_id)
}, [data?.plugins, shouldExclude])
// Initialize from URL params (legacy support) or use nuqs state
const queryFromSearchParams = searchParams?.q || urlFilters.q
const tagsFromSearchParams = getValidTagKeys(urlFilters.tags as TagKey[])
const hasValidTags = !!tagsFromSearchParams.length
const hasValidCategory = getValidCategoryKeys(urlFilters.category)
const categoryFromSearchParams = hasValidCategory || PLUGIN_TYPE_SEARCH_MAP.all
const [searchPluginText, setSearchPluginText] = useState(queryFromSearchParams)
const searchPluginTextRef = useRef(searchPluginText)
const [filterPluginTags, setFilterPluginTags] = useState<string[]>(tagsFromSearchParams)
const filterPluginTagsRef = useRef(filterPluginTags)
const [activePluginType, setActivePluginType] = useState(categoryFromSearchParams)
const activePluginTypeRef = useRef(activePluginType)
const [sort, setSort] = useState(DEFAULT_SORT)
const sortRef = useRef(sort)
const {
marketplaceCollections: marketplaceCollectionsFromClient,
setMarketplaceCollections: setMarketplaceCollectionsFromClient,
marketplaceCollectionPluginsMap: marketplaceCollectionPluginsMapFromClient,
setMarketplaceCollectionPluginsMap: setMarketplaceCollectionPluginsMapFromClient,
queryMarketplaceCollectionsAndPlugins,
isLoading,
isSuccess: isSuccessCollections,
} = useMarketplaceCollectionsAndPlugins()
const {
plugins,
total: pluginsTotal,
resetPlugins,
queryPlugins,
queryPluginsWithDebounced,
cancelQueryPluginsWithDebounced,
isLoading: isPluginsLoading,
fetchNextPage: fetchNextPluginsPage,
hasNextPage: hasNextPluginsPage,
page: pluginsPage,
} = useMarketplacePlugins()
const page = Math.max(pluginsPage || 0, 1)
useEffect(() => {
if (queryFromSearchParams || hasValidTags || hasValidCategory) {
queryPlugins({
query: queryFromSearchParams,
category: hasValidCategory,
tags: hasValidTags ? tagsFromSearchParams : [],
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
else {
if (shouldExclude && isSuccess) {
queryMarketplaceCollectionsAndPlugins({
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
}
}, [queryPlugins, queryMarketplaceCollectionsAndPlugins, isSuccess, exclude])
const handleQueryMarketplaceCollectionsAndPlugins = useCallback(() => {
queryMarketplaceCollectionsAndPlugins({
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
condition: getMarketplaceListCondition(activePluginTypeRef.current),
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
resetPlugins()
}, [exclude, queryMarketplaceCollectionsAndPlugins, resetPlugins])
const applyUrlFilters = useCallback(() => {
if (!showSearchParams)
return
const nextFilters = {
q: searchPluginTextRef.current,
category: activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
}
const categoryChanged = urlFilters.category !== nextFilters.category
setUrlFilters(nextFilters, {
history: categoryChanged ? 'push' : 'replace',
})
}, [setUrlFilters, showSearchParams, urlFilters.category])
const debouncedUpdateSearchParams = useMemo(() => debounce(() => {
applyUrlFilters()
}, 500), [applyUrlFilters])
const handleUpdateSearchParams = useCallback((debounced?: boolean) => {
if (debounced) {
debouncedUpdateSearchParams()
}
else {
applyUrlFilters()
}
}, [applyUrlFilters, debouncedUpdateSearchParams])
const handleQueryPlugins = useCallback((debounced?: boolean) => {
handleUpdateSearchParams(debounced)
if (debounced) {
queryPluginsWithDebounced({
query: searchPluginTextRef.current,
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
else {
queryPlugins({
query: searchPluginTextRef.current,
category: activePluginTypeRef.current === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginTypeRef.current,
tags: filterPluginTagsRef.current,
sortBy: sortRef.current.sortBy,
sortOrder: sortRef.current.sortOrder,
exclude,
type: getMarketplaceListFilterType(activePluginTypeRef.current),
})
}
}, [exclude, queryPluginsWithDebounced, queryPlugins, handleUpdateSearchParams])
const handleQuery = useCallback((debounced?: boolean) => {
if (!searchPluginTextRef.current && !filterPluginTagsRef.current.length) {
handleUpdateSearchParams(debounced)
cancelQueryPluginsWithDebounced()
handleQueryMarketplaceCollectionsAndPlugins()
return
}
handleQueryPlugins(debounced)
}, [handleQueryMarketplaceCollectionsAndPlugins, handleQueryPlugins, cancelQueryPluginsWithDebounced, handleUpdateSearchParams])
const handleSearchPluginTextChange = useCallback((text: string) => {
setSearchPluginText(text)
searchPluginTextRef.current = text
handleQuery(true)
}, [handleQuery])
const handleFilterPluginTagsChange = useCallback((tags: string[]) => {
setFilterPluginTags(tags)
filterPluginTagsRef.current = tags
handleQuery()
}, [handleQuery])
const handleActivePluginTypeChange = useCallback((type: string) => {
setActivePluginType(type)
activePluginTypeRef.current = type
handleQuery()
}, [handleQuery])
const handleSortChange = useCallback((sort: PluginsSort) => {
setSort(sort)
sortRef.current = sort
handleQueryPlugins()
}, [handleQueryPlugins])
const handlePageChange = useCallback(() => {
if (hasNextPluginsPage)
fetchNextPluginsPage()
}, [fetchNextPluginsPage, hasNextPluginsPage])
const handleMoreClick = useCallback((searchParams: SearchParamsFromCollection) => {
setSearchPluginText(searchParams?.query || '')
searchPluginTextRef.current = searchParams?.query || ''
setSort({
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
})
sortRef.current = {
sortBy: searchParams?.sort_by || DEFAULT_SORT.sortBy,
sortOrder: searchParams?.sort_order || DEFAULT_SORT.sortOrder,
}
handleQueryPlugins()
}, [handleQueryPlugins])
useMarketplaceContainerScroll(handlePageChange, scrollContainerId)
return (
<MarketplaceContext.Provider
value={{
searchPluginText,
handleSearchPluginTextChange,
filterPluginTags,
handleFilterPluginTagsChange,
activePluginType,
handleActivePluginTypeChange,
page,
handlePageChange,
plugins,
pluginsTotal,
resetPlugins,
sort,
handleSortChange,
handleQueryPlugins,
handleMoreClick,
marketplaceCollectionsFromClient,
setMarketplaceCollectionsFromClient,
marketplaceCollectionPluginsMapFromClient,
setMarketplaceCollectionPluginsMapFromClient,
isLoading: isLoading || isPluginsLoading,
isSuccessCollections,
}}
>
{children}
</MarketplaceContext.Provider>
)
}
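A consumption sketch for the provider introduced here; because useMarketplaceContext is built on use-context-selector, components subscribe to individual slices and only re-render when that slice changes. The component hosting these selectors is assumed.

const plugins = useMarketplaceContext(v => v.plugins)
const isLoading = useMarketplaceContext(v => v.isLoading)
const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange)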

View File

@@ -26,9 +26,6 @@ import {
getMarketplacePluginsByCollectionId,
} from './utils'
/**
* @deprecated Use useMarketplaceCollectionsAndPlugins from query.ts instead
*/
export const useMarketplaceCollectionsAndPlugins = () => {
const [queryParams, setQueryParams] = useState<CollectionsAndPluginsSearchParams>()
const [marketplaceCollectionsOverride, setMarketplaceCollections] = useState<MarketplaceCollection[]>()
@@ -92,9 +89,7 @@ export const useMarketplacePluginsByCollectionId = (
isSuccess,
}
}
/**
* @deprecated Use useMarketplacePlugins from query.ts instead
*/
export const useMarketplacePlugins = () => {
const queryClient = useQueryClient()
const [queryParams, setQueryParams] = useState<PluginsSearchParams>()

View File

@@ -1,15 +0,0 @@
'use client'
import { useHydrateAtoms } from 'jotai/utils'
import { preserveSearchStateInQueryAtom } from './atoms'
export function HydrateMarketplaceAtoms({
preserveSearchStateInQuery,
children,
}: {
preserveSearchStateInQuery: boolean
children: React.ReactNode
}) {
useHydrateAtoms([[preserveSearchStateInQueryAtom, preserveSearchStateInQuery]])
return <>{children}</>
}
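A sketch of how this component was used before its removal: the server entry hydrated the atom once per request so client hooks read the same flag during SSR and after hydration. The children shown are the real siblings from the old marketplace index further down.

<HydrateMarketplaceAtoms preserveSearchStateInQuery={!!searchParams}>
  <Description />
  <StickySearchAndSwitchWrapper pluginTypeSwitchClassName={pluginTypeSwitchClassName} />
  <ListWrapper showInstallButton={showInstallButton} />
</HydrateMarketplaceAtoms>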

View File

@@ -1,45 +0,0 @@
import type { SearchParams } from 'nuqs'
import { dehydrate, HydrationBoundary } from '@tanstack/react-query'
import { createLoader } from 'nuqs/server'
import { getQueryClientServer } from '@/context/query-client-server'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS } from './constants'
import { marketplaceKeys } from './query'
import { marketplaceSearchParamsParsers } from './search-params'
import { getCollectionsParams, getMarketplaceCollectionsAndPlugins } from './utils'
// The server side logic should move to marketplace's codebase so that we can get rid of Next.js
async function getDehydratedState(searchParams?: Promise<SearchParams>) {
if (!searchParams) {
return
}
const loadSearchParams = createLoader(marketplaceSearchParamsParsers)
const params = await loadSearchParams(searchParams)
if (!PLUGIN_CATEGORY_WITH_COLLECTIONS.has(params.category)) {
return
}
const queryClient = getQueryClientServer()
await queryClient.prefetchQuery({
queryKey: marketplaceKeys.collections(getCollectionsParams(params.category)),
queryFn: () => getMarketplaceCollectionsAndPlugins(getCollectionsParams(params.category)),
})
return dehydrate(queryClient)
}
export async function HydrateQueryClient({
searchParams,
children,
}: {
searchParams: Promise<SearchParams> | undefined
children: React.ReactNode
}) {
const dehydratedState = await getDehydratedState(searchParams)
return (
<HydrationBoundary state={dehydratedState}>
{children}
</HydrationBoundary>
)
}

File diff suppressed because it is too large

View File

@@ -1,39 +1,55 @@
import type { SearchParams } from 'nuqs'
import type { MarketplaceCollection, SearchParams } from './types'
import type { Plugin } from '@/app/components/plugins/types'
import { TanstackQueryInitializer } from '@/context/query-client'
import { MarketplaceContextProvider } from './context'
import Description from './description'
import { HydrateMarketplaceAtoms } from './hydration-client'
import { HydrateQueryClient } from './hydration-server'
import ListWrapper from './list/list-wrapper'
import StickySearchAndSwitchWrapper from './sticky-search-and-switch-wrapper'
import { getMarketplaceCollectionsAndPlugins } from './utils'
type MarketplaceProps = {
showInstallButton?: boolean
shouldExclude?: boolean
searchParams?: SearchParams
pluginTypeSwitchClassName?: string
/**
* Pass the search params from the request to prefetch data on the server
* and preserve the search params in the URL.
*/
searchParams?: Promise<SearchParams>
scrollContainerId?: string
showSearchParams?: boolean
}
const Marketplace = async ({
showInstallButton = true,
pluginTypeSwitchClassName,
shouldExclude,
searchParams,
pluginTypeSwitchClassName,
scrollContainerId,
showSearchParams = true,
}: MarketplaceProps) => {
let marketplaceCollections: MarketplaceCollection[] = []
let marketplaceCollectionPluginsMap: Record<string, Plugin[]> = {}
if (!shouldExclude) {
const marketplaceCollectionsAndPluginsData = await getMarketplaceCollectionsAndPlugins()
marketplaceCollections = marketplaceCollectionsAndPluginsData.marketplaceCollections
marketplaceCollectionPluginsMap = marketplaceCollectionsAndPluginsData.marketplaceCollectionPluginsMap
}
return (
<TanstackQueryInitializer>
<HydrateQueryClient searchParams={searchParams}>
<HydrateMarketplaceAtoms preserveSearchStateInQuery={!!searchParams}>
<Description />
<StickySearchAndSwitchWrapper
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
/>
<ListWrapper
showInstallButton={showInstallButton}
/>
</HydrateMarketplaceAtoms>
</HydrateQueryClient>
<MarketplaceContextProvider
searchParams={searchParams}
shouldExclude={shouldExclude}
scrollContainerId={scrollContainerId}
showSearchParams={showSearchParams}
>
<Description />
<StickySearchAndSwitchWrapper
pluginTypeSwitchClassName={pluginTypeSwitchClassName}
showSearchParams={showSearchParams}
/>
<ListWrapper
marketplaceCollections={marketplaceCollections}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap}
showInstallButton={showInstallButton}
/>
</MarketplaceContextProvider>
</TanstackQueryInitializer>
)
}
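A usage sketch of the post-change component from a server page; the page location and the exact SearchParams shape are assumptions, though the q field is read by the context provider above.

<Marketplace
  showInstallButton
  shouldExclude={false}
  searchParams={{ q: 'agent' }}
  scrollContainerId="marketplace-scroll"
/>

// With shouldExclude unset or false, collections and their plugins are fetched on the server and
// passed straight to ListWrapper; with shouldExclude, installed plugins are filtered out client-side.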

View File

@@ -1,6 +1,6 @@
import type { MarketplaceCollection, SearchParamsFromCollection } from '../types'
import type { Plugin } from '@/app/components/plugins/types'
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import List from './index'
@@ -30,27 +30,23 @@ vi.mock('#i18n', () => ({
useLocale: () => 'en-US',
}))
// Mock marketplace state hooks with controllable values
const { mockMarketplaceData, mockMoreClick } = vi.hoisted(() => {
return {
mockMarketplaceData: {
plugins: undefined as Plugin[] | undefined,
pluginsTotal: 0,
marketplaceCollections: undefined as MarketplaceCollection[] | undefined,
marketplaceCollectionPluginsMap: undefined as Record<string, Plugin[]> | undefined,
isLoading: false,
page: 1,
},
mockMoreClick: vi.fn(),
}
})
// Mock useMarketplaceContext with controllable values
const mockContextValues = {
plugins: undefined as Plugin[] | undefined,
pluginsTotal: 0,
marketplaceCollectionsFromClient: undefined as MarketplaceCollection[] | undefined,
marketplaceCollectionPluginsMapFromClient: undefined as Record<string, Plugin[]> | undefined,
isLoading: false,
isSuccessCollections: false,
handleQueryPlugins: vi.fn(),
searchPluginText: '',
filterPluginTags: [] as string[],
page: 1,
handleMoreClick: vi.fn(),
}
vi.mock('../state', () => ({
useMarketplaceData: () => mockMarketplaceData,
}))
vi.mock('../atoms', () => ({
useMarketplaceMoreClick: () => mockMoreClick,
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
}))
// Mock useLocale context
@@ -582,7 +578,7 @@ describe('ListWithCollection', () => {
// View More Button Tests
// ================================
describe('View More Button', () => {
it('should render View More button when collection is searchable', () => {
it('should render View More button when collection is searchable and onMoreClick is provided', () => {
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
@@ -591,12 +587,14 @@ describe('ListWithCollection', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
@@ -611,24 +609,24 @@ describe('ListWithCollection', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
expect(screen.queryByText('View More')).not.toBeInTheDocument()
})
it('should call moreClick hook with search_params when View More is clicked', () => {
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
it('should not render View More button when onMoreClick is not provided', () => {
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: searchParams,
})]
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
@@ -639,13 +637,38 @@ describe('ListWithCollection', () => {
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={undefined}
/>,
)
expect(screen.queryByText('View More')).not.toBeInTheDocument()
})
it('should call onMoreClick with search_params when View More is clicked', () => {
const searchParams: SearchParamsFromCollection = { query: 'test-query', sort_by: 'install_count' }
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: searchParams,
})]
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
fireEvent.click(screen.getByText('View More'))
expect(mockMoreClick).toHaveBeenCalledTimes(1)
expect(mockMoreClick).toHaveBeenCalledWith(searchParams)
expect(onMoreClick).toHaveBeenCalledTimes(1)
expect(onMoreClick).toHaveBeenCalledWith(searchParams)
})
})
@@ -779,15 +802,24 @@ describe('ListWithCollection', () => {
// ListWrapper Component Tests
// ================================
describe('ListWrapper', () => {
const defaultProps = {
marketplaceCollections: [] as MarketplaceCollection[],
marketplaceCollectionPluginsMap: {} as Record<string, Plugin[]>,
showInstallButton: false,
}
beforeEach(() => {
vi.clearAllMocks()
// Reset mock data
mockMarketplaceData.plugins = undefined
mockMarketplaceData.pluginsTotal = 0
mockMarketplaceData.marketplaceCollections = undefined
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
// Reset context values
mockContextValues.plugins = undefined
mockContextValues.pluginsTotal = 0
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
mockContextValues.isLoading = false
mockContextValues.isSuccessCollections = false
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
mockContextValues.page = 1
})
// ================================
@@ -795,32 +827,32 @@ describe('ListWrapper', () => {
// ================================
describe('Rendering', () => {
it('should render without crashing', () => {
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(document.body).toBeInTheDocument()
})
it('should render with scrollbarGutter style', () => {
const { container } = render(<ListWrapper />)
const { container } = render(<ListWrapper {...defaultProps} />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveStyle({ scrollbarGutter: 'stable' })
})
it('should render Loading component when isLoading is true and page is 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 1
mockContextValues.isLoading = true
mockContextValues.page = 1
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
})
it('should not render Loading component when page > 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockContextValues.isLoading = true
mockContextValues.page = 2
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
})
@@ -831,26 +863,26 @@ describe('ListWrapper', () => {
// ================================
describe('Plugins Header', () => {
it('should render plugins result count when plugins are present', () => {
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 5
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 5
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
})
it('should render SortDropdown when plugins are present', () => {
mockMarketplaceData.plugins = createMockPluginList(1)
mockContextValues.plugins = createMockPluginList(1)
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('sort-dropdown')).toBeInTheDocument()
})
it('should not render plugins header when plugins is undefined', () => {
mockMarketplaceData.plugins = undefined
mockContextValues.plugins = undefined
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.queryByTestId('sort-dropdown')).not.toBeInTheDocument()
})
@@ -860,60 +892,197 @@ describe('ListWrapper', () => {
// List Rendering Logic Tests
// ================================
describe('List Rendering Logic', () => {
it('should render collections when not loading', () => {
mockMarketplaceData.isLoading = false
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
it('should render List when not loading', () => {
mockContextValues.isLoading = false
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should render List when loading but page > 1', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
mockContextValues.isLoading = true
mockContextValues.page = 2
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should use client collections when available', () => {
const serverCollections = createMockCollectionList(1)
serverCollections[0].label = { 'en-US': 'Server Collection' }
const clientCollections = createMockCollectionList(1)
clientCollections[0].label = { 'en-US': 'Client Collection' }
const serverPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const clientPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = clientCollections
mockContextValues.marketplaceCollectionPluginsMapFromClient = clientPluginsMap
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={serverCollections}
marketplaceCollectionPluginsMap={serverPluginsMap}
/>,
)
expect(screen.getByText('Client Collection')).toBeInTheDocument()
expect(screen.queryByText('Server Collection')).not.toBeInTheDocument()
})
it('should use server collections when client collections are not available', () => {
const serverCollections = createMockCollectionList(1)
serverCollections[0].label = { 'en-US': 'Server Collection' }
const serverPluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={serverCollections}
marketplaceCollectionPluginsMap={serverPluginsMap}
/>,
)
expect(screen.getByText('Server Collection')).toBeInTheDocument()
})
})
// ================================
// Data Integration Tests
// Context Integration Tests
// ================================
describe('Data Integration', () => {
it('should pass plugins from state to List', () => {
mockMarketplaceData.plugins = createMockPluginList(2)
describe('Context Integration', () => {
it('should pass plugins from context to List', () => {
const plugins = createMockPluginList(2)
mockContextValues.plugins = plugins
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByTestId('card-plugin-0')).toBeInTheDocument()
expect(screen.getByTestId('card-plugin-1')).toBeInTheDocument()
})
it('should show View More button and call moreClick hook', () => {
mockMarketplaceData.marketplaceCollections = [createMockCollection({
it('should pass handleMoreClick from context to List', () => {
const mockHandleMoreClick = vi.fn()
mockContextValues.handleMoreClick = mockHandleMoreClick
const collections = [createMockCollection({
name: 'collection-0',
searchable: true,
search_params: { query: 'test' },
})]
mockMarketplaceData.marketplaceCollectionPluginsMap = {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
render(<ListWrapper />)
render(
<ListWrapper
{...defaultProps}
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
/>,
)
fireEvent.click(screen.getByText('View More'))
expect(mockMoreClick).toHaveBeenCalled()
expect(mockHandleMoreClick).toHaveBeenCalled()
})
})
// ================================
// Effect Tests (handleQueryPlugins)
// ================================
describe('handleQueryPlugins Effect', () => {
it('should call handleQueryPlugins when conditions are met', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when client collections exist', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = createMockCollectionList(1)
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
// Give time for effect to run
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when search text exists', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = 'search text'
mockContextValues.filterPluginTags = []
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
it('should not call handleQueryPlugins when filter tags exist', async () => {
const mockHandleQueryPlugins = vi.fn()
mockContextValues.handleQueryPlugins = mockHandleQueryPlugins
mockContextValues.isSuccessCollections = true
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = ['tag1']
render(<ListWrapper {...defaultProps} />)
await waitFor(() => {
expect(mockHandleQueryPlugins).not.toHaveBeenCalled()
})
})
})
@@ -921,32 +1090,32 @@ describe('ListWrapper', () => {
// Edge Cases Tests
// ================================
describe('Edge Cases', () => {
it('should handle empty plugins array', () => {
mockMarketplaceData.plugins = []
mockMarketplaceData.pluginsTotal = 0
it('should handle empty plugins array from context', () => {
mockContextValues.plugins = []
mockContextValues.pluginsTotal = 0
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
})
it('should handle large pluginsTotal', () => {
mockMarketplaceData.plugins = createMockPluginList(10)
mockMarketplaceData.pluginsTotal = 10000
mockContextValues.plugins = createMockPluginList(10)
mockContextValues.pluginsTotal = 10000
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
expect(screen.getByText('10000 plugins found')).toBeInTheDocument()
})
it('should handle both loading and has plugins', () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 50
mockContextValues.isLoading = true
mockContextValues.page = 2
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 50
render(<ListWrapper />)
render(<ListWrapper {...defaultProps} />)
// Should show plugins header and list
expect(screen.getByText('50 plugins found')).toBeInTheDocument()
@@ -1259,72 +1428,106 @@ describe('CardWrapper (via List integration)', () => {
describe('Combined Workflows', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMarketplaceData.plugins = undefined
mockMarketplaceData.pluginsTotal = 0
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
mockMarketplaceData.marketplaceCollections = undefined
mockMarketplaceData.marketplaceCollectionPluginsMap = undefined
mockContextValues.plugins = undefined
mockContextValues.pluginsTotal = 0
mockContextValues.isLoading = false
mockContextValues.page = 1
mockContextValues.marketplaceCollectionsFromClient = undefined
mockContextValues.marketplaceCollectionPluginsMapFromClient = undefined
})
it('should transition from loading to showing collections', async () => {
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 1
mockContextValues.isLoading = true
mockContextValues.page = 1
const { rerender } = render(<ListWrapper />)
const { rerender } = render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByTestId('loading-component')).toBeInTheDocument()
// Simulate loading complete
mockMarketplaceData.isLoading = false
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
mockContextValues.isLoading = false
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = collections
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
rerender(<ListWrapper />)
rerender(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.queryByTestId('loading-component')).not.toBeInTheDocument()
expect(screen.getByText('Collection 0')).toBeInTheDocument()
})
it('should transition from collections to search results', async () => {
mockMarketplaceData.marketplaceCollections = createMockCollectionList(1)
mockMarketplaceData.marketplaceCollectionPluginsMap = {
const collections = createMockCollectionList(1)
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
mockContextValues.marketplaceCollectionsFromClient = collections
mockContextValues.marketplaceCollectionPluginsMapFromClient = pluginsMap
const { rerender } = render(<ListWrapper />)
const { rerender } = render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByText('Collection 0')).toBeInTheDocument()
// Simulate search results
mockMarketplaceData.plugins = createMockPluginList(5)
mockMarketplaceData.pluginsTotal = 5
mockContextValues.plugins = createMockPluginList(5)
mockContextValues.pluginsTotal = 5
rerender(<ListWrapper />)
rerender(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.queryByText('Collection 0')).not.toBeInTheDocument()
expect(screen.getByText('5 plugins found')).toBeInTheDocument()
})
it('should handle empty search results', () => {
mockMarketplaceData.plugins = []
mockMarketplaceData.pluginsTotal = 0
mockContextValues.plugins = []
mockContextValues.pluginsTotal = 0
render(<ListWrapper />)
render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
expect(screen.getByTestId('empty-component')).toBeInTheDocument()
expect(screen.getByText('0 plugins found')).toBeInTheDocument()
})
it('should support pagination (page > 1)', () => {
mockMarketplaceData.plugins = createMockPluginList(40)
mockMarketplaceData.pluginsTotal = 80
mockMarketplaceData.isLoading = true
mockMarketplaceData.page = 2
mockContextValues.plugins = createMockPluginList(40)
mockContextValues.pluginsTotal = 80
mockContextValues.isLoading = true
mockContextValues.page = 2
render(<ListWrapper />)
render(
<ListWrapper
marketplaceCollections={[]}
marketplaceCollectionPluginsMap={{}}
/>,
)
// Should show existing results while loading more
expect(screen.getByText('80 plugins found')).toBeInTheDocument()
@@ -1339,9 +1542,9 @@ describe('Combined Workflows', () => {
describe('Accessibility', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMarketplaceData.plugins = undefined
mockMarketplaceData.isLoading = false
mockMarketplaceData.page = 1
mockContextValues.plugins = undefined
mockContextValues.isLoading = false
mockContextValues.page = 1
})
it('should have semantic structure with collections', () => {
@@ -1370,11 +1573,13 @@ describe('Accessibility', () => {
const pluginsMap: Record<string, Plugin[]> = {
'collection-0': createMockPluginList(1),
}
const onMoreClick = vi.fn()
render(
<ListWithCollection
marketplaceCollections={collections}
marketplaceCollectionPluginsMap={pluginsMap}
onMoreClick={onMoreClick}
/>,
)
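The test hunks above swap the module-level mockMarketplaceData mock for a plain mockContextValues object read through a mocked selector-style useMarketplaceContext. A condensed sketch of that pattern, assuming Vitest, with the field list trimmed to the values these tests actually mutate:

import { vi } from 'vitest'
import type { Plugin } from '@/app/components/plugins/types'
import type { MarketplaceCollection } from '../types'

// Tests mutate these fields directly between render() and rerender().
const mockContextValues = {
  plugins: undefined as Plugin[] | undefined,
  pluginsTotal: 0,
  isLoading: false,
  page: 1,
  marketplaceCollectionsFromClient: undefined as MarketplaceCollection[] | undefined,
  marketplaceCollectionPluginsMapFromClient: undefined as Record<string, Plugin[]> | undefined,
}

vi.mock('../context', () => ({
  // Each useMarketplaceContext(v => v.field) call re-reads the mock object at render time,
  // so flipping isLoading or page before rerender() drives the state transitions under test.
  useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) =>
    selector(mockContextValues),
}))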

View File

@@ -13,6 +13,7 @@ type ListProps = {
showInstallButton?: boolean
cardContainerClassName?: string
cardRender?: (plugin: Plugin) => React.JSX.Element | null
onMoreClick?: () => void
emptyClassName?: string
}
const List = ({
@@ -22,6 +23,7 @@ const List = ({
showInstallButton,
cardContainerClassName,
cardRender,
onMoreClick,
emptyClassName,
}: ListProps) => {
return (
@@ -34,6 +36,7 @@ const List = ({
showInstallButton={showInstallButton}
cardContainerClassName={cardContainerClassName}
cardRender={cardRender}
onMoreClick={onMoreClick}
/>
)
}

View File

@@ -1,12 +1,12 @@
'use client'
import type { MarketplaceCollection } from '../types'
import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types'
import type { Plugin } from '@/app/components/plugins/types'
import { useLocale, useTranslation } from '#i18n'
import { RiArrowRightSLine } from '@remixicon/react'
import { getLanguage } from '@/i18n-config/language'
import { cn } from '@/utils/classnames'
import { useMarketplaceMoreClick } from '../atoms'
import CardWrapper from './card-wrapper'
type ListWithCollectionProps = {
@@ -15,6 +15,7 @@ type ListWithCollectionProps = {
showInstallButton?: boolean
cardContainerClassName?: string
cardRender?: (plugin: Plugin) => React.JSX.Element | null
onMoreClick?: (searchParams?: SearchParamsFromCollection) => void
}
const ListWithCollection = ({
marketplaceCollections,
@@ -22,10 +23,10 @@ const ListWithCollection = ({
showInstallButton,
cardContainerClassName,
cardRender,
onMoreClick,
}: ListWithCollectionProps) => {
const { t } = useTranslation()
const locale = useLocale()
const onMoreClick = useMarketplaceMoreClick()
return (
<>
@@ -43,10 +44,10 @@ const ListWithCollection = ({
<div className="system-xs-regular text-text-tertiary">{collection.description[getLanguage(locale)]}</div>
</div>
{
collection.searchable && (
collection.searchable && onMoreClick && (
<div
className="system-xs-medium flex cursor-pointer items-center text-text-accent "
onClick={() => onMoreClick(collection.search_params)}
onClick={() => onMoreClick?.(collection.search_params)}
>
{t('marketplace.viewMore', { ns: 'plugin' })}
<RiArrowRightSLine className="h-4 w-4" />
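In the two hunks above, the view-more handler stops coming from the removed useMarketplaceMoreClick hook and is instead threaded down as an optional onMoreClick prop, with rendering guarded on both collection.searchable and the handler being present. A minimal sketch of that guard; component name and label text are illustrative:

import type { SearchParamsFromCollection } from '@/app/components/plugins/marketplace/types'

type MoreClickHandler = (searchParams?: SearchParamsFromCollection) => void

type CollectionMoreLinkProps = {
  searchable: boolean
  searchParams?: SearchParamsFromCollection
  onMoreClick?: MoreClickHandler
}

const CollectionMoreLink = ({ searchable, searchParams, onMoreClick }: CollectionMoreLinkProps) => {
  // Render the link only when the collection is searchable AND the parent supplied a handler.
  if (!searchable || !onMoreClick)
    return null
  return (
    <button type="button" onClick={() => onMoreClick(searchParams)}>
      View more
    </button>
  )
}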

View File

@@ -1,26 +1,46 @@
'use client'
import type { Plugin } from '../../types'
import type { MarketplaceCollection } from '../types'
import { useTranslation } from '#i18n'
import { useEffect } from 'react'
import Loading from '@/app/components/base/loading'
import { useMarketplaceContext } from '../context'
import SortDropdown from '../sort-dropdown'
import { useMarketplaceData } from '../state'
import List from './index'
type ListWrapperProps = {
marketplaceCollections: MarketplaceCollection[]
marketplaceCollectionPluginsMap: Record<string, Plugin[]>
showInstallButton?: boolean
}
const ListWrapper = ({
marketplaceCollections,
marketplaceCollectionPluginsMap,
showInstallButton,
}: ListWrapperProps) => {
const { t } = useTranslation()
const plugins = useMarketplaceContext(v => v.plugins)
const pluginsTotal = useMarketplaceContext(v => v.pluginsTotal)
const marketplaceCollectionsFromClient = useMarketplaceContext(v => v.marketplaceCollectionsFromClient)
const marketplaceCollectionPluginsMapFromClient = useMarketplaceContext(v => v.marketplaceCollectionPluginsMapFromClient)
const isLoading = useMarketplaceContext(v => v.isLoading)
const isSuccessCollections = useMarketplaceContext(v => v.isSuccessCollections)
const handleQueryPlugins = useMarketplaceContext(v => v.handleQueryPlugins)
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
const page = useMarketplaceContext(v => v.page)
const handleMoreClick = useMarketplaceContext(v => v.handleMoreClick)
const {
plugins,
pluginsTotal,
marketplaceCollections,
marketplaceCollectionPluginsMap,
isLoading,
page,
} = useMarketplaceData()
useEffect(() => {
if (
!marketplaceCollectionsFromClient?.length
&& isSuccessCollections
&& !searchPluginText
&& !filterPluginTags.length
) {
handleQueryPlugins()
}
}, [handleQueryPlugins, marketplaceCollections, marketplaceCollectionsFromClient, isSuccessCollections, searchPluginText, filterPluginTags])
return (
<div
@@ -46,10 +66,11 @@ const ListWrapper = ({
{
(!isLoading || page > 1) && (
<List
marketplaceCollections={marketplaceCollections || []}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMap || {}}
marketplaceCollections={marketplaceCollectionsFromClient || marketplaceCollections}
marketplaceCollectionPluginsMap={marketplaceCollectionPluginsMapFromClient || marketplaceCollectionPluginsMap}
plugins={plugins}
showInstallButton={showInstallButton}
onMoreClick={handleMoreClick}
/>
)
}

View File

@@ -1,5 +1,4 @@
'use client'
import type { ActivePluginType } from './constants'
import { useTranslation } from '#i18n'
import {
RiArchive2Line,
@@ -9,27 +8,35 @@ import {
RiPuzzle2Line,
RiSpeakAiLine,
} from '@remixicon/react'
import { useSetAtom } from 'jotai'
import { useCallback, useEffect } from 'react'
import { Trigger as TriggerIcon } from '@/app/components/base/icons/src/vender/plugin'
import { cn } from '@/utils/classnames'
import { searchModeAtom, useActivePluginType } from './atoms'
import { PLUGIN_CATEGORY_WITH_COLLECTIONS, PLUGIN_TYPE_SEARCH_MAP } from './constants'
import { PluginCategoryEnum } from '../types'
import { useMarketplaceContext } from './context'
export const PLUGIN_TYPE_SEARCH_MAP = {
all: 'all',
model: PluginCategoryEnum.model,
tool: PluginCategoryEnum.tool,
agent: PluginCategoryEnum.agent,
extension: PluginCategoryEnum.extension,
datasource: PluginCategoryEnum.datasource,
trigger: PluginCategoryEnum.trigger,
bundle: 'bundle',
}
type PluginTypeSwitchProps = {
className?: string
showSearchParams?: boolean
}
const PluginTypeSwitch = ({
className,
showSearchParams,
}: PluginTypeSwitchProps) => {
const { t } = useTranslation()
const [activePluginType, handleActivePluginTypeChange] = useActivePluginType()
const setSearchMode = useSetAtom(searchModeAtom)
const activePluginType = useMarketplaceContext(s => s.activePluginType)
const handleActivePluginTypeChange = useMarketplaceContext(s => s.handleActivePluginTypeChange)
const options: Array<{
value: ActivePluginType
text: string
icon: React.ReactNode | null
}> = [
const options = [
{
value: PLUGIN_TYPE_SEARCH_MAP.all,
text: t('category.all', { ns: 'plugin' }),
@@ -72,6 +79,23 @@ const PluginTypeSwitch = ({
},
]
const handlePopState = useCallback(() => {
if (!showSearchParams)
return
// nuqs handles popstate automatically
const url = new URL(window.location.href)
const category = url.searchParams.get('category') || PLUGIN_TYPE_SEARCH_MAP.all
handleActivePluginTypeChange(category)
}, [showSearchParams, handleActivePluginTypeChange])
useEffect(() => {
// nuqs manages popstate internally, but we keep this for URL sync
window.addEventListener('popstate', handlePopState)
return () => {
window.removeEventListener('popstate', handlePopState)
}
}, [handlePopState])
return (
<div className={cn(
'flex shrink-0 items-center justify-center space-x-2 bg-background-body py-3',
@@ -88,9 +112,6 @@ const PluginTypeSwitch = ({
)}
onClick={() => {
handleActivePluginTypeChange(option.value)
if (PLUGIN_CATEGORY_WITH_COLLECTIONS.has(option.value)) {
setSearchMode(null)
}
}}
>
{option.icon}
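The added handlePopState re-derives the active plugin type from the URL whenever the user navigates back or forward, falling back to the 'all' entry of PLUGIN_TYPE_SEARCH_MAP. A small worked example of that read, with an illustrative helper name:

// '?category=tool' -> 'tool'; no category param -> the 'all' fallback.
const readCategoryFromUrl = (href: string, fallback = 'all'): string => {
  const url = new URL(href)
  return url.searchParams.get('category') || fallback
}

// readCategoryFromUrl('https://example.com/plugins?category=tool') // 'tool'
// readCategoryFromUrl('https://example.com/plugins')               // 'all'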

View File

@@ -1,38 +0,0 @@
import type { CollectionsAndPluginsSearchParams, PluginsSearchParams } from './types'
import { useInfiniteQuery, useQuery } from '@tanstack/react-query'
import { getMarketplaceCollectionsAndPlugins, getMarketplacePlugins } from './utils'
// TODO: Avoid manual maintenance of query keys and better service management,
// https://github.com/langgenius/dify/issues/30342
export const marketplaceKeys = {
all: ['marketplace'] as const,
collections: (params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collections', params] as const,
collectionPlugins: (collectionId: string, params?: CollectionsAndPluginsSearchParams) => [...marketplaceKeys.all, 'collectionPlugins', collectionId, params] as const,
plugins: (params?: PluginsSearchParams) => [...marketplaceKeys.all, 'plugins', params] as const,
}
export function useMarketplaceCollectionsAndPlugins(
collectionsParams: CollectionsAndPluginsSearchParams,
) {
return useQuery({
queryKey: marketplaceKeys.collections(collectionsParams),
queryFn: ({ signal }) => getMarketplaceCollectionsAndPlugins(collectionsParams, { signal }),
})
}
export function useMarketplacePlugins(
queryParams: PluginsSearchParams | undefined,
) {
return useInfiniteQuery({
queryKey: marketplaceKeys.plugins(queryParams),
queryFn: ({ pageParam = 1, signal }) => getMarketplacePlugins(queryParams, pageParam, signal),
getNextPageParam: (lastPage) => {
const nextPage = lastPage.page + 1
const loaded = lastPage.page * lastPage.pageSize
return loaded < (lastPage.total || 0) ? nextPage : undefined
},
initialPageParam: 1,
enabled: queryParams !== undefined,
})
}
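The removed useMarketplacePlugins hook paged through results by comparing page * pageSize against total. A worked sketch of that getNextPageParam logic, using the same 40-per-page / 80-total numbers the pagination test above exercises:

type PluginsPage = { page: number, pageSize: number, total: number }

const getNextPageParam = (lastPage: PluginsPage): number | undefined => {
  const loaded = lastPage.page * lastPage.pageSize
  // Stop paginating once everything reported by `total` has been fetched.
  return loaded < lastPage.total ? lastPage.page + 1 : undefined
}

// getNextPageParam({ page: 1, pageSize: 40, total: 80 }) // 2 (40 of 80 loaded)
// getNextPageParam({ page: 2, pageSize: 40, total: 80 }) // undefined (done)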

View File

@@ -26,19 +26,16 @@ vi.mock('#i18n', () => ({
}),
}))
// Mock marketplace state hooks
const { mockSearchPluginText, mockHandleSearchPluginTextChange, mockFilterPluginTags, mockHandleFilterPluginTagsChange } = vi.hoisted(() => {
return {
mockSearchPluginText: '',
mockHandleSearchPluginTextChange: vi.fn(),
mockFilterPluginTags: [] as string[],
mockHandleFilterPluginTagsChange: vi.fn(),
}
})
// Mock useMarketplaceContext
const mockContextValues = {
searchPluginText: '',
handleSearchPluginTextChange: vi.fn(),
filterPluginTags: [] as string[],
handleFilterPluginTagsChange: vi.fn(),
}
vi.mock('../atoms', () => ({
useSearchPluginText: () => [mockSearchPluginText, mockHandleSearchPluginTextChange],
useFilterPluginTags: () => [mockFilterPluginTags, mockHandleFilterPluginTagsChange],
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (v: typeof mockContextValues) => unknown) => selector(mockContextValues),
}))
// Mock useTags hook
@@ -433,6 +430,9 @@ describe('SearchBoxWrapper', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPortalOpenState = false
// Reset context values
mockContextValues.searchPluginText = ''
mockContextValues.filterPluginTags = []
})
describe('Rendering', () => {
@@ -456,14 +456,28 @@ describe('SearchBoxWrapper', () => {
})
})
describe('Hook Integration', () => {
describe('Context Integration', () => {
it('should use searchPluginText from context', () => {
mockContextValues.searchPluginText = 'context search'
render(<SearchBoxWrapper />)
expect(screen.getByDisplayValue('context search')).toBeInTheDocument()
})
it('should call handleSearchPluginTextChange when search changes', () => {
render(<SearchBoxWrapper />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: 'new search' } })
expect(mockHandleSearchPluginTextChange).toHaveBeenCalledWith('new search')
expect(mockContextValues.handleSearchPluginTextChange).toHaveBeenCalledWith('new search')
})
it('should use filterPluginTags from context', () => {
mockContextValues.filterPluginTags = ['agent', 'rag']
render(<SearchBoxWrapper />)
expect(screen.getByTestId('portal-elem')).toBeInTheDocument()
})
})

View File

@@ -1,13 +1,15 @@
'use client'
import { useTranslation } from '#i18n'
import { useFilterPluginTags, useSearchPluginText } from '../atoms'
import { useMarketplaceContext } from '../context'
import SearchBox from './index'
const SearchBoxWrapper = () => {
const { t } = useTranslation()
const [searchPluginText, handleSearchPluginTextChange] = useSearchPluginText()
const [filterPluginTags, handleFilterPluginTagsChange] = useFilterPluginTags()
const searchPluginText = useMarketplaceContext(v => v.searchPluginText)
const handleSearchPluginTextChange = useMarketplaceContext(v => v.handleSearchPluginTextChange)
const filterPluginTags = useMarketplaceContext(v => v.filterPluginTags)
const handleFilterPluginTagsChange = useMarketplaceContext(v => v.handleFilterPluginTagsChange)
return (
<SearchBox

View File

@@ -1,9 +0,0 @@
import type { ActivePluginType } from './constants'
import { parseAsArrayOf, parseAsString, parseAsStringEnum } from 'nuqs/server'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
export const marketplaceSearchParamsParsers = {
category: parseAsStringEnum<ActivePluginType>(Object.values(PLUGIN_TYPE_SEARCH_MAP) as ActivePluginType[]).withDefault('all').withOptions({ history: 'replace', clearOnDefault: false }),
q: parseAsString.withDefault('').withOptions({ history: 'replace' }),
tags: parseAsArrayOf(parseAsString).withDefault([]).withOptions({ history: 'replace' }),
}
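The deleted marketplaceSearchParamsParsers kept category, q, and tags in the URL via nuqs. A minimal client-side sketch of the same idea, assuming nuqs's useQueryStates hook (not shown in this diff) and simplifying category to a plain string instead of the original enum parser:

import { parseAsArrayOf, parseAsString, useQueryStates } from 'nuqs'

const useMarketplaceSearchParams = () => {
  return useQueryStates({
    category: parseAsString.withDefault('all'),
    q: parseAsString.withDefault(''),
    tags: parseAsArrayOf(parseAsString).withDefault([]),
  }, { history: 'replace' }) // replace, so typing a query does not spam browser history
}

// const [{ category, q, tags }, setParams] = useMarketplaceSearchParams()
// setParams({ category: 'tool' }) // updates ?category=tool in place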

View File

@@ -1,3 +1,4 @@
import type { MarketplaceContextValue } from '../context'
import { fireEvent, render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
@@ -27,12 +28,18 @@ vi.mock('#i18n', () => ({
}),
}))
// Mock marketplace atoms with controllable values
let mockSort: { sortBy: string, sortOrder: string } = { sortBy: 'install_count', sortOrder: 'DESC' }
// Mock marketplace context with controllable values
let mockSort = { sortBy: 'install_count', sortOrder: 'DESC' }
const mockHandleSortChange = vi.fn()
vi.mock('../atoms', () => ({
useMarketplaceSort: () => [mockSort, mockHandleSortChange],
vi.mock('../context', () => ({
useMarketplaceContext: (selector: (value: MarketplaceContextValue) => unknown) => {
const contextValue = {
sort: mockSort,
handleSortChange: mockHandleSortChange,
} as unknown as MarketplaceContextValue
return selector(contextValue)
},
}))
// Mock portal component with controllable open state

View File

@@ -10,7 +10,7 @@ import {
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { useMarketplaceSort } from '../atoms'
import { useMarketplaceContext } from '../context'
const SortDropdown = () => {
const { t } = useTranslation()
@@ -36,7 +36,8 @@ const SortDropdown = () => {
text: t('marketplace.sortOption.firstReleased', { ns: 'plugin' }),
},
]
const [sort, handleSortChange] = useMarketplaceSort()
const sort = useMarketplaceContext(v => v.sort)
const handleSortChange = useMarketplaceContext(v => v.handleSortChange)
const [open, setOpen] = useState(false)
const selectedOption = options.find(option => option.value === sort.sortBy && option.order === sort.sortOrder) ?? options[0]

View File

@@ -1,54 +0,0 @@
import type { PluginsSearchParams } from './types'
import { useDebounce } from 'ahooks'
import { useCallback, useMemo } from 'react'
import { useActivePluginType, useFilterPluginTags, useMarketplaceSearchMode, useMarketplaceSortValue, useSearchPluginText } from './atoms'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
import { useMarketplaceContainerScroll } from './hooks'
import { useMarketplaceCollectionsAndPlugins, useMarketplacePlugins } from './query'
import { getCollectionsParams, getMarketplaceListFilterType } from './utils'
export function useMarketplaceData() {
const [searchPluginTextOriginal] = useSearchPluginText()
const searchPluginText = useDebounce(searchPluginTextOriginal, { wait: 500 })
const [filterPluginTags] = useFilterPluginTags()
const [activePluginType] = useActivePluginType()
const collectionsQuery = useMarketplaceCollectionsAndPlugins(
getCollectionsParams(activePluginType),
)
const sort = useMarketplaceSortValue()
const isSearchMode = useMarketplaceSearchMode()
const queryParams = useMemo((): PluginsSearchParams | undefined => {
if (!isSearchMode)
return undefined
return {
query: searchPluginText,
category: activePluginType === PLUGIN_TYPE_SEARCH_MAP.all ? undefined : activePluginType,
tags: filterPluginTags,
sortBy: sort.sortBy,
sortOrder: sort.sortOrder,
type: getMarketplaceListFilterType(activePluginType),
}
}, [isSearchMode, searchPluginText, activePluginType, filterPluginTags, sort])
const pluginsQuery = useMarketplacePlugins(queryParams)
const { hasNextPage, fetchNextPage, isFetching } = pluginsQuery
const handlePageChange = useCallback(() => {
if (hasNextPage && !isFetching)
fetchNextPage()
}, [fetchNextPage, hasNextPage, isFetching])
// Scroll pagination
useMarketplaceContainerScroll(handlePageChange)
return {
marketplaceCollections: collectionsQuery.data?.marketplaceCollections,
marketplaceCollectionPluginsMap: collectionsQuery.data?.marketplaceCollectionPluginsMap,
plugins: pluginsQuery.data?.pages.flatMap(page => page.plugins),
pluginsTotal: pluginsQuery.data?.pages[0]?.total,
page: pluginsQuery.data?.pages.length || 1,
isLoading: collectionsQuery.isLoading || pluginsQuery.isLoading,
}
}

View File

@@ -6,10 +6,12 @@ import SearchBoxWrapper from './search-box/search-box-wrapper'
type StickySearchAndSwitchWrapperProps = {
pluginTypeSwitchClassName?: string
showSearchParams?: boolean
}
const StickySearchAndSwitchWrapper = ({
pluginTypeSwitchClassName,
showSearchParams,
}: StickySearchAndSwitchWrapperProps) => {
const hasCustomTopClass = pluginTypeSwitchClassName?.includes('top-')
@@ -22,7 +24,9 @@ const StickySearchAndSwitchWrapper = ({
)}
>
<SearchBoxWrapper />
<PluginTypeSwitch />
<PluginTypeSwitch
showSearchParams={showSearchParams}
/>
</div>
)
}

View File

@@ -1,19 +1,16 @@
import type { ActivePluginType } from './constants'
import type {
CollectionsAndPluginsSearchParams,
MarketplaceCollection,
PluginsSearchParams,
} from '@/app/components/plugins/marketplace/types'
import type { Plugin, PluginsFromMarketplaceResponse } from '@/app/components/plugins/types'
import type { Plugin } from '@/app/components/plugins/types'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import {
APP_VERSION,
IS_MARKETPLACE,
MARKETPLACE_API_PREFIX,
} from '@/config'
import { postMarketplace } from '@/service/base'
import { getMarketplaceUrl } from '@/utils/var'
import { PLUGIN_TYPE_SEARCH_MAP } from './constants'
import { PLUGIN_TYPE_SEARCH_MAP } from './plugin-type-switch'
type MarketplaceFetchOptions = {
signal?: AbortSignal
@@ -29,13 +26,12 @@ export const getPluginIconInMarketplace = (plugin: Plugin) => {
return `${MARKETPLACE_API_PREFIX}/plugins/${plugin.org}/${plugin.name}/icon`
}
export const getFormattedPlugin = (bundle: Plugin): Plugin => {
export const getFormattedPlugin = (bundle: any) => {
if (bundle.type === 'bundle') {
return {
...bundle,
icon: getPluginIconInMarketplace(bundle),
brief: bundle.description,
// @ts-expect-error I do not have enough information
label: bundle.labels,
}
}
@@ -133,64 +129,6 @@ export const getMarketplaceCollectionsAndPlugins = async (
}
}
export const getMarketplacePlugins = async (
queryParams: PluginsSearchParams | undefined,
pageParam: number,
signal?: AbortSignal,
) => {
if (!queryParams) {
return {
plugins: [] as Plugin[],
total: 0,
page: 1,
pageSize: 40,
}
}
const {
query,
sortBy,
sortOrder,
category,
tags,
type,
pageSize = 40,
} = queryParams
const pluginOrBundle = type === 'bundle' ? 'bundles' : 'plugins'
try {
const res = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>(`/${pluginOrBundle}/search/advanced`, {
body: {
page: pageParam,
page_size: pageSize,
query,
sort_by: sortBy,
sort_order: sortOrder,
category: category !== 'all' ? category : '',
tags,
type,
},
signal,
})
const resPlugins = res.data.bundles || res.data.plugins || []
return {
plugins: resPlugins.map(plugin => getFormattedPlugin(plugin)),
total: res.data.total,
page: pageParam,
pageSize,
}
}
catch {
return {
plugins: [],
total: 0,
page: pageParam,
pageSize,
}
}
}
export const getMarketplaceListCondition = (pluginType: string) => {
if ([PluginCategoryEnum.tool, PluginCategoryEnum.agent, PluginCategoryEnum.model, PluginCategoryEnum.datasource, PluginCategoryEnum.trigger].includes(pluginType as PluginCategoryEnum))
return `category=${pluginType}`
@@ -204,7 +142,7 @@ export const getMarketplaceListCondition = (pluginType: string) => {
return ''
}
export const getMarketplaceListFilterType = (category: ActivePluginType) => {
export const getMarketplaceListFilterType = (category: string) => {
if (category === PLUGIN_TYPE_SEARCH_MAP.all)
return undefined
@@ -213,14 +151,3 @@ export const getMarketplaceListFilterType = (category: ActivePluginType) => {
return 'plugin'
}
export function getCollectionsParams(category: ActivePluginType): CollectionsAndPluginsSearchParams {
if (category === PLUGIN_TYPE_SEARCH_MAP.all) {
return {}
}
return {
category,
condition: getMarketplaceListCondition(category),
type: getMarketplaceListFilterType(category),
}
}
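The removed getMarketplacePlugins helper POSTed a snake_case paging and sort payload to the marketplace advanced-search endpoint (switching to the bundles path when type === 'bundle') via postMarketplace. A hedged sketch of that request shape using plain fetch; the body fields mirror the removed code, while the helper name is illustrative:

import { MARKETPLACE_API_PREFIX } from '@/config'

type AdvancedSearchBody = {
  page: number
  page_size: number
  query?: string
  sort_by?: string
  sort_order?: string
  category?: string
  tags?: string[]
  type?: string
}

const searchMarketplacePlugins = async (body: AdvancedSearchBody, signal?: AbortSignal) => {
  // The removed helper used the 'bundles' path segment when body.type === 'bundle'.
  const pluginOrBundle = body.type === 'bundle' ? 'bundles' : 'plugins'
  const res = await fetch(`${MARKETPLACE_API_PREFIX}/${pluginOrBundle}/search/advanced`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
    signal,
  })
  return res.json()
}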

View File

@@ -27,7 +27,7 @@ import { cn } from '@/utils/classnames'
import { PLUGIN_PAGE_TABS_MAP } from '../hooks'
import InstallFromLocalPackage from '../install-plugin/install-from-local-package'
import InstallFromMarketplace from '../install-plugin/install-from-marketplace'
import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/constants'
import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/plugin-type-switch'
import {
PluginPageContextProvider,
usePluginPageContext,

View File

@@ -262,7 +262,7 @@ vi.mock('@/app/components/base/icons/src/vender/other', () => ({
}))
// Mock PLUGIN_TYPE_SEARCH_MAP
vi.mock('../../marketplace/constants', () => ({
vi.mock('../../marketplace/plugin-type-switch', () => ({
PLUGIN_TYPE_SEARCH_MAP: {
all: 'all',
model: 'model',

View File

@@ -1,6 +1,5 @@
'use client'
import type { FC } from 'react'
import type { ActivePluginType } from '../../marketplace/constants'
import * as React from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -13,7 +12,7 @@ import {
import SearchBox from '@/app/components/plugins/marketplace/search-box'
import { useInstalledPluginList } from '@/service/use-plugins'
import { cn } from '@/utils/classnames'
import { PLUGIN_TYPE_SEARCH_MAP } from '../../marketplace/constants'
import { PLUGIN_TYPE_SEARCH_MAP } from '../../marketplace/plugin-type-switch'
import { PluginSource } from '../../types'
import NoDataPlaceholder from './no-data-placeholder'
import ToolItem from './tool-item'
@@ -74,7 +73,7 @@ const ToolPicker: FC<Props> = ({
},
]
const [pluginType, setPluginType] = useState<ActivePluginType>(PLUGIN_TYPE_SEARCH_MAP.all)
const [pluginType, setPluginType] = useState(PLUGIN_TYPE_SEARCH_MAP.all)
const [query, setQuery] = useState('')
const [tags, setTags] = useState<string[]>([])
const { data, isLoading } = useInstalledPluginList()

View File

@@ -195,7 +195,7 @@ const RunOnce: FC<IRunOnceProps> = ({
noWrapper
className="bg h-[80px] overflow-y-auto rounded-[10px] bg-components-input-bg-normal p-1"
placeholder={
<div className="whitespace-pre">{typeof item.json_schema === 'string' ? item.json_schema : JSON.stringify(item.json_schema || '', null, 2)}</div>
<div className="whitespace-pre">{item.json_schema}</div>
}
/>
)}

Some files were not shown because too many files have changed in this diff.