Compare commits

..

11 Commits

112 changed files with 1438 additions and 4719 deletions

View File

@@ -134,7 +134,6 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel):
email: EmailStr
language: str | None = None
phase: str | None = None
token: str | None = None
@@ -548,13 +547,17 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.token is not None:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@@ -574,7 +577,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
phase=send_phase,
)
return {"result": "success", "data": token}
@@ -609,12 +612,26 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
phase_transitions: dict[str, str] = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@@ -644,13 +661,22 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -47,6 +47,7 @@ from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -521,9 +522,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session=session, model=workflow)
message = _refresh_model(session=session, model=message)
assert message is not None
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
@@ -690,21 +690,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e
@overload
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
_T = TypeVar("_T", bound=Base)
@overload
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
if session is not None:
detached_model = session.get(type(model), model.id)
assert detached_model is not None
return detached_model
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
detached_model = refresh_session.get(type(model), model.id)
assert detached_model is not None
return detached_model
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator, Mapping
from typing import Any
from collections.abc import Generator, Mapping
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -16,26 +16,24 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
stream_response = response
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(stream_response)
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
stream_response = response
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(stream_response)
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
return _generate_simple_response()
@@ -52,14 +50,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
raise NotImplementedError

View File

@@ -224,7 +224,6 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory(
session: Session,
@@ -241,7 +240,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=debug_account,
user=account,
)
else:

View File

@@ -166,19 +166,15 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
assert conversation is not None
assert message is not None
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=generated_conversation_id,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=generated_message_id,
message_id=message.id,
)
# new thread with request context
@@ -188,8 +184,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=generated_conversation_id,
message_id=generated_message_id,
conversation_id=conversation.id,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -149,8 +149,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -314,19 +312,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation_id,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message_id,
message_id=message.id,
)
# new thread with request context
@@ -336,7 +330,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message_id,
message_id=message.id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator, Iterator
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Iterator[AppStreamResponse]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,17 +1,13 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Protocol, TypeAlias
from typing import Any, cast
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -40,7 +36,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -79,14 +75,6 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
node_id: str
inputs: Mapping[str, object]
class WorkflowBasedAppRunner:
def __init__(
@@ -110,7 +98,7 @@ class WorkflowBasedAppRunner:
def _init_graph(
self,
graph_config: GraphConfigMapping,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -166,8 +154,8 @@ class WorkflowBasedAppRunner:
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: SingleNodeRunEntity | None = None,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@@ -220,88 +208,11 @@ class WorkflowBasedAppRunner:
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
nodes = graph_config.get("nodes")
edges = graph_config.get("edges")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(edges, list):
raise ValueError("edges in workflow graph must be a list")
validated_nodes: list[GraphConfigMapping] = []
for node in nodes:
if not isinstance(node, Mapping):
raise ValueError("nodes in workflow graph must be mappings")
validated_nodes.append(node)
validated_edges: list[GraphConfigMapping] = []
for edge in edges:
if not isinstance(edge, Mapping):
raise ValueError("edges in workflow graph must be mappings")
validated_edges.append(edge)
return validated_nodes, validated_edges
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
if node_config is None:
return None
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return None
start_node_id = node_data.get("start_node_id")
return start_node_id if isinstance(start_node_id, str) else None
@classmethod
def _build_single_node_graph_config(
cls,
*,
graph_config: GraphConfigMapping,
node_id: str,
node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
node_configs, edge_configs = cls._get_graph_items(graph_config)
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
start_node_id = cls._extract_start_node_id(main_node_config)
filtered_node_configs = [
dict(node)
for node in node_configs
if node.get("id") == node_id
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
or (start_node_id and node.get("id") == start_node_id)
]
if not filtered_node_configs:
raise ValueError(f"node id {node_id} not found in workflow graph")
filtered_node_ids = {
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
}
filtered_edge_configs = [
dict(edge)
for edge in edge_configs
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
]
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
if target_node_config is None:
raise ValueError(f"node id {node_id} not found in workflow graph")
return (
{
"nodes": filtered_node_configs,
"edges": filtered_edge_configs,
},
NodeConfigDictAdapter.validate_python(target_node_config),
)
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: Mapping[str, object],
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@@ -325,14 +236,41 @@ class WorkflowBasedAppRunner:
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
graph_config, target_node_config = self._build_single_node_graph_config(
graph_config=graph_config,
node_id=node_id,
node_type_filter_key=node_type_filter_key,
)
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
@@ -361,6 +299,18 @@ class WorkflowBasedAppRunner:
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)

View File

@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping[str, object]
inputs: Mapping
single_iteration_run: SingleIterationRunEntity | None = None
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping[str, object]
inputs: Mapping
single_loop_run: SingleLoopRunEntity | None = None
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: Mapping[str, object]
inputs: dict
single_iteration_run: SingleIterationRunEntity | None = None
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: Mapping[str, object]
inputs: dict
single_loop_run: SingleLoopRunEntity | None = None

View File

@@ -5,7 +5,6 @@ This module provides integration with Weaviate vector database for storing and r
document embeddings used in retrieval-augmented generation workflows.
"""
import atexit
import datetime
import json
import logging
@@ -38,32 +37,6 @@ _weaviate_client: weaviate.WeaviateClient | None = None
_weaviate_client_lock = threading.Lock()
def _shutdown_weaviate_client() -> None:
"""
Best-effort shutdown hook to close the module-level Weaviate client.
This is registered with atexit so that HTTP/gRPC resources are released
when the Python interpreter exits.
"""
global _weaviate_client
# Ensure thread-safety when accessing the shared client instance
with _weaviate_client_lock:
client = _weaviate_client
_weaviate_client = None
if client is not None:
try:
client.close()
except Exception:
# Best-effort cleanup; log at debug level and ignore errors.
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
# Register the shutdown hook once per process.
atexit.register(_shutdown_weaviate_client)
class WeaviateConfig(BaseModel):
"""
Configuration model for Weaviate connection settings.
@@ -112,6 +85,18 @@ class WeaviateVector(BaseVector):
self._client = self._init_client(config)
self._attributes = attributes
def __del__(self):
"""
Destructor to properly close the Weaviate client connection.
Prevents connection leaks and resource warnings.
"""
if hasattr(self, "_client") and self._client is not None:
try:
self._client.close()
except Exception as e:
# Ignore errors during cleanup as object is being destroyed
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
"""
Initializes and returns a connected Weaviate client.

View File

@@ -1045,10 +1045,9 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {variable_selector} does not exist")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type == "constant":
parameter_value = tool_input.value

View File

@@ -1,24 +1,13 @@
from __future__ import annotations
from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from typing import Any, Literal, Union
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
@@ -32,20 +21,8 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput]

View File

@@ -1,17 +1,16 @@
from __future__ import annotations
import json
from collections.abc import Mapping, Sequence
from typing import TypeAlias
from collections.abc import Sequence
from typing import Any, cast
from packaging.version import Version
from pydantic import TypeAdapter, ValidationError
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
@@ -29,14 +28,6 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport:
def build_parameters(
@@ -48,12 +39,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: InvokeFrom,
invoke_from: Any,
for_log: bool = False,
) -> dict[str, object]:
) -> dict[str, Any]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, object] = {}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -63,10 +54,9 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
try:
@@ -89,38 +79,60 @@ class AgentRuntimeSupport:
value = parameter_value
if parameter.type == "array[tools]":
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = self._normalize_tool_payloads(
strategy=strategy,
tools=tool_payloads,
variable_pool=variable_pool,
)
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = self._coerce_json_object(tool.get("parameters")) or {}
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity(
provider_id=provider_id,
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool_name,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=plugin_unique_identifier,
credential_id=credential_id,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = self._coerce_json_object(tool.get("extra")) or {}
extra = tool.get("extra", {})
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
@@ -133,9 +145,8 @@ class AgentRuntimeSupport:
runtime_variable_pool,
)
if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = (
description_override or tool_runtime.entity.description.llm
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
@@ -156,13 +167,13 @@ class AgentRuntimeSupport:
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": credential_id,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = _JSON_OBJECT_ADAPTER.validate_python(value)
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
@@ -188,27 +199,17 @@ class AgentRuntimeSupport:
return result
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
for tool in parameters.get("tools", []):
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
return credentials
def fetch_memory(
@@ -231,14 +232,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=str(value.get("provider", "")),
provider=value.get("provider", ""),
model_type=ModelType.LLM,
)
model_name = str(value.get("model", ""))
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
@@ -248,7 +249,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(str(value.get("model_type", ""))),
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@@ -267,88 +268,9 @@ class AgentRuntimeSupport:
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
) -> JsonObjectList:
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
self,
*,
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
variable_pool: VariablePool,
) -> JsonObjectList:
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
for tool in normalized_tools:
tool.pop("schemas", None)
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
tool["settings"] = self._resolve_tool_settings(tool)
return normalized_tools
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
if parameter_configs is None:
raw_parameters = self._coerce_json_object(tool.get("parameters"))
return raw_parameters or {}
resolved_parameters: JsonObject = {}
for key, parameter_config in parameter_configs.items():
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
value_param = self._coerce_json_object(parameter_config.get("value"))
if value_param and value_param.get("type") == "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
resolved_parameters[key] = variable.value
else:
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
else:
resolved_parameters[key] = None
return resolved_parameters
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
if settings is None:
return {}
return {key: setting.get("value") for key, setting in settings.items()}
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
try:
return _JSON_OBJECT_ADAPTER.validate_python(value)
except ValidationError:
return None
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
if isinstance(value, ToolProviderType):
return value
if isinstance(value, str):
return ToolProviderType(value)
return ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
if not isinstance(value, dict):
return None
coerced: dict[str, JsonObject] = {}
for key, item in value.items():
if not isinstance(key, str):
return None
json_object = cls._coerce_json_object(item)
if json_object is None:
return None
coerced[key] = json_object
return coerced

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import Any, TypeAlias, cast
from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -32,13 +32,6 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
SpecialValueScalar: TypeAlias = str | int | float | bool | None
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
SerializedSpecialValue: TypeAlias = (
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder:
@staticmethod
@@ -283,10 +276,10 @@ class WorkflowEntry:
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: Mapping[str, object],
node_data: dict[str, Any],
node_width: int = 114,
node_height: int = 514,
) -> SingleNodeGraphConfig:
) -> dict[str, Any]:
"""
Create a minimal graph structure for testing a single node in isolation.
@@ -296,14 +289,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node
"""
node_config: dict[str, object] = {
node_config = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": dict(node_data),
"data": node_data,
}
start_node_config: dict[str, object] = {
start_node_config = {
"id": "start",
"width": node_width,
"height": node_height,
@@ -328,12 +321,7 @@ class WorkflowEntry:
@classmethod
def run_free_node(
cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@@ -351,8 +339,6 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
@@ -383,7 +369,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -419,34 +405,30 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
def _handle_special_values(value: Any):
if value is None:
return value
if isinstance(value, Mapping):
res: dict[str, SerializedSpecialValue] = {}
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list: list[SerializedSpecialValue] = []
res_list = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return dict(value.to_dict())
return value.to_dict()
return value
@classmethod

View File

@@ -112,8 +112,6 @@ def _get_encoded_string(f: File, /) -> str:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return base64.b64encode(data).decode("utf-8")

View File

@@ -133,8 +133,6 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason)

View File

@@ -336,7 +336,12 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
node_executions = graph_execution.node_executions
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
@@ -390,7 +395,8 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
yield event.model_copy(update={"id": self.execution_id})
event.id = self.execution_id
yield event
else:
yield event
except Exception as e:

View File

@@ -443,10 +443,7 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
it = iter(doc.element.body)
part = next(it, None)
i = 0
while part is not None:

View File

@@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Literal, NotRequired
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -11,17 +10,11 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
schema: Mapping[str, object]
name: NotRequired[str]
description: NotRequired[str]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, object] = Field(default_factory=dict)
completion_params: dict[str, Any] = Field(default_factory=dict)
class ContextConfig(BaseModel):
@@ -40,7 +33,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before")
@classmethod
def convert_none_configs(cls, v: object):
def convert_none_configs(cls, v: Any):
if v is None:
return VisionConfigOptions()
return v
@@ -51,7 +44,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: object):
def convert_none_jinja2_variables(cls, v: Any):
if v is None:
return []
return v
@@ -74,7 +67,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: StructuredOutputConfig | None = None
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
@@ -97,30 +90,11 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: object):
def convert_none_prompt_config(cls, v: Any):
if v is None:
return PromptConfig()
return v
@field_validator("structured_output", mode="before")
@classmethod
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
if not isinstance(v, Mapping):
return v
schema = v.get("schema")
if schema is None:
return None
normalized: StructuredOutputConfig = {"schema": schema}
name = v.get("name")
description = v.get("description")
if isinstance(name, str):
normalized["name"] = name
if isinstance(description, str):
normalized["description"] = description
return normalized
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -9,7 +9,6 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from pydantic import TypeAdapter
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
@@ -75,7 +74,6 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
StructuredOutputConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -90,7 +88,6 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
class LLMNode(Node[LLMNodeData]):
@@ -357,7 +354,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: StructuredOutputConfig | None = None,
structured_output: Mapping[str, Any] | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -370,10 +367,8 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output,
structured_output=structured_output or {},
)
request_start_time = time.perf_counter()
@@ -925,12 +920,6 @@ class LLMNode(Node[LLMNodeData]):
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
structured_output = (
dict(invoke_result.structured_output)
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
else None
)
event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
@@ -939,7 +928,7 @@ class LLMNode(Node[LLMNodeData]):
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=structured_output,
structured_output=getattr(invoke_result, "structured_output", None),
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
@@ -973,18 +962,27 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: StructuredOutputConfig,
) -> dict[str, object]:
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, object]: The structured output schema
dict[str, Any]: The structured output schema
"""
schema = structured_output.get("schema")
if not schema:
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
return _JSON_OBJECT_ADAPTER.validate_python(schema)
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(

View File

@@ -1,10 +1,7 @@
from __future__ import annotations
from enum import StrEnum
from typing import Annotated, Any, Literal, TypeAlias, cast
from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from pydantic import AfterValidator, BaseModel, Field, field_validator
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
@@ -12,12 +9,6 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
@@ -38,36 +29,6 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
@@ -76,29 +37,7 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
raise ValueError("Variable loop inputs require a selector")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _validate_loop_value(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
if self.value_type != "variable":
raise ValueError(f"Expected variable loop input, got {self.value_type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _validate_loop_value(self.value)
value: Any | list[str] | None = None
class LoopNodeData(BaseLoopNodeData):
@@ -107,14 +46,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: LoopValueMapping = Field(default_factory=dict)
outputs: dict[str, Any] = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
def validate_outputs(cls, v):
if v is None:
return {}
return _validate_loop_value_mapping(value)
return v
class LoopStartNodeData(BaseNodeData):
@@ -138,8 +77,8 @@ class LoopState(BaseLoopState):
Loop State.
"""
outputs: list[LoopValue] = Field(default_factory=list)
current_output: LoopValue | None = None
outputs: list[Any] = Field(default_factory=list)
current_output: Any = None
class MetaData(BaseLoopState.MetaData):
"""
@@ -148,7 +87,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> LoopValue | None:
def get_last_output(self) -> Any:
"""
Get last output.
"""
@@ -156,7 +95,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> LoopValue | None:
def get_current_output(self) -> Any:
"""
Get current output.
"""

View File

@@ -3,7 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
)
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs: dict[str, object] = {"loop_count": loop_count}
inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@@ -68,14 +68,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors: dict[str, list[str]] = {}
loop_variable_selectors = {}
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
),
}
for loop_variable in self.node_data.loop_variables:
@@ -97,7 +95,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@@ -148,7 +146,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable: dict[str, LoopValue] = {}
single_loop_variable = {}
for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
@@ -299,29 +297,20 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, object],
graph_config: Mapping[str, Any],
node_id: str,
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
variable_mapping: dict[str, Sequence[str]] = {}
variable_mapping = {}
# Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config
raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items():
sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
if sub_node_config.get("data", {}).get("loop_id") != node_id:
continue
# variable selector to variable mapping
@@ -352,8 +341,9 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.require_variable_selector()
selector = loop_variable.value
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
@@ -362,7 +352,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
@@ -373,19 +363,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids: set[str] = set()
loop_node_ids = set()
# Find all nodes that belong to this loop
raw_nodes = graph_config.get("nodes")
if not isinstance(raw_nodes, list):
return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
@@ -394,7 +377,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
@@ -406,12 +389,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
# New typed payloads may already provide native lists, while legacy
# configs still serialize array constants as JSON strings.
if isinstance(original_value, str):
value = json.loads(original_value) if original_value else []
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
else:
value = original_value
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
else:
raise AssertionError("this statement should be unreachable.")
try:

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Literal
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
@@ -6,7 +6,6 @@ from pydantic import (
Field,
field_validator,
)
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -56,7 +55,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value: object) -> str:
def validate_name(cls, value) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
@@ -80,23 +79,6 @@ class ParameterConfig(BaseModel):
return element_type
class JsonSchemaArrayItems(TypedDict):
type: str
class ParameterJsonSchemaProperty(TypedDict, total=False):
description: str
type: str
items: JsonSchemaArrayItems
enum: list[str]
class ParameterJsonSchema(TypedDict):
type: Literal["object"]
properties: dict[str, ParameterJsonSchemaProperty]
required: list[str]
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
@@ -113,19 +95,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v: object) -> str:
return str(v) if v else "function_call"
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self) -> ParameterJsonSchema:
def get_parameter_json_schema(self):
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
parameter_schema: dict[str, Any] = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
@@ -136,7 +118,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type.value
parameter_schema["type"] = parameter.type
if parameter.options:
parameter_schema["enum"] = parameter.options

View File

