Mirror of https://github.com/langgenius/dify.git (synced 2026-01-12 01:12:01 +00:00)

Compare commits: 8 commits, 1.3.0...feat/model
| Author | SHA1 | Date |
|---|---|---|
| | a9bae7aafd | |
| | 48be8fb6cc | |
| | dd02a9ac9d | |
| | b203139356 | |
| | c479fcf251 | |
| | d7c3e54eaa | |
| | d5fe50e471 | |
| | 205535c8e9 | |
@@ -6,7 +6,7 @@
This guide, like Dify itself, is a work in progress. If any part lags behind the actual project, please bear with us; we also welcome any suggestions for improvement.

Regarding licensing, please take a minute to read our short [License and Contributor Agreement](./LICENSE). The community also follows the [code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).
Regarding licensing, please take a minute to read our short [License and Contributor Agreement](./LICENSE). Please also follow the community [code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).

## Before You Start

@@ -80,8 +80,6 @@ class ChatMessageTextApi(Resource):
    @account_initialization_required
    @get_app_model
    def post(self, app_model: App):
        from werkzeug.exceptions import InternalServerError

        try:
            parser = reqparse.RequestParser()
            parser.add_argument("message_id", type=str, location="json")
api/core/memory/__init__.py (new file, 0 lines)
api/core/memory/base_memory.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from abc import abstractmethod
from collections.abc import Sequence
from typing import Optional

from core.model_runtime.entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
)


class BaseMemory:
    @abstractmethod
    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: Optional[int] = None,
    ) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
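The base class above only requires subclasses to implement `get_history_prompt_messages`; the text rendering in `get_history_prompt_text` is shared. A minimal sketch of a hypothetical subclass (the `StaticMemory` class and its fixed message list are illustrative, not part of this change):

```python
# Minimal sketch: only get_history_prompt_messages is implemented;
# get_history_prompt_text is inherited from BaseMemory.
from collections.abc import Sequence
from typing import Optional

from core.memory.base_memory import BaseMemory
from core.model_runtime.entities import AssistantPromptMessage, PromptMessage, UserPromptMessage


class StaticMemory(BaseMemory):
    """Hypothetical memory that replays a fixed list of messages."""

    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = messages

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        # Ignore the limits in this toy example and return everything.
        return self._messages


memory = StaticMemory(
    [
        UserPromptMessage(content="What is Dify?"),
        AssistantPromptMessage(content="An LLM app development platform."),
    ]
)
# Inherited helper renders the history as plain text:
# Human: What is Dify?
# Assistant: An LLM app development platform.
print(memory.get_history_prompt_text())
```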
api/core/memory/model_context_memory.py (new file, 200 lines)
@@ -0,0 +1,200 @@
import json
from collections.abc import Sequence
from typing import Optional, cast

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base_memory import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun


class ModelContextMemory(BaseMemory):
    def __init__(self, conversation: Conversation, node_id: str, model_instance: ModelInstance) -> None:
        self.conversation = conversation
        self.node_id = node_id
        self.model_instance = model_instance

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.
        :param max_token_limit: max token limit
        :param message_limit: message limit
        """
        thread_messages = list(reversed(self._fetch_thread_messages(message_limit)))
        if not thread_messages:
            return []
        # Get all required workflow_run_ids
        workflow_run_ids = [msg.workflow_run_id for msg in thread_messages]

        # Batch query all related WorkflowNodeExecution records
        node_executions = (
            db.session.query(WorkflowNodeExecution)
            .filter(
                WorkflowNodeExecution.workflow_run_id.in_(workflow_run_ids),
                WorkflowNodeExecution.node_id == self.node_id,
                WorkflowNodeExecution.status.in_(
                    [WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION]
                ),
            )
            .all()
        )

        # Create mapping from workflow_run_id to node_execution
        node_execution_map = {ne.workflow_run_id: ne for ne in node_executions}

        # Get the last node_execution
        last_node_execution = node_execution_map.get(thread_messages[-1].workflow_run_id)
        prompt_messages = self._get_prompt_messages_in_process_data(last_node_execution)

        # Batch query all message-related files
        message_ids = [msg.id for msg in thread_messages]
        all_files = db.session.query(MessageFile).filter(MessageFile.message_id.in_(message_ids)).all()

        # Create mapping from message_id to files
        files_map = {}
        for file in all_files:
            if file.message_id not in files_map:
                files_map[file.message_id] = []
            files_map[file.message_id].append(file)

        for message in thread_messages:
            files = files_map.get(message.id, [])
            node_execution = node_execution_map.get(message.workflow_run_id)
            if node_execution and files:
                file_objs, detail = self._handle_file(message, files)
                if file_objs:
                    outputs = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
                    if not outputs:
                        continue
                    if outputs not in [prompt.content for prompt in prompt_messages]:
                        continue
                    outputs_index = [prompt.content for prompt in prompt_messages].index(outputs)
                    prompt_index = outputs_index - 1
                    prompt_message_contents: list[PromptMessageContentUnionTypes] = []
                    content = cast(str, prompt_messages[prompt_index].content)
                    prompt_message_contents.append(TextPromptMessageContent(data=content))
                    for file in file_objs:
                        prompt_message = file_manager.to_prompt_message_content(
                            file,
                            image_detail_config=detail,
                        )
                        prompt_message_contents.append(prompt_message)
                    prompt_messages[prompt_index].content = prompt_message_contents
        return prompt_messages

    def _get_prompt_messages_in_process_data(
        self,
        node_execution: WorkflowNodeExecution,
    ) -> list[PromptMessage]:
        """
        Get prompt messages in process data.
        :param node_execution: node execution
        :return: prompt messages
        """
        prompt_messages = []
        if not node_execution.process_data:
            return []

        try:
            process_data = json.loads(node_execution.process_data)
            if process_data.get("memory_type", "") != LLMMemoryType.INDEPENDENT:
                return []
            prompts = process_data.get("prompts", [])
            for prompt in prompts:
                prompt_content = prompt.get("text", "")
                if prompt.get("role", "") == "user":
                    prompt_messages.append(UserPromptMessage(content=prompt_content))
                elif prompt.get("role", "") == "assistant":
                    prompt_messages.append(AssistantPromptMessage(content=prompt_content))
            output = node_execution.outputs_dict.get("text", "") if node_execution.outputs_dict else ""
            prompt_messages.append(AssistantPromptMessage(content=output))
        except json.JSONDecodeError:
            return []
        return prompt_messages

    def _fetch_thread_messages(self, message_limit: int | None = None) -> list[Message]:
        """
        Fetch thread messages.
        :param message_limit: message limit
        :return: thread messages
        """
        query = (
            db.session.query(
                Message.id,
                Message.query,
                Message.answer,
                Message.created_at,
                Message.workflow_run_id,
                Message.parent_message_id,
                Message.answer_tokens,
            )
            .filter(
                Message.conversation_id == self.conversation.id,
            )
            .order_by(Message.created_at.desc())
        )

        if message_limit and message_limit > 0:
            message_limit = min(message_limit, 500)
        else:
            message_limit = 500

        messages = query.limit(message_limit).all()

        # fetch the thread messages
        thread_messages = extract_thread_messages(messages)

        # for newly created message, its answer is temporarily empty, we don't need to add it to memory
        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
            thread_messages.pop(0)
        if not thread_messages:
            return []
        return thread_messages

    def _handle_file(self, message: Message, files: list[MessageFile]):
        """
        Handle file for memory.
        :param message: message
        :param files: files
        :return: file objects and detail
        """
        file_extra_config = None
        if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
        else:
            if message.workflow_run_id:
                workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()

                if workflow_run and workflow_run.workflow:
                    file_extra_config = FileUploadConfigManager.convert(
                        workflow_run.workflow.features_dict, is_vision=False
                    )

        detail = ImagePromptMessageContent.DETAIL.LOW
        app_record = self.conversation.app

        if file_extra_config and app_record:
            file_objs = file_factory.build_from_message_files(
                message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
            )
            if file_extra_config.image_config and file_extra_config.image_config.detail:
                detail = file_extra_config.image_config.detail
        else:
            file_objs = []
        return file_objs, detail
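Unlike the conversation-wide `TokenBufferMemory` below, this class replays only the prompts and outputs that one LLM node recorded in its own `WorkflowNodeExecution.process_data`. A hedged usage sketch (the `conversation`, `model_instance`, and node id values are assumed placeholders, not taken from the diff):

```python
# Hedged sketch of driving the per-node memory directly; in the workflow runtime
# this is done by LLMNode._fetch_memory.
from core.memory.model_context_memory import ModelContextMemory

memory = ModelContextMemory(
    conversation=conversation,      # an existing models.model.Conversation row
    node_id="llm_node_1",           # the workflow node whose history we want (illustrative id)
    model_instance=model_instance,  # a core.model_manager.ModelInstance
)

# Replays only this node's recorded prompts/outputs, newest run last.
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)
text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant")
```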
@@ -3,12 +3,12 @@ from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base_memory import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)
@@ -20,7 +20,7 @@ from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun


class TokenBufferMemory:
class TokenBufferMemory(BaseMemory):
    def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
        self.conversation = conversation
        self.model_instance = model_instance
@@ -129,44 +129,3 @@ class TokenBufferMemory:
        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        return prompt_messages

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: Optional[int] = None,
    ) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Literal, Optional

from pydantic import BaseModel
@@ -24,6 +25,11 @@ class CompletionModelPromptTemplate(BaseModel):
    edition_type: Optional[Literal["basic", "jinja2"]] = None


class LLMMemoryType(str, Enum):
    INDEPENDENT = "independent"
    GLOBAL = "global"


class MemoryConfig(BaseModel):
    """
    Memory Config.
@@ -48,3 +54,4 @@ class MemoryConfig(BaseModel):
    role_prefix: Optional[RolePrefix] = None
    window: WindowConfig
    query_prompt_template: Optional[str] = None
    type: LLMMemoryType = LLMMemoryType.GLOBAL
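The new `type` field defaults to `GLOBAL` (whole-conversation history) and can be set to `INDEPENDENT` for per-node history. A hedged sketch of constructing the config (values are illustrative; the nested window dict is assumed to be coerced by pydantic into the existing `WindowConfig` model):

```python
# Sketch only: a node-level memory config using the new "type" field.
from core.prompt.entities.advanced_prompt_entities import LLMMemoryType, MemoryConfig

memory_config = MemoryConfig(
    window={"enabled": True, "size": 10},   # existing window settings, coerced into WindowConfig
    query_prompt_template="{{#sys.query#}}",
    type=LLMMemoryType.INDEPENDENT,         # per-node memory; the default stays GLOBAL
)
```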
@@ -1,9 +1,8 @@
from typing import Any

from constants import UUID_NIL
from models.model import Message


def extract_thread_messages(messages: list[Any]):
def extract_thread_messages(messages: list[Message]) -> list[Message]:
    thread_messages = []
    next_message = None
@@ -13,6 +13,7 @@ from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.model_context_memory import ModelContextMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@@ -39,7 +40,7 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, LLMMemoryType, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import (
    ArrayAnySegment,
@@ -190,6 +191,7 @@ class LLMNode(BaseNode[LLMNodeData]):
            ),
            "model_provider": model_config.provider,
            "model_name": model_config.model,
            "memory_type": self.node_data.memory.type if self.node_data.memory else None,
        }

        # handle invoke result
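The `memory_type` value written into `process_data` above is what `ModelContextMemory._get_prompt_messages_in_process_data` later reads back, alongside the `prompts` list. A hedged sketch of the recorded shape (field values are illustrative; only keys visible in this diff or in the reader code are assumed):

```python
# Illustrative shape of WorkflowNodeExecution.process_data for an LLM node run
# with per-node memory enabled.
process_data = {
    "prompts": [
        {"role": "user", "text": "What is Dify?"},
        {"role": "assistant", "text": "An LLM app development platform."},
    ],
    "model_provider": "openai",     # from model_config.provider (illustrative value)
    "model_name": "gpt-4o",         # from model_config.model (illustrative value)
    "memory_type": "independent",   # LLMMemoryType value, or None when memory is off
}
```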
@@ -553,10 +555,9 @@ class LLMNode(BaseNode[LLMNodeData]):

    def _fetch_memory(
        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
    ) -> Optional[TokenBufferMemory]:
    ) -> Optional[TokenBufferMemory | ModelContextMemory]:
        if not node_data_memory:
            return None

        # get conversation id
        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
            ["sys", SystemVariableKey.CONVERSATION_ID.value]
@@ -575,7 +576,15 @@ class LLMNode(BaseNode[LLMNodeData]):
        if not conversation:
            return None

        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
        memory = (
            TokenBufferMemory(conversation=conversation, model_instance=model_instance)
            if node_data_memory.type == LLMMemoryType.GLOBAL
            else ModelContextMemory(
                conversation=conversation,
                node_id=self.node_id,
                model_instance=model_instance,
            )
        )

        return memory
@@ -585,7 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
        sys_query: str | None = None,
        sys_files: Sequence["File"],
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        memory: TokenBufferMemory | ModelContextMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        memory_config: MemoryConfig | None = None,
@@ -1201,7 +1210,7 @@ def _calculate_rest_token(

def _handle_memory_chat_mode(
    *,
    memory: TokenBufferMemory | None,
    memory: TokenBufferMemory | ModelContextMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1218,7 +1227,7 @@ def _handle_memory_chat_mode(

def _handle_memory_completion_mode(
    *,
    memory: TokenBufferMemory | None,
    memory: TokenBufferMemory | ModelContextMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:

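The bodies of these two helpers are not part of this diff; they only needed their type hints widened because both memory classes satisfy the shared `BaseMemory` interface. A hedged sketch of that interplay (the helper names below are hypothetical, and the exact window and role-prefix handling is assumed):

```python
# Sketch only: either memory class can be passed, since both implement BaseMemory.
def _history_for_chat(memory, memory_config):
    # Chat mode works with structured PromptMessage objects.
    if memory is None or memory_config is None:
        return []
    limit = memory_config.window.size if memory_config.window.enabled else None
    return memory.get_history_prompt_messages(message_limit=limit)


def _history_for_completion(memory, memory_config):
    # Completion mode flattens history into one text block, optionally using
    # the configured role prefixes.
    if memory is None or memory_config is None:
        return ""
    prefix = memory_config.role_prefix
    return memory.get_history_prompt_text(
        human_prefix=prefix.user if prefix else "Human",
        ai_prefix=prefix.assistant if prefix else "Assistant",
    )
```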
@@ -3,8 +3,8 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Optional
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
@@ -13,9 +13,6 @@ from services.plugin.plugin_service import PluginService
if TYPE_CHECKING:
    from models.workflow import Workflow

from enum import StrEnum
from typing import TYPE_CHECKING, Any, Literal, cast

import sqlalchemy as sa
from flask import request
from flask_login import UserMixin  # type: ignore
@@ -1,14 +1,12 @@
import json
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from enum import Enum
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4

if TYPE_CHECKING:
    from models.model import AppMode
from enum import StrEnum
from typing import TYPE_CHECKING

import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
@@ -608,6 +606,17 @@ class WorkflowNodeExecution(Base):
            "triggered_from",
            "node_execution_id",
        ),
        db.Index(
            "workflow_node_execution_run_node_status_idx",
            "workflow_run_id",
            "node_id",
            "status",
        ),
        db.Index(
            "workflow_node_execution_run_status_idx",
            "workflow_run_id",
            "status",
        ),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@@ -48,6 +48,7 @@ const OPTION_MAP = {
      : ''},
    systemVariables: {
      // user_id: 'YOU CAN DEFINE USER ID HERE',
      // conversation_id: 'YOU CAN DEFINE CONVERSATION ID HERE, IT MUST BE A VALID UUID',
    },
  }
</script>
@@ -39,6 +39,7 @@ export type EmbeddedChatbotContextValue = {
  chatShouldReloadKey: string
  isMobile: boolean
  isInstalledApp: boolean
  allowResetChat: boolean
  appId?: string
  handleFeedback: (messageId: string, feedback: Feedback) => void
  currentChatInstanceRef: RefObject<{ handleStop: () => void }>
@@ -67,6 +68,7 @@ export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>
  chatShouldReloadKey: '',
  isMobile: false,
  isInstalledApp: false,
  allowResetChat: true,
  handleFeedback: noop,
  currentChatInstanceRef: { current: { handleStop: noop } },
  clearChatList: false,
@@ -16,6 +16,7 @@ import cn from '@/utils/classnames'

export type IHeaderProps = {
  isMobile?: boolean
  allowResetChat?: boolean
  customerIcon?: React.ReactNode
  title: string
  theme?: Theme
@@ -23,6 +24,7 @@ export type IHeaderProps = {
}
const Header: FC<IHeaderProps> = ({
  isMobile,
  allowResetChat,
  customerIcon,
  title,
  theme,
@@ -57,7 +59,7 @@ const Header: FC<IHeaderProps> = ({
        {currentConversationId && (
          <Divider type='vertical' className='h-3.5' />
        )}
        {currentConversationId && (
        {currentConversationId && allowResetChat && (
          <Tooltip
            popupContent={t('share.chat.resetChat')}
          >
@@ -89,7 +91,7 @@ const Header: FC<IHeaderProps> = ({
          </div>
        </div>
        <div className='flex items-center gap-1'>
          {currentConversationId && (
          {currentConversationId && allowResetChat && (
            <Tooltip
              popupContent={t('share.chat.resetChat')}
            >
|
  const appId = useMemo(() => appData?.app_id, [appData])

  const [userId, setUserId] = useState<string>()
  const [conversationId, setConversationId] = useState<string>()
  useEffect(() => {
    getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
    getProcessedSystemVariablesFromUrlParams().then(({ user_id, conversation_id }) => {
      setUserId(user_id)
      setConversationId(conversation_id)
    })
  }, [])
@@ -109,7 +111,9 @@ export const useEmbeddedChatbot = () => {
  const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
    defaultValue: {},
  })
  const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
  const allowResetChat = !conversationId
  const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || conversationId || '',
    [appId, conversationIdInfo, userId, conversationId])
  const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
    if (appId) {
      let prevValue = conversationIdInfo?.[appId || '']
@@ -362,6 +366,7 @@ export const useEmbeddedChatbot = () => {
    appInfoError,
    appInfoLoading,
    isInstalledApp,
    allowResetChat,
    appId,
    currentConversationId,
    currentConversationItem,
@@ -25,6 +25,7 @@ import cn from '@/utils/classnames'
const Chatbot = () => {
  const {
    isMobile,
    allowResetChat,
    appInfoError,
    appInfoLoading,
    appData,
@@ -90,6 +91,7 @@ const Chatbot = () => {
    >
      <Header
        isMobile={isMobile}
        allowResetChat={allowResetChat}
        title={site?.title || ''}
        customerIcon={isDify() ? difyIcon : ''}
        theme={themeBuilder?.theme}
@@ -153,6 +155,7 @@ const EmbeddedChatbotWrapper = () => {
    handleNewConversationCompleted,
    chatShouldReloadKey,
    isInstalledApp,
    allowResetChat,
    appId,
    handleFeedback,
    currentChatInstanceRef,
@@ -187,6 +190,7 @@ const EmbeddedChatbotWrapper = () => {
    chatShouldReloadKey,
    isMobile,
    isInstalledApp,
    allowResetChat,
    appId,
    handleFeedback,
    currentChatInstanceRef,
@@ -87,7 +87,7 @@ const Doc = ({ appDetail }: IDocProps) => {
        <div className={`fixed right-8 top-32 z-10 transition-all ${isTocExpanded ? 'w-64' : 'w-10'}`}>
          {isTocExpanded
            ? (
              <nav className="toc w-full rounded-lg bg-components-panel-bg p-4 shadow-md">
              <nav className="toc max-h-[calc(100vh-150px)] w-full overflow-y-auto rounded-lg bg-components-panel-bg p-4 shadow-md">
                <div className="mb-4 flex items-center justify-between">
                  <h3 className="text-lg font-semibold text-text-primary">{t('appApi.develop.toc')}</h3>
                  <button
@@ -4,12 +4,13 @@ import React, { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import produce from 'immer'
import type { Memory } from '../../../types'
import { MemoryRole } from '../../../types'
import { LLMMemoryType, MemoryRole } from '../../../types'
import cn from '@/utils/classnames'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import Switch from '@/app/components/base/switch'
import Slider from '@/app/components/base/slider'
import Input from '@/app/components/base/input'
import { SimpleSelect } from '@/app/components/base/select'

const i18nPrefix = 'workflow.nodes.common.memory'
const WINDOW_SIZE_MIN = 1
@@ -54,6 +55,7 @@ type Props = {
const MEMORY_DEFAULT: Memory = {
  window: { enabled: false, size: WINDOW_SIZE_DEFAULT },
  query_prompt_template: '{{#sys.query#}}',
  type: LLMMemoryType.GLOBAL,
}

const MemoryConfig: FC<Props> = ({
@@ -178,6 +180,24 @@ const MemoryConfig: FC<Props> = ({
          />
        </div>
      </div>
      <div>
        <SimpleSelect
          items={[{
            value: LLMMemoryType.INDEPENDENT,
            name: 'Individual memory',
          }, {
            value: LLMMemoryType.GLOBAL,
            name: 'Global memory',
          }]}
          onSelect={(value) => {
            const newPayload = produce(payload || MEMORY_DEFAULT, (draft) => {
              draft.type = value.value as LLMMemoryType
            })
            onChange(newPayload)
          }}
          defaultValue={payload.type}
        />
      </div>
      {canSetRoleName && (
        <div className='mt-4'>
          <div className='text-xs font-medium uppercase leading-6 text-text-tertiary'>{t(`${i18nPrefix}.conversationRoleName`)}</div>
@@ -234,6 +234,11 @@ export type RolePrefix = {
  assistant: string
}

export enum LLMMemoryType {
  INDEPENDENT = 'independent',
  GLOBAL = 'global',
}

export type Memory = {
  role_prefix?: RolePrefix
  window: {
@@ -241,6 +246,7 @@ export type Memory = {
    size: number | string | null
  }
  query_prompt_template: string
  type: LLMMemoryType
}

export enum VarType {
@@ -54,6 +54,9 @@ const LocaleLayout = async ({
        data-public-indexing-max-segmentation-tokens-length={process.env.NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH}
        data-public-loop-node-max-count={process.env.NEXT_PUBLIC_LOOP_NODE_MAX_COUNT}
        data-public-max-iterations-num={process.env.NEXT_PUBLIC_MAX_ITERATIONS_NUM}
        data-public-enable-website-jinareader={process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER}
        data-public-enable-website-firecrawl={process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL}
        data-public-enable-website-watercrawl={process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL}
      >
        <BrowserInitor>
          <SentryInitor>
@@ -306,12 +306,12 @@ export const MAX_ITERATIONS_NUM = maxIterationsNum
export const ENABLE_WEBSITE_JINAREADER = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER !== undefined
  ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER === 'true'
  : true
  : globalThis.document?.body?.getAttribute('data-public-enable-website-jinareader') === 'true' || true

export const ENABLE_WEBSITE_FIRECRAWL = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL !== undefined
  ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL === 'true'
  : true
  : globalThis.document?.body?.getAttribute('data-public-enable-website-firecrawl') === 'true' || true

export const ENABLE_WEBSITE_WATERCRAWL = process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL !== undefined
  ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL === 'true'
  : true
  : globalThis.document?.body?.getAttribute('data-public-enable-website-watercrawl') === 'true' || true