mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 06:17:04 +00:00
Compare commits
18 Commits
test/workf
...
yanli/phas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
710ac3b90a | ||
|
|
8548498f25 | ||
|
|
d014f0b91a | ||
|
|
cc5aac268a | ||
|
|
4c1d27431b | ||
|
|
9a86f280eb | ||
|
|
c5920fb28a | ||
|
|
a87b928079 | ||
|
|
93f9546353 | ||
|
|
2f81d5dfdf | ||
|
|
7639d8e43f | ||
|
|
1dce81c604 | ||
|
|
f874ca183e | ||
|
|
0d805e624e | ||
|
|
61196180b8 | ||
|
|
79433b0091 | ||
|
|
c4aeaa35d4 | ||
|
|
9f0d79b8b0 |
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -47,7 +47,6 @@ from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -522,8 +521,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
workflow = _refresh_model(session=session, model=workflow)
|
||||
message = _refresh_model(session=session, model=message)
|
||||
assert message is not None
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
@@ -690,11 +690,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
@overload
|
||||
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
@overload
|
||||
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
|
||||
|
||||
|
||||
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
|
||||
if session is not None:
|
||||
detached_model = session.get(type(model), model.id)
|
||||
assert detached_model is not None
|
||||
return detached_model
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
|
||||
detached_model = refresh_session.get(type(model), model.id)
|
||||
assert detached_model is not None
|
||||
return detached_model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from collections.abc import Generator, Iterator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_full_response(stream_response)
|
||||
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_simple_response(stream_response)
|
||||
|
||||
return _generate_simple_response()
|
||||
|
||||
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ class BaseAppGenerator:
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
assert isinstance(account, Account)
|
||||
debug_account = account
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
@@ -240,7 +241,7 @@ class BaseAppGenerator:
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
user=account,
|
||||
user=debug_account,
|
||||
)
|
||||
else:
|
||||
|
||||
|
||||
@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
generated_conversation_id = str(conversation.id)
|
||||
generated_message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
conversation_id=generated_conversation_id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
message_id=generated_message_id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
conversation_id=generated_conversation_id,
|
||||
message_id=generated_message_id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
conversation_id = str(conversation.id)
|
||||
message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
conversation_id=conversation_id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Protocol, TypeAlias
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GraphConfigObject: TypeAlias = dict[str, object]
|
||||
GraphConfigMapping: TypeAlias = Mapping[str, object]
|
||||
|
||||
|
||||
class SingleNodeRunEntity(Protocol):
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
def __init__(
|
||||
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: GraphConfigMapping,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
single_iteration_run: SingleNodeRunEntity | None = None,
|
||||
single_loop_run: SingleNodeRunEntity | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
@staticmethod
|
||||
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
|
||||
nodes = graph_config.get("nodes")
|
||||
edges = graph_config.get("edges")
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
if not isinstance(edges, list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
validated_nodes: list[GraphConfigMapping] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
raise ValueError("nodes in workflow graph must be mappings")
|
||||
validated_nodes.append(node)
|
||||
|
||||
validated_edges: list[GraphConfigMapping] = []
|
||||
for edge in edges:
|
||||
if not isinstance(edge, Mapping):
|
||||
raise ValueError("edges in workflow graph must be mappings")
|
||||
validated_edges.append(edge)
|
||||
|
||||
return validated_nodes, validated_edges
|
||||
|
||||
@staticmethod
|
||||
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
|
||||
if node_config is None:
|
||||
return None
|
||||
node_data = node_config.get("data")
|
||||
if not isinstance(node_data, Mapping):
|
||||
return None
|
||||
start_node_id = node_data.get("start_node_id")
|
||||
return start_node_id if isinstance(start_node_id, str) else None
|
||||
|
||||
@classmethod
|
||||
def _build_single_node_graph_config(
|
||||
cls,
|
||||
*,
|
||||
graph_config: GraphConfigMapping,
|
||||
node_id: str,
|
||||
node_type_filter_key: str,
|
||||
) -> tuple[GraphConfigObject, NodeConfigDict]:
|
||||
node_configs, edge_configs = cls._get_graph_items(graph_config)
|
||||
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
|
||||
start_node_id = cls._extract_start_node_id(main_node_config)
|
||||
|
||||
filtered_node_configs = [
|
||||
dict(node)
|
||||
for node in node_configs
|
||||
if node.get("id") == node_id
|
||||
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
if not filtered_node_configs:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
filtered_node_ids = {
|
||||
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
|
||||
}
|
||||
filtered_edge_configs = [
|
||||
dict(edge)
|
||||
for edge in edge_configs
|
||||
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
|
||||
]
|
||||
|
||||
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
|
||||
if target_node_config is None:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
return (
|
||||
{
|
||||
"nodes": filtered_node_configs,
|
||||
"edges": filtered_edge_configs,
|
||||
},
|
||||
NodeConfigDictAdapter.validate_python(target_node_config),
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
user_inputs: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
@@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
graph_config, target_node_config = self._build_single_node_graph_config(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_type_filter_key=node_type_filter_key,
|
||||
)
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
|
||||
@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@@ -1045,9 +1045,10 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "constant":
|
||||
parameter_value = tool_input.value
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
AgentInputConstantValue: TypeAlias = (
|
||||
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
|
||||
)
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
|
||||
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.AGENT
|
||||
@@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: AgentInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> AgentInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "variable":
|
||||
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type in {"mixed", "constant"}:
|
||||
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown agent input type: {input_type}")
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
@@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
JsonObject: TypeAlias = dict[str, object]
|
||||
JsonObjectList: TypeAlias = list[JsonObject]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
@@ -39,12 +48,12 @@ class AgentRuntimeSupport:
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@@ -54,9 +63,10 @@ class AgentRuntimeSupport:
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
@@ -79,60 +89,38 @@ class AgentRuntimeSupport:
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
value = self._normalize_tool_payloads(
|
||||
strategy=strategy,
|
||||
tools=tool_payloads,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
provider_type = self._coerce_tool_provider_type(tool.get("type"))
|
||||
setting_params = self._coerce_json_object(tool.get("settings")) or {}
|
||||
parameters = self._coerce_json_object(tool.get("parameters")) or {}
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
|
||||
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
|
||||
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_name=tool_name,
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
extra = self._coerce_json_object(tool.get("extra")) or {}
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
@@ -145,8 +133,9 @@ class AgentRuntimeSupport:
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
description_override = self._coerce_optional_string(extra.get("description"))
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
description_override or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
@@ -167,13 +156,13 @@ class AgentRuntimeSupport:
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"credential_id": credential_id,
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
value = _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
@@ -199,17 +188,27 @@ class AgentRuntimeSupport:
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
tools = parameters.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
return credentials
|
||||
|
||||
for raw_tool in tools:
|
||||
tool = self._coerce_json_object(raw_tool)
|
||||
if tool is None:
|
||||
continue
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
if credential_id is None:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = credential_id
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
@@ -232,14 +231,14 @@ class AgentRuntimeSupport:
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
provider=str(value.get("provider", "")),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_name = str(value.get("model", ""))
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
@@ -249,7 +248,7 @@ class AgentRuntimeSupport:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model_type=ModelType(str(value.get("model_type", ""))),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
@@ -268,9 +267,88 @@ class AgentRuntimeSupport:
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
tools: JsonObjectList,
|
||||
) -> JsonObjectList:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _normalize_tool_payloads(
|
||||
self,
|
||||
*,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: JsonObjectList,
|
||||
variable_pool: VariablePool,
|
||||
) -> JsonObjectList:
|
||||
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
|
||||
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
|
||||
for tool in normalized_tools:
|
||||
tool.pop("schemas", None)
|
||||
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
|
||||
tool["settings"] = self._resolve_tool_settings(tool)
|
||||
return normalized_tools
|
||||
|
||||
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
|
||||
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
|
||||
if parameter_configs is None:
|
||||
raw_parameters = self._coerce_json_object(tool.get("parameters"))
|
||||
return raw_parameters or {}
|
||||
|
||||
resolved_parameters: JsonObject = {}
|
||||
for key, parameter_config in parameter_configs.items():
|
||||
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
|
||||
value_param = self._coerce_json_object(parameter_config.get("value"))
|
||||
if value_param and value_param.get("type") == "variable":
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
resolved_parameters[key] = variable.value
|
||||
else:
|
||||
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
resolved_parameters[key] = None
|
||||
|
||||
return resolved_parameters
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
|
||||
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
|
||||
if settings is None:
|
||||
return {}
|
||||
return {key: setting.get("value") for key, setting in settings.items()}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json_object(value: object) -> JsonObject | None:
|
||||
try:
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
except ValidationError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_optional_string(value: object) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
|
||||
if isinstance(value, ToolProviderType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return ToolProviderType(value)
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@classmethod
|
||||
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
coerced: dict[str, JsonObject] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
return None
|
||||
json_object = cls._coerce_json_object(item)
|
||||
if json_object is None:
|
||||
return None
|
||||
coerced[key] = json_object
|
||||
return coerced
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@@ -32,6 +32,13 @@ from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SpecialValueScalar: TypeAlias = str | int | float | bool | None
|
||||
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
|
||||
SerializedSpecialValue: TypeAlias = (
|
||||
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
|
||||
)
|
||||
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@staticmethod
|
||||
@@ -276,10 +283,10 @@ class WorkflowEntry:
|
||||
@staticmethod
|
||||
def _create_single_node_graph(
|
||||
node_id: str,
|
||||
node_data: dict[str, Any],
|
||||
node_data: Mapping[str, object],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> dict[str, Any]:
|
||||
) -> SingleNodeGraphConfig:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
@@ -289,14 +296,14 @@ class WorkflowEntry:
|
||||
:param node_height: height for UI layout (default: 100)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config = {
|
||||
node_config: dict[str, object] = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": node_data,
|
||||
"data": dict(node_data),
|
||||
}
|
||||
start_node_config = {
|
||||
start_node_config: dict[str, object] = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
@@ -321,7 +328,12 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
cls,
|
||||
node_data: Mapping[str, object],
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_inputs: Mapping[str, object],
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
@@ -339,6 +351,8 @@ class WorkflowEntry:
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = node_data.get("type", "")
|
||||
if not isinstance(node_type, str):
|
||||
raise ValueError("Node type must be a string")
|
||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
@@ -369,7 +383,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@@ -405,30 +419,34 @@ class WorkflowEntry:
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
raise TypeError("handle_special_values expects a mapping input")
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any):
|
||||
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
if isinstance(value, Mapping):
|
||||
res: dict[str, SerializedSpecialValue] = {}
|
||||
for k, v in value.items():
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
res_list: list[SerializedSpecialValue] = []
|
||||
for item in value:
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
return dict(value.to_dict())
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
else:
|
||||
return
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
|
||||
@@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_executions = graph_execution.node_executions
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
@@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self.execution_id
|
||||
yield event
|
||||
yield event.model_copy(update={"id": self.execution_id})
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
|
||||
@@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
it = iter(doc.element.body)
|
||||
doc_body = getattr(doc.element, "body", None)
|
||||
if doc_body is None:
|
||||
raise TextExtractionError("DOCX body not found")
|
||||
it = iter(doc_body)
|
||||
part = next(it, None)
|
||||
i = 0
|
||||
while part is not None:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Literal, NotRequired
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class StructuredOutputConfig(TypedDict):
|
||||
schema: Mapping[str, object]
|
||||
name: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
completion_params: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
@@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: Any):
|
||||
def convert_none_configs(cls, v: object):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
@@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
def convert_none_jinja2_variables(cls, v: object):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
@@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
structured_output: StructuredOutputConfig | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
@@ -90,11 +97,30 @@ class LLMNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
def convert_none_prompt_config(cls, v: object):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@field_validator("structured_output", mode="before")
|
||||
@classmethod
|
||||
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
|
||||
if not isinstance(v, Mapping):
|
||||
return v
|
||||
|
||||
schema = v.get("schema")
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
normalized: StructuredOutputConfig = {"schema": schema}
|
||||
name = v.get("name")
|
||||
description = v.get("description")
|
||||
if isinstance(name, str):
|
||||
normalized["name"] = name
|
||||
if isinstance(description, str):
|
||||
normalized["description"] = description
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
|
||||
@@ -9,6 +9,7 @@ import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@@ -74,6 +75,7 @@ from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
StructuredOutputConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@@ -88,6 +90,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
|
||||
class LLMNode(Node[LLMNodeData]):
|
||||
@@ -354,7 +357,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
structured_output: StructuredOutputConfig | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
@@ -367,8 +370,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if structured_output_enabled:
|
||||
if structured_output is None:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=structured_output or {},
|
||||
structured_output=structured_output,
|
||||
)
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
@@ -920,6 +925,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
structured_output = (
|
||||
dict(invoke_result.structured_output)
|
||||
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
|
||||
else None
|
||||
)
|
||||
|
||||
event = ModelInvokeCompletedEvent(
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
@@ -928,7 +939,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if enabled
|
||||
structured_output=getattr(invoke_result, "structured_output", None),
|
||||
structured_output=structured_output,
|
||||
)
|
||||
if request_latency is not None:
|
||||
event.usage.latency = round(request_latency, 3)
|
||||
@@ -962,27 +973,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
structured_output: StructuredOutputConfig,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Fetch the structured output schema from the node data.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The structured output schema
|
||||
dict[str, object]: The structured output schema
|
||||
"""
|
||||
if not structured_output:
|
||||
schema = structured_output.get("schema")
|
||||
if not schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
||||
if not structured_output_schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
|
||||
try:
|
||||
schema = json.loads(structured_output_schema)
|
||||
if not isinstance(schema, dict):
|
||||
raise LLMNodeError("structured_output_schema must be a JSON object")
|
||||
return schema
|
||||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(schema)
|
||||
|
||||
@staticmethod
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, TypeAlias, cast
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
@@ -9,6 +12,12 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
@@ -29,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
return seg_type
|
||||
|
||||
|
||||
def _validate_loop_value(value: object) -> LoopValue:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return cast(LoopValue, value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_validate_loop_value(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
normalized: dict[str, LoopValue] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop values only support string object keys")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
|
||||
|
||||
|
||||
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("Loop outputs must be an object")
|
||||
|
||||
normalized: LoopValueMapping = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop output keys must be strings")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
Loop Variable Data.
|
||||
@@ -37,7 +76,29 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Any | list[str] | None = None
|
||||
value: LoopValue | VariableSelector | None = None
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
||||
value_type = validation_info.data.get("value_type")
|
||||
if value_type == "variable":
|
||||
if value is None:
|
||||
raise ValueError("Variable loop inputs require a selector")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if value_type == "constant":
|
||||
return _validate_loop_value(value)
|
||||
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.value_type != "variable":
|
||||
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
def require_constant_value(self) -> LoopValue:
|
||||
if self.value_type != "constant":
|
||||
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||
return _validate_loop_value(self.value)
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@@ -46,14 +107,14 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
outputs: LoopValueMapping = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||
if value is None:
|
||||
return {}
|
||||
return v
|
||||
return _validate_loop_value_mapping(value)
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@@ -77,8 +138,8 @@ class LoopState(BaseLoopState):
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
outputs: list[LoopValue] = Field(default_factory=list)
|
||||
current_output: LoopValue | None = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@@ -87,7 +148,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Any:
|
||||
def get_last_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@@ -95,7 +156,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any:
|
||||
def get_current_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
|
||||
)
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
|
||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
from dify_graph.variables import Segment, SegmentType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
inputs: dict[str, object] = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
@@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
loop_variable_selectors: dict[str, list[str]] = {}
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
|
||||
"variable": lambda var: (
|
||||
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
|
||||
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
|
||||
if var.value is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
@@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||
|
||||
@@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
single_loop_variable: dict[str, LoopValue] = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
@@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
variable_mapping = {}
|
||||
variable_mapping: dict[str, Sequence[str]] = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
node_configs: dict[str, Mapping[str, object]] = {}
|
||||
if isinstance(raw_nodes, list):
|
||||
for raw_node in raw_nodes:
|
||||
if not isinstance(raw_node, dict):
|
||||
continue
|
||||
raw_node_id = raw_node.get("id")
|
||||
if isinstance(raw_node_id, str):
|
||||
node_configs[raw_node_id] = raw_node
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
sub_node_data = sub_node_config.get("data")
|
||||
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
@@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
selector = loop_variable.require_variable_selector()
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
@@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
@@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
loop_node_ids: set[str] = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
if not isinstance(raw_nodes, list):
|
||||
return loop_node_ids
|
||||
|
||||
for node in raw_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
node_data = node.get("data")
|
||||
if not isinstance(node_data, dict):
|
||||
continue
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
@@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
@@ -389,11 +406,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
# New typed payloads may already provide native lists, while legacy
|
||||
# configs still serialize array constants as JSON strings.
|
||||
if isinstance(original_value, str):
|
||||
value = json.loads(original_value) if original_value else []
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
value = original_value
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -6,6 +6,7 @@ from pydantic import (
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value) -> str:
|
||||
def validate_name(cls, value: object) -> str:
|
||||
if not value:
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
@@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
|
||||
return element_type
|
||||
|
||||
|
||||
class JsonSchemaArrayItems(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterJsonSchemaProperty(TypedDict, total=False):
|
||||
description: str
|
||||
type: str
|
||||
items: JsonSchemaArrayItems
|
||||
enum: list[str]
|
||||
|
||||
|
||||
class ParameterJsonSchema(TypedDict):
|
||||
type: Literal["object"]
|
||||
properties: dict[str, ParameterJsonSchemaProperty]
|
||||
required: list[str]
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
@@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
def set_reasoning_mode(cls, v: object) -> str:
|
||||
return str(v) if v else "function_call"
|
||||
|
||||
def get_parameter_json_schema(self):
|
||||
def get_parameter_json_schema(self) -> ParameterJsonSchema:
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
|
||||
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
@@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
parameter_schema["type"] = parameter.type.value
|
||||
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
@@ -5,6 +5,8 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -63,6 +65,7 @@ from .prompts import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
@@ -70,7 +73,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
def extract_json(text: str) -> str | None:
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
@@ -392,10 +395,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
# generate tool
|
||||
parameter_schema = node_data.get_parameter_json_schema()
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
parameters={
|
||||
"type": parameter_schema["type"],
|
||||
"properties": dict(parameter_schema["properties"]),
|
||||
"required": list(parameter_schema["required"]),
|
||||
},
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
@@ -602,19 +610,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result: dict[str, Any] = {}
|
||||
transformed_result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
if isinstance(param_value, (bool, int, float, str)):
|
||||
numeric_value: bool | int | float | str = param_value
|
||||
transformed = self._transform_number(numeric_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
@@ -661,7 +671,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@@ -672,11 +682,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@@ -690,16 +700,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
|
||||
@@ -1,12 +1,66 @@
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Literal, TypeAlias, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class WorkflowToolInputValue(TypedDict):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
|
||||
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
|
||||
|
||||
|
||||
class ToolInputPayload(BaseModel):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
|
||||
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return cast(ToolConfigurationEntry, value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
|
||||
|
||||
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
@@ -14,52 +68,29 @@ class ToolEntity(BaseModel):
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
tool_configurations: ToolConfigurations
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
raise TypeError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
normalized: ToolConfigurations = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("tool_configurations keys must be strings")
|
||||
normalized[key] = _validate_tool_configuration_entry(item)
|
||||
return normalized
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = BuiltinNodeTypes.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
|
||||
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
|
||||
return typ
|
||||
class ToolInput(ToolInputPayload):
|
||||
pass
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
@@ -69,7 +100,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
def filter_none_tool_inputs(cls, value: object) -> object:
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
@@ -80,8 +111,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
def _has_valid_value(tool_input: object) -> bool:
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
if isinstance(tool_input, ToolNodeData.ToolInput):
|
||||
return tool_input.value is not None
|
||||
return False
|
||||
|
||||
@@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
@@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
variable_selector = input.require_variable_selector()
|
||||
selector_key = ".".join(variable_selector)
|
||||
result[f"#{selector_key}#"] = variable_selector
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from dify_graph.variables import SegmentType, VariableBase
|
||||
from dify_graph.variables import Segment, SegmentType, VariableBase
|
||||
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
@@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
income_value: Segment
|
||||
updated_variable: VariableBase
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
@@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
@property
|
||||
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
|
||||
"""Return node execution state keyed by node id for resume support."""
|
||||
...
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
...
|
||||
@@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeExecutionProtocol(Protocol):
|
||||
"""Structural interface for per-node execution state used during resume."""
|
||||
|
||||
execution_id: str | None
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
"""Structural interface for response stream coordinator."""
|
||||
|
||||
|
||||
@@ -13,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
|
||||
controllers/service_api/app/annotation.py
|
||||
controllers/web/workflow_events.py
|
||||
core/agent/fc_agent_runner.py
|
||||
core/app/apps/advanced_chat/app_generator.py
|
||||
core/app/apps/advanced_chat/app_runner.py
|
||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
||||
core/app/apps/agent_chat/app_generator.py
|
||||
core/app/apps/base_app_generate_response_converter.py
|
||||
core/app/apps/base_app_generator.py
|
||||
core/app/apps/chat/app_generator.py
|
||||
core/app/apps/common/workflow_response_converter.py
|
||||
core/app/apps/completion/app_generator.py
|
||||
core/app/apps/pipeline/pipeline_generator.py
|
||||
core/app/apps/pipeline/pipeline_runner.py
|
||||
core/app/apps/workflow/app_generator.py
|
||||
core/app/apps/workflow/app_runner.py
|
||||
core/app/apps/workflow/generate_task_pipeline.py
|
||||
core/app/apps/workflow_app_runner.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/datasource/datasource_manager.py
|
||||
core/external_data_tool/api/api.py
|
||||
@@ -108,35 +93,6 @@ core/tools/workflow_as_tool/provider.py
|
||||
core/trigger/debug/event_selectors.py
|
||||
core/trigger/entities/entities.py
|
||||
core/trigger/provider.py
|
||||
core/workflow/workflow_entry.py
|
||||
dify_graph/entities/workflow_execution.py
|
||||
dify_graph/file/file_manager.py
|
||||
dify_graph/graph_engine/error_handler.py
|
||||
dify_graph/graph_engine/layers/execution_limits.py
|
||||
dify_graph/nodes/agent/agent_node.py
|
||||
dify_graph/nodes/base/node.py
|
||||
dify_graph/nodes/code/code_node.py
|
||||
dify_graph/nodes/datasource/datasource_node.py
|
||||
dify_graph/nodes/document_extractor/node.py
|
||||
dify_graph/nodes/human_input/human_input_node.py
|
||||
dify_graph/nodes/if_else/if_else_node.py
|
||||
dify_graph/nodes/iteration/iteration_node.py
|
||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
||||
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
||||
dify_graph/nodes/list_operator/node.py
|
||||
dify_graph/nodes/llm/node.py
|
||||
dify_graph/nodes/loop/loop_node.py
|
||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
||||
dify_graph/nodes/start/start_node.py
|
||||
dify_graph/nodes/template_transform/template_transform_node.py
|
||||
dify_graph/nodes/tool/tool_node.py
|
||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
||||
dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
|
||||
|
||||
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
|
||||
refreshed = _refresh_model(session=None, model=source_model)
|
||||
|
||||
assert refreshed is detached_model
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -33,8 +33,8 @@ def _make_graph_state():
|
||||
],
|
||||
)
|
||||
def test_run_uses_single_node_execution_branch(
|
||||
single_iteration_run: Any,
|
||||
single_loop_run: Any,
|
||||
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
|
||||
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
|
||||
) -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
@@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "other-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "other-node",
|
||||
"target": "loop-node",
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
original_graph_dict = deepcopy(workflow.graph_dict)
|
||||
|
||||
_, _, graph_runtime_state = _make_graph_state()
|
||||
seen_configs: list[object] = []
|
||||
@@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
|
||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
|
||||
):
|
||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
@@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
assert workflow.graph_dict == original_graph_dict
|
||||
graph_config = graph_init.call_args.kwargs["graph_config"]
|
||||
assert graph_config is not workflow.graph_dict
|
||||
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
|
||||
assert graph_config["edges"] == []
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.nodes.loop.entities import LoopNodeData
|
||||
from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
|
||||
from dify_graph.nodes.loop.loop_node import LoopNode
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
||||
@@ -50,3 +53,21 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
|
||||
)
|
||||
|
||||
assert seen_configs == [child_node_config]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("var_type", "original_value", "expected_value"),
|
||||
[
|
||||
(SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
|
||||
(SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
|
||||
(SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
|
||||
(SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
|
||||
],
|
||||
)
|
||||
def test_get_segment_for_constant_accepts_native_array_values(
|
||||
var_type: SegmentType, original_value: LoopValue, expected_value: LoopValue
|
||||
) -> None:
|
||||
segment = LoopNode._get_segment_for_constant(var_type, original_value)
|
||||
|
||||
assert segment.value_type == var_type
|
||||
assert segment.value == expected_value
|
||||
|
||||
@@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type'
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ALL_PLANS } from '@/app/components/billing/config'
|
||||
import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item'
|
||||
@@ -21,7 +22,6 @@ let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockFetchSubscriptionUrls = vi.fn()
|
||||
const mockInvoices = vi.fn()
|
||||
const mockOpenAsyncWindow = vi.fn()
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
// ─── Context mocks ───────────────────────────────────────────────────────────
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@@ -49,10 +49,6 @@ vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => mockOpenAsyncWindow,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
@@ -82,12 +78,15 @@ const renderCloudPlanItem = ({
|
||||
canPay = true,
|
||||
}: RenderCloudPlanItemOptions = {}) => {
|
||||
return render(
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>,
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
|
||||
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
|
||||
@@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
// Should not proceed with payment
|
||||
expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled()
|
||||
|
||||
@@ -10,12 +10,12 @@
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config'
|
||||
import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item'
|
||||
import { SelfHostedPlan } from '@/app/components/billing/type'
|
||||
|
||||
let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
const originalLocation = window.location
|
||||
let assignedHref = ''
|
||||
@@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({
|
||||
AwsMarketplaceDark: () => <span data-testid="icon-aws-dark" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({
|
||||
default: ({ plan }: { plan: string }) => (
|
||||
<div data-testid={`self-hosted-list-${plan}`}>Features</div>
|
||||
@@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record<string, unknown> = {}) => {
|
||||
}
|
||||
}
|
||||
|
||||
const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<SelfHostedPlanItem plan={plan} />
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Self-Hosted Plan Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
|
||||
// Mock window.location with minimal getter/setter (Location props are non-enumerable)
|
||||
@@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
// ─── 1. Plan Rendering ──────────────────────────────────────────────────
|
||||
describe('Plan rendering', () => {
|
||||
it('should render community plan with name and description', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render premium plan with cloud provider icons', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('icon-azure')).toBeInTheDocument()
|
||||
@@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
})
|
||||
|
||||
it('should render enterprise plan without cloud provider icons', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show price tip for community (free) plan', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show price tip for premium plan', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render features list for each plan', () => {
|
||||
const { unmount: unmount1 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument()
|
||||
unmount1()
|
||||
|
||||
const { unmount: unmount2 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument()
|
||||
unmount2()
|
||||
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show AWS marketplace icon for premium plan button', () => {
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument()
|
||||
})
|
||||
@@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
describe('Navigation flow', () => {
|
||||
it('should redirect to GitHub when clicking community plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to AWS Marketplace when clicking premium plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to Typeform when clicking enterprise plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks community button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
// Should NOT redirect
|
||||
expect(assignedHref).toBe('')
|
||||
@@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks premium button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
@@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks enterprise button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
|
||||
@@ -13,7 +13,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
@@ -91,9 +91,9 @@ export default function OAuthAuthorize() {
|
||||
globalThis.location.href = url.toString()
|
||||
}
|
||||
catch (err: any) {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
|
||||
title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,10 +102,10 @@ export default function OAuthAuthorize() {
|
||||
const invalidParams = !client_id || !redirect_uri
|
||||
if ((invalidParams || isError) && !hasNotifiedRef.current) {
|
||||
hasNotifiedRef.current = true
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
duration: 0,
|
||||
title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
timeout: 0,
|
||||
})
|
||||
}
|
||||
}, [client_id, redirect_uri, isError])
|
||||
|
||||
@@ -39,8 +39,8 @@ vi.mock('../app-card', () => ({
|
||||
vi.mock('@/app/components/explore/create-app-modal', () => ({
|
||||
default: () => <div data-testid="create-from-template-modal" />,
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: vi.fn() },
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: { add: vi.fn() },
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
|
||||
@@ -12,7 +12,7 @@ import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import CreateAppModal from '@/app/components/explore/create-app-modal'
|
||||
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
@@ -137,9 +137,9 @@ const Apps = ({
|
||||
})
|
||||
|
||||
setIsShowCreateModal(false)
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('newApp.appCreated', { ns: 'app' }),
|
||||
title: t('newApp.appCreated', { ns: 'app' }),
|
||||
})
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
@@ -149,7 +149,7 @@ const Apps = ({
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push)
|
||||
}
|
||||
catch {
|
||||
Toast.notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) })
|
||||
toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* @deprecated Use `@/app/components/base/ui/toast` instead.
|
||||
* This module will be removed after migration is complete.
|
||||
* See: https://github.com/langgenius/dify/issues/32811
|
||||
*/
|
||||
|
||||
import type { ReactNode } from 'react'
|
||||
import { createContext, useContext } from 'use-context-selector'
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export type IToastProps = {
|
||||
type?: 'success' | 'error' | 'warning' | 'info'
|
||||
size?: 'md' | 'sm'
|
||||
@@ -19,5 +26,8 @@ type IToastContext = {
|
||||
close: () => void
|
||||
}
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const ToastContext = createContext<IToastContext>({} as IToastContext)
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const useToastContext = () => useContext(ToastContext)
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* @deprecated Use `@/app/components/base/ui/toast` instead.
|
||||
* This component will be removed after migration is complete.
|
||||
* See: https://github.com/langgenius/dify/issues/32811
|
||||
*/
|
||||
|
||||
import type { ReactNode } from 'react'
|
||||
import type { IToastProps } from './context'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
@@ -12,6 +19,7 @@ import { ToastContext, useToastContext } from './context'
|
||||
export type ToastHandle = {
|
||||
clear?: VoidFunction
|
||||
}
|
||||
|
||||
const Toast = ({
|
||||
type = 'info',
|
||||
size = 'md',
|
||||
@@ -74,6 +82,7 @@ const Toast = ({
|
||||
)
|
||||
}
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const ToastProvider = ({
|
||||
children,
|
||||
}: {
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../../config'
|
||||
import { Plan } from '../../../../type'
|
||||
import { PlanRange } from '../../../plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '../index'
|
||||
|
||||
vi.mock('../../../../../base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
@@ -47,11 +41,19 @@ const mockUseAppContext = useAppContext as Mock
|
||||
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
|
||||
const mockBillingInvoices = consoleClient.billing.invoices as Mock
|
||||
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
|
||||
const mockToastNotify = Toast.notify as Mock
|
||||
|
||||
let assignedHref = ''
|
||||
const originalLocation = window.location
|
||||
|
||||
const renderWithToastHost = (ui: React.ReactNode) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
{ui}
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
configurable: true,
|
||||
@@ -68,6 +70,7 @@ beforeAll(() => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toast.close()
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
|
||||
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
|
||||
mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' })
|
||||
@@ -163,7 +166,7 @@ describe('CloudPlanItem', () => {
|
||||
it('should show toast when non-manager tries to buy a plan', () => {
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
|
||||
|
||||
render(
|
||||
renderWithToastHost(
|
||||
<CloudPlanItem
|
||||
plan={Plan.professional}
|
||||
currentPlan={Plan.sandbox}
|
||||
@@ -173,10 +176,7 @@ describe('CloudPlanItem', () => {
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }))
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: 'billing.buyPermissionDeniedTip',
|
||||
}))
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockBillingInvoices).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import type { BasicPlan } from '../../../type'
|
||||
import * as React from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../config'
|
||||
import { Plan } from '../../../type'
|
||||
import { Professional, Sandbox, Team } from '../../assets'
|
||||
@@ -66,10 +66,9 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
return
|
||||
|
||||
if (!isCurrentWorkspaceManager) {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
className: 'z-[1001]',
|
||||
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -83,7 +82,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
throw new Error('Failed to open billing page')
|
||||
}, {
|
||||
onError: (err) => {
|
||||
Toast.notify({ type: 'error', message: err.message || String(err) })
|
||||
toast.add({ type: 'error', title: err.message || String(err) })
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -111,34 +110,34 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
{
|
||||
isMostPopularPlan && (
|
||||
<div className="flex items-center justify-center bg-saas-dify-blue-static px-1.5 py-1">
|
||||
<span className="system-2xs-semibold-uppercase text-text-primary-on-surface">
|
||||
<span className="text-text-primary-on-surface system-2xs-semibold-uppercase">
|
||||
{t('plansCommon.mostPopular', { ns: 'billing' })}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className="system-sm-regular text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
<div className="text-text-secondary system-sm-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* Price */}
|
||||
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
|
||||
{isFreePlan && (
|
||||
<span className="title-4xl-semi-bold text-text-primary">{t('plansCommon.free', { ns: 'billing' })}</span>
|
||||
<span className="text-text-primary title-4xl-semi-bold">{t('plansCommon.free', { ns: 'billing' })}</span>
|
||||
)}
|
||||
{!isFreePlan && (
|
||||
<>
|
||||
{isYear && (
|
||||
<span className="title-4xl-semi-bold text-text-quaternary line-through">
|
||||
<span className="text-text-quaternary line-through title-4xl-semi-bold">
|
||||
$
|
||||
{planInfo.price * 12}
|
||||
</span>
|
||||
)}
|
||||
<span className="title-4xl-semi-bold text-text-primary">
|
||||
<span className="text-text-primary title-4xl-semi-bold">
|
||||
$
|
||||
{isYear ? planInfo.price * 10 : planInfo.price}
|
||||
</span>
|
||||
<span className="system-md-regular pb-0.5 text-text-tertiary">
|
||||
<span className="pb-0.5 text-text-tertiary system-md-regular">
|
||||
{t('plansCommon.priceTip', { ns: 'billing' })}
|
||||
{t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })}
|
||||
</span>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import Toast from '../../../../../base/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../../config'
|
||||
import { SelfHostedPlan } from '../../../../type'
|
||||
import SelfHostedPlanItem from '../index'
|
||||
@@ -16,12 +16,6 @@ vi.mock('../list', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../../../../../base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
@@ -35,11 +29,19 @@ vi.mock('../../../assets', () => ({
|
||||
}))
|
||||
|
||||
const mockUseAppContext = useAppContext as Mock
|
||||
const mockToastNotify = Toast.notify as Mock
|
||||
|
||||
let assignedHref = ''
|
||||
const originalLocation = window.location
|
||||
|
||||
const renderWithToastHost = (ui: React.ReactNode) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
{ui}
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
configurable: true,
|
||||
@@ -56,6 +58,7 @@ beforeAll(() => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toast.close()
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
|
||||
assignedHref = ''
|
||||
})
|
||||
@@ -90,13 +93,10 @@ describe('SelfHostedPlanItem', () => {
|
||||
it('should show toast when non-manager tries to proceed', () => {
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
|
||||
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
renderWithToastHost(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
fireEvent.click(screen.getByRole('button', { name: /billing\.plans\.premium\.btnText/ }))
|
||||
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: 'billing.buyPermissionDeniedTip',
|
||||
}))
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should redirect to community url when community plan button clicked', () => {
|
||||
|
||||
@@ -4,9 +4,9 @@ import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Azure, GoogleCloud } from '@/app/components/base/icons/src/public/billing'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../config'
|
||||
import { SelfHostedPlan } from '../../../type'
|
||||
import { Community, Enterprise, EnterpriseNoise, Premium, PremiumNoise } from '../../assets'
|
||||
@@ -56,10 +56,9 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
const handleGetPayUrl = useCallback(() => {
|
||||
// Only workspace manager can buy plan
|
||||
if (!isCurrentWorkspaceManager) {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
className: 'z-[1001]',
|
||||
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -82,18 +81,18 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
{/* Noise Effect */}
|
||||
{STYLE_MAP[plan].noise}
|
||||
<div className="flex flex-col px-5 py-4">
|
||||
<div className=" flex flex-col gap-y-6 px-1 pt-10">
|
||||
<div className="flex flex-col gap-y-6 px-1 pt-10">
|
||||
{STYLE_MAP[plan].icon}
|
||||
<div className="flex min-h-[104px] flex-col gap-y-2">
|
||||
<div className="text-[30px] font-medium leading-[1.2] text-text-primary">{t(`${i18nPrefix}.name`, { ns: 'billing' })}</div>
|
||||
<div className="system-md-regular line-clamp-2 text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
<div className="line-clamp-2 text-text-secondary system-md-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* Price */}
|
||||
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
|
||||
<div className="title-4xl-semi-bold shrink-0 text-text-primary">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
|
||||
<div className="shrink-0 text-text-primary title-4xl-semi-bold">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
|
||||
{!isFreePlan && (
|
||||
<span className="system-md-regular pb-0.5 text-text-tertiary">
|
||||
<span className="pb-0.5 text-text-tertiary system-md-regular">
|
||||
{t(`${i18nPrefix}.priceTip`, { ns: 'billing' })}
|
||||
</span>
|
||||
)}
|
||||
@@ -114,7 +113,7 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
<GoogleCloud />
|
||||
</div>
|
||||
</div>
|
||||
<span className="system-xs-regular text-text-tertiary">
|
||||
<span className="text-text-tertiary system-xs-regular">
|
||||
{t('plans.premium.comingSoon', { ns: 'billing' })}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type * as React from 'react'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { IndexingType } from '../../../create/step-two'
|
||||
|
||||
@@ -13,14 +13,7 @@ vi.mock('@/next/navigation', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('use-context-selector', async (importOriginal) => {
|
||||
const actual = await importOriginal() as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ notify: mockNotify }),
|
||||
}
|
||||
})
|
||||
const toastAddSpy = vi.spyOn(toast, 'add')
|
||||
|
||||
// Mock dataset detail context
|
||||
let mockIndexingTechnique = IndexingType.QUALIFIED
|
||||
@@ -51,11 +44,6 @@ vi.mock('@/service/knowledge/use-segment', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock app store
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: () => ({ appSidebarExpand: 'expand' }),
|
||||
}))
|
||||
|
||||
vi.mock('../completed/common/action-buttons', () => ({
|
||||
default: ({ handleCancel, handleSave, loading, actionType }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string }) => (
|
||||
<div data-testid="action-buttons">
|
||||
@@ -139,6 +127,8 @@ vi.mock('@/app/components/datasets/common/image-uploader/image-uploader-in-chunk
|
||||
describe('NewSegmentModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useRealTimers()
|
||||
toast.close()
|
||||
mockFullScreen = false
|
||||
mockIndexingTechnique = IndexingType.QUALIFIED
|
||||
})
|
||||
@@ -258,7 +248,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -272,7 +262,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -287,7 +277,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -337,7 +327,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
}),
|
||||
@@ -430,10 +420,9 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('CustomButton in success notification', () => {
|
||||
it('should call viewNewlyAddedChunk when custom button is clicked', async () => {
|
||||
describe('Action button in success notification', () => {
|
||||
it('should call viewNewlyAddedChunk when the toast action is clicked', async () => {
|
||||
const mockViewNewlyAddedChunk = vi.fn()
|
||||
mockNotify.mockImplementation(() => {})
|
||||
|
||||
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
|
||||
options.onSuccess()
|
||||
@@ -442,37 +431,25 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
|
||||
render(
|
||||
<NewSegmentModal
|
||||
{...defaultProps}
|
||||
docForm={ChunkingMode.text}
|
||||
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
|
||||
/>,
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<NewSegmentModal
|
||||
{...defaultProps}
|
||||
docForm={ChunkingMode.text}
|
||||
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
|
||||
// Enter content and save
|
||||
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
|
||||
fireEvent.click(actionButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
customComponent: expect.anything(),
|
||||
}),
|
||||
)
|
||||
expect(mockViewNewlyAddedChunk).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Extract customComponent from the notify call args
|
||||
const notifyCallArgs = mockNotify.mock.calls[0][0] as { customComponent?: React.ReactElement }
|
||||
expect(notifyCallArgs.customComponent).toBeDefined()
|
||||
const customComponent = notifyCallArgs.customComponent!
|
||||
const { container: btnContainer } = render(customComponent)
|
||||
const viewButton = btnContainer.querySelector('.system-xs-semibold.text-text-accent') as HTMLElement
|
||||
expect(viewButton).toBeInTheDocument()
|
||||
fireEvent.click(viewButton)
|
||||
|
||||
// Assert that viewNewlyAddedChunk was called via the onClick handler (lines 66-67)
|
||||
expect(mockViewNewlyAddedChunk).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -599,9 +576,8 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('onSave delayed call', () => {
|
||||
it('should call onSave after timeout in success handler', async () => {
|
||||
vi.useFakeTimers()
|
||||
describe('onSave after success', () => {
|
||||
it('should call onSave immediately after save succeeds', async () => {
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
|
||||
options.onSuccess()
|
||||
@@ -611,15 +587,12 @@ describe('NewSegmentModal', () => {
|
||||
|
||||
render(<NewSegmentModal {...defaultProps} onSave={mockOnSave} docForm={ChunkingMode.text} />)
|
||||
|
||||
// Enter content and save
|
||||
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
// Fast-forward timer
|
||||
vi.advanceTimersByTime(3000)
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalled()
|
||||
vi.useRealTimers()
|
||||
await waitFor(() => {
|
||||
expect(mockOnSave).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
|
||||
import NewChildSegmentModal from '../new-child-segment'
|
||||
|
||||
@@ -10,14 +11,7 @@ vi.mock('@/next/navigation', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('use-context-selector', async (importOriginal) => {
|
||||
const actual = await importOriginal() as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ notify: mockNotify }),
|
||||
}
|
||||
})
|
||||
const toastAddSpy = vi.spyOn(toast, 'add')
|
||||
|
||||
// Mock document context
|
||||
let mockParentMode = 'paragraph'
|
||||
@@ -48,11 +42,6 @@ vi.mock('@/service/knowledge/use-segment', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock app store
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: () => ({ appSidebarExpand: 'expand' }),
|
||||
}))
|
||||
|
||||
vi.mock('../common/action-buttons', () => ({
|
||||
default: ({ handleCancel, handleSave, loading, actionType, isChildChunk }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string, isChildChunk?: boolean }) => (
|
||||
<div data-testid="action-buttons">
|
||||
@@ -103,6 +92,8 @@ vi.mock('../common/segment-index-tag', () => ({
|
||||
describe('NewChildSegmentModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useRealTimers()
|
||||
toast.close()
|
||||
mockFullScreen = false
|
||||
mockParentMode = 'paragraph'
|
||||
})
|
||||
@@ -198,7 +189,7 @@ describe('NewChildSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -253,7 +244,7 @@ describe('NewChildSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
}),
|
||||
@@ -374,35 +365,62 @@ describe('NewChildSegmentModal', () => {
|
||||
|
||||
// View newly added chunk
|
||||
describe('View Newly Added Chunk', () => {
|
||||
it('should show custom button in full-doc mode after save', async () => {
|
||||
it('should call viewNewlyAddedChildChunk when the toast action is clicked', async () => {
|
||||
mockParentMode = 'full-doc'
|
||||
const mockViewNewlyAddedChildChunk = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
options.onSuccess({ data: { id: 'new-child-id' } })
|
||||
options.onSettled()
|
||||
return Promise.resolve()
|
||||
})
|
||||
|
||||
render(<NewChildSegmentModal {...defaultProps} />)
|
||||
render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<NewChildSegmentModal
|
||||
{...defaultProps}
|
||||
viewNewlyAddedChildChunk={mockViewNewlyAddedChildChunk}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
|
||||
// Enter valid content
|
||||
fireEvent.change(screen.getByTestId('content-input'), {
|
||||
target: { value: 'Valid content' },
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
// Assert - success notification with custom component
|
||||
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
|
||||
fireEvent.click(actionButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
customComponent: expect.anything(),
|
||||
}),
|
||||
)
|
||||
expect(mockViewNewlyAddedChildChunk).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('should not show custom button in paragraph mode after save', async () => {
|
||||
it('should call onSave immediately in full-doc mode after save succeeds', async () => {
|
||||
mockParentMode = 'full-doc'
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
options.onSuccess({ data: { id: 'new-child-id' } })
|
||||
options.onSettled()
|
||||
return Promise.resolve()
|
||||
})
|
||||
|
||||
render(<NewChildSegmentModal {...defaultProps} onSave={mockOnSave} />)
|
||||
|
||||
fireEvent.change(screen.getByTestId('content-input'), {
|
||||
target: { value: 'Valid content' },
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnSave).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onSave with the new child chunk in paragraph mode', async () => {
|
||||
mockParentMode = 'paragraph'
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import type { FC } from 'react'
|
||||
import type { ChildChunkDetail, SegmentUpdater } from '@/models/datasets'
|
||||
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
|
||||
import { memo, useMemo, useRef, useState } from 'react'
|
||||
import { memo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { useParams } from '@/next/navigation'
|
||||
import { useAddChildSegment } from '@/service/knowledge/use-segment'
|
||||
@@ -35,39 +32,15 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
viewNewlyAddedChildChunk,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [content, setContent] = useState('')
|
||||
const { datasetId, documentId } = useParams<{ datasetId: string, documentId: string }>()
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [addAnother, setAddAnother] = useState(true)
|
||||
const fullScreen = useSegmentListContext(s => s.fullScreen)
|
||||
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
|
||||
const { appSidebarExpand } = useAppStore(useShallow(state => ({
|
||||
appSidebarExpand: state.appSidebarExpand,
|
||||
})))
|
||||
const parentMode = useDocumentContext(s => s.parentMode)
|
||||
|
||||
const refreshTimer = useRef<any>(null)
|
||||
|
||||
const isFullDocMode = useMemo(() => {
|
||||
return parentMode === 'full-doc'
|
||||
}, [parentMode])
|
||||
|
||||
const CustomButton = (
|
||||
<>
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChildChunk?.()
|
||||
}}
|
||||
>
|
||||
{t('operation.view', { ns: 'common' })}
|
||||
</button>
|
||||
</>
|
||||
)
|
||||
const isFullDocMode = parentMode === 'full-doc'
|
||||
|
||||
const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
|
||||
if (actionType === 'esc' || !addAnother)
|
||||
@@ -80,26 +53,27 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
const params: SegmentUpdater = { content: '' }
|
||||
|
||||
if (!content.trim())
|
||||
return notify({ type: 'error', message: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
|
||||
return toast.add({ type: 'error', title: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
|
||||
|
||||
params.content = content
|
||||
|
||||
setLoading(true)
|
||||
await addChildSegment({ datasetId, documentId, segmentId: chunkId, body: params }, {
|
||||
onSuccess(res) {
|
||||
notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
|
||||
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
|
||||
!top-auto !right-auto !mb-[52px] !ml-11`,
|
||||
customComponent: isFullDocMode && CustomButton,
|
||||
title: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
|
||||
actionProps: isFullDocMode
|
||||
? {
|
||||
children: t('operation.view', { ns: 'common' }),
|
||||
onClick: viewNewlyAddedChildChunk,
|
||||
}
|
||||
: undefined,
|
||||
})
|
||||
handleCancel('add')
|
||||
setContent('')
|
||||
if (isFullDocMode) {
|
||||
refreshTimer.current = setTimeout(() => {
|
||||
onSave()
|
||||
}, 3000)
|
||||
onSave()
|
||||
}
|
||||
else {
|
||||
onSave(res.data)
|
||||
@@ -111,10 +85,8 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
})
|
||||
}
|
||||
|
||||
const wordCountText = useMemo(() => {
|
||||
const count = content.length
|
||||
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
}, [content.length])
|
||||
const count = content.length
|
||||
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
|
||||
@@ -2,13 +2,10 @@ import type { FC } from 'react'
|
||||
import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types'
|
||||
import type { SegmentUpdater } from '@/models/datasets'
|
||||
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { memo, useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
@@ -39,7 +36,6 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
viewNewlyAddedChunk,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [question, setQuestion] = useState('')
|
||||
const [answer, setAnswer] = useState('')
|
||||
const [attachments, setAttachments] = useState<FileEntity[]>([])
|
||||
@@ -50,27 +46,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
const fullScreen = useSegmentListContext(s => s.fullScreen)
|
||||
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
|
||||
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
|
||||
const { appSidebarExpand } = useAppStore(useShallow(state => ({
|
||||
appSidebarExpand: state.appSidebarExpand,
|
||||
})))
|
||||
const [imageUploaderKey, setImageUploaderKey] = useState(Date.now())
|
||||
const refreshTimer = useRef<any>(null)
|
||||
|
||||
const CustomButton = useMemo(() => (
|
||||
<>
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChunk()
|
||||
}}
|
||||
>
|
||||
{t('operation.view', { ns: 'common' })}
|
||||
</button>
|
||||
</>
|
||||
), [viewNewlyAddedChunk, t])
|
||||
const [imageUploaderKey, setImageUploaderKey] = useState(() => Date.now())
|
||||
|
||||
const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
|
||||
if (actionType === 'esc' || !addAnother)
|
||||
@@ -87,15 +63,15 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
const params: SegmentUpdater = { content: '', attachment_ids: [] }
|
||||
if (docForm === ChunkingMode.qa) {
|
||||
if (!question.trim()) {
|
||||
return notify({
|
||||
return toast.add({
|
||||
type: 'error',
|
||||
message: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
|
||||
title: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
if (!answer.trim()) {
|
||||
return notify({
|
||||
return toast.add({
|
||||
type: 'error',
|
||||
message: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
|
||||
title: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -104,9 +80,9 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
}
|
||||
else {
|
||||
if (!question.trim()) {
|
||||
return notify({
|
||||
return toast.add({
|
||||
type: 'error',
|
||||
message: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
|
||||
title: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -122,12 +98,13 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
setLoading(true)
|
||||
await addSegment({ datasetId, documentId, body: params }, {
|
||||
onSuccess() {
|
||||
notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
|
||||
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
|
||||
!top-auto !right-auto !mb-[52px] !ml-11`,
|
||||
customComponent: CustomButton,
|
||||
title: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
|
||||
actionProps: {
|
||||
children: t('operation.view', { ns: 'common' }),
|
||||
onClick: viewNewlyAddedChunk,
|
||||
},
|
||||
})
|
||||
handleCancel('add')
|
||||
setQuestion('')
|
||||
@@ -135,20 +112,16 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
setAttachments([])
|
||||
setImageUploaderKey(Date.now())
|
||||
setKeywords([])
|
||||
refreshTimer.current = setTimeout(() => {
|
||||
onSave()
|
||||
}, 3000)
|
||||
onSave()
|
||||
},
|
||||
onSettled() {
|
||||
setLoading(false)
|
||||
},
|
||||
})
|
||||
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])
|
||||
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, t, handleCancel, onSave, viewNewlyAddedChunk])
|
||||
|
||||
const wordCountText = useMemo(() => {
|
||||
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
|
||||
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
}, [question.length, answer.length, docForm, t])
|
||||
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
|
||||
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
|
||||
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
|
||||
|
||||
|
||||
@@ -21,11 +21,11 @@ vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`,
|
||||
}))
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
const mockNotify = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockNotify,
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock modal context
|
||||
@@ -164,7 +164,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
// Verify success notification
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'External Knowledge Base Connected Successfully',
|
||||
title: 'External Knowledge Base Connected Successfully',
|
||||
})
|
||||
|
||||
// Verify navigation back
|
||||
@@ -206,7 +206,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'Failed to connect External Knowledge Base',
|
||||
title: 'Failed to connect External Knowledge Base',
|
||||
})
|
||||
})
|
||||
|
||||
@@ -228,7 +228,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'Failed to connect External Knowledge Base',
|
||||
title: 'Failed to connect External Knowledge Base',
|
||||
})
|
||||
})
|
||||
|
||||
@@ -274,7 +274,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'External Knowledge Base Connected Successfully',
|
||||
title: 'External Knowledge Base Connected Successfully',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,13 +4,12 @@ import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-
|
||||
import * as React from 'react'
|
||||
import { useState } from 'react'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import { useToastContext } from '@/app/components/base/toast/context'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { createExternalKnowledgeBase } from '@/service/datasets'
|
||||
|
||||
const ExternalKnowledgeBaseConnector = () => {
|
||||
const { notify } = useToastContext()
|
||||
const [loading, setLoading] = useState(false)
|
||||
const router = useRouter()
|
||||
|
||||
@@ -19,7 +18,7 @@ const ExternalKnowledgeBaseConnector = () => {
|
||||
setLoading(true)
|
||||
const result = await createExternalKnowledgeBase({ body: formValue })
|
||||
if (result && result.id) {
|
||||
notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' })
|
||||
toast.add({ type: 'success', title: 'External Knowledge Base Connected Successfully' })
|
||||
trackEvent('create_external_knowledge_base', {
|
||||
provider: formValue.provider,
|
||||
name: formValue.name,
|
||||
@@ -30,7 +29,7 @@ const ExternalKnowledgeBaseConnector = () => {
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error creating external knowledge base:', error)
|
||||
notify({ type: 'error', message: 'Failed to connect External Knowledge Base' })
|
||||
toast.add({ type: 'error', title: 'Failed to connect External Knowledge Base' })
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
@@ -43,10 +43,10 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockNotify,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
@@ -150,7 +150,7 @@ describe('SystemModel', () => {
|
||||
expect(mockUpdateDefaultModel).toHaveBeenCalledTimes(1)
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'Modified successfully',
|
||||
title: 'Modified successfully',
|
||||
})
|
||||
expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5)
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(5)
|
||||
|
||||
@@ -6,13 +6,13 @@ import type {
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { useToastContext } from '@/app/components/base/toast/context'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
DialogTitle,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
@@ -64,7 +64,6 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
isLoading,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const { textGenerationModelList } = useProviderContext()
|
||||
const updateModelList = useUpdateModelList()
|
||||
@@ -124,7 +123,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
},
|
||||
})
|
||||
if (res.result === 'success') {
|
||||
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
toast.add({ type: 'success', title: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
setOpen(false)
|
||||
|
||||
const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts]
|
||||
|
||||
@@ -4,7 +4,7 @@ import { DeleteConfirm } from '../delete-confirm'
|
||||
|
||||
const mockRefetch = vi.fn()
|
||||
const mockDelete = vi.fn()
|
||||
const mockToast = vi.fn()
|
||||
const mockToastAdd = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('../use-subscription-list', () => ({
|
||||
useSubscriptionList: () => ({ refetch: mockRefetch }),
|
||||
@@ -14,9 +14,9 @@ vi.mock('@/service/use-triggers', () => ({
|
||||
useDeleteTriggerSubscription: () => ({ mutate: mockDelete, isPending: false }),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: (args: { type: string, message: string }) => mockToast(args),
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockToastAdd,
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -42,7 +42,7 @@ describe('DeleteConfirm', () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
|
||||
|
||||
expect(mockDelete).not.toHaveBeenCalled()
|
||||
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
|
||||
})
|
||||
|
||||
it('should allow deletion after matching input name', () => {
|
||||
@@ -87,6 +87,6 @@ describe('DeleteConfirm', () => {
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
|
||||
|
||||
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', message: 'network error' }))
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', title: 'network error' }))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useDeleteTriggerSubscription } from '@/service/use-triggers'
|
||||
import { useSubscriptionList } from './use-subscription-list'
|
||||
|
||||
@@ -23,58 +31,74 @@ export const DeleteConfirm = (props: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const [inputName, setInputName] = useState('')
|
||||
|
||||
const handleOpenChange = (open: boolean) => {
|
||||
if (isDeleting)
|
||||
return
|
||||
|
||||
if (!open)
|
||||
onClose(false)
|
||||
}
|
||||
|
||||
const onConfirm = () => {
|
||||
if (workflowsInUse > 0 && inputName !== currentName) {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
|
||||
// temporarily
|
||||
className: 'z-[10000001]',
|
||||
title: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
|
||||
})
|
||||
return
|
||||
}
|
||||
deleteSubscription(currentId, {
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
|
||||
className: 'z-[10000001]',
|
||||
title: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
|
||||
})
|
||||
refetch?.()
|
||||
onClose(true)
|
||||
},
|
||||
onError: (error: any) => {
|
||||
Toast.notify({
|
||||
onError: (error: unknown) => {
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: error?.message || t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
|
||||
className: 'z-[10000001]',
|
||||
title: error instanceof Error ? error.message : t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<Confirm
|
||||
title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
|
||||
confirmText={t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
|
||||
content={workflowsInUse > 0
|
||||
? (
|
||||
<>
|
||||
{t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })}
|
||||
<div className="system-sm-medium mb-2 mt-6 text-text-secondary">{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}</div>
|
||||
<AlertDialog open={isShow} onOpenChange={handleOpenChange}>
|
||||
<AlertDialogContent backdropProps={{ forceRender: true }}>
|
||||
<div className="flex flex-col gap-2 px-6 pb-4 pt-6">
|
||||
<AlertDialogTitle title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })} className="w-full truncate text-text-primary title-2xl-semi-bold">
|
||||
{t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
|
||||
</AlertDialogTitle>
|
||||
<AlertDialogDescription className="w-full whitespace-pre-wrap break-words text-text-tertiary system-md-regular">
|
||||
{workflowsInUse > 0
|
||||
? t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })
|
||||
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
|
||||
</AlertDialogDescription>
|
||||
{workflowsInUse > 0 && (
|
||||
<div className="mt-6">
|
||||
<div className="mb-2 text-text-secondary system-sm-medium">
|
||||
{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}
|
||||
</div>
|
||||
<Input
|
||||
value={inputName}
|
||||
onChange={e => setInputName(e.target.value)}
|
||||
placeholder={t(`${tPrefix}.confirmInputPlaceholder`, { ns: 'pluginTrigger', name: currentName })}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
|
||||
isShow={isShow}
|
||||
isLoading={isDeleting}
|
||||
isDisabled={isDeleting}
|
||||
onConfirm={onConfirm}
|
||||
onCancel={() => onClose(false)}
|
||||
maskClosable={false}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton disabled={isDeleting}>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton loading={isDeleting} disabled={isDeleting} onClick={onConfirm}>
|
||||
{t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { OutputVar } from '../../../code/types'
|
||||
import type { ToastHandle } from '@/app/components/base/toast'
|
||||
import type { VarType } from '@/app/components/workflow/types'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
|
||||
import RemoveButton from '../remove-button'
|
||||
import VarTypePicker from './var-type-picker'
|
||||
@@ -30,7 +29,6 @@ const OutputVarList: FC<Props> = ({
|
||||
onRemove,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [toastHandler, setToastHandler] = useState<ToastHandle>()
|
||||
|
||||
const list = outputKeyOrders.map((key) => {
|
||||
return {
|
||||
@@ -42,20 +40,17 @@ const OutputVarList: FC<Props> = ({
|
||||
const { run: validateVarInput } = useDebounceFn((existingVariables: typeof list, newKey: string) => {
|
||||
const result = checkKeys([newKey], true)
|
||||
if (!result.isValid) {
|
||||
setToastHandler(Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
}))
|
||||
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (existingVariables.some(key => key.variable?.trim() === newKey.trim())) {
|
||||
setToastHandler(Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
}))
|
||||
}
|
||||
else {
|
||||
toastHandler?.clear?.()
|
||||
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
})
|
||||
}
|
||||
}, { wait: 500 })
|
||||
|
||||
@@ -66,7 +61,6 @@ const OutputVarList: FC<Props> = ({
|
||||
replaceSpaceWithUnderscoreInVarNameInput(e.target)
|
||||
const newKey = e.target.value
|
||||
|
||||
toastHandler?.clear?.()
|
||||
validateVarInput(list.toSpliced(index, 1), newKey)
|
||||
|
||||
const newOutputs = produce(outputs, (draft) => {
|
||||
@@ -75,7 +69,7 @@ const OutputVarList: FC<Props> = ({
|
||||
})
|
||||
onChange(newOutputs, index, newKey)
|
||||
}
|
||||
}, [list, onChange, outputs, outputKeyOrders, validateVarInput])
|
||||
}, [list, onChange, outputs, validateVarInput])
|
||||
|
||||
const handleVarTypeChange = useCallback((index: number) => {
|
||||
return (value: string) => {
|
||||
@@ -85,7 +79,7 @@ const OutputVarList: FC<Props> = ({
|
||||
})
|
||||
onChange(newOutputs)
|
||||
}
|
||||
}, [list, onChange, outputs, outputKeyOrders])
|
||||
}, [list, onChange, outputs])
|
||||
|
||||
const handleVarRemove = useCallback((index: number) => {
|
||||
return () => {
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { ToastHandle } from '@/app/components/base/toast'
|
||||
import type { ValueSelector, Var, Variable } from '@/app/components/workflow/types'
|
||||
import { RiDraggable } from '@remixicon/react'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ReactSortable } from 'react-sortablejs'
|
||||
import { v4 as uuid4 } from 'uuid'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
|
||||
@@ -42,7 +41,6 @@ const VarList: FC<Props> = ({
|
||||
isSupportFileVar = true,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [toastHandle, setToastHandle] = useState<ToastHandle>()
|
||||
|
||||
const listWithIds = useMemo(() => list.map((item) => {
|
||||
const id = uuid4()
|
||||
@@ -55,20 +53,17 @@ const VarList: FC<Props> = ({
|
||||
const { run: validateVarInput } = useDebounceFn((list: Variable[], newKey: string) => {
|
||||
const result = checkKeys([newKey], true)
|
||||
if (!result.isValid) {
|
||||
setToastHandle(Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
}))
|
||||
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (list.some(item => item.variable?.trim() === newKey.trim())) {
|
||||
setToastHandle(Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
}))
|
||||
}
|
||||
else {
|
||||
toastHandle?.clear?.()
|
||||
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
})
|
||||
}
|
||||
}, { wait: 500 })
|
||||
|
||||
@@ -78,7 +73,6 @@ const VarList: FC<Props> = ({
|
||||
|
||||
const newKey = e.target.value
|
||||
|
||||
toastHandle?.clear?.()
|
||||
validateVarInput(list.toSpliced(index, 1), newKey)
|
||||
|
||||
onVarNameChange?.(list[index].variable, newKey)
|
||||
|
||||
@@ -7,7 +7,7 @@ import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import VersionInfoModal from '@/app/components/app/app-publisher/version-info-modal'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useSelector as useAppContextSelector } from '@/context/app-context'
|
||||
import { useDeleteWorkflow, useInvalidAllLastRun, useResetWorkflowVersionHistory, useUpdateWorkflow, useWorkflowVersionHistory } from '@/service/use-workflow'
|
||||
import { useDSL, useNodesSyncDraft, useWorkflowRun } from '../../hooks'
|
||||
@@ -118,9 +118,9 @@ export const VersionHistoryPanel = ({
|
||||
break
|
||||
case VersionHistoryContextMenuOptions.copyId:
|
||||
copy(item.id)
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
break
|
||||
case VersionHistoryContextMenuOptions.exportDSL:
|
||||
@@ -152,17 +152,17 @@ export const VersionHistoryPanel = ({
|
||||
workflowStore.setState({ backupDraft: undefined })
|
||||
handleSyncWorkflowDraft(true, false, {
|
||||
onSuccess: () => {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
deleteAllInspectVars()
|
||||
invalidAllLastRun()
|
||||
},
|
||||
onError: () => {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
@@ -177,18 +177,18 @@ export const VersionHistoryPanel = ({
|
||||
await deleteWorkflow(deleteVersionUrl?.(id) || '', {
|
||||
onSuccess: () => {
|
||||
setDeleteConfirmOpen(false)
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
resetWorkflowVersionHistory()
|
||||
deleteAllInspectVars()
|
||||
invalidAllLastRun()
|
||||
},
|
||||
onError: () => {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
@@ -207,16 +207,16 @@ export const VersionHistoryPanel = ({
|
||||
}, {
|
||||
onSuccess: () => {
|
||||
setEditModalOpen(false)
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'success',
|
||||
message: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
resetWorkflowVersionHistory()
|
||||
},
|
||||
onError: () => {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
|
||||
title: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { emailRegex } from '@/config'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import Link from '@/next/link'
|
||||
@@ -35,18 +35,18 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis
|
||||
|
||||
const handleEmailPasswordLogin = async () => {
|
||||
if (!email) {
|
||||
Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) })
|
||||
toast.add({ type: 'error', title: t('error.emailEmpty', { ns: 'login' }) })
|
||||
return
|
||||
}
|
||||
if (!emailRegex.test(email)) {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('error.emailInValid', { ns: 'login' }),
|
||||
title: t('error.emailInValid', { ns: 'login' }),
|
||||
})
|
||||
return
|
||||
}
|
||||
if (!password?.trim()) {
|
||||
Toast.notify({ type: 'error', message: t('error.passwordEmpty', { ns: 'login' }) })
|
||||
toast.add({ type: 'error', title: t('error.passwordEmpty', { ns: 'login' }) })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -83,17 +83,17 @@ export default function MailAndPasswordAuth({ isInvite, isEmailSetup, allowRegis
|
||||
}
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: res.data,
|
||||
title: res.data,
|
||||
})
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
if ((error as ResponseError).code === 'authentication_failed') {
|
||||
Toast.notify({
|
||||
toast.add({
|
||||
type: 'error',
|
||||
message: t('error.invalidEmailOrPassword', { ns: 'login' }),
|
||||
title: t('error.invalidEmailOrPassword', { ns: 'login' }),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useQueryClient } from '@tanstack/react-query'
|
||||
import dayjs from 'dayjs'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { setZendeskConversationFields } from '@/app/components/base/zendesk/utils'
|
||||
import { defaultPlan } from '@/app/components/billing/config'
|
||||
import { parseCurrentPlan } from '@/app/components/billing/utils'
|
||||
@@ -132,13 +132,11 @@ export const ProviderContextProvider = ({
|
||||
if (anthropic && anthropic.system_configuration.current_quota_type === CurrentSystemQuotaTypeEnum.trial) {
|
||||
const quota = anthropic.system_configuration.quota_configurations.find(item => item.quota_type === anthropic.system_configuration.current_quota_type)
|
||||
if (quota && quota.is_valid && quota.quota_used < quota.quota_limit) {
|
||||
Toast.notify({
|
||||
localStorage.setItem('anthropic_quota_notice', 'true')
|
||||
toast.add({
|
||||
type: 'info',
|
||||
message: t('provider.anthropicHosted.trialQuotaTip', { ns: 'common' }),
|
||||
duration: 60000,
|
||||
onClose: () => {
|
||||
localStorage.setItem('anthropic_quota_notice', 'true')
|
||||
},
|
||||
title: t('provider.anthropicHosted.trialQuotaTip', { ns: 'common' }),
|
||||
timeout: 60000,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,9 +335,6 @@
|
||||
}
|
||||
},
|
||||
"app/account/oauth/authorize/page.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
@@ -1127,9 +1124,6 @@
|
||||
}
|
||||
},
|
||||
"app/components/app/create-app-dialog/app-list/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 5
|
||||
}
|
||||
@@ -2924,14 +2918,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/billing/pricing/plans/cloud-plan-item/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 6
|
||||
}
|
||||
},
|
||||
"app/components/billing/pricing/plans/cloud-plan-item/list/item/index.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
@@ -2947,17 +2933,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/billing/pricing/plans/self-hosted-plan-item/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 4
|
||||
},
|
||||
"tailwindcss/no-unnecessary-whitespace": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/billing/pricing/plans/self-hosted-plan-item/list/index.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
@@ -3786,14 +3761,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/detail/completed/new-child-segment.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/detail/completed/segment-card/chunk-content.tsx": {
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
@@ -3862,14 +3829,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/detail/new-segment.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/documents/detail/segment-add/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
@@ -3930,11 +3889,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/external-knowledge-base/connector/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx": {
|
||||
"react-hooks-extra/no-direct-set-state-in-use-effect": {
|
||||
"count": 1
|
||||
@@ -4859,11 +4813,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/header/account-setting/model-provider-page/system-model-selector/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
@@ -5394,17 +5343,6 @@
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/subscription-list/delete-confirm.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
@@ -7105,11 +7043,6 @@
|
||||
"count": 5
|
||||
}
|
||||
},
|
||||
"app/components/workflow/nodes/_base/components/variable/output-var-list.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/workflow/nodes/_base/components/variable/utils.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 32
|
||||
@@ -7123,11 +7056,6 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/workflow/nodes/_base/components/variable/var-list.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 2
|
||||
@@ -8877,9 +8805,6 @@
|
||||
}
|
||||
},
|
||||
"app/components/workflow/panel/version-history-panel/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
}
|
||||
@@ -9450,9 +9375,6 @@
|
||||
}
|
||||
},
|
||||
"app/signin/components/mail-and-password-auth.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
@@ -9564,9 +9486,6 @@
|
||||
}
|
||||
},
|
||||
"context/provider-context-provider.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
@@ -9752,9 +9671,6 @@
|
||||
}
|
||||
},
|
||||
"service/fetch.ts": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"regexp/no-unused-capturing-group": {
|
||||
"count": 1
|
||||
},
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { base } from './fetch'
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { AfterResponseHook, BeforeRequestHook, Hooks } from 'ky'
|
||||
import type { IOtherOptions } from './base'
|
||||
import Cookies from 'js-cookie'
|
||||
import ky, { HTTPError } from 'ky'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { API_PREFIX, APP_VERSION, CSRF_COOKIE_NAME, CSRF_HEADER_NAME, IS_MARKETPLACE, MARKETPLACE_API_PREFIX, PASSPORT_HEADER_NAME, PUBLIC_API_PREFIX, WEB_APP_SHARE_CODE_HEADER_NAME } from '@/config'
|
||||
import { getWebAppAccessToken, getWebAppPassport } from './webapp-auth'
|
||||
|
||||
@@ -48,7 +48,7 @@ const afterResponseErrorCode = (otherOptions: IOtherOptions): AfterResponseHook
|
||||
const shouldNotifyError = response.status !== 401 && errorData && !otherOptions.silent
|
||||
|
||||
if (shouldNotifyError)
|
||||
Toast.notify({ type: 'error', message: errorData.message })
|
||||
toast.add({ type: 'error', title: errorData.message })
|
||||
|
||||
if (response.status === 403 && errorData?.code === 'already_setup')
|
||||
globalThis.location.href = `${globalThis.location.origin}/signin`
|
||||
|
||||
Reference in New Issue
Block a user