@@ -5,8 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from pydantic import TypeAdapter
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -65,7 +63,6 @@ from .prompts import (
)
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
@@ -73,7 +70,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
def extract_json(text: str) -> str | None:
def extract_json(text):
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
@@ -395,15 +392,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
# generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
parameters=node_data.get_parameter_json_schema(),
)
return prompt_messages, [tool]
@@ -610,21 +602,19 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
transformed_result: dict[str, object] = {}
transformed_result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
if isinstance(param_value, (bool, int, float, str)):
numeric_value: bool | int | float | str = param_value
transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
@@ -671,7 +661,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
def _extract_complete_json_response(self, result: str) -> dict | None:
"""
Extract complete json response.
"""
@@ -682,11 +672,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
return cast(dict, json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
"""
Extract json from tool call.
"""
@@ -700,16 +690,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
return cast(dict, json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
result: dict[str, object] = {}
result: dict[str, Any] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0

View File

@@ -1,66 +1,12 @@
from __future__ import annotations
from typing import Any, Literal, Union
from typing import Literal, TypeAlias, cast
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import TypedDict
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class WorkflowToolInputValue(TypedDict):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
class ToolInputPayload(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
if isinstance(value, (str, int, float, bool)):
return cast(ToolConfigurationEntry, value)
if isinstance(value, dict):
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
class ToolEntity(BaseModel):
provider_id: str
@@ -68,29 +14,52 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: ToolConfigurations
tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise TypeError("tool_configurations must be a dictionary")
raise ValueError("tool_configurations must be a dictionary")
normalized: ToolConfigurations = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("tool_configurations keys must be strings")
normalized[key] = _validate_tool_configuration_entry(item)
return normalized
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(ToolInputPayload):
pass
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
@@ -100,7 +69,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before")
@classmethod
def filter_none_tool_inputs(cls, value: object) -> object:
def filter_none_tool_inputs(cls, value):
if not isinstance(value, dict):
return value
@@ -111,10 +80,8 @@ class ToolNodeData(BaseNodeData, ToolEntity):
}
@staticmethod
def _has_valid_value(tool_input: object) -> bool:
def _has_valid_value(tool_input):
"""Check if the value is valid"""
if isinstance(tool_input, dict):
return tool_input.get("value") is not None
if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False
return getattr(tool_input, "value", None) is not None

View File

@@ -225,11 +225,10 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
variable = variable_pool.get(tool_input.value)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {variable_selector} does not exist")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
@@ -511,9 +510,8 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
variable_selector = input.require_variable_selector()
selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
case "constant":
pass

View File

@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import Segment, SegmentType, VariableBase
from dify_graph.variables import SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode
@@ -74,29 +74,23 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@@ -66,11 +66,6 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return node execution state keyed by node id for resume support."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...
@@ -96,12 +91,6 @@ class GraphExecutionProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for per-node execution state used during resume."""
execution_id: str | None
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""

View File

@@ -11,13 +11,6 @@ class CreatorUserRole(StrEnum):
ACCOUNT = "account"
END_USER = "end_user"
@classmethod
def _missing_(cls, value):
if value == "end-user":
return cls.END_USER
else:
return super()._missing_(value)
class WorkflowRunTriggeredFrom(StrEnum):
DEBUGGING = "debugging"

View File

@@ -13,6 +13,21 @@ controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
@@ -93,6 +108,35 @@ core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py

View File

@@ -4,6 +4,7 @@ import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from hashlib import sha256
from typing import Any, cast
@@ -90,12 +91,25 @@ class TokenPair(BaseModel):
csrf_token: str
class ChangeEmailPhase(StrEnum):
OLD = "old_email"
OLD_VERIFIED = "old_email_verified"
NEW = "new_email"
NEW_VERIFIED = "new_email_verified"
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@@ -552,13 +566,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
code, token = cls.generate_change_email_token(
account_email,
account,
old_email=old_email,
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
)
send_change_mail_task.delay(
language=language,

View File

@@ -950,6 +950,16 @@ class TestWorkflowAppService:
assert result_with_new_email["total"] == 3
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
# Create another account in a different tenant using the original email.
# Querying by the old email should still fail for this app's tenant.
cross_tenant_account = AccountService.create_account(
email=original_email,
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(cross_tenant_account, name=fake.company())
# Old email unbound, is unexpected input, should raise ValueError
with pytest.raises(ValueError) as exc_info:
service.get_paginate_workflow_app_logs(

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.auth.error import InvalidTokenError
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
ChangeEmailCheckApi,
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
def test_should_infer_new_email_phase_from_token(
self,
mock_features,
mock_csrf,
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.db")
@patch("controllers.console.workspace.account.Session")
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_account_by_email,
mock_session_cls,
mock_account_db,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
existing_account = _build_account("old@example.com", "acc-old")
mock_get_account_by_email.return_value = existing_account
mock_send_email.return_value = "token-legacy"
mock_session = MagicMock()
mock_session_cm = MagicMock()
mock_session_cm.__enter__.return_value = mock_session
mock_session_cm.__exit__.return_value = None
mock_session_cls.return_value = mock_session_cm
mock_account_db.engine = MagicMock()
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
assert response == {"result": "success", "data": "token-legacy"}
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=existing_account,
email="old@example.com",
old_email="old@example.com",
language="en-US",
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_unverified_old_email_token_for_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
mock_send_email.assert_not_called()
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
"user@example.com",
code="1234",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_refresh_new_email_phase_to_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("old@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "new@example.com",
"code": "5678",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
}
mock_generate_token.return_value = (None, "new-phase-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
mock_is_rate_limit.assert_called_once_with("new@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-456")
mock_generate_token.assert_called_once_with(
"new@example.com",
code="5678",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("new@example.com")
mock_csrf.assert_called_once()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "new@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_old_phase_token_for_reset(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_mismatched_new_email_for_verified_token(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "another@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-789"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")

View File

@@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
refreshed = _refresh_model(session=None, model=source_model)
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
assert refreshed is detached_model

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
@@ -33,8 +33,8 @@ def _make_graph_state():
],
)
def test_run_uses_single_node_execution_branch(
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
single_iteration_run: Any,
single_loop_run: Any,
) -> None:
app_config = MagicMock()
app_config.app_id = "app"
@@ -130,23 +130,10 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [],
"logical_operator": "and",
},
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
}
],
"edges": [],
}
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
@@ -156,19 +143,13 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value)
return original_validate_python(value)
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
@@ -180,8 +161,3 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []

View File

@@ -1,9 +1,6 @@
import pytest
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
from dify_graph.nodes.loop.entities import LoopNodeData
from dify_graph.nodes.loop.loop_node import LoopNode
from dify_graph.variables.types import SegmentType
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
@@ -53,21 +50,3 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
)
assert seen_configs == [child_node_config]
@pytest.mark.parametrize(
("var_type", "original_value", "expected_value"),
[
(SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
(SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
(SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
(SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
],
)
def test_get_segment_for_constant_accepts_native_array_values(
var_type: SegmentType, original_value: LoopValue, expected_value: LoopValue
) -> None:
segment = LoopNode._get_segment_for_constant(var_type, original_value)
assert segment.value_type == var_type
assert segment.value == expected_value

View File

@@ -1,19 +0,0 @@
import pytest
from models.enums import CreatorUserRole
def test_creator_user_role_missing_maps_hyphen_to_enum():
# given an alias with hyphen
value = "end-user"
# when converting to enum (invokes StrEnum._missing_ override)
role = CreatorUserRole(value)
# then it should map to END_USER
assert role is CreatorUserRole.END_USER
def test_creator_user_role_missing_raises_for_unknown():
with pytest.raises(ValueError):
CreatorUserRole("unknown")

View File

@@ -11,7 +11,6 @@ import type { BasicPlan } from '@/app/components/billing/type'
import { cleanup, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import { ALL_PLANS } from '@/app/components/billing/config'
import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher'
import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item'
@@ -22,6 +21,7 @@ let mockAppCtx: Record<string, unknown> = {}
const mockFetchSubscriptionUrls = vi.fn()
const mockInvoices = vi.fn()
const mockOpenAsyncWindow = vi.fn()
const mockToastNotify = vi.fn()
// ─── Context mocks ───────────────────────────────────────────────────────────
vi.mock('@/context/app-context', () => ({
@@ -49,6 +49,10 @@ vi.mock('@/hooks/use-async-window-open', () => ({
useAsyncWindowOpen: () => mockOpenAsyncWindow,
}))
vi.mock('@/app/components/base/toast', () => ({
default: { notify: (args: unknown) => mockToastNotify(args) },
}))
// ─── Navigation mocks ───────────────────────────────────────────────────────
vi.mock('@/next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
@@ -78,15 +82,12 @@ const renderCloudPlanItem = ({
canPay = true,
}: RenderCloudPlanItemOptions = {}) => {
return render(
<>
<ToastHost timeout={0} />
<CloudPlanItem
currentPlan={currentPlan}
plan={plan}
planRange={planRange}
canPay={canPay}
/>
</>,
<CloudPlanItem
currentPlan={currentPlan}
plan={plan}
planRange={planRange}
canPay={canPay}
/>,
)
}
@@ -95,7 +96,6 @@ describe('Cloud Plan Payment Flow', () => {
beforeEach(() => {
vi.clearAllMocks()
cleanup()
toast.close()
setupAppContext()
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
@@ -283,7 +283,11 @@ describe('Cloud Plan Payment Flow', () => {
await user.click(button)
await waitFor(() => {
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
)
})
// Should not proceed with payment
expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled()

View File

@@ -10,12 +10,12 @@
import { cleanup, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config'
import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item'
import { SelfHostedPlan } from '@/app/components/billing/type'
let mockAppCtx: Record<string, unknown> = {}
const mockToastNotify = vi.fn()
const originalLocation = window.location
let assignedHref = ''
@@ -40,6 +40,10 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({
AwsMarketplaceDark: () => <span data-testid="icon-aws-dark" />,
}))
vi.mock('@/app/components/base/toast', () => ({
default: { notify: (args: unknown) => mockToastNotify(args) },
}))
vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({
default: ({ plan }: { plan: string }) => (
<div data-testid={`self-hosted-list-${plan}`}>Features</div>
@@ -53,20 +57,10 @@ const setupAppContext = (overrides: Record<string, unknown> = {}) => {
}
}
const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => {
return render(
<>
<ToastHost timeout={0} />
<SelfHostedPlanItem plan={plan} />
</>,
)
}
describe('Self-Hosted Plan Flow', () => {
beforeEach(() => {
vi.clearAllMocks()
cleanup()
toast.close()
setupAppContext()
// Mock window.location with minimal getter/setter (Location props are non-enumerable)
@@ -91,14 +85,14 @@ describe('Self-Hosted Plan Flow', () => {
// ─── 1. Plan Rendering ──────────────────────────────────────────────────
describe('Plan rendering', () => {
it('should render community plan with name and description', () => {
renderSelfHostedPlanItem(SelfHostedPlan.community)
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument()
expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument()
})
it('should render premium plan with cloud provider icons', () => {
renderSelfHostedPlanItem(SelfHostedPlan.premium)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument()
expect(screen.getByTestId('icon-azure')).toBeInTheDocument()
@@ -106,39 +100,39 @@ describe('Self-Hosted Plan Flow', () => {
})
it('should render enterprise plan without cloud provider icons', () => {
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument()
expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument()
})
it('should not show price tip for community (free) plan', () => {
renderSelfHostedPlanItem(SelfHostedPlan.community)
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument()
})
it('should show price tip for premium plan', () => {
renderSelfHostedPlanItem(SelfHostedPlan.premium)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument()
})
it('should render features list for each plan', () => {
const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community)
const { unmount: unmount1 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument()
unmount1()
const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium)
const { unmount: unmount2 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument()
unmount2()
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument()
})
it('should show AWS marketplace icon for premium plan button', () => {
renderSelfHostedPlanItem(SelfHostedPlan.premium)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument()
})
@@ -148,7 +142,7 @@ describe('Self-Hosted Plan Flow', () => {
describe('Navigation flow', () => {
it('should redirect to GitHub when clicking community plan button', async () => {
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.community)
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
const button = screen.getByRole('button')
await user.click(button)
@@ -158,7 +152,7 @@ describe('Self-Hosted Plan Flow', () => {
it('should redirect to AWS Marketplace when clicking premium plan button', async () => {
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.premium)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
const button = screen.getByRole('button')
await user.click(button)
@@ -168,7 +162,7 @@ describe('Self-Hosted Plan Flow', () => {
it('should redirect to Typeform when clicking enterprise plan button', async () => {
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
const button = screen.getByRole('button')
await user.click(button)
@@ -182,13 +176,15 @@ describe('Self-Hosted Plan Flow', () => {
it('should show error toast when non-manager clicks community button', async () => {
setupAppContext({ isCurrentWorkspaceManager: false })
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.community)
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
const button = screen.getByRole('button')
await user.click(button)
await waitFor(() => {
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({ type: 'error' }),
)
})
// Should NOT redirect
expect(assignedHref).toBe('')
@@ -197,13 +193,15 @@ describe('Self-Hosted Plan Flow', () => {
it('should show error toast when non-manager clicks premium button', async () => {
setupAppContext({ isCurrentWorkspaceManager: false })
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.premium)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
const button = screen.getByRole('button')
await user.click(button)
await waitFor(() => {
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({ type: 'error' }),
)
})
expect(assignedHref).toBe('')
})
@@ -211,13 +209,15 @@ describe('Self-Hosted Plan Flow', () => {
it('should show error toast when non-manager clicks enterprise button', async () => {
setupAppContext({ isCurrentWorkspaceManager: false })
const user = userEvent.setup()
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
const button = screen.getByRole('button')
await user.click(button)
await waitFor(() => {
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(
expect.objectContaining({ type: 'error' }),
)
})
expect(assignedHref).toBe('')
})

View File

@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}, 1000)
}
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
const sendEmail = async (email: string, token?: string) => {
try {
const res = await sendVerifyCode({
email,
phase: isOrigin ? 'old_email' : 'new_email',
token,
})
startCount()
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
const sendCodeToOriginEmail = async () => {
await sendEmail(
email,
true,
)
setStep(STEP.verifyOrigin)
}
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}
await sendEmail(
mail,
false,
stepToken,
)
setStep(STEP.verifyNew)

View File

@@ -13,7 +13,7 @@ import { useTranslation } from 'react-i18next'
import { Avatar } from '@/app/components/base/avatar'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
import { useRouter, useSearchParams } from '@/next/navigation'
@@ -91,9 +91,9 @@ export default function OAuthAuthorize() {
globalThis.location.href = url.toString()
}
catch (err: any) {
toast.add({
Toast.notify({
type: 'error',
title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
})
}
}
@@ -102,10 +102,10 @@ export default function OAuthAuthorize() {
const invalidParams = !client_id || !redirect_uri
if ((invalidParams || isError) && !hasNotifiedRef.current) {
hasNotifiedRef.current = true
toast.add({
Toast.notify({
type: 'error',
title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
timeout: 0,
message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
duration: 0,
})
}
}, [client_id, redirect_uri, isError])

View File

@@ -39,8 +39,8 @@ vi.mock('../app-card', () => ({
vi.mock('@/app/components/explore/create-app-modal', () => ({
default: () => <div data-testid="create-from-template-modal" />,
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: { add: vi.fn() },
vi.mock('@/app/components/base/toast', () => ({
default: { notify: vi.fn() },
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),

View File

@@ -12,7 +12,7 @@ import { trackEvent } from '@/app/components/base/amplitude'
import Divider from '@/app/components/base/divider'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import CreateAppModal from '@/app/components/explore/create-app-modal'
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
@@ -137,9 +137,9 @@ const Apps = ({
})
setIsShowCreateModal(false)
toast.add({
Toast.notify({
type: 'success',
title: t('newApp.appCreated', { ns: 'app' }),
message: t('newApp.appCreated', { ns: 'app' }),
})
if (onSuccess)
onSuccess()
@@ -149,7 +149,7 @@ const Apps = ({
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push)
}
catch {
toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) })
Toast.notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) })
}
}

View File

@@ -1,15 +1,8 @@
'use client'
/**
* @deprecated Use `@/app/components/base/ui/toast` instead.
* This module will be removed after migration is complete.
* See: https://github.com/langgenius/dify/issues/32811
*/
import type { ReactNode } from 'react'
import { createContext, useContext } from 'use-context-selector'
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
export type IToastProps = {
type?: 'success' | 'error' | 'warning' | 'info'
size?: 'md' | 'sm'
@@ -26,8 +19,5 @@ type IToastContext = {
close: () => void
}
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
export const ToastContext = createContext<IToastContext>({} as IToastContext)
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
export const useToastContext = () => useContext(ToastContext)

View File

@@ -1,11 +1,4 @@
'use client'
/**
* @deprecated Use `@/app/components/base/ui/toast` instead.
* This component will be removed after migration is complete.
* See: https://github.com/langgenius/dify/issues/32811
*/
import type { ReactNode } from 'react'
import type { IToastProps } from './context'
import { noop } from 'es-toolkit/function'
@@ -19,7 +12,6 @@ import { ToastContext, useToastContext } from './context'
export type ToastHandle = {
clear?: VoidFunction
}
const Toast = ({
type = 'info',
size = 'md',
@@ -82,7 +74,6 @@ const Toast = ({
)
}
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
export const ToastProvider = ({
children,
}: {

View File

@@ -1,16 +1,22 @@
import type { Mock } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../../base/toast'
import { ALL_PLANS } from '../../../../config'
import { Plan } from '../../../../type'
import { PlanRange } from '../../../plan-switcher/plan-range-switcher'
import CloudPlanItem from '../index'
vi.mock('../../../../../base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
@@ -41,19 +47,11 @@ const mockUseAppContext = useAppContext as Mock
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
const mockBillingInvoices = consoleClient.billing.invoices as Mock
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
const mockToastNotify = Toast.notify as Mock
let assignedHref = ''
const originalLocation = window.location
const renderWithToastHost = (ui: React.ReactNode) => {
return render(
<>
<ToastHost timeout={0} />
{ui}
</>,
)
}
beforeAll(() => {
Object.defineProperty(window, 'location', {
configurable: true,
@@ -70,7 +68,6 @@ beforeAll(() => {
beforeEach(() => {
vi.clearAllMocks()
toast.close()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' })
@@ -166,7 +163,7 @@ describe('CloudPlanItem', () => {
it('should show toast when non-manager tries to buy a plan', () => {
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
renderWithToastHost(
render(
<CloudPlanItem
plan={Plan.professional}
currentPlan={Plan.sandbox}
@@ -176,7 +173,10 @@ describe('CloudPlanItem', () => {
)
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }))
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
expect(mockBillingInvoices).not.toHaveBeenCalled()
})

View File

@@ -4,11 +4,11 @@ import type { BasicPlan } from '../../../type'
import * as React from 'react'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { toast } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
import { fetchSubscriptionUrls } from '@/service/billing'
import { consoleClient } from '@/service/client'
import Toast from '../../../../base/toast'
import { ALL_PLANS } from '../../../config'
import { Plan } from '../../../type'
import { Professional, Sandbox, Team } from '../../assets'
@@ -66,9 +66,10 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
return
if (!isCurrentWorkspaceManager) {
toast.add({
Toast.notify({
type: 'error',
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
className: 'z-[1001]',
})
return
}
@@ -82,7 +83,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
throw new Error('Failed to open billing page')
}, {
onError: (err) => {
toast.add({ type: 'error', title: err.message || String(err) })
Toast.notify({ type: 'error', message: err.message || String(err) })
},
})
return
@@ -110,34 +111,34 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
{
isMostPopularPlan && (
<div className="flex items-center justify-center bg-saas-dify-blue-static px-1.5 py-1">
<span className="text-text-primary-on-surface system-2xs-semibold-uppercase">
<span className="system-2xs-semibold-uppercase text-text-primary-on-surface">
{t('plansCommon.mostPopular', { ns: 'billing' })}
</span>
</div>
)
}
</div>
<div className="text-text-secondary system-sm-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
<div className="system-sm-regular text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
</div>
</div>
{/* Price */}
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
{isFreePlan && (
<span className="text-text-primary title-4xl-semi-bold">{t('plansCommon.free', { ns: 'billing' })}</span>
<span className="title-4xl-semi-bold text-text-primary">{t('plansCommon.free', { ns: 'billing' })}</span>
)}
{!isFreePlan && (
<>
{isYear && (
<span className="text-text-quaternary line-through title-4xl-semi-bold">
<span className="title-4xl-semi-bold text-text-quaternary line-through">
$
{planInfo.price * 12}
</span>
)}
<span className="text-text-primary title-4xl-semi-bold">
<span className="title-4xl-semi-bold text-text-primary">
$
{isYear ? planInfo.price * 10 : planInfo.price}
</span>
<span className="pb-0.5 text-text-tertiary system-md-regular">
<span className="system-md-regular pb-0.5 text-text-tertiary">
{t('plansCommon.priceTip', { ns: 'billing' })}
{t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })}
</span>

View File

@@ -1,8 +1,8 @@
import type { Mock } from 'vitest'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import Toast from '../../../../../base/toast'
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../../config'
import { SelfHostedPlan } from '../../../../type'
import SelfHostedPlanItem from '../index'
@@ -16,6 +16,12 @@ vi.mock('../list', () => ({
),
}))
vi.mock('../../../../../base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
vi.mock('@/context/app-context', () => ({
useAppContext: vi.fn(),
}))
@@ -29,19 +35,11 @@ vi.mock('../../../assets', () => ({
}))
const mockUseAppContext = useAppContext as Mock
const mockToastNotify = Toast.notify as Mock
let assignedHref = ''
const originalLocation = window.location
const renderWithToastHost = (ui: React.ReactNode) => {
return render(
<>
<ToastHost timeout={0} />
{ui}
</>,
)
}
beforeAll(() => {
Object.defineProperty(window, 'location', {
configurable: true,
@@ -58,7 +56,6 @@ beforeAll(() => {
beforeEach(() => {
vi.clearAllMocks()
toast.close()
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
assignedHref = ''
})
@@ -93,10 +90,13 @@ describe('SelfHostedPlanItem', () => {
it('should show toast when non-manager tries to proceed', () => {
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
renderWithToastHost(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
fireEvent.click(screen.getByRole('button', { name: /billing\.plans\.premium\.btnText/ }))
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'error',
message: 'billing.buyPermissionDeniedTip',
}))
})
it('should redirect to community url when community plan button clicked', () => {

View File

@@ -4,9 +4,9 @@ import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { Azure, GoogleCloud } from '@/app/components/base/icons/src/public/billing'
import { toast } from '@/app/components/base/ui/toast'
import { useAppContext } from '@/context/app-context'
import { cn } from '@/utils/classnames'
import Toast from '../../../../base/toast'
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../config'
import { SelfHostedPlan } from '../../../type'
import { Community, Enterprise, EnterpriseNoise, Premium, PremiumNoise } from '../../assets'
@@ -56,9 +56,10 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
const handleGetPayUrl = useCallback(() => {
// Only workspace manager can buy plan
if (!isCurrentWorkspaceManager) {
toast.add({
Toast.notify({
type: 'error',
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
className: 'z-[1001]',
})
return
}
@@ -81,18 +82,18 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
{/* Noise Effect */}
{STYLE_MAP[plan].noise}
<div className="flex flex-col px-5 py-4">
<div className="flex flex-col gap-y-6 px-1 pt-10">
<div className=" flex flex-col gap-y-6 px-1 pt-10">
{STYLE_MAP[plan].icon}
<div className="flex min-h-[104px] flex-col gap-y-2">
<div className="text-[30px] font-medium leading-[1.2] text-text-primary">{t(`${i18nPrefix}.name`, { ns: 'billing' })}</div>
<div className="line-clamp-2 text-text-secondary system-md-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
<div className="system-md-regular line-clamp-2 text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
</div>
</div>
{/* Price */}
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
<div className="shrink-0 text-text-primary title-4xl-semi-bold">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
<div className="title-4xl-semi-bold shrink-0 text-text-primary">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
{!isFreePlan && (
<span className="pb-0.5 text-text-tertiary system-md-regular">
<span className="system-md-regular pb-0.5 text-text-tertiary">
{t(`${i18nPrefix}.priceTip`, { ns: 'billing' })}
</span>
)}
@@ -113,7 +114,7 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
<GoogleCloud />
</div>
</div>
<span className="text-text-tertiary system-xs-regular">
<span className="system-xs-regular text-text-tertiary">
{t('plans.premium.comingSoon', { ns: 'billing' })}
</span>
</div>

View File

@@ -1,6 +1,6 @@
import type * as React from 'react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import { ChunkingMode } from '@/models/datasets'
import { IndexingType } from '../../../create/step-two'
@@ -13,7 +13,14 @@ vi.mock('@/next/navigation', () => ({
}),
}))
const toastAddSpy = vi.spyOn(toast, 'add')
const mockNotify = vi.fn()
vi.mock('use-context-selector', async (importOriginal) => {
const actual = await importOriginal() as Record<string, unknown>
return {
...actual,
useContext: () => ({ notify: mockNotify }),
}
})
// Mock dataset detail context
let mockIndexingTechnique = IndexingType.QUALIFIED
@@ -44,6 +51,11 @@ vi.mock('@/service/knowledge/use-segment', () => ({
}),
}))
// Mock app store
vi.mock('@/app/components/app/store', () => ({
useStore: () => ({ appSidebarExpand: 'expand' }),
}))
vi.mock('../completed/common/action-buttons', () => ({
default: ({ handleCancel, handleSave, loading, actionType }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string }) => (
<div data-testid="action-buttons">
@@ -127,8 +139,6 @@ vi.mock('@/app/components/datasets/common/image-uploader/image-uploader-in-chunk
describe('NewSegmentModal', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useRealTimers()
toast.close()
mockFullScreen = false
mockIndexingTechnique = IndexingType.QUALIFIED
})
@@ -248,7 +258,7 @@ describe('NewSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
@@ -262,7 +272,7 @@ describe('NewSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
@@ -277,7 +287,7 @@ describe('NewSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
@@ -327,7 +337,7 @@ describe('NewSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'success',
}),
@@ -420,9 +430,10 @@ describe('NewSegmentModal', () => {
})
})
describe('Action button in success notification', () => {
it('should call viewNewlyAddedChunk when the toast action is clicked', async () => {
describe('CustomButton in success notification', () => {
it('should call viewNewlyAddedChunk when custom button is clicked', async () => {
const mockViewNewlyAddedChunk = vi.fn()
mockNotify.mockImplementation(() => {})
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
options.onSuccess()
@@ -431,25 +442,37 @@ describe('NewSegmentModal', () => {
})
render(
<>
<ToastHost timeout={0} />
<NewSegmentModal
{...defaultProps}
docForm={ChunkingMode.text}
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
/>
</>,
<NewSegmentModal
{...defaultProps}
docForm={ChunkingMode.text}
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
/>,
)
// Enter content and save
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
fireEvent.click(screen.getByTestId('save-btn'))
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
fireEvent.click(actionButton)
await waitFor(() => {
expect(mockViewNewlyAddedChunk).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'success',
customComponent: expect.anything(),
}),
)
})
// Extract customComponent from the notify call args
const notifyCallArgs = mockNotify.mock.calls[0][0] as { customComponent?: React.ReactElement }
expect(notifyCallArgs.customComponent).toBeDefined()
const customComponent = notifyCallArgs.customComponent!
const { container: btnContainer } = render(customComponent)
const viewButton = btnContainer.querySelector('.system-xs-semibold.text-text-accent') as HTMLElement
expect(viewButton).toBeInTheDocument()
fireEvent.click(viewButton)
// Assert that viewNewlyAddedChunk was called via the onClick handler (lines 66-67)
expect(mockViewNewlyAddedChunk).toHaveBeenCalled()
})
})
@@ -576,8 +599,9 @@ describe('NewSegmentModal', () => {
})
})
describe('onSave after success', () => {
it('should call onSave immediately after save succeeds', async () => {
describe('onSave delayed call', () => {
it('should call onSave after timeout in success handler', async () => {
vi.useFakeTimers()
const mockOnSave = vi.fn()
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
options.onSuccess()
@@ -587,12 +611,15 @@ describe('NewSegmentModal', () => {
render(<NewSegmentModal {...defaultProps} onSave={mockOnSave} docForm={ChunkingMode.text} />)
// Enter content and save
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(mockOnSave).toHaveBeenCalledTimes(1)
})
// Fast-forward timer
vi.advanceTimersByTime(3000)
expect(mockOnSave).toHaveBeenCalled()
vi.useRealTimers()
})
})

View File

@@ -1,6 +1,5 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { toast, ToastHost } from '@/app/components/base/ui/toast'
import NewChildSegmentModal from '../new-child-segment'
@@ -11,7 +10,14 @@ vi.mock('@/next/navigation', () => ({
}),
}))
const toastAddSpy = vi.spyOn(toast, 'add')
const mockNotify = vi.fn()
vi.mock('use-context-selector', async (importOriginal) => {
const actual = await importOriginal() as Record<string, unknown>
return {
...actual,
useContext: () => ({ notify: mockNotify }),
}
})
// Mock document context
let mockParentMode = 'paragraph'
@@ -42,6 +48,11 @@ vi.mock('@/service/knowledge/use-segment', () => ({
}),
}))
// Mock app store
vi.mock('@/app/components/app/store', () => ({
useStore: () => ({ appSidebarExpand: 'expand' }),
}))
vi.mock('../common/action-buttons', () => ({
default: ({ handleCancel, handleSave, loading, actionType, isChildChunk }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string, isChildChunk?: boolean }) => (
<div data-testid="action-buttons">
@@ -92,8 +103,6 @@ vi.mock('../common/segment-index-tag', () => ({
describe('NewChildSegmentModal', () => {
beforeEach(() => {
vi.clearAllMocks()
vi.useRealTimers()
toast.close()
mockFullScreen = false
mockParentMode = 'paragraph'
})
@@ -189,7 +198,7 @@ describe('NewChildSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'error',
}),
@@ -244,7 +253,7 @@ describe('NewChildSegmentModal', () => {
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(toastAddSpy).toHaveBeenCalledWith(
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'success',
}),
@@ -365,62 +374,35 @@ describe('NewChildSegmentModal', () => {
// View newly added chunk
describe('View Newly Added Chunk', () => {
it('should call viewNewlyAddedChildChunk when the toast action is clicked', async () => {
it('should show custom button in full-doc mode after save', async () => {
mockParentMode = 'full-doc'
const mockViewNewlyAddedChildChunk = vi.fn()
mockAddChildSegment.mockImplementation((_params, options) => {
options.onSuccess({ data: { id: 'new-child-id' } })
options.onSettled()
return Promise.resolve()
})
render(
<>
<ToastHost timeout={0} />
<NewChildSegmentModal
{...defaultProps}
viewNewlyAddedChildChunk={mockViewNewlyAddedChildChunk}
/>
</>,
)
render(<NewChildSegmentModal {...defaultProps} />)
// Enter valid content
fireEvent.change(screen.getByTestId('content-input'), {
target: { value: 'Valid content' },
})
fireEvent.click(screen.getByTestId('save-btn'))
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
fireEvent.click(actionButton)
// Assert - success notification with custom component
await waitFor(() => {
expect(mockViewNewlyAddedChildChunk).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith(
expect.objectContaining({
type: 'success',
customComponent: expect.anything(),
}),
)
})
})
it('should call onSave immediately in full-doc mode after save succeeds', async () => {
mockParentMode = 'full-doc'
const mockOnSave = vi.fn()
mockAddChildSegment.mockImplementation((_params, options) => {
options.onSuccess({ data: { id: 'new-child-id' } })
options.onSettled()
return Promise.resolve()
})
render(<NewChildSegmentModal {...defaultProps} onSave={mockOnSave} />)
fireEvent.change(screen.getByTestId('content-input'), {
target: { value: 'Valid content' },
})
fireEvent.click(screen.getByTestId('save-btn'))
await waitFor(() => {
expect(mockOnSave).toHaveBeenCalledTimes(1)
})
})
it('should call onSave with the new child chunk in paragraph mode', async () => {
it('should not show custom button in paragraph mode after save', async () => {
mockParentMode = 'paragraph'
const mockOnSave = vi.fn()
mockAddChildSegment.mockImplementation((_params, options) => {

View File

@@ -1,10 +1,13 @@
import type { FC } from 'react'
import type { ChildChunkDetail, SegmentUpdater } from '@/models/datasets'
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
import { memo, useState } from 'react'
import { memo, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { useShallow } from 'zustand/react/shallow'
import { useStore as useAppStore } from '@/app/components/app/store'
import Divider from '@/app/components/base/divider'
import { toast } from '@/app/components/base/ui/toast'
import { ToastContext } from '@/app/components/base/toast/context'
import { ChunkingMode } from '@/models/datasets'
import { useParams } from '@/next/navigation'
import { useAddChildSegment } from '@/service/knowledge/use-segment'
@@ -32,15 +35,39 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
viewNewlyAddedChildChunk,
}) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const [content, setContent] = useState('')
const { datasetId, documentId } = useParams<{ datasetId: string, documentId: string }>()
const [loading, setLoading] = useState(false)
const [addAnother, setAddAnother] = useState(true)
const fullScreen = useSegmentListContext(s => s.fullScreen)
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const { appSidebarExpand } = useAppStore(useShallow(state => ({
appSidebarExpand: state.appSidebarExpand,
})))
const parentMode = useDocumentContext(s => s.parentMode)
const isFullDocMode = parentMode === 'full-doc'
const refreshTimer = useRef<any>(null)
const isFullDocMode = useMemo(() => {
return parentMode === 'full-doc'
}, [parentMode])
const CustomButton = (
<>
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
<button
type="button"
className="text-text-accent system-xs-semibold"
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChildChunk?.()
}}
>
{t('operation.view', { ns: 'common' })}
</button>
</>
)
const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
if (actionType === 'esc' || !addAnother)
@@ -53,27 +80,26 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
const params: SegmentUpdater = { content: '' }
if (!content.trim())
return toast.add({ type: 'error', title: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
return notify({ type: 'error', message: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
params.content = content
setLoading(true)
await addChildSegment({ datasetId, documentId, segmentId: chunkId, body: params }, {
onSuccess(res) {
toast.add({
notify({
type: 'success',
title: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
actionProps: isFullDocMode
? {
children: t('operation.view', { ns: 'common' }),
onClick: viewNewlyAddedChildChunk,
}
: undefined,
message: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
!top-auto !right-auto !mb-[52px] !ml-11`,
customComponent: isFullDocMode && CustomButton,
})
handleCancel('add')
setContent('')
if (isFullDocMode) {
onSave()
refreshTimer.current = setTimeout(() => {
onSave()
}, 3000)
}
else {
onSave(res.data)
@@ -85,8 +111,10 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
})
}
const count = content.length
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
const wordCountText = useMemo(() => {
const count = content.length
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
}, [content.length])
return (
<div className="flex h-full flex-col">

View File

@@ -2,10 +2,13 @@ import type { FC } from 'react'
import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types'
import type { SegmentUpdater } from '@/models/datasets'
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
import { memo, useCallback, useState } from 'react'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { useShallow } from 'zustand/react/shallow'
import { useStore as useAppStore } from '@/app/components/app/store'
import Divider from '@/app/components/base/divider'
import { toast } from '@/app/components/base/ui/toast'
import { ToastContext } from '@/app/components/base/toast/context'
import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { ChunkingMode } from '@/models/datasets'
@@ -36,6 +39,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
viewNewlyAddedChunk,
}) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const [question, setQuestion] = useState('')
const [answer, setAnswer] = useState('')
const [attachments, setAttachments] = useState<FileEntity[]>([])
@@ -46,7 +50,27 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
const fullScreen = useSegmentListContext(s => s.fullScreen)
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
const [imageUploaderKey, setImageUploaderKey] = useState(() => Date.now())
const { appSidebarExpand } = useAppStore(useShallow(state => ({
appSidebarExpand: state.appSidebarExpand,
})))
const [imageUploaderKey, setImageUploaderKey] = useState(Date.now())
const refreshTimer = useRef<any>(null)
const CustomButton = useMemo(() => (
<>
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
<button
type="button"
className="text-text-accent system-xs-semibold"
onClick={() => {
clearTimeout(refreshTimer.current)
viewNewlyAddedChunk()
}}
>
{t('operation.view', { ns: 'common' })}
</button>
</>
), [viewNewlyAddedChunk, t])
const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
if (actionType === 'esc' || !addAnother)
@@ -63,15 +87,15 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
const params: SegmentUpdater = { content: '', attachment_ids: [] }
if (docForm === ChunkingMode.qa) {
if (!question.trim()) {
return toast.add({
return notify({
type: 'error',
title: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
message: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
})
}
if (!answer.trim()) {
return toast.add({
return notify({
type: 'error',
title: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
message: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
})
}
@@ -80,9 +104,9 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
}
else {
if (!question.trim()) {
return toast.add({
return notify({
type: 'error',
title: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
message: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
})
}
@@ -98,13 +122,12 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
setLoading(true)
await addSegment({ datasetId, documentId, body: params }, {
onSuccess() {
toast.add({
notify({
type: 'success',
title: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
actionProps: {
children: t('operation.view', { ns: 'common' }),
onClick: viewNewlyAddedChunk,
},
message: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
!top-auto !right-auto !mb-[52px] !ml-11`,
customComponent: CustomButton,
})
handleCancel('add')
setQuestion('')
@@ -112,16 +135,20 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
setAttachments([])
setImageUploaderKey(Date.now())
setKeywords([])
onSave()
refreshTimer.current = setTimeout(() => {
onSave()
}, 3000)
},
onSettled() {
setLoading(false)
},
})
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, t, handleCancel, onSave, viewNewlyAddedChunk])
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
const wordCountText = useMemo(() => {
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
}, [question.length, answer.length, docForm, t])
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL

View File

@@ -21,11 +21,11 @@ vi.mock('@/context/i18n', () => ({
useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`,
}))
const mockNotify = vi.hoisted(() => vi.fn())
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
add: mockNotify,
},
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
}))
// Mock modal context
@@ -164,7 +164,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
// Verify success notification
expect(mockNotify).toHaveBeenCalledWith({
type: 'success',
title: 'External Knowledge Base Connected Successfully',
message: 'External Knowledge Base Connected Successfully',
})
// Verify navigation back
@@ -206,7 +206,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
title: 'Failed to connect External Knowledge Base',
message: 'Failed to connect External Knowledge Base',
})
})
@@ -228,7 +228,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
title: 'Failed to connect External Knowledge Base',
message: 'Failed to connect External Knowledge Base',
})
})
@@ -274,7 +274,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'success',
title: 'External Knowledge Base Connected Successfully',
message: 'External Knowledge Base Connected Successfully',
})
})
})

View File

@@ -4,12 +4,13 @@ import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-
import * as React from 'react'
import { useState } from 'react'
import { trackEvent } from '@/app/components/base/amplitude'
import { toast } from '@/app/components/base/ui/toast'
import { useToastContext } from '@/app/components/base/toast/context'
import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create'
import { useRouter } from '@/next/navigation'
import { createExternalKnowledgeBase } from '@/service/datasets'
const ExternalKnowledgeBaseConnector = () => {
const { notify } = useToastContext()
const [loading, setLoading] = useState(false)
const router = useRouter()
@@ -18,7 +19,7 @@ const ExternalKnowledgeBaseConnector = () => {
setLoading(true)
const result = await createExternalKnowledgeBase({ body: formValue })
if (result && result.id) {
toast.add({ type: 'success', title: 'External Knowledge Base Connected Successfully' })
notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' })
trackEvent('create_external_knowledge_base', {
provider: formValue.provider,
name: formValue.name,
@@ -29,7 +30,7 @@ const ExternalKnowledgeBaseConnector = () => {
}
catch (error) {
console.error('Error creating external knowledge base:', error)
toast.add({ type: 'error', title: 'Failed to connect External Knowledge Base' })
notify({ type: 'error', message: 'Failed to connect External Knowledge Base' })
}
setLoading(false)
}

View File

@@ -36,13 +36,30 @@ const TransferOwnershipModal = ({ onClose, show }: Props) => {
const [stepToken, setStepToken] = useState<string>('')
const [newOwner, setNewOwner] = useState<string>('')
const [isTransfer, setIsTransfer] = useState<boolean>(false)
const timerRef = React.useRef<ReturnType<typeof setInterval> | null>(null)
React.useEffect(() => {
return () => {
if (timerRef.current) {
clearInterval(timerRef.current)
timerRef.current = null
}
}
}, [])
const startCount = () => {
if (timerRef.current) {
clearInterval(timerRef.current)
timerRef.current = null
}
setTime(60)
const timer = setInterval(() => {
timerRef.current = setInterval(() => {
setTime((prev) => {
if (prev <= 0) {
clearInterval(timer)
if (timerRef.current) {
clearInterval(timerRef.current)
timerRef.current = null
}
return 0
}
return prev - 1

View File

@@ -43,10 +43,10 @@ vi.mock('@/context/provider-context', () => ({
}),
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
add: mockNotify,
},
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
}))
vi.mock('../../hooks', () => ({
@@ -150,7 +150,7 @@ describe('SystemModel', () => {
expect(mockUpdateDefaultModel).toHaveBeenCalledTimes(1)
expect(mockNotify).toHaveBeenCalledWith({
type: 'success',
title: 'Modified successfully',
message: 'Modified successfully',
})
expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5)
expect(mockUpdateModelList).toHaveBeenCalledTimes(5)

View File

@@ -6,13 +6,13 @@ import type {
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import { useToastContext } from '@/app/components/base/toast/context'
import {
Dialog,
DialogCloseButton,
DialogContent,
DialogTitle,
} from '@/app/components/base/ui/dialog'
import { toast } from '@/app/components/base/ui/toast'
import {
Tooltip,
TooltipContent,
@@ -64,6 +64,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
isLoading,
}) => {
const { t } = useTranslation()
const { notify } = useToastContext()
const { isCurrentWorkspaceManager } = useAppContext()
const { textGenerationModelList } = useProviderContext()
const updateModelList = useUpdateModelList()
@@ -123,7 +124,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
},
})
if (res.result === 'success') {
toast.add({ type: 'success', title: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
setOpen(false)
const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts]

View File

@@ -4,7 +4,7 @@ import { DeleteConfirm } from '../delete-confirm'
const mockRefetch = vi.fn()
const mockDelete = vi.fn()
const mockToastAdd = vi.hoisted(() => vi.fn())
const mockToast = vi.fn()
vi.mock('../use-subscription-list', () => ({
useSubscriptionList: () => ({ refetch: mockRefetch }),
@@ -14,9 +14,9 @@ vi.mock('@/service/use-triggers', () => ({
useDeleteTriggerSubscription: () => ({ mutate: mockDelete, isPending: false }),
}))
vi.mock('@/app/components/base/ui/toast', () => ({
toast: {
add: mockToastAdd,
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (args: { type: string, message: string }) => mockToast(args),
},
}))
@@ -42,7 +42,7 @@ describe('DeleteConfirm', () => {
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
expect(mockDelete).not.toHaveBeenCalled()
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
})
it('should allow deletion after matching input name', () => {
@@ -87,6 +87,6 @@ describe('DeleteConfirm', () => {
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', title: 'network error' }))
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', message: 'network error' }))
})
})

View File

@@ -1,16 +1,8 @@
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Confirm from '@/app/components/base/confirm'
import Input from '@/app/components/base/input'
import {
AlertDialog,
AlertDialogActions,
AlertDialogCancelButton,
AlertDialogConfirmButton,
AlertDialogContent,
AlertDialogDescription,
AlertDialogTitle,
} from '@/app/components/base/ui/alert-dialog'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import { useDeleteTriggerSubscription } from '@/service/use-triggers'
import { useSubscriptionList } from './use-subscription-list'
@@ -31,74 +23,58 @@ export const DeleteConfirm = (props: Props) => {
const { t } = useTranslation()
const [inputName, setInputName] = useState('')
const handleOpenChange = (open: boolean) => {
if (isDeleting)
return
if (!open)
onClose(false)
}
const onConfirm = () => {
if (workflowsInUse > 0 && inputName !== currentName) {
toast.add({
Toast.notify({
type: 'error',
title: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
message: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
// temporarily
className: 'z-[10000001]',
})
return
}
deleteSubscription(currentId, {
onSuccess: () => {
toast.add({
Toast.notify({
type: 'success',
title: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
message: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
className: 'z-[10000001]',
})
refetch?.()
onClose(true)
},
onError: (error: unknown) => {
toast.add({
onError: (error: any) => {
Toast.notify({
type: 'error',
title: error instanceof Error ? error.message : t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
message: error?.message || t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
className: 'z-[10000001]',
})
},
})
}
return (
<AlertDialog open={isShow} onOpenChange={handleOpenChange}>
<AlertDialogContent backdropProps={{ forceRender: true }}>
<div className="flex flex-col gap-2 px-6 pb-4 pt-6">
<AlertDialogTitle title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })} className="w-full truncate text-text-primary title-2xl-semi-bold">
{t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
</AlertDialogTitle>
<AlertDialogDescription className="w-full whitespace-pre-wrap break-words text-text-tertiary system-md-regular">
{workflowsInUse > 0
? t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
</AlertDialogDescription>
{workflowsInUse > 0 && (
<div className="mt-6">
<div className="mb-2 text-text-secondary system-sm-medium">
{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}
</div>
<Confirm
title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
confirmText={t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
content={workflowsInUse > 0
? (
<>
{t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })}
<div className="system-sm-medium mb-2 mt-6 text-text-secondary">{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}</div>
<Input
value={inputName}
onChange={e => setInputName(e.target.value)}
placeholder={t(`${tPrefix}.confirmInputPlaceholder`, { ns: 'pluginTrigger', name: currentName })}
/>
</div>
)}
</div>
<AlertDialogActions>
<AlertDialogCancelButton disabled={isDeleting}>
{t('operation.cancel', { ns: 'common' })}
</AlertDialogCancelButton>
<AlertDialogConfirmButton loading={isDeleting} disabled={isDeleting} onClick={onConfirm}>
{t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
</AlertDialogConfirmButton>
</AlertDialogActions>
</AlertDialogContent>
</AlertDialog>
</>
)
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
isShow={isShow}
isLoading={isDeleting}
isDisabled={isDeleting}
onConfirm={onConfirm}
onCancel={() => onClose(false)}
maskClosable={false}
/>
)
}

View File

@@ -1,40 +0,0 @@
import type { Node } from '../types'
import { screen } from '@testing-library/react'
import CandidateNode from '../candidate-node'
import { BlockEnum } from '../types'
import { renderWorkflowComponent } from './workflow-test-env'
vi.mock('../candidate-node-main', () => ({
default: ({ candidateNode }: { candidateNode: Node }) => (
<div data-testid="candidate-node-main">{candidateNode.id}</div>
),
}))
const createCandidateNode = (): Node => ({
id: 'candidate-node-1',
type: 'custom',
position: { x: 0, y: 0 },
data: {
type: BlockEnum.Start,
title: 'Candidate node',
desc: 'candidate',
},
})
describe('CandidateNode', () => {
it('should not render when candidateNode is missing from the workflow store', () => {
renderWorkflowComponent(<CandidateNode />)
expect(screen.queryByTestId('candidate-node-main')).not.toBeInTheDocument()
})
it('should render CandidateNodeMain with the stored candidate node', () => {
renderWorkflowComponent(<CandidateNode />, {
initialStoreState: {
candidateNode: createCandidateNode(),
},
})
expect(screen.getByTestId('candidate-node-main')).toHaveTextContent('candidate-node-1')
})
})

View File

@@ -1,81 +0,0 @@
import type { ComponentProps } from 'react'
import { render } from '@testing-library/react'
import { getBezierPath, Position } from 'reactflow'
import CustomConnectionLine from '../custom-connection-line'
const createConnectionLineProps = (
overrides: Partial<ComponentProps<typeof CustomConnectionLine>> = {},
): ComponentProps<typeof CustomConnectionLine> => ({
fromX: 10,
fromY: 20,
toX: 70,
toY: 80,
fromPosition: Position.Right,
toPosition: Position.Left,
connectionLineType: undefined,
connectionStatus: null,
...overrides,
} as ComponentProps<typeof CustomConnectionLine>)
describe('CustomConnectionLine', () => {
it('should render the bezier path and target marker', () => {
const [expectedPath] = getBezierPath({
sourceX: 10,
sourceY: 20,
sourcePosition: Position.Right,
targetX: 70,
targetY: 80,
targetPosition: Position.Left,
curvature: 0.16,
})
const { container } = render(
<svg>
<CustomConnectionLine {...createConnectionLineProps()} />
</svg>,
)
const path = container.querySelector('path')
const marker = container.querySelector('rect')
expect(path).toHaveAttribute('fill', 'none')
expect(path).toHaveAttribute('stroke', '#D0D5DD')
expect(path).toHaveAttribute('stroke-width', '2')
expect(path).toHaveAttribute('d', expectedPath)
expect(marker).toHaveAttribute('x', '70')
expect(marker).toHaveAttribute('y', '76')
expect(marker).toHaveAttribute('width', '2')
expect(marker).toHaveAttribute('height', '8')
expect(marker).toHaveAttribute('fill', '#2970FF')
})
it('should update the path when the endpoints change', () => {
const [expectedPath] = getBezierPath({
sourceX: 30,
sourceY: 40,
sourcePosition: Position.Right,
targetX: 160,
targetY: 200,
targetPosition: Position.Left,
curvature: 0.16,
})
const { container } = render(
<svg>
<CustomConnectionLine
{...createConnectionLineProps({
fromX: 30,
fromY: 40,
toX: 160,
toY: 200,
})}
/>
</svg>,
)
expect(container.querySelector('path')).toHaveAttribute('d', expectedPath)
expect(container.querySelector('rect')).toHaveAttribute('x', '160')
expect(container.querySelector('rect')).toHaveAttribute('y', '196')
})
})

View File

@@ -1,57 +0,0 @@
import { render } from '@testing-library/react'
import CustomEdgeLinearGradientRender from '../custom-edge-linear-gradient-render'
describe('CustomEdgeLinearGradientRender', () => {
it('should render gradient definition with the provided id and positions', () => {
const { container } = render(
<svg>
<CustomEdgeLinearGradientRender
id="edge-gradient"
startColor="#123456"
stopColor="#abcdef"
position={{
x1: 10,
y1: 20,
x2: 30,
y2: 40,
}}
/>
</svg>,
)
const gradient = container.querySelector('linearGradient')
expect(gradient).toHaveAttribute('id', 'edge-gradient')
expect(gradient).toHaveAttribute('gradientUnits', 'userSpaceOnUse')
expect(gradient).toHaveAttribute('x1', '10')
expect(gradient).toHaveAttribute('y1', '20')
expect(gradient).toHaveAttribute('x2', '30')
expect(gradient).toHaveAttribute('y2', '40')
})
it('should render start and stop colors at both ends of the gradient', () => {
const { container } = render(
<svg>
<CustomEdgeLinearGradientRender
id="gradient-colors"
startColor="#111111"
stopColor="#222222"
position={{
x1: 0,
y1: 0,
x2: 100,
y2: 100,
}}
/>
</svg>,
)
const stops = container.querySelectorAll('stop')
expect(stops).toHaveLength(2)
expect(stops[0]).toHaveAttribute('offset', '0%')
expect(stops[0].getAttribute('style')).toContain('stop-color: rgb(17, 17, 17)')
expect(stops[0].getAttribute('style')).toContain('stop-opacity: 1')
expect(stops[1]).toHaveAttribute('offset', '100%')
expect(stops[1].getAttribute('style')).toContain('stop-color: rgb(34, 34, 34)')
expect(stops[1].getAttribute('style')).toContain('stop-opacity: 1')
})
})

View File

@@ -1,127 +0,0 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DSLExportConfirmModal from '../dsl-export-confirm-modal'
const envList = [
{
id: 'env-1',
name: 'SECRET_TOKEN',
value: 'masked-value',
value_type: 'secret' as const,
description: 'secret token',
},
]
const multiEnvList = [
...envList,
{
id: 'env-2',
name: 'SERVICE_KEY',
value: 'another-secret',
value_type: 'secret' as const,
description: 'service key',
},
]
describe('DSLExportConfirmModal', () => {
it('should render environment rows and close when cancel is clicked', async () => {
const user = userEvent.setup()
const onConfirm = vi.fn()
const onClose = vi.fn()
render(
<DSLExportConfirmModal
envList={envList}
onConfirm={onConfirm}
onClose={onClose}
/>,
)
expect(screen.getByText('SECRET_TOKEN')).toBeInTheDocument()
expect(screen.getByText('masked-value')).toBeInTheDocument()
await user.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(onClose).toHaveBeenCalledTimes(1)
expect(onConfirm).not.toHaveBeenCalled()
})
it('should confirm with exportSecrets=false by default', async () => {
const user = userEvent.setup()
const onConfirm = vi.fn()
const onClose = vi.fn()
render(
<DSLExportConfirmModal
envList={envList}
onConfirm={onConfirm}
onClose={onClose}
/>,
)
await user.click(screen.getByRole('button', { name: 'workflow.env.export.ignore' }))
expect(onConfirm).toHaveBeenCalledWith(false)
expect(onClose).toHaveBeenCalledTimes(1)
})
it('should confirm with exportSecrets=true after toggling the checkbox', async () => {
const user = userEvent.setup()
const onConfirm = vi.fn()
const onClose = vi.fn()
render(
<DSLExportConfirmModal
envList={envList}
onConfirm={onConfirm}
onClose={onClose}
/>,
)
await user.click(screen.getByRole('checkbox'))
await user.click(screen.getByRole('button', { name: 'workflow.env.export.export' }))
expect(onConfirm).toHaveBeenCalledWith(true)
expect(onClose).toHaveBeenCalledTimes(1)
})
it('should also toggle exportSecrets when the label text is clicked', async () => {
const user = userEvent.setup()
const onConfirm = vi.fn()
const onClose = vi.fn()
render(
<DSLExportConfirmModal
envList={envList}
onConfirm={onConfirm}
onClose={onClose}
/>,
)
await user.click(screen.getByText('workflow.env.export.checkbox'))
await user.click(screen.getByRole('button', { name: 'workflow.env.export.export' }))
expect(onConfirm).toHaveBeenCalledWith(true)
expect(onClose).toHaveBeenCalledTimes(1)
})
it('should render border separators for all rows except the last one', () => {
render(
<DSLExportConfirmModal
envList={multiEnvList}
onConfirm={vi.fn()}
onClose={vi.fn()}
/>,
)
const firstNameCell = screen.getByText('SECRET_TOKEN').closest('td')
const lastNameCell = screen.getByText('SERVICE_KEY').closest('td')
const firstValueCell = screen.getByText('masked-value').closest('td')
const lastValueCell = screen.getByText('another-secret').closest('td')
expect(firstNameCell).toHaveClass('border-b')
expect(firstValueCell).toHaveClass('border-b')
expect(lastNameCell).not.toHaveClass('border-b')
expect(lastValueCell).not.toHaveClass('border-b')
})
})

View File

@@ -1,193 +0,0 @@
import type { InputVar } from '../types'
import type { PromptVariable } from '@/models/debug'
import { screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ReactFlow, { ReactFlowProvider, useNodes } from 'reactflow'
import Features from '../features'
import { InputVarType } from '../types'
import { createStartNode } from './fixtures'
import { renderWorkflowComponent } from './workflow-test-env'
const mockHandleSyncWorkflowDraft = vi.fn()
const mockHandleAddVariable = vi.fn()
let mockIsChatMode = true
let mockNodesReadOnly = false
vi.mock('../hooks', async () => {
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
return {
...actual,
useIsChatMode: () => mockIsChatMode,
useNodesReadOnly: () => ({
nodesReadOnly: mockNodesReadOnly,
}),
useNodesSyncDraft: () => ({
handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft,
}),
}
})
vi.mock('../nodes/start/use-config', () => ({
default: () => ({
handleAddVariable: mockHandleAddVariable,
}),
}))
vi.mock('@/app/components/base/features/new-feature-panel', () => ({
default: ({
show,
isChatMode,
disabled,
onChange,
onClose,
onAutoAddPromptVariable,
workflowVariables,
}: {
show: boolean
isChatMode: boolean
disabled: boolean
onChange: () => void
onClose: () => void
onAutoAddPromptVariable: (variables: PromptVariable[]) => void
workflowVariables: InputVar[]
}) => {
if (!show)
return null
return (
<section aria-label="new feature panel">
<div>{isChatMode ? 'chat mode' : 'completion mode'}</div>
<div>{disabled ? 'panel disabled' : 'panel enabled'}</div>
<ul aria-label="workflow variables">
{workflowVariables.map(variable => (
<li key={variable.variable}>
{`${variable.label}:${variable.variable}`}
</li>
))}
</ul>
<button type="button" onClick={onChange}>open features</button>
<button type="button" onClick={onClose}>close features</button>
<button
type="button"
onClick={() => onAutoAddPromptVariable([{
key: 'opening_statement',
name: 'Opening Statement',
type: 'string',
max_length: 200,
required: true,
}])}
>
add required variable
</button>
<button
type="button"
onClick={() => onAutoAddPromptVariable([{
key: 'optional_statement',
name: 'Optional Statement',
type: 'string',
max_length: 120,
}])}
>
add optional variable
</button>
</section>
)
},
}))
const startNode = createStartNode({
id: 'start-node',
data: {
variables: [{ variable: 'existing_variable', label: 'Existing Variable' }],
},
})
const DelayedFeatures = () => {
const nodes = useNodes()
if (!nodes.length)
return null
return <Features />
}
const renderFeatures = (options?: Parameters<typeof renderWorkflowComponent>[1]) => {
return renderWorkflowComponent(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow nodes={[startNode]} edges={[]} fitView />
<DelayedFeatures />
</ReactFlowProvider>
</div>,
options,
)
}
describe('Features', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsChatMode = true
mockNodesReadOnly = false
})
describe('Rendering', () => {
it('should pass workflow context to the feature panel', () => {
renderFeatures()
expect(screen.getByText('chat mode')).toBeInTheDocument()
expect(screen.getByText('panel enabled')).toBeInTheDocument()
expect(screen.getByRole('list', { name: 'workflow variables' })).toHaveTextContent('Existing Variable:existing_variable')
})
})
describe('User Interactions', () => {
it('should sync the draft and open the workflow feature panel when users change features', async () => {
const user = userEvent.setup()
const { store } = renderFeatures()
await user.click(screen.getByRole('button', { name: 'open features' }))
expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1)
expect(store.getState().showFeaturesPanel).toBe(true)
})
it('should close the workflow feature panel and transform required prompt variables', async () => {
const user = userEvent.setup()
const { store } = renderFeatures({
initialStoreState: {
showFeaturesPanel: true,
},
})
await user.click(screen.getByRole('button', { name: 'close features' }))
expect(store.getState().showFeaturesPanel).toBe(false)
await user.click(screen.getByRole('button', { name: 'add required variable' }))
expect(mockHandleAddVariable).toHaveBeenCalledWith({
variable: 'opening_statement',
label: 'Opening Statement',
type: InputVarType.textInput,
max_length: 200,
required: true,
options: [],
})
})
it('should default prompt variables to optional when required is omitted', async () => {
const user = userEvent.setup()
renderFeatures()
await user.click(screen.getByRole('button', { name: 'add optional variable' }))
expect(mockHandleAddVariable).toHaveBeenCalledWith({
variable: 'optional_statement',
label: 'Optional Statement',
type: InputVarType.textInput,
max_length: 120,
required: false,
options: [],
})
})
})
})

View File

@@ -16,8 +16,8 @@ import * as React from 'react'
type MockNode = {
id: string
position: { x: number, y: number }
width?: number | null
height?: number | null
width?: number
height?: number
parentId?: string
data: Record<string, unknown>
}

View File

@@ -1,22 +0,0 @@
import SyncingDataModal from '../syncing-data-modal'
import { renderWorkflowComponent } from './workflow-test-env'
describe('SyncingDataModal', () => {
it('should not render when workflow draft syncing is disabled', () => {
const { container } = renderWorkflowComponent(<SyncingDataModal />)
expect(container).toBeEmptyDOMElement()
})
it('should render the fullscreen overlay when workflow draft syncing is enabled', () => {
const { container } = renderWorkflowComponent(<SyncingDataModal />, {
initialStoreState: {
isSyncingWorkflowDraft: true,
},
})
const overlay = container.firstElementChild
expect(overlay).toHaveClass('absolute', 'inset-0')
expect(overlay).toHaveClass('z-[9999]')
})
})

View File

@@ -1,108 +0,0 @@
import type * as React from 'react'
import { act, renderHook } from '@testing-library/react'
import useCheckVerticalScrollbar from '../use-check-vertical-scrollbar'
const resizeObserve = vi.fn()
const resizeDisconnect = vi.fn()
const mutationObserve = vi.fn()
const mutationDisconnect = vi.fn()
let resizeCallback: ResizeObserverCallback | null = null
let mutationCallback: MutationCallback | null = null
class MockResizeObserver implements ResizeObserver {
observe = resizeObserve
unobserve = vi.fn()
disconnect = resizeDisconnect
constructor(callback: ResizeObserverCallback) {
resizeCallback = callback
}
}
class MockMutationObserver implements MutationObserver {
observe = mutationObserve
disconnect = mutationDisconnect
takeRecords = vi.fn(() => [])
constructor(callback: MutationCallback) {
mutationCallback = callback
}
}
const setElementHeights = (element: HTMLElement, scrollHeight: number, clientHeight: number) => {
Object.defineProperty(element, 'scrollHeight', {
configurable: true,
value: scrollHeight,
})
Object.defineProperty(element, 'clientHeight', {
configurable: true,
value: clientHeight,
})
}
describe('useCheckVerticalScrollbar', () => {
beforeEach(() => {
vi.clearAllMocks()
resizeCallback = null
mutationCallback = null
vi.stubGlobal('ResizeObserver', MockResizeObserver)
vi.stubGlobal('MutationObserver', MockMutationObserver)
})
afterEach(() => {
vi.unstubAllGlobals()
})
it('should return false when the element ref is empty', () => {
const ref = { current: null } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() => useCheckVerticalScrollbar(ref))
expect(result.current).toBe(false)
expect(resizeObserve).not.toHaveBeenCalled()
expect(mutationObserve).not.toHaveBeenCalled()
})
it('should detect the initial scrollbar state and react to observer updates', () => {
const element = document.createElement('div')
setElementHeights(element, 200, 100)
const ref = { current: element } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() => useCheckVerticalScrollbar(ref))
expect(result.current).toBe(true)
expect(resizeObserve).toHaveBeenCalledWith(element)
expect(mutationObserve).toHaveBeenCalledWith(element, {
childList: true,
subtree: true,
characterData: true,
})
setElementHeights(element, 100, 100)
act(() => {
resizeCallback?.([] as ResizeObserverEntry[], new MockResizeObserver(() => {}))
})
expect(result.current).toBe(false)
setElementHeights(element, 180, 100)
act(() => {
mutationCallback?.([] as MutationRecord[], new MockMutationObserver(() => {}))
})
expect(result.current).toBe(true)
})
it('should disconnect observers on unmount', () => {
const element = document.createElement('div')
setElementHeights(element, 120, 100)
const ref = { current: element } as React.RefObject<HTMLElement | null>
const { unmount } = renderHook(() => useCheckVerticalScrollbar(ref))
unmount()
expect(resizeDisconnect).toHaveBeenCalledTimes(1)
expect(mutationDisconnect).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,103 +0,0 @@
import type * as React from 'react'
import { act, renderHook } from '@testing-library/react'
import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll'
const setRect = (element: HTMLElement, top: number, height: number) => {
element.getBoundingClientRect = vi.fn(() => new DOMRect(0, top, 100, height))
}
describe('useStickyScroll', () => {
beforeEach(() => {
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
const runScroll = (handleScroll: () => void) => {
act(() => {
handleScroll()
vi.advanceTimersByTime(120)
})
}
it('should keep the default state when refs are missing', () => {
const wrapElemRef = { current: null } as React.RefObject<HTMLElement | null>
const nextToStickyELemRef = { current: null } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() =>
useStickyScroll({
wrapElemRef,
nextToStickyELemRef,
}),
)
runScroll(result.current.handleScroll)
expect(result.current.scrollPosition).toBe(ScrollPosition.belowTheWrap)
})
it('should mark the sticky element as below the wrapper when it is outside the visible area', () => {
const wrapElement = document.createElement('div')
const nextElement = document.createElement('div')
setRect(wrapElement, 100, 200)
setRect(nextElement, 320, 20)
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() =>
useStickyScroll({
wrapElemRef,
nextToStickyELemRef,
}),
)
runScroll(result.current.handleScroll)
expect(result.current.scrollPosition).toBe(ScrollPosition.belowTheWrap)
})
it('should mark the sticky element as showing when it is within the wrapper', () => {
const wrapElement = document.createElement('div')
const nextElement = document.createElement('div')
setRect(wrapElement, 100, 200)
setRect(nextElement, 220, 20)
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() =>
useStickyScroll({
wrapElemRef,
nextToStickyELemRef,
}),
)
runScroll(result.current.handleScroll)
expect(result.current.scrollPosition).toBe(ScrollPosition.showing)
})
it('should mark the sticky element as above the wrapper when it has scrolled past the top', () => {
const wrapElement = document.createElement('div')
const nextElement = document.createElement('div')
setRect(wrapElement, 100, 200)
setRect(nextElement, 90, 20)
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
const { result } = renderHook(() =>
useStickyScroll({
wrapElemRef,
nextToStickyELemRef,
}),
)
runScroll(result.current.handleScroll)
expect(result.current.scrollPosition).toBe(ScrollPosition.aboveTheWrap)
})
})

View File

@@ -1,108 +0,0 @@
import type { DataSourceItem } from '../types'
import { transformDataSourceToTool } from '../utils'
const createLocalizedText = (text: string) => ({
en_US: text,
zh_Hans: text,
})
const createDataSourceItem = (overrides: Partial<DataSourceItem> = {}): DataSourceItem => ({
plugin_id: 'plugin-1',
plugin_unique_identifier: 'plugin-1@provider',
provider: 'provider-a',
declaration: {
credentials_schema: [{ name: 'api_key' }],
provider_type: 'hosted',
identity: {
author: 'Dify',
description: createLocalizedText('Datasource provider'),
icon: 'provider-icon',
label: createLocalizedText('Provider A'),
name: 'provider-a',
tags: ['retrieval', 'storage'],
},
datasources: [
{
description: createLocalizedText('Search in documents'),
identity: {
author: 'Dify',
label: createLocalizedText('Document Search'),
name: 'document_search',
provider: 'provider-a',
},
parameters: [{ name: 'query', type: 'string' }],
output_schema: {
type: 'object',
properties: {
result: { type: 'string' },
},
},
},
],
},
is_authorized: true,
...overrides,
})
describe('transformDataSourceToTool', () => {
it('should map datasource provider fields to tool shape', () => {
const dataSourceItem = createDataSourceItem()
const result = transformDataSourceToTool(dataSourceItem)
expect(result).toMatchObject({
id: 'plugin-1',
provider: 'provider-a',
name: 'provider-a',
author: 'Dify',
description: createLocalizedText('Datasource provider'),
icon: 'provider-icon',
label: createLocalizedText('Provider A'),
type: 'hosted',
allow_delete: true,
is_authorized: true,
is_team_authorization: true,
labels: ['retrieval', 'storage'],
plugin_id: 'plugin-1',
plugin_unique_identifier: 'plugin-1@provider',
credentialsSchema: [{ name: 'api_key' }],
meta: { version: '' },
})
expect(result.team_credentials).toEqual({})
expect(result.tools).toEqual([
{
name: 'document_search',
author: 'Dify',
label: createLocalizedText('Document Search'),
description: createLocalizedText('Search in documents'),
parameters: [{ name: 'query', type: 'string' }],
labels: [],
output_schema: {
type: 'object',
properties: {
result: { type: 'string' },
},
},
},
])
})
it('should fallback to empty arrays when tags and credentials schema are missing', () => {
const baseDataSourceItem = createDataSourceItem()
const dataSourceItem = createDataSourceItem({
declaration: {
...baseDataSourceItem.declaration,
credentials_schema: undefined as unknown as DataSourceItem['declaration']['credentials_schema'],
identity: {
...baseDataSourceItem.declaration.identity,
tags: undefined as unknown as DataSourceItem['declaration']['identity']['tags'],
},
},
})
const result = transformDataSourceToTool(dataSourceItem)
expect(result.labels).toEqual([])
expect(result.credentialsSchema).toEqual([])
})
})

View File

@@ -1,57 +0,0 @@
import { fireEvent, render } from '@testing-library/react'
import ViewTypeSelect, { ViewType } from '../view-type-select'
const getViewOptions = (container: HTMLElement) => {
const options = container.firstElementChild?.children
if (!options || options.length !== 2)
throw new Error('Expected two view options')
return [options[0] as HTMLDivElement, options[1] as HTMLDivElement]
}
describe('ViewTypeSelect', () => {
it('should highlight the active view type', () => {
const onChange = vi.fn()
const { container } = render(
<ViewTypeSelect
viewType={ViewType.flat}
onChange={onChange}
/>,
)
const [flatOption, treeOption] = getViewOptions(container)
expect(flatOption).toHaveClass('bg-components-segmented-control-item-active-bg')
expect(treeOption).toHaveClass('cursor-pointer')
})
it('should call onChange when switching to a different view type', () => {
const onChange = vi.fn()
const { container } = render(
<ViewTypeSelect
viewType={ViewType.flat}
onChange={onChange}
/>,
)
const [, treeOption] = getViewOptions(container)
fireEvent.click(treeOption)
expect(onChange).toHaveBeenCalledWith(ViewType.tree)
expect(onChange).toHaveBeenCalledTimes(1)
})
it('should ignore clicks on the current view type', () => {
const onChange = vi.fn()
const { container } = render(
<ViewTypeSelect
viewType={ViewType.tree}
onChange={onChange}
/>,
)
const [, treeOption] = getViewOptions(container)
fireEvent.click(treeOption)
expect(onChange).not.toHaveBeenCalled()
})
})

View File

@@ -1,6 +1,6 @@
import { useEffect, useState } from 'react'
const useCheckVerticalScrollbar = (ref: React.RefObject<HTMLElement | null>) => {
const useCheckVerticalScrollbar = (ref: React.RefObject<HTMLElement>) => {
const [hasVerticalScrollbar, setHasVerticalScrollbar] = useState(false)
useEffect(() => {

View File

@@ -1,59 +0,0 @@
import { fireEvent, screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import ChatVariableButton from '../chat-variable-button'
let mockTheme: 'light' | 'dark' = 'light'
vi.mock('@/hooks/use-theme', () => ({
default: () => ({
theme: mockTheme,
}),
}))
describe('ChatVariableButton', () => {
beforeEach(() => {
vi.clearAllMocks()
mockTheme = 'light'
})
it('opens the chat variable panel and closes the other workflow panels', () => {
const { store } = renderWorkflowComponent(<ChatVariableButton disabled={false} />, {
initialStoreState: {
showEnvPanel: true,
showGlobalVariablePanel: true,
showDebugAndPreviewPanel: true,
},
})
fireEvent.click(screen.getByRole('button'))
expect(store.getState().showChatVariablePanel).toBe(true)
expect(store.getState().showEnvPanel).toBe(false)
expect(store.getState().showGlobalVariablePanel).toBe(false)
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
})
it('applies the active dark theme styles when the chat variable panel is visible', () => {
mockTheme = 'dark'
renderWorkflowComponent(<ChatVariableButton disabled={false} />, {
initialStoreState: {
showChatVariablePanel: true,
},
})
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
})
it('stays disabled without mutating panel state', () => {
const { store } = renderWorkflowComponent(<ChatVariableButton disabled />, {
initialStoreState: {
showChatVariablePanel: false,
},
})
fireEvent.click(screen.getByRole('button'))
expect(screen.getByRole('button')).toBeDisabled()
expect(store.getState().showChatVariablePanel).toBe(false)
})
})

View File

@@ -1,63 +0,0 @@
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import EditingTitle from '../editing-title'
const mockFormatTime = vi.fn()
const mockFormatTimeFromNow = vi.fn()
vi.mock('@/hooks/use-timestamp', () => ({
default: () => ({
formatTime: mockFormatTime,
}),
}))
vi.mock('@/hooks/use-format-time-from-now', () => ({
useFormatTimeFromNow: () => ({
formatTimeFromNow: mockFormatTimeFromNow,
}),
}))
describe('EditingTitle', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFormatTime.mockReturnValue('08:00:00')
mockFormatTimeFromNow.mockReturnValue('2 hours ago')
})
it('should render autosave, published time, and syncing status when the draft has metadata', () => {
const { container } = renderWorkflowComponent(<EditingTitle />, {
initialStoreState: {
draftUpdatedAt: 1_710_000_000_000,
publishedAt: 1_710_003_600_000,
isSyncingWorkflowDraft: true,
maximizeCanvas: true,
},
})
expect(mockFormatTime).toHaveBeenCalledWith(1_710_000_000, 'HH:mm:ss')
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(1_710_003_600_000)
expect(container.firstChild).toHaveClass('ml-2')
expect(container).toHaveTextContent('workflow.common.autoSaved')
expect(container).toHaveTextContent('08:00:00')
expect(container).toHaveTextContent('workflow.common.published')
expect(container).toHaveTextContent('2 hours ago')
expect(container).toHaveTextContent('workflow.common.syncingData')
})
it('should render unpublished status without autosave metadata when the workflow has not been published', () => {
const { container } = renderWorkflowComponent(<EditingTitle />, {
initialStoreState: {
draftUpdatedAt: 0,
publishedAt: 0,
isSyncingWorkflowDraft: false,
maximizeCanvas: false,
},
})
expect(mockFormatTime).not.toHaveBeenCalled()
expect(mockFormatTimeFromNow).not.toHaveBeenCalled()
expect(container.firstChild).not.toHaveClass('ml-2')
expect(container).toHaveTextContent('workflow.common.unpublished')
expect(container).not.toHaveTextContent('workflow.common.autoSaved')
expect(container).not.toHaveTextContent('workflow.common.syncingData')
})
})

View File

@@ -1,68 +0,0 @@
import { fireEvent, screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import EnvButton from '../env-button'
const mockCloseAllInputFieldPanels = vi.fn()
let mockTheme: 'light' | 'dark' = 'light'
vi.mock('@/hooks/use-theme', () => ({
default: () => ({
theme: mockTheme,
}),
}))
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
useInputFieldPanel: () => ({
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
}),
}))
describe('EnvButton', () => {
beforeEach(() => {
vi.clearAllMocks()
mockTheme = 'light'
})
it('should open the environment panel and close the other panels when clicked', () => {
const { store } = renderWorkflowComponent(<EnvButton disabled={false} />, {
initialStoreState: {
showChatVariablePanel: true,
showGlobalVariablePanel: true,
showDebugAndPreviewPanel: true,
},
})
fireEvent.click(screen.getByRole('button'))
expect(store.getState().showEnvPanel).toBe(true)
expect(store.getState().showChatVariablePanel).toBe(false)
expect(store.getState().showGlobalVariablePanel).toBe(false)
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
})
it('should apply the active dark theme styles when the environment panel is visible', () => {
mockTheme = 'dark'
renderWorkflowComponent(<EnvButton disabled={false} />, {
initialStoreState: {
showEnvPanel: true,
},
})
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
})
it('should keep the button disabled when the disabled prop is true', () => {
const { store } = renderWorkflowComponent(<EnvButton disabled />, {
initialStoreState: {
showEnvPanel: false,
},
})
fireEvent.click(screen.getByRole('button'))
expect(screen.getByRole('button')).toBeDisabled()
expect(store.getState().showEnvPanel).toBe(false)
expect(mockCloseAllInputFieldPanels).not.toHaveBeenCalled()
})
})

View File

@@ -1,68 +0,0 @@
import { fireEvent, screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import GlobalVariableButton from '../global-variable-button'
const mockCloseAllInputFieldPanels = vi.fn()
let mockTheme: 'light' | 'dark' = 'light'
vi.mock('@/hooks/use-theme', () => ({
default: () => ({
theme: mockTheme,
}),
}))
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
useInputFieldPanel: () => ({
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
}),
}))
describe('GlobalVariableButton', () => {
beforeEach(() => {
vi.clearAllMocks()
mockTheme = 'light'
})
it('should open the global variable panel and close the other panels when clicked', () => {
const { store } = renderWorkflowComponent(<GlobalVariableButton disabled={false} />, {
initialStoreState: {
showEnvPanel: true,
showChatVariablePanel: true,
showDebugAndPreviewPanel: true,
},
})
fireEvent.click(screen.getByRole('button'))
expect(store.getState().showGlobalVariablePanel).toBe(true)
expect(store.getState().showEnvPanel).toBe(false)
expect(store.getState().showChatVariablePanel).toBe(false)
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
})
it('should apply the active dark theme styles when the global variable panel is visible', () => {
mockTheme = 'dark'
renderWorkflowComponent(<GlobalVariableButton disabled={false} />, {
initialStoreState: {
showGlobalVariablePanel: true,
},
})
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
})
it('should keep the button disabled when the disabled prop is true', () => {
const { store } = renderWorkflowComponent(<GlobalVariableButton disabled />, {
initialStoreState: {
showGlobalVariablePanel: false,
},
})
fireEvent.click(screen.getByRole('button'))
expect(screen.getByRole('button')).toBeDisabled()
expect(store.getState().showGlobalVariablePanel).toBe(false)
expect(mockCloseAllInputFieldPanels).not.toHaveBeenCalled()
})
})

View File

@@ -1,109 +0,0 @@
import type { VersionHistory } from '@/types/workflow'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { WorkflowVersion } from '../../types'
import RestoringTitle from '../restoring-title'
const mockFormatTime = vi.fn()
const mockFormatTimeFromNow = vi.fn()
vi.mock('@/hooks/use-timestamp', () => ({
default: () => ({
formatTime: mockFormatTime,
}),
}))
vi.mock('@/hooks/use-format-time-from-now', () => ({
useFormatTimeFromNow: () => ({
formatTimeFromNow: mockFormatTimeFromNow,
}),
}))
const createVersion = (overrides: Partial<VersionHistory> = {}): VersionHistory => ({
id: 'version-1',
graph: {
nodes: [],
edges: [],
},
created_at: 1_700_000_000,
created_by: {
id: 'user-1',
name: 'Alice',
email: 'alice@example.com',
},
hash: 'hash-1',
updated_at: 1_700_000_100,
updated_by: {
id: 'user-2',
name: 'Bob',
email: 'bob@example.com',
},
tool_published: false,
version: 'v1',
marked_name: 'Release 1',
marked_comment: '',
...overrides,
})
describe('RestoringTitle', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFormatTime.mockReturnValue('09:30:00')
mockFormatTimeFromNow.mockReturnValue('3 hours ago')
})
it('should render draft metadata when the current version is a draft', () => {
const currentVersion = createVersion({
version: WorkflowVersion.Draft,
})
const { container } = renderWorkflowComponent(<RestoringTitle />, {
initialStoreState: {
currentVersion,
},
})
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(currentVersion.updated_at * 1000)
expect(mockFormatTime).toHaveBeenCalledWith(currentVersion.created_at, 'HH:mm:ss')
expect(container).toHaveTextContent('workflow.versionHistory.currentDraft')
expect(container).toHaveTextContent('workflow.common.viewOnly')
expect(container).toHaveTextContent('workflow.common.unpublished')
expect(container).toHaveTextContent('3 hours ago 09:30:00')
expect(container).toHaveTextContent('Alice')
})
it('should render published metadata and fallback version name when the marked name is empty', () => {
const currentVersion = createVersion({
marked_name: '',
})
const { container } = renderWorkflowComponent(<RestoringTitle />, {
initialStoreState: {
currentVersion,
},
})
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(currentVersion.created_at * 1000)
expect(container).toHaveTextContent('workflow.versionHistory.defaultName')
expect(container).toHaveTextContent('workflow.common.published')
expect(container).toHaveTextContent('Alice')
})
it('should render an empty creator name when the version creator name is missing', () => {
const currentVersion = createVersion({
created_by: {
id: 'user-1',
name: '',
email: 'alice@example.com',
},
})
const { container } = renderWorkflowComponent(<RestoringTitle />, {
initialStoreState: {
currentVersion,
},
})
expect(container).toHaveTextContent('workflow.common.published')
expect(container).not.toHaveTextContent('Alice')
})
})

View File

@@ -1,61 +0,0 @@
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import RunningTitle from '../running-title'
let mockIsChatMode = false
const mockFormatWorkflowRunIdentifier = vi.fn()
vi.mock('../../hooks', () => ({
useIsChatMode: () => mockIsChatMode,
}))
vi.mock('../../utils', () => ({
formatWorkflowRunIdentifier: (finishedAt?: number) => mockFormatWorkflowRunIdentifier(finishedAt),
}))
describe('RunningTitle', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsChatMode = false
mockFormatWorkflowRunIdentifier.mockReturnValue(' (14:30:25)')
})
it('should render the test run title in workflow mode', () => {
const { container } = renderWorkflowComponent(<RunningTitle />, {
initialStoreState: {
historyWorkflowData: {
id: 'history-1',
status: 'succeeded',
finished_at: 1_700_000_000,
},
},
})
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(1_700_000_000)
expect(container).toHaveTextContent('Test Run (14:30:25)')
expect(container).toHaveTextContent('workflow.common.viewOnly')
})
it('should render the test chat title in chat mode', () => {
mockIsChatMode = true
const { container } = renderWorkflowComponent(<RunningTitle />, {
initialStoreState: {
historyWorkflowData: {
id: 'history-2',
status: 'running',
finished_at: undefined,
},
},
})
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(undefined)
expect(container).toHaveTextContent('Test Chat (14:30:25)')
})
it('should handle missing workflow history data', () => {
const { container } = renderWorkflowComponent(<RunningTitle />)
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(undefined)
expect(container).toHaveTextContent('Test Run (14:30:25)')
})
})

View File

@@ -1,53 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { createNode } from '../../__tests__/fixtures'
import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state'
import ScrollToSelectedNodeButton from '../scroll-to-selected-node-button'
const mockScrollToWorkflowNode = vi.fn()
vi.mock('reactflow', async () =>
(await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock())
vi.mock('../../utils/node-navigation', () => ({
scrollToWorkflowNode: (nodeId: string) => mockScrollToWorkflowNode(nodeId),
}))
describe('ScrollToSelectedNodeButton', () => {
beforeEach(() => {
vi.clearAllMocks()
resetReactFlowMockState()
})
it('should render nothing when there is no selected node', () => {
rfState.nodes = [
createNode({
id: 'node-1',
data: { selected: false },
}),
]
const { container } = render(<ScrollToSelectedNodeButton />)
expect(container.firstChild).toBeNull()
})
it('should render the action and scroll to the selected node when clicked', () => {
rfState.nodes = [
createNode({
id: 'node-1',
data: { selected: false },
}),
createNode({
id: 'node-2',
data: { selected: true },
}),
]
render(<ScrollToSelectedNodeButton />)
fireEvent.click(screen.getByText('workflow.panel.scrollToSelectedNode'))
expect(mockScrollToWorkflowNode).toHaveBeenCalledWith('node-2')
expect(mockScrollToWorkflowNode).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,118 +0,0 @@
import { act, fireEvent, render, screen } from '@testing-library/react'
import UndoRedo from '../undo-redo'
type TemporalSnapshot = {
pastStates: unknown[]
futureStates: unknown[]
}
const mockUnsubscribe = vi.fn()
const mockTemporalSubscribe = vi.fn()
const mockHandleUndo = vi.fn()
const mockHandleRedo = vi.fn()
let latestTemporalListener: ((state: TemporalSnapshot) => void) | undefined
let mockNodesReadOnly = false
vi.mock('@/app/components/workflow/header/view-workflow-history', () => ({
default: () => <div data-testid="view-workflow-history" />,
}))
vi.mock('@/app/components/workflow/hooks', () => ({
useNodesReadOnly: () => ({
nodesReadOnly: mockNodesReadOnly,
}),
}))
vi.mock('@/app/components/workflow/workflow-history-store', () => ({
useWorkflowHistoryStore: () => ({
store: {
temporal: {
subscribe: mockTemporalSubscribe,
},
},
shortcutsEnabled: true,
setShortcutsEnabled: vi.fn(),
}),
}))
vi.mock('@/app/components/base/divider', () => ({
default: () => <div data-testid="divider" />,
}))
vi.mock('@/app/components/workflow/operator/tip-popup', () => ({
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
}))
describe('UndoRedo', () => {
beforeEach(() => {
vi.clearAllMocks()
mockNodesReadOnly = false
latestTemporalListener = undefined
mockTemporalSubscribe.mockImplementation((listener: (state: TemporalSnapshot) => void) => {
latestTemporalListener = listener
return mockUnsubscribe
})
})
it('enables undo and redo when history exists and triggers the callbacks', () => {
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
act(() => {
latestTemporalListener?.({
pastStates: [{}],
futureStates: [{}],
})
})
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.undo' }))
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.redo' }))
expect(mockHandleUndo).toHaveBeenCalledTimes(1)
expect(mockHandleRedo).toHaveBeenCalledTimes(1)
})
it('keeps the buttons disabled before history is available', () => {
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
const undoButton = screen.getByRole('button', { name: 'workflow.common.undo' })
const redoButton = screen.getByRole('button', { name: 'workflow.common.redo' })
fireEvent.click(undoButton)
fireEvent.click(redoButton)
expect(undoButton).toBeDisabled()
expect(redoButton).toBeDisabled()
expect(mockHandleUndo).not.toHaveBeenCalled()
expect(mockHandleRedo).not.toHaveBeenCalled()
})
it('does not trigger callbacks when the canvas is read only', () => {
mockNodesReadOnly = true
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
const undoButton = screen.getByRole('button', { name: 'workflow.common.undo' })
const redoButton = screen.getByRole('button', { name: 'workflow.common.redo' })
act(() => {
latestTemporalListener?.({
pastStates: [{}],
futureStates: [{}],
})
})
fireEvent.click(undoButton)
fireEvent.click(redoButton)
expect(undoButton).toBeDisabled()
expect(redoButton).toBeDisabled()
expect(mockHandleUndo).not.toHaveBeenCalled()
expect(mockHandleRedo).not.toHaveBeenCalled()
})
it('unsubscribes from the temporal store on unmount', () => {
const { unmount } = render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
unmount()
expect(mockUnsubscribe).toHaveBeenCalledTimes(1)
})
})

View File

@@ -1,68 +0,0 @@
import { fireEvent, render, screen } from '@testing-library/react'
import VersionHistoryButton from '../version-history-button'
let mockTheme: 'light' | 'dark' = 'light'
vi.mock('@/hooks/use-theme', () => ({
default: () => ({
theme: mockTheme,
}),
}))
vi.mock('../../utils', async (importOriginal) => {
const actual = await importOriginal<typeof import('../../utils')>()
return {
...actual,
getKeyboardKeyCodeBySystem: () => 'ctrl',
}
})
describe('VersionHistoryButton', () => {
beforeEach(() => {
vi.clearAllMocks()
mockTheme = 'light'
})
it('should call onClick when the button is clicked', () => {
const onClick = vi.fn()
render(<VersionHistoryButton onClick={onClick} />)
fireEvent.click(screen.getByRole('button'))
expect(onClick).toHaveBeenCalledTimes(1)
})
it('should trigger onClick when the version history shortcut is pressed', () => {
const onClick = vi.fn()
render(<VersionHistoryButton onClick={onClick} />)
const keyboardEvent = new KeyboardEvent('keydown', {
key: 'H',
ctrlKey: true,
shiftKey: true,
bubbles: true,
cancelable: true,
})
Object.defineProperty(keyboardEvent, 'keyCode', { value: 72 })
Object.defineProperty(keyboardEvent, 'which', { value: 72 })
window.dispatchEvent(keyboardEvent)
expect(keyboardEvent.defaultPrevented).toBe(true)
expect(onClick).toHaveBeenCalledTimes(1)
})
it('should render the tooltip popup content on hover', async () => {
render(<VersionHistoryButton onClick={vi.fn()} />)
fireEvent.mouseEnter(screen.getByRole('button'))
expect(await screen.findByText('workflow.common.versionHistory')).toBeInTheDocument()
})
it('should apply dark theme styles when the theme is dark', () => {
mockTheme = 'dark'
render(<VersionHistoryButton onClick={vi.fn()} />)
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
})
})

View File

@@ -1,276 +0,0 @@
import type { WorkflowRunHistory, WorkflowRunHistoryResponse } from '@/types/workflow'
import { fireEvent, screen } from '@testing-library/react'
import * as React from 'react'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { ControlMode, WorkflowRunningStatus } from '../../types'
import ViewHistory from '../view-history'
const mockUseWorkflowRunHistory = vi.fn()
const mockFormatTimeFromNow = vi.fn((value: number) => `from-now:${value}`)
const mockCloseAllInputFieldPanels = vi.fn()
const mockHandleNodesCancelSelected = vi.fn()
const mockHandleCancelDebugAndPreviewPanel = vi.fn()
const mockFormatWorkflowRunIdentifier = vi.fn((finishedAt?: number, status?: string) => ` (${status || finishedAt || 'unknown'})`)
let mockIsChatMode = false
vi.mock('../../hooks', async () => {
const actual = await vi.importActual<typeof import('../../hooks')>('../../hooks')
return {
...actual,
useIsChatMode: () => mockIsChatMode,
useNodesInteractions: () => ({
handleNodesCancelSelected: mockHandleNodesCancelSelected,
}),
useWorkflowInteractions: () => ({
handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel,
}),
}
})
vi.mock('@/service/use-workflow', () => ({
useWorkflowRunHistory: (url?: string, enabled?: boolean) => mockUseWorkflowRunHistory(url, enabled),
}))
vi.mock('@/hooks/use-format-time-from-now', () => ({
useFormatTimeFromNow: () => ({
formatTimeFromNow: mockFormatTimeFromNow,
}),
}))
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
useInputFieldPanel: () => ({
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
}),
}))
vi.mock('@/app/components/base/loading', () => ({
default: () => <div data-testid="loading" />,
}))
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
}))
vi.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
return {
PortalToFollowElem: ({
children,
open,
}: {
children?: React.ReactNode
open: boolean
}) => <PortalContext.Provider value={{ open }}>{children}</PortalContext.Provider>,
PortalToFollowElemTrigger: ({
children,
onClick,
}: {
children?: React.ReactNode
onClick?: () => void
}) => <div data-testid="portal-trigger" onClick={onClick}>{children}</div>,
PortalToFollowElemContent: ({
children,
}: {
children?: React.ReactNode
}) => {
const { open } = React.useContext(PortalContext)
return open ? <div data-testid="portal-content">{children}</div> : null
},
}
})
vi.mock('../../utils', async () => {
const actual = await vi.importActual<typeof import('../../utils')>('../../utils')
return {
...actual,
formatWorkflowRunIdentifier: (finishedAt?: number, status?: string) => mockFormatWorkflowRunIdentifier(finishedAt, status),
}
})
const createHistoryItem = (overrides: Partial<WorkflowRunHistory> = {}): WorkflowRunHistory => ({
id: 'run-1',
version: 'v1',
graph: {
nodes: [],
edges: [],
},
inputs: {},
status: WorkflowRunningStatus.Succeeded,
outputs: {},
elapsed_time: 1,
total_tokens: 2,
total_steps: 3,
created_at: 100,
finished_at: 120,
created_by_account: {
id: 'user-1',
name: 'Alice',
email: 'alice@example.com',
},
...overrides,
})
describe('ViewHistory', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsChatMode = false
mockUseWorkflowRunHistory.mockReturnValue({
data: { data: [] } satisfies WorkflowRunHistoryResponse,
isLoading: false,
})
})
it('defers fetching until the history popup is opened and renders the empty state', () => {
renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
hooksStoreProps: {
handleBackupDraft: vi.fn(),
},
})
expect(mockUseWorkflowRunHistory).toHaveBeenCalledWith('/history', false)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
expect(mockUseWorkflowRunHistory).toHaveBeenLastCalledWith('/history', true)
expect(screen.getByText('workflow.common.notRunning')).toBeInTheDocument()
expect(screen.getByText('workflow.common.showRunHistory')).toBeInTheDocument()
})
it('renders the icon trigger variant and loading state, and clears log modals on trigger click', () => {
const onClearLogAndMessageModal = vi.fn()
mockUseWorkflowRunHistory.mockReturnValue({
data: { data: [] } satisfies WorkflowRunHistoryResponse,
isLoading: true,
})
renderWorkflowComponent(
<ViewHistory
historyUrl="/history"
onClearLogAndMessageModal={onClearLogAndMessageModal}
/>,
{
hooksStoreProps: {
handleBackupDraft: vi.fn(),
},
},
)
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.viewRunHistory' }))
expect(onClearLogAndMessageModal).toHaveBeenCalledTimes(1)
expect(screen.getByTestId('loading')).toBeInTheDocument()
})
it('renders workflow run history items and updates the workflow store when one is selected', () => {
const handleBackupDraft = vi.fn()
const pausedRun = createHistoryItem({
id: 'run-paused',
status: WorkflowRunningStatus.Paused,
created_at: 101,
finished_at: 0,
})
const failedRun = createHistoryItem({
id: 'run-failed',
status: WorkflowRunningStatus.Failed,
created_at: 102,
finished_at: 130,
})
const succeededRun = createHistoryItem({
id: 'run-succeeded',
status: WorkflowRunningStatus.Succeeded,
created_at: 103,
finished_at: 140,
})
mockUseWorkflowRunHistory.mockReturnValue({
data: {
data: [pausedRun, failedRun, succeededRun],
} satisfies WorkflowRunHistoryResponse,
isLoading: false,
})
const { store } = renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
initialStoreState: {
historyWorkflowData: failedRun,
showInputsPanel: true,
showEnvPanel: true,
controlMode: ControlMode.Pointer,
},
hooksStoreProps: {
handleBackupDraft,
},
})
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
expect(screen.getByText('Test Run (paused)')).toBeInTheDocument()
expect(screen.getByText('Test Run (failed)')).toBeInTheDocument()
expect(screen.getByText('Test Run (succeeded)')).toBeInTheDocument()
fireEvent.click(screen.getByText('Test Run (succeeded)'))
expect(store.getState().historyWorkflowData).toEqual(succeededRun)
expect(store.getState().showInputsPanel).toBe(false)
expect(store.getState().showEnvPanel).toBe(false)
expect(store.getState().controlMode).toBe(ControlMode.Hand)
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
expect(handleBackupDraft).toHaveBeenCalledTimes(1)
expect(mockHandleNodesCancelSelected).toHaveBeenCalledTimes(1)
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
})
it('renders chat history labels without workflow status icons in chat mode', () => {
mockIsChatMode = true
const chatRun = createHistoryItem({
id: 'chat-run',
status: WorkflowRunningStatus.Failed,
})
mockUseWorkflowRunHistory.mockReturnValue({
data: {
data: [chatRun],
} satisfies WorkflowRunHistoryResponse,
isLoading: false,
})
renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
hooksStoreProps: {
handleBackupDraft: vi.fn(),
},
})
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
expect(screen.getByText('Test Chat (failed)')).toBeInTheDocument()
})
it('closes the popup from the close button and clears log modals', () => {
const onClearLogAndMessageModal = vi.fn()
mockUseWorkflowRunHistory.mockReturnValue({
data: { data: [] } satisfies WorkflowRunHistoryResponse,
isLoading: false,
})
renderWorkflowComponent(
<ViewHistory
historyUrl="/history"
withText
onClearLogAndMessageModal={onClearLogAndMessageModal}
/>,
{
hooksStoreProps: {
handleBackupDraft: vi.fn(),
},
},
)
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' }))
expect(onClearLogAndMessageModal).toHaveBeenCalledTimes(1)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
})
})

View File

@@ -1,5 +1,6 @@
import type { FC } from 'react'
import type { CommonNodeType } from '../types'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { useNodes } from 'reactflow'
import { cn } from '@/utils/classnames'
@@ -10,15 +11,21 @@ const ScrollToSelectedNodeButton: FC = () => {
const nodes = useNodes<CommonNodeType>()
const selectedNode = nodes.find(node => node.data.selected)
const handleScrollToSelectedNode = useCallback(() => {
if (!selectedNode)
return
scrollToWorkflowNode(selectedNode.id)
}, [selectedNode])
if (!selectedNode)
return null
return (
<div
className={cn(
'flex h-6 cursor-pointer items-center justify-center whitespace-nowrap rounded-md border-[0.5px] border-effects-highlight bg-components-actionbar-bg px-3 text-text-tertiary shadow-lg backdrop-blur-sm transition-colors duration-200 system-xs-medium hover:text-text-accent',
'system-xs-medium flex h-6 cursor-pointer items-center justify-center whitespace-nowrap rounded-md border-[0.5px] border-effects-highlight bg-components-actionbar-bg px-3 text-text-tertiary shadow-lg backdrop-blur-sm transition-colors duration-200 hover:text-text-accent',
)}
onClick={() => scrollToWorkflowNode(selectedNode.id)}
onClick={handleScrollToSelectedNode}
>
{t('panel.scrollToSelectedNode', { ns: 'workflow' })}
</div>

View File

@@ -1,4 +1,8 @@
import type { FC } from 'react'
import {
RiArrowGoBackLine,
RiArrowGoForwardFill,
} from '@remixicon/react'
import { memo, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ViewWorkflowHistory from '@/app/components/workflow/header/view-workflow-history'
@@ -29,34 +33,28 @@ const UndoRedo: FC<UndoRedoProps> = ({ handleUndo, handleRedo }) => {
return (
<div className="flex items-center space-x-0.5 rounded-lg border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-lg backdrop-blur-[5px]">
<TipPopup title={t('common.undo', { ns: 'workflow' })!} shortcuts={['ctrl', 'z']}>
<button
type="button"
aria-label={t('common.undo', { ns: 'workflow' })!}
<div
data-tooltip-id="workflow.undo"
disabled={nodesReadOnly || buttonsDisabled.undo}
className={
cn('system-sm-medium flex h-8 w-8 cursor-pointer select-none items-center rounded-md px-1.5 text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary', (nodesReadOnly || buttonsDisabled.undo)
&& 'cursor-not-allowed text-text-disabled hover:bg-transparent hover:text-text-disabled')
}
onClick={handleUndo}
onClick={() => !nodesReadOnly && !buttonsDisabled.undo && handleUndo()}
>
<span className="i-ri-arrow-go-back-line h-4 w-4" />
</button>
<RiArrowGoBackLine className="h-4 w-4" />
</div>
</TipPopup>
<TipPopup title={t('common.redo', { ns: 'workflow' })!} shortcuts={['ctrl', 'y']}>
<button
type="button"
aria-label={t('common.redo', { ns: 'workflow' })!}
<div
data-tooltip-id="workflow.redo"
disabled={nodesReadOnly || buttonsDisabled.redo}
className={
cn('system-sm-medium flex h-8 w-8 cursor-pointer select-none items-center rounded-md px-1.5 text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary', (nodesReadOnly || buttonsDisabled.redo)
&& 'cursor-not-allowed text-text-disabled hover:bg-transparent hover:text-text-disabled')
}
onClick={handleRedo}
onClick={() => !nodesReadOnly && !buttonsDisabled.redo && handleRedo()}
>
<span className="i-ri-arrow-go-forward-fill h-4 w-4" />
</button>
<RiArrowGoForwardFill className="h-4 w-4" />
</div>
</TipPopup>
<Divider type="vertical" className="mx-0.5 h-3.5" />
<ViewWorkflowHistory />

View File

@@ -73,18 +73,15 @@ const ViewHistory = ({
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
{
withText && (
<button
type="button"
aria-label={t('common.showRunHistory', { ns: 'workflow' })}
className={cn(
'flex h-8 items-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3 shadow-xs',
'cursor-pointer text-[13px] font-medium text-components-button-secondary-text hover:bg-components-button-secondary-bg-hover',
open && 'bg-components-button-secondary-bg-hover',
)}
<div className={cn(
'flex h-8 items-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3 shadow-xs',
'cursor-pointer text-[13px] font-medium text-components-button-secondary-text hover:bg-components-button-secondary-bg-hover',
open && 'bg-components-button-secondary-bg-hover',
)}
>
<span className="i-custom-vender-line-time-clock-play mr-1 h-4 w-4" />
{t('common.showRunHistory', { ns: 'workflow' })}
</button>
</div>
)
}
{
@@ -92,16 +89,14 @@ const ViewHistory = ({
<Tooltip
popupContent={t('common.viewRunHistory', { ns: 'workflow' })}
>
<button
type="button"
aria-label={t('common.viewRunHistory', { ns: 'workflow' })}
<div
className={cn('group flex h-7 w-7 cursor-pointer items-center justify-center rounded-md hover:bg-state-accent-hover', open && 'bg-state-accent-hover')}
onClick={() => {
onClearLogAndMessageModal?.()
}}
>
<span className={cn('i-custom-vender-line-time-clock-play', 'h-4 w-4 group-hover:text-components-button-secondary-accent-text', open ? 'text-components-button-secondary-accent-text' : 'text-components-button-ghost-text')} />
</button>
</div>
</Tooltip>
)
}
@@ -115,9 +110,7 @@ const ViewHistory = ({
>
<div className="sticky top-0 flex items-center justify-between bg-components-panel-bg px-4 pt-3 text-base font-semibold text-text-primary">
<div className="grow">{t('common.runHistory', { ns: 'workflow' })}</div>
<button
type="button"
aria-label={t('operation.close', { ns: 'common' })}
<div
className="flex h-6 w-6 shrink-0 cursor-pointer items-center justify-center"
onClick={() => {
onClearLogAndMessageModal?.()
@@ -125,7 +118,7 @@ const ViewHistory = ({
}}
>
<span className="i-ri-close-line h-4 w-4 text-text-tertiary" />
</button>
</div>
</div>
{
isLoading && (

View File

@@ -1,36 +1,54 @@
import type { CommonNodeType } from '../../../types'
import { fireEvent, screen } from '@testing-library/react'
import { renderWorkflowComponent } from '../../../__tests__/workflow-test-env'
import { fireEvent, render, screen } from '@testing-library/react'
import { BlockEnum, NodeRunningStatus } from '../../../types'
import NodeControl from './node-control'
const {
mockHandleNodeSelect,
mockSetInitShowLastRunTab,
mockSetPendingSingleRun,
mockCanRunBySingle,
} = vi.hoisted(() => ({
mockHandleNodeSelect: vi.fn(),
mockSetInitShowLastRunTab: vi.fn(),
mockSetPendingSingleRun: vi.fn(),
mockCanRunBySingle: vi.fn(() => true),
}))
let mockPluginInstallLocked = false
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
vi.mock('../../../hooks', async () => {
const actual = await vi.importActual<typeof import('../../../hooks')>('../../../hooks')
return {
...actual,
useNodesInteractions: () => ({
handleNodeSelect: mockHandleNodeSelect,
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
<div data-testid="tooltip" data-content={popupContent}>{children}</div>
),
}))
vi.mock('@/app/components/base/icons/src/vender/line/mediaAndDevices', () => ({
Stop: ({ className }: { className?: string }) => <div data-testid="stop-icon" className={className} />,
}))
vi.mock('../../../hooks', () => ({
useNodesInteractions: () => ({
handleNodeSelect: mockHandleNodeSelect,
}),
}))
vi.mock('@/app/components/workflow/store', () => ({
useWorkflowStore: () => ({
getState: () => ({
setInitShowLastRunTab: mockSetInitShowLastRunTab,
setPendingSingleRun: mockSetPendingSingleRun,
}),
}
})
}),
}))
vi.mock('../../../utils', async () => {
const actual = await vi.importActual<typeof import('../../../utils')>('../../../utils')
return {
...actual,
canRunBySingle: mockCanRunBySingle,
}
})
vi.mock('../../../utils', () => ({
canRunBySingle: mockCanRunBySingle,
}))
vi.mock('./panel-operator', () => ({
default: ({ onOpenChange }: { onOpenChange: (open: boolean) => void }) => (
@@ -41,16 +59,6 @@ vi.mock('./panel-operator', () => ({
),
}))
function NodeControlHarness({ id, data }: { id: string, data: CommonNodeType, selected?: boolean }) {
return (
<NodeControl
id={id}
data={data}
pluginInstallLocked={mockPluginInstallLocked}
/>
)
}
const makeData = (overrides: Partial<CommonNodeType> = {}): CommonNodeType => ({
type: BlockEnum.Code,
title: 'Node',
@@ -65,71 +73,65 @@ const makeData = (overrides: Partial<CommonNodeType> = {}): CommonNodeType => ({
describe('NodeControl', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPluginInstallLocked = false
mockCanRunBySingle.mockReturnValue(true)
})
// Run/stop behavior should be driven by the workflow store, not CSS classes.
describe('Single Run Actions', () => {
it('should trigger a single run through the workflow store', () => {
const { store } = renderWorkflowComponent(
<NodeControlHarness id="node-1" data={makeData()} />,
)
it('should trigger a single run and show the hover control when plugins are not locked', () => {
const { container } = render(
<NodeControl
id="node-1"
data={makeData()}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'workflow.panel.runThisStep' }))
const wrapper = container.firstChild as HTMLElement
expect(wrapper.className).toContain('group-hover:flex')
expect(screen.getByTestId('tooltip')).toHaveAttribute('data-content', 'panel.runThisStep')
expect(store.getState().initShowLastRunTab).toBe(true)
expect(store.getState().pendingSingleRun).toEqual({ nodeId: 'node-1', action: 'run' })
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-1')
})
fireEvent.click(screen.getByTestId('tooltip').parentElement!)
it('should trigger stop when the node is already single-running', () => {
const { store } = renderWorkflowComponent(
<NodeControlHarness
id="node-2"
data={makeData({
selected: true,
_singleRunningStatus: NodeRunningStatus.Running,
})}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'workflow.debug.variableInspect.trigger.stop' }))
expect(store.getState().pendingSingleRun).toEqual({ nodeId: 'node-2', action: 'stop' })
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-2')
})
expect(mockSetInitShowLastRunTab).toHaveBeenCalledWith(true)
expect(mockSetPendingSingleRun).toHaveBeenCalledWith({ nodeId: 'node-1', action: 'run' })
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-1')
})
// Capability gating should hide the run control while leaving panel actions available.
describe('Availability', () => {
it('should keep the panel operator available when the plugin is install-locked', () => {
mockPluginInstallLocked = true
it('should render the stop action, keep locked controls hidden by default, and stay open when panel operator opens', () => {
const { container } = render(
<NodeControl
id="node-2"
pluginInstallLocked
data={makeData({
selected: true,
_singleRunningStatus: NodeRunningStatus.Running,
isInIteration: true,
})}
/>,
)
renderWorkflowComponent(
<NodeControlHarness
id="node-3"
data={makeData({
selected: true,
})}
/>,
)
const wrapper = container.firstChild as HTMLElement
expect(wrapper.className).not.toContain('group-hover:flex')
expect(wrapper.className).toContain('!flex')
expect(screen.getByTestId('stop-icon')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
})
fireEvent.click(screen.getByTestId('stop-icon').parentElement!)
it('should hide the run control when single-node execution is not supported', () => {
mockCanRunBySingle.mockReturnValue(false)
expect(mockSetPendingSingleRun).toHaveBeenCalledWith({ nodeId: 'node-2', action: 'stop' })
renderWorkflowComponent(
<NodeControlHarness
id="node-4"
data={makeData()}
/>,
)
fireEvent.click(screen.getByRole('button', { name: 'open panel' }))
expect(wrapper.className).toContain('!flex')
})
expect(screen.queryByRole('button', { name: 'workflow.panel.runThisStep' })).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
})
it('should hide the run control when single-node execution is not supported', () => {
mockCanRunBySingle.mockReturnValue(false)
render(
<NodeControl
id="node-3"
data={makeData()}
/>,
)
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
})
})

View File

@@ -1,5 +1,8 @@
import type { FC } from 'react'
import type { Node } from '../../../types'
import {
RiPlayLargeLine,
} from '@remixicon/react'
import {
memo,
useCallback,
@@ -51,9 +54,7 @@ const NodeControl: FC<NodeControlProps> = ({
>
{
canRunBySingle(data.type, isChildNode) && (
<button
type="button"
aria-label={isSingleRunning ? t('debug.variableInspect.trigger.stop', { ns: 'workflow' }) : t('panel.runThisStep', { ns: 'workflow' })}
<div
className={`flex h-5 w-5 items-center justify-center rounded-md ${isSingleRunning && 'cursor-pointer hover:bg-state-base-hover'}`}
onClick={() => {
const action = isSingleRunning ? 'stop' : 'run'
@@ -75,11 +76,11 @@ const NodeControl: FC<NodeControlProps> = ({
popupContent={t('panel.runThisStep', { ns: 'workflow' })}
asChild={false}
>
<span className="i-ri-play-large-line h-3 w-3" />
<RiPlayLargeLine className="h-3 w-3" />
</Tooltip>
)
}
</button>
</div>
)
}
<PanelOperator

View File

@@ -1,14 +1,15 @@
'use client'
import type { FC } from 'react'
import type { OutputVar } from '../../../code/types'
import type { ToastHandle } from '@/app/components/base/toast'
import type { VarType } from '@/app/components/workflow/types'
import { useDebounceFn } from 'ahooks'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback } from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
import RemoveButton from '../remove-button'
import VarTypePicker from './var-type-picker'
@@ -29,6 +30,7 @@ const OutputVarList: FC<Props> = ({
onRemove,
}) => {
const { t } = useTranslation()
const [toastHandler, setToastHandler] = useState<ToastHandle>()
const list = outputKeyOrders.map((key) => {
return {
@@ -40,17 +42,20 @@ const OutputVarList: FC<Props> = ({
const { run: validateVarInput } = useDebounceFn((existingVariables: typeof list, newKey: string) => {
const result = checkKeys([newKey], true)
if (!result.isValid) {
toast.add({
setToastHandler(Toast.notify({
type: 'error',
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
})
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
}))
return
}
if (existingVariables.some(key => key.variable?.trim() === newKey.trim())) {
toast.add({
setToastHandler(Toast.notify({
type: 'error',
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
})
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
}))
}
else {
toastHandler?.clear?.()
}
}, { wait: 500 })
@@ -61,6 +66,7 @@ const OutputVarList: FC<Props> = ({
replaceSpaceWithUnderscoreInVarNameInput(e.target)
const newKey = e.target.value
toastHandler?.clear?.()
validateVarInput(list.toSpliced(index, 1), newKey)
const newOutputs = produce(outputs, (draft) => {
@@ -69,7 +75,7 @@ const OutputVarList: FC<Props> = ({
})
onChange(newOutputs, index, newKey)
}
}, [list, onChange, outputs, validateVarInput])
}, [list, onChange, outputs, outputKeyOrders, validateVarInput])
const handleVarTypeChange = useCallback((index: number) => {
return (value: string) => {
@@ -79,7 +85,7 @@ const OutputVarList: FC<Props> = ({
})
onChange(newOutputs)
}
}, [list, onChange, outputs])
}, [list, onChange, outputs, outputKeyOrders])
const handleVarRemove = useCallback((index: number) => {
return () => {

View File

@@ -1,16 +1,17 @@
'use client'
import type { FC } from 'react'
import type { ToastHandle } from '@/app/components/base/toast'
import type { ValueSelector, Var, Variable } from '@/app/components/workflow/types'
import { RiDraggable } from '@remixicon/react'
import { useDebounceFn } from 'ahooks'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useMemo } from 'react'
import { useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { ReactSortable } from 'react-sortablejs'
import { v4 as uuid4 } from 'uuid'
import Input from '@/app/components/base/input'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types'
import { cn } from '@/utils/classnames'
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
@@ -41,6 +42,7 @@ const VarList: FC<Props> = ({
isSupportFileVar = true,
}) => {
const { t } = useTranslation()
const [toastHandle, setToastHandle] = useState<ToastHandle>()
const listWithIds = useMemo(() => list.map((item) => {
const id = uuid4()
@@ -53,17 +55,20 @@ const VarList: FC<Props> = ({
const { run: validateVarInput } = useDebounceFn((list: Variable[], newKey: string) => {
const result = checkKeys([newKey], true)
if (!result.isValid) {
toast.add({
setToastHandle(Toast.notify({
type: 'error',
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
})
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
}))
return
}
if (list.some(item => item.variable?.trim() === newKey.trim())) {
toast.add({
setToastHandle(Toast.notify({
type: 'error',
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
})
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
}))
}
else {
toastHandle?.clear?.()
}
}, { wait: 500 })
@@ -73,6 +78,7 @@ const VarList: FC<Props> = ({
const newKey = e.target.value
toastHandle?.clear?.()
validateVarInput(list.toSpliced(index, 1), newKey)
onVarNameChange?.(list[index].variable, newKey)

View File

@@ -1,68 +1,90 @@
import type { WebhookTriggerNodeType } from '../types'
import type { NodePanelProps } from '@/app/components/workflow/types'
import type { PanelProps } from '@/types/workflow'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import { BlockEnum } from '@/app/components/workflow/types'
import Panel from '../panel'
const {
mockHandleStatusCodeChange,
mockGenerateWebhookUrl,
mockHandleMethodChange,
mockHandleContentTypeChange,
mockHandleHeadersChange,
mockHandleParamsChange,
mockHandleBodyChange,
mockHandleResponseBodyChange,
} = vi.hoisted(() => ({
mockHandleStatusCodeChange: vi.fn(),
mockGenerateWebhookUrl: vi.fn(),
mockHandleMethodChange: vi.fn(),
mockHandleContentTypeChange: vi.fn(),
mockHandleHeadersChange: vi.fn(),
mockHandleParamsChange: vi.fn(),
mockHandleBodyChange: vi.fn(),
mockHandleResponseBodyChange: vi.fn(),
}))
const mockConfigState = {
readOnly: false,
inputs: {
method: 'POST',
webhook_url: 'https://example.com/webhook',
webhook_debug_url: '',
content_type: 'application/json',
headers: [],
params: [],
body: [],
status_code: 200,
response_body: 'ok',
variables: [],
},
}
vi.mock('../use-config', () => ({
DEFAULT_STATUS_CODE: 200,
MAX_STATUS_CODE: 399,
normalizeStatusCode: (statusCode: number) => Math.min(Math.max(statusCode, 200), 399),
useConfig: () => ({
readOnly: mockConfigState.readOnly,
inputs: mockConfigState.inputs,
handleMethodChange: mockHandleMethodChange,
handleContentTypeChange: mockHandleContentTypeChange,
handleHeadersChange: mockHandleHeadersChange,
handleParamsChange: mockHandleParamsChange,
handleBodyChange: mockHandleBodyChange,
readOnly: false,
inputs: {
method: 'POST',
webhook_url: 'https://example.com/webhook',
webhook_debug_url: '',
content_type: 'application/json',
headers: [],
params: [],
body: [],
status_code: 200,
response_body: '',
},
handleMethodChange: vi.fn(),
handleContentTypeChange: vi.fn(),
handleHeadersChange: vi.fn(),
handleParamsChange: vi.fn(),
handleBodyChange: vi.fn(),
handleStatusCodeChange: mockHandleStatusCodeChange,
handleResponseBodyChange: mockHandleResponseBodyChange,
handleResponseBodyChange: vi.fn(),
generateWebhookUrl: mockGenerateWebhookUrl,
}),
}))
const getStatusCodeInput = () => {
return screen.getAllByDisplayValue('200')
.find(element => element.getAttribute('aria-hidden') !== 'true') as HTMLInputElement
}
vi.mock('@/app/components/base/input-with-copy', () => ({
default: () => <div data-testid="input-with-copy" />,
}))
vi.mock('@/app/components/base/select', () => ({
SimpleSelect: () => <div data-testid="simple-select" />,
}))
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}))
vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({
default: ({ title, children }: { title: React.ReactNode, children: React.ReactNode }) => (
<section>
<div>{title}</div>
{children}
</section>
),
}))
vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({
default: () => <div data-testid="output-vars" />,
}))
vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({
default: () => <div data-testid="split" />,
}))
vi.mock('../components/header-table', () => ({
default: () => <div data-testid="header-table" />,
}))
vi.mock('../components/parameter-table', () => ({
default: () => <div data-testid="parameter-table" />,
}))
vi.mock('../components/paragraph-input', () => ({
default: () => <div data-testid="paragraph-input" />,
}))
vi.mock('../utils/render-output-vars', () => ({
OutputVariablesContent: () => <div data-testid="output-variables-content" />,
}))
describe('WebhookTriggerPanel', () => {
const panelProps: NodePanelProps<WebhookTriggerNodeType> = {
@@ -78,7 +100,7 @@ describe('WebhookTriggerPanel', () => {
body: [],
async_mode: false,
status_code: 200,
response_body: 'ok',
response_body: '',
variables: [],
},
panelProps: {} as PanelProps,
@@ -86,65 +108,26 @@ describe('WebhookTriggerPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockConfigState.readOnly = false
mockConfigState.inputs = {
method: 'POST',
webhook_url: 'https://example.com/webhook',
webhook_debug_url: '',
content_type: 'application/json',
headers: [],
params: [],
body: [],
status_code: 200,
response_body: 'ok',
variables: [],
}
})
describe('Rendering', () => {
it('should render the real panel fields without generating a new webhook url when one already exists', () => {
render(<Panel {...panelProps} />)
it('should update the status code when users enter a parseable value', () => {
render(<Panel {...panelProps} />)
expect(screen.getByDisplayValue('https://example.com/webhook')).toBeInTheDocument()
expect(screen.getByText('application/json')).toBeInTheDocument()
expect(screen.getByDisplayValue('ok')).toBeInTheDocument()
expect(mockGenerateWebhookUrl).not.toHaveBeenCalled()
})
fireEvent.change(screen.getByRole('textbox'), { target: { value: '201' } })
it('should request a webhook url when the node is writable and missing one', async () => {
mockConfigState.inputs = {
...mockConfigState.inputs,
webhook_url: '',
}
render(<Panel {...panelProps} />)
await waitFor(() => {
expect(mockGenerateWebhookUrl).toHaveBeenCalledTimes(1)
})
})
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(201)
})
describe('Status Code Input', () => {
it('should update the status code when users enter a parseable value', () => {
render(<Panel {...panelProps} />)
it('should ignore clear changes until the value is committed', () => {
render(<Panel {...panelProps} />)
fireEvent.change(getStatusCodeInput(), { target: { value: '201' } })
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: '' } })
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(201)
})
expect(mockHandleStatusCodeChange).not.toHaveBeenCalled()
it('should ignore clear changes until the value is committed', () => {
render(<Panel {...panelProps} />)
fireEvent.blur(input)
const input = getStatusCodeInput()
fireEvent.change(input, { target: { value: '' } })
expect(mockHandleStatusCodeChange).not.toHaveBeenCalled()
fireEvent.blur(input)
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(200)
})
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(200)
})
})

View File

@@ -1,225 +0,0 @@
import type { ReactNode } from 'react'
import { act, render, screen, waitFor } from '@testing-library/react'
import ReactFlow, { ReactFlowProvider } from 'reactflow'
import { FlowType } from '@/types/common'
import { BlockEnum } from '../../types'
import AddBlock from '../add-block'
type BlockSelectorMockProps = {
open: boolean
onOpenChange: (open: boolean) => void
disabled: boolean
onSelect: (type: BlockEnum, pluginDefaultValue?: Record<string, unknown>) => void
placement: string
offset: {
mainAxis: number
crossAxis: number
}
trigger: (open: boolean) => ReactNode
popupClassName: string
availableBlocksTypes: BlockEnum[]
showStartTab: boolean
}
const {
mockHandlePaneContextmenuCancel,
mockWorkflowStoreSetState,
mockGenerateNewNode,
mockGetNodeCustomTypeByNodeDataType,
} = vi.hoisted(() => ({
mockHandlePaneContextmenuCancel: vi.fn(),
mockWorkflowStoreSetState: vi.fn(),
mockGenerateNewNode: vi.fn(({ type, data }: { type: string, data: Record<string, unknown> }) => ({
newNode: {
id: 'generated-node',
type,
data,
},
})),
mockGetNodeCustomTypeByNodeDataType: vi.fn((type: string) => `${type}-custom`),
}))
let latestBlockSelectorProps: BlockSelectorMockProps | null = null
let mockNodesReadOnly = false
let mockIsChatMode = false
let mockFlowType: FlowType = FlowType.appFlow
const mockAvailableNextBlocks = [BlockEnum.Answer, BlockEnum.Code]
const mockNodesMetaDataMap = {
[BlockEnum.Answer]: {
defaultValue: {
title: 'Answer',
desc: '',
type: BlockEnum.Answer,
},
},
}
vi.mock('@/app/components/workflow/block-selector', () => ({
default: (props: BlockSelectorMockProps) => {
latestBlockSelectorProps = props
return (
<div data-testid="block-selector">
{props.trigger(props.open)}
</div>
)
},
}))
vi.mock('../../hooks', () => ({
useAvailableBlocks: () => ({
availableNextBlocks: mockAvailableNextBlocks,
}),
useIsChatMode: () => mockIsChatMode,
useNodesMetaData: () => ({
nodesMap: mockNodesMetaDataMap,
}),
useNodesReadOnly: () => ({
nodesReadOnly: mockNodesReadOnly,
}),
usePanelInteractions: () => ({
handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel,
}),
}))
vi.mock('../../hooks-store', () => ({
useHooksStore: (selector: (state: { configsMap?: { flowType?: FlowType } }) => unknown) =>
selector({ configsMap: { flowType: mockFlowType } }),
}))
vi.mock('../../store', () => ({
useWorkflowStore: () => ({
setState: mockWorkflowStoreSetState,
}),
}))
vi.mock('../../utils', () => ({
generateNewNode: mockGenerateNewNode,
getNodeCustomTypeByNodeDataType: mockGetNodeCustomTypeByNodeDataType,
}))
vi.mock('../tip-popup', () => ({
default: ({ children }: { children?: ReactNode }) => <>{children}</>,
}))
const renderWithReactFlow = (nodes: Array<{ id: string, position: { x: number, y: number }, data: { type: BlockEnum } }>) => {
return render(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow nodes={nodes} edges={[]} fitView />
<AddBlock />
</ReactFlowProvider>
</div>,
)
}
describe('AddBlock', () => {
beforeEach(() => {
vi.clearAllMocks()
latestBlockSelectorProps = null
mockNodesReadOnly = false
mockIsChatMode = false
mockFlowType = FlowType.appFlow
})
// Rendering and selector configuration.
describe('Rendering', () => {
it('should pass the selector props for a writable app workflow', async () => {
renderWithReactFlow([])
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
expect(screen.getByTestId('block-selector')).toBeInTheDocument()
expect(latestBlockSelectorProps).toMatchObject({
disabled: false,
availableBlocksTypes: mockAvailableNextBlocks,
showStartTab: true,
placement: 'right-start',
popupClassName: '!min-w-[256px]',
})
expect(latestBlockSelectorProps?.offset).toEqual({
mainAxis: 4,
crossAxis: -8,
})
})
it('should hide the start tab for chat mode and rag pipeline flows', async () => {
mockIsChatMode = true
const { rerender } = renderWithReactFlow([])
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
expect(latestBlockSelectorProps?.showStartTab).toBe(false)
mockIsChatMode = false
mockFlowType = FlowType.ragPipeline
rerender(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow nodes={[]} edges={[]} fitView />
<AddBlock />
</ReactFlowProvider>
</div>,
)
expect(latestBlockSelectorProps?.showStartTab).toBe(false)
})
})
// User interactions that bridge selector state and workflow state.
describe('User Interactions', () => {
it('should cancel the pane context menu when the selector closes', async () => {
renderWithReactFlow([])
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
act(() => {
latestBlockSelectorProps?.onOpenChange(false)
})
expect(mockHandlePaneContextmenuCancel).toHaveBeenCalledTimes(1)
})
it('should create a candidate node with an incremented title when a block is selected', async () => {
renderWithReactFlow([
{ id: 'node-1', position: { x: 0, y: 0 }, data: { type: BlockEnum.Answer } },
{ id: 'node-2', position: { x: 80, y: 0 }, data: { type: BlockEnum.Answer } },
])
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
act(() => {
latestBlockSelectorProps?.onSelect(BlockEnum.Answer, { pluginId: 'plugin-1' })
})
expect(mockGetNodeCustomTypeByNodeDataType).toHaveBeenCalledWith(BlockEnum.Answer)
expect(mockGenerateNewNode).toHaveBeenCalledWith({
type: 'answer-custom',
data: {
title: 'Answer 3',
desc: '',
type: BlockEnum.Answer,
pluginId: 'plugin-1',
_isCandidate: true,
},
position: {
x: 0,
y: 0,
},
})
expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({
candidateNode: {
id: 'generated-node',
type: 'answer-custom',
data: {
title: 'Answer 3',
desc: '',
type: BlockEnum.Answer,
pluginId: 'plugin-1',
_isCandidate: true,
},
},
})
})
})
})

View File

@@ -1,136 +0,0 @@
import type { ReactNode } from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import { ControlMode } from '../../types'
import Control from '../control'
type WorkflowStoreState = {
controlMode: ControlMode
maximizeCanvas: boolean
}
const {
mockHandleAddNote,
mockHandleLayout,
mockHandleModeHand,
mockHandleModePointer,
mockHandleToggleMaximizeCanvas,
} = vi.hoisted(() => ({
mockHandleAddNote: vi.fn(),
mockHandleLayout: vi.fn(),
mockHandleModeHand: vi.fn(),
mockHandleModePointer: vi.fn(),
mockHandleToggleMaximizeCanvas: vi.fn(),
}))
let mockNodesReadOnly = false
let mockStoreState: WorkflowStoreState
vi.mock('../../hooks', () => ({
useNodesReadOnly: () => ({
nodesReadOnly: mockNodesReadOnly,
getNodesReadOnly: () => mockNodesReadOnly,
}),
useWorkflowCanvasMaximize: () => ({
handleToggleMaximizeCanvas: mockHandleToggleMaximizeCanvas,
}),
useWorkflowMoveMode: () => ({
handleModePointer: mockHandleModePointer,
handleModeHand: mockHandleModeHand,
}),
useWorkflowOrganize: () => ({
handleLayout: mockHandleLayout,
}),
}))
vi.mock('../hooks', () => ({
useOperator: () => ({
handleAddNote: mockHandleAddNote,
}),
}))
vi.mock('../../store', () => ({
useStore: (selector: (state: WorkflowStoreState) => unknown) => selector(mockStoreState),
}))
vi.mock('../add-block', () => ({
default: () => <div data-testid="add-block" />,
}))
vi.mock('../more-actions', () => ({
default: () => <div data-testid="more-actions" />,
}))
vi.mock('../tip-popup', () => ({
default: ({
children,
title,
}: {
children?: ReactNode
title?: string
}) => <div data-testid={title}>{children}</div>,
}))
describe('Control', () => {
beforeEach(() => {
vi.clearAllMocks()
mockNodesReadOnly = false
mockStoreState = {
controlMode: ControlMode.Pointer,
maximizeCanvas: false,
}
})
// Rendering and visual states for control buttons.
describe('Rendering', () => {
it('should render the child action groups and highlight the active pointer mode', () => {
render(<Control />)
expect(screen.getByTestId('add-block')).toBeInTheDocument()
expect(screen.getByTestId('more-actions')).toBeInTheDocument()
expect(screen.getByTestId('workflow.common.pointerMode').firstElementChild).toHaveClass('bg-state-accent-active')
expect(screen.getByTestId('workflow.common.handMode').firstElementChild).not.toHaveClass('bg-state-accent-active')
expect(screen.getByTestId('workflow.panel.maximize')).toBeInTheDocument()
})
it('should switch the maximize tooltip and active style when the canvas is maximized', () => {
mockStoreState = {
controlMode: ControlMode.Hand,
maximizeCanvas: true,
}
render(<Control />)
expect(screen.getByTestId('workflow.common.handMode').firstElementChild).toHaveClass('bg-state-accent-active')
expect(screen.getByTestId('workflow.panel.minimize').firstElementChild).toHaveClass('bg-state-accent-active')
})
})
// User interactions exposed by the control bar.
describe('User Interactions', () => {
it('should trigger the note, mode, organize, and maximize handlers', () => {
render(<Control />)
fireEvent.click(screen.getByTestId('workflow.nodes.note.addNote').firstElementChild as HTMLElement)
fireEvent.click(screen.getByTestId('workflow.common.pointerMode').firstElementChild as HTMLElement)
fireEvent.click(screen.getByTestId('workflow.common.handMode').firstElementChild as HTMLElement)
fireEvent.click(screen.getByTestId('workflow.panel.organizeBlocks').firstElementChild as HTMLElement)
fireEvent.click(screen.getByTestId('workflow.panel.maximize').firstElementChild as HTMLElement)
expect(mockHandleAddNote).toHaveBeenCalledTimes(1)
expect(mockHandleModePointer).toHaveBeenCalledTimes(1)
expect(mockHandleModeHand).toHaveBeenCalledTimes(1)
expect(mockHandleLayout).toHaveBeenCalledTimes(1)
expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1)
})
it('should block note creation when the workflow is read only', () => {
mockNodesReadOnly = true
render(<Control />)
fireEvent.click(screen.getByTestId('workflow.nodes.note.addNote').firstElementChild as HTMLElement)
expect(mockHandleAddNote).not.toHaveBeenCalled()
})
})
})

View File

@@ -1,323 +0,0 @@
import type { Shape as HooksStoreShape } from '../../hooks-store/store'
import type { RunFile } from '../../types'
import type { FileUpload } from '@/app/components/base/features/types'
import { screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import ReactFlow, { ReactFlowProvider } from 'reactflow'
import { TransferMethod } from '@/types/app'
import { FlowType } from '@/types/common'
import { createStartNode } from '../../__tests__/fixtures'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import { InputVarType, WorkflowRunningStatus } from '../../types'
import InputsPanel from '../inputs-panel'
const mockCheckInputsForm = vi.fn()
const mockNotify = vi.fn()
vi.mock('next/navigation', () => ({
useParams: () => ({}),
}))
vi.mock('@/app/components/base/toast/context', () => ({
useToastContext: () => ({
notify: mockNotify,
close: vi.fn(),
}),
}))
vi.mock('@/app/components/base/chat/chat/check-input-forms-hooks', () => ({
useCheckInputsForms: () => ({
checkInputsForm: mockCheckInputsForm,
}),
}))
const fileSettingsWithImage = {
enabled: true,
image: {
enabled: true,
},
allowed_file_upload_methods: [TransferMethod.remote_url],
number_limits: 3,
image_file_size_limit: 10,
} satisfies FileUpload & { image_file_size_limit: number }
const uploadedRunFile = {
transfer_method: TransferMethod.remote_url,
upload_file_id: 'file-2',
} as unknown as RunFile
const uploadingRunFile = {
transfer_method: TransferMethod.local_file,
} as unknown as RunFile
const createHooksStoreProps = (
overrides: Partial<HooksStoreShape> = {},
): Partial<HooksStoreShape> => ({
handleRun: vi.fn(),
configsMap: {
flowId: 'flow-1',
flowType: FlowType.appFlow,
fileSettings: fileSettingsWithImage,
},
...overrides,
})
const renderInputsPanel = (
startNode: ReturnType<typeof createStartNode>,
options?: Parameters<typeof renderWorkflowComponent>[1],
) => {
return renderWorkflowComponent(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow nodes={[startNode]} edges={[]} fitView />
<InputsPanel onRun={vi.fn()} />
</ReactFlowProvider>
</div>,
options,
)
}
describe('InputsPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockCheckInputsForm.mockReturnValue(true)
})
describe('Rendering', () => {
it('should render current inputs, defaults, and the image uploader from the start node', () => {
renderInputsPanel(
createStartNode({
data: {
variables: [
{
type: InputVarType.textInput,
variable: 'question',
label: 'Question',
required: true,
default: 'default question',
},
{
type: InputVarType.number,
variable: 'count',
label: 'Count',
required: false,
default: '2',
},
],
},
}),
{
initialStoreState: {
inputs: {
question: 'overridden question',
},
},
hooksStoreProps: createHooksStoreProps(),
},
)
expect(screen.getByDisplayValue('overridden question')).toHaveFocus()
expect(screen.getByRole('spinbutton')).toHaveValue(2)
expect(screen.getByText('common.imageUploader.pasteImageLink')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should update workflow inputs and image files when users edit the form', async () => {
const user = userEvent.setup()
const { store } = renderInputsPanel(
createStartNode({
data: {
variables: [
{
type: InputVarType.textInput,
variable: 'question',
label: 'Question',
required: true,
},
],
},
}),
{
hooksStoreProps: createHooksStoreProps(),
},
)
await user.type(screen.getByPlaceholderText('Question'), 'changed question')
expect(store.getState().inputs).toEqual({ question: 'changed question' })
await user.click(screen.getByText('common.imageUploader.pasteImageLink'))
await user.type(
await screen.findByPlaceholderText('common.imageUploader.pasteImageLinkInputPlaceholder'),
'https://example.com/image.png',
)
await user.click(screen.getByRole('button', { name: 'common.operation.ok' }))
await waitFor(() => {
expect(store.getState().files).toEqual([{
type: 'image',
transfer_method: TransferMethod.remote_url,
url: 'https://example.com/image.png',
upload_file_id: '',
}])
})
})
it('should not start a run when input validation fails', async () => {
const user = userEvent.setup()
mockCheckInputsForm.mockReturnValue(false)
const onRun = vi.fn()
const handleRun = vi.fn()
renderWorkflowComponent(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow
nodes={[
createStartNode({
data: {
variables: [
{
type: InputVarType.textInput,
variable: 'question',
label: 'Question',
required: true,
default: 'default question',
},
],
},
}),
]}
edges={[]}
fitView
/>
<InputsPanel onRun={onRun} />
</ReactFlowProvider>
</div>,
{
hooksStoreProps: createHooksStoreProps({ handleRun }),
},
)
await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' }))
expect(mockCheckInputsForm).toHaveBeenCalledWith(
{ question: 'default question' },
expect.arrayContaining([
expect.objectContaining({ variable: 'question' }),
expect.objectContaining({ variable: '__image' }),
]),
)
expect(onRun).not.toHaveBeenCalled()
expect(handleRun).not.toHaveBeenCalled()
})
it('should start a run with processed inputs when validation succeeds', async () => {
const user = userEvent.setup()
const onRun = vi.fn()
const handleRun = vi.fn()
renderWorkflowComponent(
<div style={{ width: 800, height: 600 }}>
<ReactFlowProvider>
<ReactFlow
nodes={[
createStartNode({
data: {
variables: [
{
type: InputVarType.textInput,
variable: 'question',
label: 'Question',
required: true,
},
{
type: InputVarType.checkbox,
variable: 'confirmed',
label: 'Confirmed',
required: false,
},
],
},
}),
]}
edges={[]}
fitView
/>
<InputsPanel onRun={onRun} />
</ReactFlowProvider>
</div>,
{
initialStoreState: {
inputs: {
question: 'run this',
confirmed: 'truthy',
},
files: [uploadedRunFile],
},
hooksStoreProps: createHooksStoreProps({
handleRun,
configsMap: {
flowId: 'flow-1',
flowType: FlowType.appFlow,
fileSettings: {
enabled: false,
},
},
}),
},
)
await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' }))
expect(onRun).toHaveBeenCalledTimes(1)
expect(handleRun).toHaveBeenCalledWith({
inputs: {
question: 'run this',
confirmed: true,
},
files: [uploadedRunFile],
})
})
})
describe('Disabled States', () => {
it('should disable the run button while a local file is still uploading', () => {
renderInputsPanel(createStartNode(), {
initialStoreState: {
files: [uploadingRunFile],
},
hooksStoreProps: createHooksStoreProps({
configsMap: {
flowId: 'flow-1',
flowType: FlowType.appFlow,
fileSettings: {
enabled: false,
},
},
}),
})
expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled()
})
it('should disable the run button while the workflow is already running', () => {
renderInputsPanel(createStartNode(), {
initialStoreState: {
workflowRunningData: {
result: {
status: WorkflowRunningStatus.Running,
inputs_truncated: false,
process_data_truncated: false,
outputs_truncated: false,
},
tracing: [],
},
},
hooksStoreProps: createHooksStoreProps(),
})
expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled()
})
})
})

View File

@@ -1,163 +0,0 @@
import type { WorkflowRunDetailResponse } from '@/models/log'
import { act, screen } from '@testing-library/react'
import { createEdge, createNode } from '../../__tests__/fixtures'
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
import Record from '../record'
const mockHandleUpdateWorkflowCanvas = vi.fn()
const mockFormatWorkflowRunIdentifier = vi.fn((finishedAt?: number) => finishedAt ? ' (Finished)' : ' (Running)')
let latestGetResultCallback: ((res: WorkflowRunDetailResponse) => void) | undefined
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowUpdate: () => ({
handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas,
}),
}))
vi.mock('@/app/components/workflow/run', () => ({
default: ({
runDetailUrl,
tracingListUrl,
getResultCallback,
}: {
runDetailUrl: string
tracingListUrl: string
getResultCallback: (res: WorkflowRunDetailResponse) => void
}) => {
latestGetResultCallback = getResultCallback
return (
<div
data-run-detail-url={runDetailUrl}
data-testid="run"
data-tracing-list-url={tracingListUrl}
/>
)
},
}))
vi.mock('@/app/components/workflow/utils', () => ({
formatWorkflowRunIdentifier: (finishedAt?: number) => mockFormatWorkflowRunIdentifier(finishedAt),
}))
const createRunDetail = (overrides: Partial<WorkflowRunDetailResponse> = {}): WorkflowRunDetailResponse => ({
id: 'run-1',
version: '1',
graph: {
nodes: [],
edges: [],
},
inputs: '{}',
inputs_truncated: false,
status: 'succeeded',
outputs: '{}',
outputs_truncated: false,
total_steps: 1,
created_by_role: 'account',
created_at: 1,
finished_at: 2,
...overrides,
})
describe('Record', () => {
beforeEach(() => {
vi.clearAllMocks()
latestGetResultCallback = undefined
})
it('renders the run title and passes run and trace URLs to the run panel', () => {
const getWorkflowRunAndTraceUrl = vi.fn((runId?: string) => ({
runUrl: `/runs/${runId}`,
traceUrl: `/traces/${runId}`,
}))
renderWorkflowComponent(<Record />, {
initialStoreState: {
historyWorkflowData: {
id: 'run-1',
status: 'succeeded',
finished_at: 1700000000000,
},
},
hooksStoreProps: {
getWorkflowRunAndTraceUrl,
},
})
expect(screen.getByText('Test Run (Finished)')).toBeInTheDocument()
expect(screen.getByTestId('run')).toHaveAttribute('data-run-detail-url', '/runs/run-1')
expect(screen.getByTestId('run')).toHaveAttribute('data-tracing-list-url', '/traces/run-1')
expect(getWorkflowRunAndTraceUrl).toHaveBeenCalledTimes(2)
expect(getWorkflowRunAndTraceUrl).toHaveBeenNthCalledWith(1, 'run-1')
expect(getWorkflowRunAndTraceUrl).toHaveBeenNthCalledWith(2, 'run-1')
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(1700000000000)
})
it('updates the workflow canvas with a fallback viewport when the response omits one', () => {
const nodes = [createNode({ id: 'node-1' })]
const edges = [createEdge({ id: 'edge-1' })]
renderWorkflowComponent(<Record />, {
initialStoreState: {
historyWorkflowData: {
id: 'run-1',
status: 'succeeded',
},
},
hooksStoreProps: {
getWorkflowRunAndTraceUrl: () => ({ runUrl: '/runs/run-1', traceUrl: '/traces/run-1' }),
},
})
expect(latestGetResultCallback).toBeDefined()
act(() => {
latestGetResultCallback?.(createRunDetail({
graph: {
nodes,
edges,
},
}))
})
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
nodes,
edges,
viewport: { x: 0, y: 0, zoom: 1 },
})
})
it('uses the response viewport when one is available', () => {
const nodes = [createNode({ id: 'node-1' })]
const edges = [createEdge({ id: 'edge-1' })]
const viewport = { x: 12, y: 24, zoom: 0.75 }
renderWorkflowComponent(<Record />, {
initialStoreState: {
historyWorkflowData: {
id: 'run-1',
status: 'succeeded',
},
},
hooksStoreProps: {
getWorkflowRunAndTraceUrl: () => ({ runUrl: '/runs/run-1', traceUrl: '/traces/run-1' }),
},
})
act(() => {
latestGetResultCallback?.(createRunDetail({
graph: {
nodes,
edges,
viewport,
},
}))
})
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
nodes,
edges,
viewport,
})
})
})

View File

@@ -7,7 +7,7 @@ import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import VersionInfoModal from '@/app/components/app/app-publisher/version-info-modal'
import Divider from '@/app/components/base/divider'
import { toast } from '@/app/components/base/ui/toast'
import Toast from '@/app/components/base/toast'
import { useSelector as useAppContextSelector } from '@/context/app-context'
import { useDeleteWorkflow, useInvalidAllLastRun, useResetWorkflowVersionHistory, useUpdateWorkflow, useWorkflowVersionHistory } from '@/service/use-workflow'
import { useDSL, useNodesSyncDraft, useWorkflowRun } from '../../hooks'
@@ -118,9 +118,9 @@ export const VersionHistoryPanel = ({
break
case VersionHistoryContextMenuOptions.copyId:
copy(item.id)
toast.add({
Toast.notify({
type: 'success',
title: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
message: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
})
break
case VersionHistoryContextMenuOptions.exportDSL:
@@ -152,17 +152,17 @@ export const VersionHistoryPanel = ({
workflowStore.setState({ backupDraft: undefined })
handleSyncWorkflowDraft(true, false, {
onSuccess: () => {
toast.add({
Toast.notify({
type: 'success',
title: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
})
deleteAllInspectVars()
invalidAllLastRun()
},
onError: () => {
toast.add({
Toast.notify({
type: 'error',
title: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
})
},
onSettled: () => {
@@ -177,18 +177,18 @@ export const VersionHistoryPanel = ({
await deleteWorkflow(deleteVersionUrl?.(id) || '', {
onSuccess: () => {
setDeleteConfirmOpen(false)
toast.add({
Toast.notify({
type: 'success',
title: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
message: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
})
resetWorkflowVersionHistory()
deleteAllInspectVars()
invalidAllLastRun()
},
onError: () => {
toast.add({
Toast.notify({
type: 'error',
title: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
message: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
})
},
onSettled: () => {
@@ -207,16 +207,16 @@ export const VersionHistoryPanel = ({
}, {
onSuccess: () => {
setEditModalOpen(false)
toast.add({
Toast.notify({
type: 'success',
title: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
message: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
})
resetWorkflowVersionHistory()
},
onError: () => {
toast.add({
Toast.notify({
type: 'error',
title: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
message: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
})
},
onSettled: () => {

Some files were not shown because too many files have changed in this diff Show More