mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 14:27:05 +00:00
Compare commits
11 Commits
yanli/phas
...
verify-ema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0da31b1a14 | ||
|
|
10d1904e59 | ||
|
|
095b436621 | ||
|
|
baaf4e8041 | ||
|
|
bc41371975 | ||
|
|
9c5c935ed5 | ||
|
|
559f8263b7 | ||
|
|
59c5638342 | ||
|
|
897ffb6b35 | ||
|
|
d367a6b1e1 | ||
|
|
daa9d38788 |
@@ -134,7 +134,6 @@ class EducationAutocompleteQuery(BaseModel):
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
|
||||
@@ -548,13 +547,17 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.token is not None:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@@ -574,7 +577,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
phase=send_phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -609,12 +612,26 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
phase_transitions: dict[str, str] = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@@ -644,13 +661,22 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -47,6 +47,7 @@ from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -521,9 +522,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session=session, model=workflow)
|
||||
message = _refresh_model(session=session, model=message)
|
||||
assert message is not None
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
@@ -690,21 +690,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
raise e
|
||||
|
||||
|
||||
@overload
|
||||
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
|
||||
|
||||
@overload
|
||||
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
|
||||
|
||||
|
||||
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
|
||||
if session is not None:
|
||||
detached_model = session.get(type(model), model.id)
|
||||
assert detached_model is not None
|
||||
return detached_model
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
|
||||
detached_model = refresh_session.get(type(model), model.id)
|
||||
assert detached_model is not None
|
||||
return detached_model
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Iterator, Mapping
|
||||
from typing import Any
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@@ -16,26 +16,24 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_full_response(stream_response)
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_simple_response(stream_response)
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
|
||||
return _generate_simple_response()
|
||||
|
||||
@@ -52,14 +50,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -224,7 +224,6 @@ class BaseAppGenerator:
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
assert isinstance(account, Account)
|
||||
debug_account = account
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
@@ -241,7 +240,7 @@ class BaseAppGenerator:
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
user=debug_account,
|
||||
user=account,
|
||||
)
|
||||
else:
|
||||
|
||||
|
||||
@@ -166,19 +166,15 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
generated_conversation_id = str(conversation.id)
|
||||
generated_message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=generated_conversation_id,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=generated_message_id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -188,8 +184,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=generated_conversation_id,
|
||||
message_id=generated_message_id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -149,8 +149,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@@ -314,19 +312,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
conversation_id = str(conversation.id)
|
||||
message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation_id,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message_id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -336,7 +330,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message_id,
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator, Iterator
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, TypeAlias
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@@ -40,7 +36,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -79,14 +75,6 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GraphConfigObject: TypeAlias = dict[str, object]
|
||||
GraphConfigMapping: TypeAlias = Mapping[str, object]
|
||||
|
||||
|
||||
class SingleNodeRunEntity(Protocol):
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
def __init__(
|
||||
@@ -110,7 +98,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: GraphConfigMapping,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@@ -166,8 +154,8 @@ class WorkflowBasedAppRunner:
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: SingleNodeRunEntity | None = None,
|
||||
single_loop_run: SingleNodeRunEntity | None = None,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@@ -220,88 +208,11 @@ class WorkflowBasedAppRunner:
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
@staticmethod
|
||||
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
|
||||
nodes = graph_config.get("nodes")
|
||||
edges = graph_config.get("edges")
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
if not isinstance(edges, list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
validated_nodes: list[GraphConfigMapping] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
raise ValueError("nodes in workflow graph must be mappings")
|
||||
validated_nodes.append(node)
|
||||
|
||||
validated_edges: list[GraphConfigMapping] = []
|
||||
for edge in edges:
|
||||
if not isinstance(edge, Mapping):
|
||||
raise ValueError("edges in workflow graph must be mappings")
|
||||
validated_edges.append(edge)
|
||||
|
||||
return validated_nodes, validated_edges
|
||||
|
||||
@staticmethod
|
||||
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
|
||||
if node_config is None:
|
||||
return None
|
||||
node_data = node_config.get("data")
|
||||
if not isinstance(node_data, Mapping):
|
||||
return None
|
||||
start_node_id = node_data.get("start_node_id")
|
||||
return start_node_id if isinstance(start_node_id, str) else None
|
||||
|
||||
@classmethod
|
||||
def _build_single_node_graph_config(
|
||||
cls,
|
||||
*,
|
||||
graph_config: GraphConfigMapping,
|
||||
node_id: str,
|
||||
node_type_filter_key: str,
|
||||
) -> tuple[GraphConfigObject, NodeConfigDict]:
|
||||
node_configs, edge_configs = cls._get_graph_items(graph_config)
|
||||
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
|
||||
start_node_id = cls._extract_start_node_id(main_node_config)
|
||||
|
||||
filtered_node_configs = [
|
||||
dict(node)
|
||||
for node in node_configs
|
||||
if node.get("id") == node_id
|
||||
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
if not filtered_node_configs:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
filtered_node_ids = {
|
||||
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
|
||||
}
|
||||
filtered_edge_configs = [
|
||||
dict(edge)
|
||||
for edge in edge_configs
|
||||
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
|
||||
]
|
||||
|
||||
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
|
||||
if target_node_config is None:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
return (
|
||||
{
|
||||
"nodes": filtered_node_configs,
|
||||
"edges": filtered_edge_configs,
|
||||
},
|
||||
NodeConfigDictAdapter.validate_python(target_node_config),
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: Mapping[str, object],
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
@@ -325,14 +236,41 @@ class WorkflowBasedAppRunner:
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
graph_config, target_node_config = self._build_single_node_graph_config(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_type_filter_key=node_type_filter_key,
|
||||
)
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -361,6 +299,18 @@ class WorkflowBasedAppRunner:
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
|
||||
@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
inputs: Mapping
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
inputs: Mapping
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -38,32 +37,6 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@@ -112,6 +85,18 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@@ -1045,10 +1045,9 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "constant":
|
||||
parameter_value = tool_input.value
|
||||
|
||||
@@ -1,24 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Literal, TypeAlias
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
AgentInputConstantValue: TypeAlias = (
|
||||
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
|
||||
)
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
|
||||
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.AGENT
|
||||
@@ -32,20 +21,8 @@ class AgentNodeData(BaseNodeData):
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: AgentInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> AgentInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "variable":
|
||||
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type in {"mixed", "constant"}:
|
||||
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown agent input type: {input_type}")
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypeAlias
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
@@ -29,14 +28,6 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
JsonObject: TypeAlias = dict[str, object]
|
||||
JsonObjectList: TypeAlias = list[JsonObject]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
@@ -48,12 +39,12 @@ class AgentRuntimeSupport:
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
invoke_from: Any,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, object]:
|
||||
) -> dict[str, Any]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, object] = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@@ -63,10 +54,9 @@ class AgentRuntimeSupport:
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
|
||||
variable = variable_pool.get(variable_selector)
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
@@ -89,38 +79,60 @@ class AgentRuntimeSupport:
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
value = self._normalize_tool_payloads(
|
||||
strategy=strategy,
|
||||
tools=tool_payloads,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = self._coerce_tool_provider_type(tool.get("type"))
|
||||
setting_params = self._coerce_json_object(tool.get("settings")) or {}
|
||||
parameters = self._coerce_json_object(tool.get("parameters")) or {}
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
|
||||
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
|
||||
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
entity = AgentToolEntity(
|
||||
provider_id=provider_id,
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool_name,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
credential_id=credential_id,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = self._coerce_json_object(tool.get("extra")) or {}
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
@@ -133,9 +145,8 @@ class AgentRuntimeSupport:
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
description_override = self._coerce_optional_string(extra.get("description"))
|
||||
tool_runtime.entity.description.llm = (
|
||||
description_override or tool_runtime.entity.description.llm
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
@@ -156,13 +167,13 @@ class AgentRuntimeSupport:
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": credential_id,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
@@ -188,27 +199,17 @@ class AgentRuntimeSupport:
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
tools = parameters.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
return credentials
|
||||
|
||||
for raw_tool in tools:
|
||||
tool = self._coerce_json_object(raw_tool)
|
||||
if tool is None:
|
||||
continue
|
||||
for tool in parameters.get("tools", []):
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
if credential_id is None:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = credential_id
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
@@ -231,14 +232,14 @@ class AgentRuntimeSupport:
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=str(value.get("provider", "")),
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = str(value.get("model", ""))
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
@@ -248,7 +249,7 @@ class AgentRuntimeSupport:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(str(value.get("model_type", ""))),
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
@@ -267,88 +268,9 @@ class AgentRuntimeSupport:
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: JsonObjectList,
|
||||
) -> JsonObjectList:
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _normalize_tool_payloads(
|
||||
self,
|
||||
*,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: JsonObjectList,
|
||||
variable_pool: VariablePool,
|
||||
) -> JsonObjectList:
|
||||
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
|
||||
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
|
||||
for tool in normalized_tools:
|
||||
tool.pop("schemas", None)
|
||||
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
|
||||
tool["settings"] = self._resolve_tool_settings(tool)
|
||||
return normalized_tools
|
||||
|
||||
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
|
||||
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
|
||||
if parameter_configs is None:
|
||||
raw_parameters = self._coerce_json_object(tool.get("parameters"))
|
||||
return raw_parameters or {}
|
||||
|
||||
resolved_parameters: JsonObject = {}
|
||||
for key, parameter_config in parameter_configs.items():
|
||||
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
|
||||
value_param = self._coerce_json_object(parameter_config.get("value"))
|
||||
if value_param and value_param.get("type") == "variable":
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
resolved_parameters[key] = variable.value
|
||||
else:
|
||||
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
resolved_parameters[key] = None
|
||||
|
||||
return resolved_parameters
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
|
||||
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
|
||||
if settings is None:
|
||||
return {}
|
||||
return {key: setting.get("value") for key, setting in settings.items()}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json_object(value: object) -> JsonObject | None:
|
||||
try:
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
except ValidationError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_optional_string(value: object) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
|
||||
if isinstance(value, ToolProviderType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return ToolProviderType(value)
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@classmethod
|
||||
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
coerced: dict[str, JsonObject] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
return None
|
||||
json_object = cls._coerce_json_object(item)
|
||||
if json_object is None:
|
||||
return None
|
||||
coerced[key] = json_object
|
||||
return coerced
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@@ -32,13 +32,6 @@ from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SpecialValueScalar: TypeAlias = str | int | float | bool | None
|
||||
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
|
||||
SerializedSpecialValue: TypeAlias = (
|
||||
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
|
||||
)
|
||||
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@staticmethod
|
||||
@@ -283,10 +276,10 @@ class WorkflowEntry:
|
||||
@staticmethod
|
||||
def _create_single_node_graph(
|
||||
node_id: str,
|
||||
node_data: Mapping[str, object],
|
||||
node_data: dict[str, Any],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> SingleNodeGraphConfig:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
@@ -296,14 +289,14 @@ class WorkflowEntry:
|
||||
:param node_height: height for UI layout (default: 100)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config: dict[str, object] = {
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": dict(node_data),
|
||||
"data": node_data,
|
||||
}
|
||||
start_node_config: dict[str, object] = {
|
||||
start_node_config = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
@@ -328,12 +321,7 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls,
|
||||
node_data: Mapping[str, object],
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_inputs: Mapping[str, object],
|
||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
@@ -351,8 +339,6 @@ class WorkflowEntry:
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = node_data.get("type", "")
|
||||
if not isinstance(node_type, str):
|
||||
raise ValueError("Node type must be a string")
|
||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
@@ -383,7 +369,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@@ -419,34 +405,30 @@ class WorkflowEntry:
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
|
||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
raise TypeError("handle_special_values expects a mapping input")
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
|
||||
def _handle_special_values(value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, Mapping):
|
||||
res: dict[str, SerializedSpecialValue] = {}
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list: list[SerializedSpecialValue] = []
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return dict(value.to_dict())
|
||||
return value.to_dict()
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -112,8 +112,6 @@ def _get_encoded_string(f: File, /) -> str:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@@ -133,8 +133,6 @@ class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
else:
|
||||
return
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
|
||||
@@ -336,7 +336,12 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
node_executions = graph_execution.node_executions
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
@@ -390,7 +395,8 @@ class Node(Generic[NodeDataT]):
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield event.model_copy(update={"id": self.execution_id})
|
||||
event.id = self.execution_id
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
|
||||
@@ -443,10 +443,7 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
doc_body = getattr(doc.element, "body", None)
|
||||
if doc_body is None:
|
||||
raise TextExtractionError("DOCX body not found")
|
||||
it = iter(doc_body)
|
||||
it = iter(doc.element.body)
|
||||
part = next(it, None)
|
||||
i = 0
|
||||
while part is not None:
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Literal, NotRequired
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -11,17 +10,11 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class StructuredOutputConfig(TypedDict):
|
||||
schema: Mapping[str, object]
|
||||
name: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, object] = Field(default_factory=dict)
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
@@ -40,7 +33,7 @@ class VisionConfig(BaseModel):
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: object):
|
||||
def convert_none_configs(cls, v: Any):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
@@ -51,7 +44,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: object):
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
@@ -74,7 +67,7 @@ class LLMNodeData(BaseNodeData):
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: StructuredOutputConfig | None = None
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
@@ -97,30 +90,11 @@ class LLMNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: object):
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@field_validator("structured_output", mode="before")
|
||||
@classmethod
|
||||
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
|
||||
if not isinstance(v, Mapping):
|
||||
return v
|
||||
|
||||
schema = v.get("schema")
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
normalized: StructuredOutputConfig = {"schema": schema}
|
||||
name = v.get("name")
|
||||
description = v.get("description")
|
||||
if isinstance(name, str):
|
||||
normalized["name"] = name
|
||||
if isinstance(description, str):
|
||||
normalized["description"] = description
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def structured_output_enabled(self) -> bool:
|
||||
return self.structured_output_switch_on and self.structured_output is not None
|
||||
|
||||
@@ -9,7 +9,6 @@ import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@@ -75,7 +74,6 @@ from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
StructuredOutputConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@@ -90,7 +88,6 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
|
||||
class LLMNode(Node[LLMNodeData]):
|
||||
@@ -357,7 +354,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: StructuredOutputConfig | None = None,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
@@ -370,10 +367,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if structured_output_enabled:
|
||||
if structured_output is None:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=structured_output,
|
||||
structured_output=structured_output or {},
|
||||
)
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
@@ -925,12 +920,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
structured_output = (
|
||||
dict(invoke_result.structured_output)
|
||||
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
|
||||
else None
|
||||
)
|
||||
|
||||
event = ModelInvokeCompletedEvent(
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
@@ -939,7 +928,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if enabled
|
||||
structured_output=structured_output,
|
||||
structured_output=getattr(invoke_result, "structured_output", None),
|
||||
)
|
||||
if request_latency is not None:
|
||||
event.usage.latency = round(request_latency, 3)
|
||||
@@ -973,18 +962,27 @@ class LLMNode(Node[LLMNodeData]):
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
structured_output: StructuredOutputConfig,
|
||||
) -> dict[str, object]:
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Fetch the structured output schema from the node data.
|
||||
|
||||
Returns:
|
||||
dict[str, object]: The structured output schema
|
||||
dict[str, Any]: The structured output schema
|
||||
"""
|
||||
schema = structured_output.get("schema")
|
||||
if not schema:
|
||||
if not structured_output:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(schema)
|
||||
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
||||
if not structured_output_schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
|
||||
try:
|
||||
schema = json.loads(structured_output_schema)
|
||||
if not isinstance(schema, dict):
|
||||
raise LLMNodeError("structured_output_schema must be a JSON object")
|
||||
return schema
|
||||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
|
||||
@staticmethod
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, TypeAlias, cast
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
@@ -12,12 +9,6 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
@@ -38,36 +29,6 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
return seg_type
|
||||
|
||||
|
||||
def _validate_loop_value(value: object) -> LoopValue:
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return cast(LoopValue, value)
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_validate_loop_value(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
normalized: dict[str, LoopValue] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop values only support string object keys")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
|
||||
|
||||
|
||||
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("Loop outputs must be an object")
|
||||
|
||||
normalized: LoopValueMapping = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("Loop output keys must be strings")
|
||||
normalized[key] = _validate_loop_value(item)
|
||||
return normalized
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
Loop Variable Data.
|
||||
@@ -76,29 +37,7 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: LoopValue | VariableSelector | None = None
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
||||
value_type = validation_info.data.get("value_type")
|
||||
if value_type == "variable":
|
||||
if value is None:
|
||||
raise ValueError("Variable loop inputs require a selector")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if value_type == "constant":
|
||||
return _validate_loop_value(value)
|
||||
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.value_type != "variable":
|
||||
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
def require_constant_value(self) -> LoopValue:
|
||||
if self.value_type != "constant":
|
||||
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||
return _validate_loop_value(self.value)
|
||||
value: Any | list[str] | None = None
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@@ -107,14 +46,14 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: LoopValueMapping = Field(default_factory=dict)
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||
if value is None:
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
return {}
|
||||
return _validate_loop_value_mapping(value)
|
||||
return v
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@@ -138,8 +77,8 @@ class LoopState(BaseLoopState):
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[LoopValue] = Field(default_factory=list)
|
||||
current_output: LoopValue | None = None
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@@ -148,7 +87,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> LoopValue | None:
|
||||
def get_last_output(self) -> Any:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@@ -156,7 +95,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> LoopValue | None:
|
||||
def get_current_output(self) -> Any:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
|
||||
)
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
from dify_graph.variables import Segment, SegmentType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs: dict[str, object] = {"loop_count": loop_count}
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
@@ -68,14 +68,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors: dict[str, list[str]] = {}
|
||||
loop_variable_selectors = {}
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var: (
|
||||
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
|
||||
if var.value is not None
|
||||
else None
|
||||
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
|
||||
),
|
||||
}
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
@@ -97,7 +95,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||
|
||||
@@ -148,7 +146,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable: dict[str, LoopValue] = {}
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
@@ -299,29 +297,20 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, object],
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
variable_mapping: dict[str, Sequence[str]] = {}
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Get node configs from graph_config
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
node_configs: dict[str, Mapping[str, object]] = {}
|
||||
if isinstance(raw_nodes, list):
|
||||
for raw_node in raw_nodes:
|
||||
if not isinstance(raw_node, dict):
|
||||
continue
|
||||
raw_node_id = raw_node.get("id")
|
||||
if isinstance(raw_node_id, str):
|
||||
node_configs[raw_node_id] = raw_node
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
sub_node_data = sub_node_config.get("data")
|
||||
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
@@ -352,8 +341,9 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.require_variable_selector()
|
||||
selector = loop_variable.value
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
@@ -362,7 +352,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
@@ -373,19 +363,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids: set[str] = set()
|
||||
loop_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
if not isinstance(raw_nodes, list):
|
||||
return loop_node_ids
|
||||
|
||||
for node in raw_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
node_data = node.get("data")
|
||||
if not isinstance(node_data, dict):
|
||||
continue
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
@@ -394,7 +377,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
@@ -406,12 +389,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
# New typed payloads may already provide native lists, while legacy
|
||||
# configs still serialize array constants as JSON strings.
|
||||
if isinstance(original_value, str):
|
||||
value = json.loads(original_value) if original_value else []
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
value = original_value
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -6,7 +6,6 @@ from pydantic import (
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -56,7 +55,7 @@ class ParameterConfig(BaseModel):
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value: object) -> str:
|
||||
def validate_name(cls, value) -> str:
|
||||
if not value:
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
@@ -80,23 +79,6 @@ class ParameterConfig(BaseModel):
|
||||
return element_type
|
||||
|
||||
|
||||
class JsonSchemaArrayItems(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterJsonSchemaProperty(TypedDict, total=False):
|
||||
description: str
|
||||
type: str
|
||||
items: JsonSchemaArrayItems
|
||||
enum: list[str]
|
||||
|
||||
|
||||
class ParameterJsonSchema(TypedDict):
|
||||
type: Literal["object"]
|
||||
properties: dict[str, ParameterJsonSchemaProperty]
|
||||
required: list[str]
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
@@ -113,19 +95,19 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v: object) -> str:
|
||||
return str(v) if v else "function_call"
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
|
||||
def get_parameter_json_schema(self) -> ParameterJsonSchema:
|
||||
def get_parameter_json_schema(self):
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
@@ -136,7 +118,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type.value
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
@@ -5,8 +5,6 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -65,7 +63,6 @@ from .prompts import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
@@ -73,7 +70,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def extract_json(text: str) -> str | None:
|
||||
def extract_json(text):
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
@@ -395,15 +392,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
# generate tool
|
||||
parameter_schema = node_data.get_parameter_json_schema()
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters={
|
||||
"type": parameter_schema["type"],
|
||||
"properties": dict(parameter_schema["properties"]),
|
||||
"required": list(parameter_schema["required"]),
|
||||
},
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
@@ -610,21 +602,19 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result: dict[str, object] = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
if isinstance(param_value, (bool, int, float, str)):
|
||||
numeric_value: bool | int | float | str = param_value
|
||||
transformed = self._transform_number(numeric_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
@@ -671,7 +661,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@@ -682,11 +672,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
return cast(dict, json.loads(json_str))
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@@ -700,16 +690,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
return cast(dict, json.loads(json_str))
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result: dict[str, object] = {}
|
||||
result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
|
||||
@@ -1,66 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from typing import Literal, TypeAlias, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class WorkflowToolInputValue(TypedDict):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
|
||||
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
|
||||
|
||||
|
||||
class ToolInputPayload(BaseModel):
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
|
||||
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
|
||||
if isinstance(value, (str, int, float, bool)):
|
||||
return cast(ToolConfigurationEntry, value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
|
||||
|
||||
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
@@ -68,29 +14,52 @@ class ToolEntity(BaseModel):
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: ToolConfigurations
|
||||
tool_configurations: dict[str, Any]
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("tool_configurations must be a dictionary")
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
normalized: ToolConfigurations = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("tool_configurations keys must be strings")
|
||||
normalized[key] = _validate_tool_configuration_entry(item)
|
||||
return normalized
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = BuiltinNodeTypes.TOOL
|
||||
|
||||
class ToolInput(ToolInputPayload):
|
||||
pass
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
|
||||
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
|
||||
return typ
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
@@ -100,7 +69,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value: object) -> object:
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
@@ -111,10 +80,8 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input: object) -> bool:
|
||||
def _has_valid_value(tool_input):
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
if isinstance(tool_input, ToolNodeData.ToolInput):
|
||||
return tool_input.value is not None
|
||||
return False
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
|
||||
@@ -225,11 +225,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
@@ -511,9 +510,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
variable_selector = input.require_variable_selector()
|
||||
selector_key = ".".join(variable_selector)
|
||||
result[f"#{selector_key}#"] = variable_selector
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from dify_graph.variables import Segment, SegmentType, VariableBase
|
||||
from dify_graph.variables import SegmentType, VariableBase
|
||||
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
@@ -74,29 +74,23 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
income_value: Segment
|
||||
updated_variable: VariableBase
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
@@ -66,11 +66,6 @@ class GraphExecutionProtocol(Protocol):
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
@property
|
||||
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
|
||||
"""Return node execution state keyed by node id for resume support."""
|
||||
...
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
...
|
||||
@@ -96,12 +91,6 @@ class GraphExecutionProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeExecutionProtocol(Protocol):
|
||||
"""Structural interface for per-node execution state used during resume."""
|
||||
|
||||
execution_id: str | None
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
"""Structural interface for response stream coordinator."""
|
||||
|
||||
|
||||
@@ -11,13 +11,6 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@@ -13,6 +13,21 @@ controllers/console/workspace/trigger_providers.py
|
||||
controllers/service_api/app/annotation.py
|
||||
controllers/web/workflow_events.py
|
||||
core/agent/fc_agent_runner.py
|
||||
core/app/apps/advanced_chat/app_generator.py
|
||||
core/app/apps/advanced_chat/app_runner.py
|
||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
||||
core/app/apps/agent_chat/app_generator.py
|
||||
core/app/apps/base_app_generate_response_converter.py
|
||||
core/app/apps/base_app_generator.py
|
||||
core/app/apps/chat/app_generator.py
|
||||
core/app/apps/common/workflow_response_converter.py
|
||||
core/app/apps/completion/app_generator.py
|
||||
core/app/apps/pipeline/pipeline_generator.py
|
||||
core/app/apps/pipeline/pipeline_runner.py
|
||||
core/app/apps/workflow/app_generator.py
|
||||
core/app/apps/workflow/app_runner.py
|
||||
core/app/apps/workflow/generate_task_pipeline.py
|
||||
core/app/apps/workflow_app_runner.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/datasource/datasource_manager.py
|
||||
core/external_data_tool/api/api.py
|
||||
@@ -93,6 +108,35 @@ core/tools/workflow_as_tool/provider.py
|
||||
core/trigger/debug/event_selectors.py
|
||||
core/trigger/entities/entities.py
|
||||
core/trigger/provider.py
|
||||
core/workflow/workflow_entry.py
|
||||
dify_graph/entities/workflow_execution.py
|
||||
dify_graph/file/file_manager.py
|
||||
dify_graph/graph_engine/error_handler.py
|
||||
dify_graph/graph_engine/layers/execution_limits.py
|
||||
dify_graph/nodes/agent/agent_node.py
|
||||
dify_graph/nodes/base/node.py
|
||||
dify_graph/nodes/code/code_node.py
|
||||
dify_graph/nodes/datasource/datasource_node.py
|
||||
dify_graph/nodes/document_extractor/node.py
|
||||
dify_graph/nodes/human_input/human_input_node.py
|
||||
dify_graph/nodes/if_else/if_else_node.py
|
||||
dify_graph/nodes/iteration/iteration_node.py
|
||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
||||
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
||||
dify_graph/nodes/list_operator/node.py
|
||||
dify_graph/nodes/llm/node.py
|
||||
dify_graph/nodes/loop/loop_node.py
|
||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
||||
dify_graph/nodes/start/start_node.py
|
||||
dify_graph/nodes/template_transform/template_transform_node.py
|
||||
dify_graph/nodes/tool/tool_node.py
|
||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
||||
dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -90,12 +91,25 @@ class TokenPair(BaseModel):
|
||||
csrf_token: str
|
||||
|
||||
|
||||
class ChangeEmailPhase(StrEnum):
|
||||
OLD = "old_email"
|
||||
OLD_VERIFIED = "old_email_verified"
|
||||
NEW = "new_email"
|
||||
NEW_VERIFIED = "new_email_verified"
|
||||
|
||||
|
||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
class AccountService:
|
||||
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
|
||||
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
|
||||
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
|
||||
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
|
||||
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
@@ -552,13 +566,20 @@ class AccountService:
|
||||
raise ValueError("Email must be provided.")
|
||||
if not phase:
|
||||
raise ValueError("phase must be provided.")
|
||||
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
|
||||
raise ValueError("phase must be one of old_email or new_email.")
|
||||
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
code, token = cls.generate_change_email_token(
|
||||
account_email,
|
||||
account,
|
||||
old_email=old_email,
|
||||
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
|
||||
)
|
||||
|
||||
send_change_mail_task.delay(
|
||||
language=language,
|
||||
|
||||
@@ -950,6 +950,16 @@ class TestWorkflowAppService:
|
||||
assert result_with_new_email["total"] == 3
|
||||
assert all(log.created_by_role == CreatorUserRole.ACCOUNT for log in result_with_new_email["data"])
|
||||
|
||||
# Create another account in a different tenant using the original email.
|
||||
# Querying by the old email should still fail for this app's tenant.
|
||||
cross_tenant_account = AccountService.create_account(
|
||||
email=original_email,
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(cross_tenant_account, name=fake.company())
|
||||
|
||||
# Old email unbound, is unexpected input, should raise ValueError
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.get_paginate_workflow_app_logs(
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
def test_should_infer_new_email_phase_from_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.db")
|
||||
@patch("controllers.console.workspace.account.Session")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account_by_email,
|
||||
mock_session_cls,
|
||||
mock_account_db,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
|
||||
existing_account = _build_account("old@example.com", "acc-old")
|
||||
mock_get_account_by_email.return_value = existing_account
|
||||
mock_send_email.return_value = "token-legacy"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__.return_value = mock_session
|
||||
mock_session_cm.__exit__.return_value = None
|
||||
mock_session_cls.return_value = mock_session_cm
|
||||
mock_account_db.engine = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-legacy"}
|
||||
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=existing_account,
|
||||
email="old@example.com",
|
||||
old_email="old@example.com",
|
||||
language="en-US",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_unverified_old_email_token_for_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_refresh_new_email_phase_to_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("old@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "5678",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-phase-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("new@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-456")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="5678",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "new@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_phase_token_for_reset(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_mismatched_new_email_for_verified_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "another@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-789"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
|
||||
|
||||
refreshed = _refresh_model(session=None, model=source_model)
|
||||
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
|
||||
|
||||
assert refreshed is detached_model
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -33,8 +33,8 @@ def _make_graph_state():
|
||||
],
|
||||
)
|
||||
def test_run_uses_single_node_execution_branch(
|
||||
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
|
||||
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
|
||||
single_iteration_run: Any,
|
||||
single_loop_run: Any,
|
||||
) -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
@@ -130,23 +130,10 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "other-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "other-node",
|
||||
"target": "loop-node",
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
original_graph_dict = deepcopy(workflow.graph_dict)
|
||||
|
||||
_, _, graph_runtime_state = _make_graph_state()
|
||||
seen_configs: list[object] = []
|
||||
@@ -156,19 +143,13 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
|
||||
):
|
||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
@@ -180,8 +161,3 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
assert workflow.graph_dict == original_graph_dict
|
||||
graph_config = graph_init.call_args.kwargs["graph_config"]
|
||||
assert graph_config is not workflow.graph_dict
|
||||
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
|
||||
assert graph_config["edges"] == []
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
|
||||
from dify_graph.nodes.loop.entities import LoopNodeData
|
||||
from dify_graph.nodes.loop.loop_node import LoopNode
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
|
||||
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
|
||||
@@ -53,21 +50,3 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
|
||||
)
|
||||
|
||||
assert seen_configs == [child_node_config]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("var_type", "original_value", "expected_value"),
|
||||
[
|
||||
(SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
|
||||
(SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
|
||||
(SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
|
||||
(SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
|
||||
],
|
||||
)
|
||||
def test_get_segment_for_constant_accepts_native_array_values(
|
||||
var_type: SegmentType, original_value: LoopValue, expected_value: LoopValue
|
||||
) -> None:
|
||||
segment = LoopNode._get_segment_for_constant(var_type, original_value)
|
||||
|
||||
assert segment.value_type == var_type
|
||||
assert segment.value == expected_value
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def test_creator_user_role_missing_maps_hyphen_to_enum():
|
||||
# given an alias with hyphen
|
||||
value = "end-user"
|
||||
|
||||
# when converting to enum (invokes StrEnum._missing_ override)
|
||||
role = CreatorUserRole(value)
|
||||
|
||||
# then it should map to END_USER
|
||||
assert role is CreatorUserRole.END_USER
|
||||
|
||||
|
||||
def test_creator_user_role_missing_raises_for_unknown():
|
||||
with pytest.raises(ValueError):
|
||||
CreatorUserRole("unknown")
|
||||
@@ -11,7 +11,6 @@ import type { BasicPlan } from '@/app/components/billing/type'
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ALL_PLANS } from '@/app/components/billing/config'
|
||||
import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item'
|
||||
@@ -22,6 +21,7 @@ let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockFetchSubscriptionUrls = vi.fn()
|
||||
const mockInvoices = vi.fn()
|
||||
const mockOpenAsyncWindow = vi.fn()
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
// ─── Context mocks ───────────────────────────────────────────────────────────
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
@@ -49,6 +49,10 @@ vi.mock('@/hooks/use-async-window-open', () => ({
|
||||
useAsyncWindowOpen: () => mockOpenAsyncWindow,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
// ─── Navigation mocks ───────────────────────────────────────────────────────
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
@@ -78,15 +82,12 @@ const renderCloudPlanItem = ({
|
||||
canPay = true,
|
||||
}: RenderCloudPlanItemOptions = {}) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>
|
||||
</>,
|
||||
<CloudPlanItem
|
||||
currentPlan={currentPlan}
|
||||
plan={plan}
|
||||
planRange={planRange}
|
||||
canPay={canPay}
|
||||
/>,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -95,7 +96,6 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' })
|
||||
mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' })
|
||||
@@ -283,7 +283,11 @@ describe('Cloud Plan Payment Flow', () => {
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
)
|
||||
})
|
||||
// Should not proceed with payment
|
||||
expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled()
|
||||
|
||||
@@ -10,12 +10,12 @@
|
||||
import { cleanup, render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config'
|
||||
import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item'
|
||||
import { SelfHostedPlan } from '@/app/components/billing/type'
|
||||
|
||||
let mockAppCtx: Record<string, unknown> = {}
|
||||
const mockToastNotify = vi.fn()
|
||||
|
||||
const originalLocation = window.location
|
||||
let assignedHref = ''
|
||||
@@ -40,6 +40,10 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({
|
||||
AwsMarketplaceDark: () => <span data-testid="icon-aws-dark" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: (args: unknown) => mockToastNotify(args) },
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({
|
||||
default: ({ plan }: { plan: string }) => (
|
||||
<div data-testid={`self-hosted-list-${plan}`}>Features</div>
|
||||
@@ -53,20 +57,10 @@ const setupAppContext = (overrides: Record<string, unknown> = {}) => {
|
||||
}
|
||||
}
|
||||
|
||||
const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<SelfHostedPlanItem plan={plan} />
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Self-Hosted Plan Flow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
cleanup()
|
||||
toast.close()
|
||||
setupAppContext()
|
||||
|
||||
// Mock window.location with minimal getter/setter (Location props are non-enumerable)
|
||||
@@ -91,14 +85,14 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
// ─── 1. Plan Rendering ──────────────────────────────────────────────────
|
||||
describe('Plan rendering', () => {
|
||||
it('should render community plan with name and description', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
|
||||
expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render premium plan with cloud provider icons', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument()
|
||||
expect(screen.getByTestId('icon-azure')).toBeInTheDocument()
|
||||
@@ -106,39 +100,39 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
})
|
||||
|
||||
it('should render enterprise plan without cloud provider icons', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
|
||||
expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument()
|
||||
expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not show price tip for community (free) plan', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
|
||||
expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show price tip for premium plan', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
|
||||
expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render features list for each plan', () => {
|
||||
const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
const { unmount: unmount1 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument()
|
||||
unmount1()
|
||||
|
||||
const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
const { unmount: unmount2 } = render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument()
|
||||
unmount2()
|
||||
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show AWS marketplace icon for premium plan button', () => {
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
|
||||
expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument()
|
||||
})
|
||||
@@ -148,7 +142,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
describe('Navigation flow', () => {
|
||||
it('should redirect to GitHub when clicking community plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -158,7 +152,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to AWS Marketplace when clicking premium plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -168,7 +162,7 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
|
||||
it('should redirect to Typeform when clicking enterprise plan button', async () => {
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
@@ -182,13 +176,15 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks community button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.community)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.community} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
// Should NOT redirect
|
||||
expect(assignedHref).toBe('')
|
||||
@@ -197,13 +193,15 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks premium button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.premium)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
@@ -211,13 +209,15 @@ describe('Self-Hosted Plan Flow', () => {
|
||||
it('should show error toast when non-manager clicks enterprise button', async () => {
|
||||
setupAppContext({ isCurrentWorkspaceManager: false })
|
||||
const user = userEvent.setup()
|
||||
renderSelfHostedPlanItem(SelfHostedPlan.enterprise)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.enterprise} />)
|
||||
|
||||
const button = screen.getByRole('button')
|
||||
await user.click(button)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'error' }),
|
||||
)
|
||||
})
|
||||
expect(assignedHref).toBe('')
|
||||
})
|
||||
|
||||
@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
const sendEmail = async (email: string, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
|
||||
@@ -13,7 +13,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
@@ -91,9 +91,9 @@ export default function OAuthAuthorize() {
|
||||
globalThis.location.href = url.toString()
|
||||
}
|
||||
catch (err: any) {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
|
||||
message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -102,10 +102,10 @@ export default function OAuthAuthorize() {
|
||||
const invalidParams = !client_id || !redirect_uri
|
||||
if ((invalidParams || isError) && !hasNotifiedRef.current) {
|
||||
hasNotifiedRef.current = true
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
timeout: 0,
|
||||
message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }),
|
||||
duration: 0,
|
||||
})
|
||||
}
|
||||
}, [client_id, redirect_uri, isError])
|
||||
|
||||
@@ -39,8 +39,8 @@ vi.mock('../app-card', () => ({
|
||||
vi.mock('@/app/components/explore/create-app-modal', () => ({
|
||||
default: () => <div data-testid="create-from-template-modal" />,
|
||||
}))
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: { add: vi.fn() },
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: vi.fn() },
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
|
||||
@@ -12,7 +12,7 @@ import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import CreateAppModal from '@/app/components/explore/create-app-modal'
|
||||
import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
|
||||
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
|
||||
@@ -137,9 +137,9 @@ const Apps = ({
|
||||
})
|
||||
|
||||
setIsShowCreateModal(false)
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t('newApp.appCreated', { ns: 'app' }),
|
||||
message: t('newApp.appCreated', { ns: 'app' }),
|
||||
})
|
||||
if (onSuccess)
|
||||
onSuccess()
|
||||
@@ -149,7 +149,7 @@ const Apps = ({
|
||||
getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push)
|
||||
}
|
||||
catch {
|
||||
toast.add({ type: 'error', title: t('newApp.appCreateFailed', { ns: 'app' }) })
|
||||
Toast.notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* @deprecated Use `@/app/components/base/ui/toast` instead.
|
||||
* This module will be removed after migration is complete.
|
||||
* See: https://github.com/langgenius/dify/issues/32811
|
||||
*/
|
||||
|
||||
import type { ReactNode } from 'react'
|
||||
import { createContext, useContext } from 'use-context-selector'
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export type IToastProps = {
|
||||
type?: 'success' | 'error' | 'warning' | 'info'
|
||||
size?: 'md' | 'sm'
|
||||
@@ -26,8 +19,5 @@ type IToastContext = {
|
||||
close: () => void
|
||||
}
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const ToastContext = createContext<IToastContext>({} as IToastContext)
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const useToastContext = () => useContext(ToastContext)
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
'use client'
|
||||
|
||||
/**
|
||||
* @deprecated Use `@/app/components/base/ui/toast` instead.
|
||||
* This component will be removed after migration is complete.
|
||||
* See: https://github.com/langgenius/dify/issues/32811
|
||||
*/
|
||||
|
||||
import type { ReactNode } from 'react'
|
||||
import type { IToastProps } from './context'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
@@ -19,7 +12,6 @@ import { ToastContext, useToastContext } from './context'
|
||||
export type ToastHandle = {
|
||||
clear?: VoidFunction
|
||||
}
|
||||
|
||||
const Toast = ({
|
||||
type = 'info',
|
||||
size = 'md',
|
||||
@@ -82,7 +74,6 @@ const Toast = ({
|
||||
)
|
||||
}
|
||||
|
||||
/** @deprecated Use `@/app/components/base/ui/toast` instead. See issue #32811. */
|
||||
export const ToastProvider = ({
|
||||
children,
|
||||
}: {
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../../config'
|
||||
import { Plan } from '../../../../type'
|
||||
import { PlanRange } from '../../../plan-switcher/plan-range-switcher'
|
||||
import CloudPlanItem from '../index'
|
||||
|
||||
vi.mock('../../../../../base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
@@ -41,19 +47,11 @@ const mockUseAppContext = useAppContext as Mock
|
||||
const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock
|
||||
const mockBillingInvoices = consoleClient.billing.invoices as Mock
|
||||
const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock
|
||||
const mockToastNotify = Toast.notify as Mock
|
||||
|
||||
let assignedHref = ''
|
||||
const originalLocation = window.location
|
||||
|
||||
const renderWithToastHost = (ui: React.ReactNode) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
{ui}
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
configurable: true,
|
||||
@@ -70,7 +68,6 @@ beforeAll(() => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toast.close()
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
|
||||
mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open()))
|
||||
mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' })
|
||||
@@ -166,7 +163,7 @@ describe('CloudPlanItem', () => {
|
||||
it('should show toast when non-manager tries to buy a plan', () => {
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
|
||||
|
||||
renderWithToastHost(
|
||||
render(
|
||||
<CloudPlanItem
|
||||
plan={Plan.professional}
|
||||
currentPlan={Plan.sandbox}
|
||||
@@ -176,7 +173,10 @@ describe('CloudPlanItem', () => {
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }))
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: 'billing.buyPermissionDeniedTip',
|
||||
}))
|
||||
expect(mockBillingInvoices).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import type { BasicPlan } from '../../../type'
|
||||
import * as React from 'react'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useAsyncWindowOpen } from '@/hooks/use-async-window-open'
|
||||
import { fetchSubscriptionUrls } from '@/service/billing'
|
||||
import { consoleClient } from '@/service/client'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { ALL_PLANS } from '../../../config'
|
||||
import { Plan } from '../../../type'
|
||||
import { Professional, Sandbox, Team } from '../../assets'
|
||||
@@ -66,9 +66,10 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
return
|
||||
|
||||
if (!isCurrentWorkspaceManager) {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
className: 'z-[1001]',
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -82,7 +83,7 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
throw new Error('Failed to open billing page')
|
||||
}, {
|
||||
onError: (err) => {
|
||||
toast.add({ type: 'error', title: err.message || String(err) })
|
||||
Toast.notify({ type: 'error', message: err.message || String(err) })
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -110,34 +111,34 @@ const CloudPlanItem: FC<CloudPlanItemProps> = ({
|
||||
{
|
||||
isMostPopularPlan && (
|
||||
<div className="flex items-center justify-center bg-saas-dify-blue-static px-1.5 py-1">
|
||||
<span className="text-text-primary-on-surface system-2xs-semibold-uppercase">
|
||||
<span className="system-2xs-semibold-uppercase text-text-primary-on-surface">
|
||||
{t('plansCommon.mostPopular', { ns: 'billing' })}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className="text-text-secondary system-sm-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
<div className="system-sm-regular text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* Price */}
|
||||
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
|
||||
{isFreePlan && (
|
||||
<span className="text-text-primary title-4xl-semi-bold">{t('plansCommon.free', { ns: 'billing' })}</span>
|
||||
<span className="title-4xl-semi-bold text-text-primary">{t('plansCommon.free', { ns: 'billing' })}</span>
|
||||
)}
|
||||
{!isFreePlan && (
|
||||
<>
|
||||
{isYear && (
|
||||
<span className="text-text-quaternary line-through title-4xl-semi-bold">
|
||||
<span className="title-4xl-semi-bold text-text-quaternary line-through">
|
||||
$
|
||||
{planInfo.price * 12}
|
||||
</span>
|
||||
)}
|
||||
<span className="text-text-primary title-4xl-semi-bold">
|
||||
<span className="title-4xl-semi-bold text-text-primary">
|
||||
$
|
||||
{isYear ? planInfo.price * 10 : planInfo.price}
|
||||
</span>
|
||||
<span className="pb-0.5 text-text-tertiary system-md-regular">
|
||||
<span className="system-md-regular pb-0.5 text-text-tertiary">
|
||||
{t('plansCommon.priceTip', { ns: 'billing' })}
|
||||
{t(`plansCommon.${!isYear ? 'month' : 'year'}`, { ns: 'billing' })}
|
||||
</span>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import Toast from '../../../../../base/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../../config'
|
||||
import { SelfHostedPlan } from '../../../../type'
|
||||
import SelfHostedPlanItem from '../index'
|
||||
@@ -16,6 +16,12 @@ vi.mock('../list', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('../../../../../base/toast', () => ({
|
||||
default: {
|
||||
notify: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: vi.fn(),
|
||||
}))
|
||||
@@ -29,19 +35,11 @@ vi.mock('../../../assets', () => ({
|
||||
}))
|
||||
|
||||
const mockUseAppContext = useAppContext as Mock
|
||||
const mockToastNotify = Toast.notify as Mock
|
||||
|
||||
let assignedHref = ''
|
||||
const originalLocation = window.location
|
||||
|
||||
const renderWithToastHost = (ui: React.ReactNode) => {
|
||||
return render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
{ui}
|
||||
</>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeAll(() => {
|
||||
Object.defineProperty(window, 'location', {
|
||||
configurable: true,
|
||||
@@ -58,7 +56,6 @@ beforeAll(() => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toast.close()
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true })
|
||||
assignedHref = ''
|
||||
})
|
||||
@@ -93,10 +90,13 @@ describe('SelfHostedPlanItem', () => {
|
||||
it('should show toast when non-manager tries to proceed', () => {
|
||||
mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false })
|
||||
|
||||
renderWithToastHost(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
render(<SelfHostedPlanItem plan={SelfHostedPlan.premium} />)
|
||||
fireEvent.click(screen.getByRole('button', { name: /billing\.plans\.premium\.btnText/ }))
|
||||
|
||||
expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument()
|
||||
expect(mockToastNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
type: 'error',
|
||||
message: 'billing.buyPermissionDeniedTip',
|
||||
}))
|
||||
})
|
||||
|
||||
it('should redirect to community url when community plan button clicked', () => {
|
||||
|
||||
@@ -4,9 +4,9 @@ import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Azure, GoogleCloud } from '@/app/components/base/icons/src/public/billing'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Toast from '../../../../base/toast'
|
||||
import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '../../../config'
|
||||
import { SelfHostedPlan } from '../../../type'
|
||||
import { Community, Enterprise, EnterpriseNoise, Premium, PremiumNoise } from '../../assets'
|
||||
@@ -56,9 +56,10 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
const handleGetPayUrl = useCallback(() => {
|
||||
// Only workspace manager can buy plan
|
||||
if (!isCurrentWorkspaceManager) {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
message: t('buyPermissionDeniedTip', { ns: 'billing' }),
|
||||
className: 'z-[1001]',
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -81,18 +82,18 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
{/* Noise Effect */}
|
||||
{STYLE_MAP[plan].noise}
|
||||
<div className="flex flex-col px-5 py-4">
|
||||
<div className="flex flex-col gap-y-6 px-1 pt-10">
|
||||
<div className=" flex flex-col gap-y-6 px-1 pt-10">
|
||||
{STYLE_MAP[plan].icon}
|
||||
<div className="flex min-h-[104px] flex-col gap-y-2">
|
||||
<div className="text-[30px] font-medium leading-[1.2] text-text-primary">{t(`${i18nPrefix}.name`, { ns: 'billing' })}</div>
|
||||
<div className="line-clamp-2 text-text-secondary system-md-regular">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
<div className="system-md-regular line-clamp-2 text-text-secondary">{t(`${i18nPrefix}.description`, { ns: 'billing' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* Price */}
|
||||
<div className="flex items-end gap-x-2 px-1 pb-8 pt-4">
|
||||
<div className="shrink-0 text-text-primary title-4xl-semi-bold">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
|
||||
<div className="title-4xl-semi-bold shrink-0 text-text-primary">{t(`${i18nPrefix}.price`, { ns: 'billing' })}</div>
|
||||
{!isFreePlan && (
|
||||
<span className="pb-0.5 text-text-tertiary system-md-regular">
|
||||
<span className="system-md-regular pb-0.5 text-text-tertiary">
|
||||
{t(`${i18nPrefix}.priceTip`, { ns: 'billing' })}
|
||||
</span>
|
||||
)}
|
||||
@@ -113,7 +114,7 @@ const SelfHostedPlanItem: FC<SelfHostedPlanItemProps> = ({
|
||||
<GoogleCloud />
|
||||
</div>
|
||||
</div>
|
||||
<span className="text-text-tertiary system-xs-regular">
|
||||
<span className="system-xs-regular text-text-tertiary">
|
||||
{t('plans.premium.comingSoon', { ns: 'billing' })}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type * as React from 'react'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { IndexingType } from '../../../create/step-two'
|
||||
|
||||
@@ -13,7 +13,14 @@ vi.mock('@/next/navigation', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const toastAddSpy = vi.spyOn(toast, 'add')
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('use-context-selector', async (importOriginal) => {
|
||||
const actual = await importOriginal() as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ notify: mockNotify }),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock dataset detail context
|
||||
let mockIndexingTechnique = IndexingType.QUALIFIED
|
||||
@@ -44,6 +51,11 @@ vi.mock('@/service/knowledge/use-segment', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock app store
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: () => ({ appSidebarExpand: 'expand' }),
|
||||
}))
|
||||
|
||||
vi.mock('../completed/common/action-buttons', () => ({
|
||||
default: ({ handleCancel, handleSave, loading, actionType }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string }) => (
|
||||
<div data-testid="action-buttons">
|
||||
@@ -127,8 +139,6 @@ vi.mock('@/app/components/datasets/common/image-uploader/image-uploader-in-chunk
|
||||
describe('NewSegmentModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useRealTimers()
|
||||
toast.close()
|
||||
mockFullScreen = false
|
||||
mockIndexingTechnique = IndexingType.QUALIFIED
|
||||
})
|
||||
@@ -248,7 +258,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -262,7 +272,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -277,7 +287,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -327,7 +337,7 @@ describe('NewSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
}),
|
||||
@@ -420,9 +430,10 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('Action button in success notification', () => {
|
||||
it('should call viewNewlyAddedChunk when the toast action is clicked', async () => {
|
||||
describe('CustomButton in success notification', () => {
|
||||
it('should call viewNewlyAddedChunk when custom button is clicked', async () => {
|
||||
const mockViewNewlyAddedChunk = vi.fn()
|
||||
mockNotify.mockImplementation(() => {})
|
||||
|
||||
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
|
||||
options.onSuccess()
|
||||
@@ -431,25 +442,37 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
|
||||
render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<NewSegmentModal
|
||||
{...defaultProps}
|
||||
docForm={ChunkingMode.text}
|
||||
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
|
||||
/>
|
||||
</>,
|
||||
<NewSegmentModal
|
||||
{...defaultProps}
|
||||
docForm={ChunkingMode.text}
|
||||
viewNewlyAddedChunk={mockViewNewlyAddedChunk}
|
||||
/>,
|
||||
)
|
||||
|
||||
// Enter content and save
|
||||
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
|
||||
fireEvent.click(actionButton)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockViewNewlyAddedChunk).toHaveBeenCalledTimes(1)
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
customComponent: expect.anything(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
|
||||
// Extract customComponent from the notify call args
|
||||
const notifyCallArgs = mockNotify.mock.calls[0][0] as { customComponent?: React.ReactElement }
|
||||
expect(notifyCallArgs.customComponent).toBeDefined()
|
||||
const customComponent = notifyCallArgs.customComponent!
|
||||
const { container: btnContainer } = render(customComponent)
|
||||
const viewButton = btnContainer.querySelector('.system-xs-semibold.text-text-accent') as HTMLElement
|
||||
expect(viewButton).toBeInTheDocument()
|
||||
fireEvent.click(viewButton)
|
||||
|
||||
// Assert that viewNewlyAddedChunk was called via the onClick handler (lines 66-67)
|
||||
expect(mockViewNewlyAddedChunk).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -576,8 +599,9 @@ describe('NewSegmentModal', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('onSave after success', () => {
|
||||
it('should call onSave immediately after save succeeds', async () => {
|
||||
describe('onSave delayed call', () => {
|
||||
it('should call onSave after timeout in success handler', async () => {
|
||||
vi.useFakeTimers()
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddSegment.mockImplementation((_params: unknown, options: { onSuccess: () => void, onSettled: () => void }) => {
|
||||
options.onSuccess()
|
||||
@@ -587,12 +611,15 @@ describe('NewSegmentModal', () => {
|
||||
|
||||
render(<NewSegmentModal {...defaultProps} onSave={mockOnSave} docForm={ChunkingMode.text} />)
|
||||
|
||||
// Enter content and save
|
||||
fireEvent.change(screen.getByTestId('question-input'), { target: { value: 'Test content' } })
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnSave).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
// Fast-forward timer
|
||||
vi.advanceTimersByTime(3000)
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalled()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { toast, ToastHost } from '@/app/components/base/ui/toast'
|
||||
|
||||
import NewChildSegmentModal from '../new-child-segment'
|
||||
|
||||
@@ -11,7 +10,14 @@ vi.mock('@/next/navigation', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
const toastAddSpy = vi.spyOn(toast, 'add')
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('use-context-selector', async (importOriginal) => {
|
||||
const actual = await importOriginal() as Record<string, unknown>
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ notify: mockNotify }),
|
||||
}
|
||||
})
|
||||
|
||||
// Mock document context
|
||||
let mockParentMode = 'paragraph'
|
||||
@@ -42,6 +48,11 @@ vi.mock('@/service/knowledge/use-segment', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock app store
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
useStore: () => ({ appSidebarExpand: 'expand' }),
|
||||
}))
|
||||
|
||||
vi.mock('../common/action-buttons', () => ({
|
||||
default: ({ handleCancel, handleSave, loading, actionType, isChildChunk }: { handleCancel: () => void, handleSave: () => void, loading: boolean, actionType: string, isChildChunk?: boolean }) => (
|
||||
<div data-testid="action-buttons">
|
||||
@@ -92,8 +103,6 @@ vi.mock('../common/segment-index-tag', () => ({
|
||||
describe('NewChildSegmentModal', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useRealTimers()
|
||||
toast.close()
|
||||
mockFullScreen = false
|
||||
mockParentMode = 'paragraph'
|
||||
})
|
||||
@@ -189,7 +198,7 @@ describe('NewChildSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'error',
|
||||
}),
|
||||
@@ -244,7 +253,7 @@ describe('NewChildSegmentModal', () => {
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastAddSpy).toHaveBeenCalledWith(
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
}),
|
||||
@@ -365,62 +374,35 @@ describe('NewChildSegmentModal', () => {
|
||||
|
||||
// View newly added chunk
|
||||
describe('View Newly Added Chunk', () => {
|
||||
it('should call viewNewlyAddedChildChunk when the toast action is clicked', async () => {
|
||||
it('should show custom button in full-doc mode after save', async () => {
|
||||
mockParentMode = 'full-doc'
|
||||
const mockViewNewlyAddedChildChunk = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
options.onSuccess({ data: { id: 'new-child-id' } })
|
||||
options.onSettled()
|
||||
return Promise.resolve()
|
||||
})
|
||||
|
||||
render(
|
||||
<>
|
||||
<ToastHost timeout={0} />
|
||||
<NewChildSegmentModal
|
||||
{...defaultProps}
|
||||
viewNewlyAddedChildChunk={mockViewNewlyAddedChildChunk}
|
||||
/>
|
||||
</>,
|
||||
)
|
||||
render(<NewChildSegmentModal {...defaultProps} />)
|
||||
|
||||
// Enter valid content
|
||||
fireEvent.change(screen.getByTestId('content-input'), {
|
||||
target: { value: 'Valid content' },
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
const actionButton = await screen.findByRole('button', { name: 'common.operation.view' })
|
||||
fireEvent.click(actionButton)
|
||||
|
||||
// Assert - success notification with custom component
|
||||
await waitFor(() => {
|
||||
expect(mockViewNewlyAddedChildChunk).toHaveBeenCalledTimes(1)
|
||||
expect(mockNotify).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'success',
|
||||
customComponent: expect.anything(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onSave immediately in full-doc mode after save succeeds', async () => {
|
||||
mockParentMode = 'full-doc'
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
options.onSuccess({ data: { id: 'new-child-id' } })
|
||||
options.onSettled()
|
||||
return Promise.resolve()
|
||||
})
|
||||
|
||||
render(<NewChildSegmentModal {...defaultProps} onSave={mockOnSave} />)
|
||||
|
||||
fireEvent.change(screen.getByTestId('content-input'), {
|
||||
target: { value: 'Valid content' },
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByTestId('save-btn'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnSave).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onSave with the new child chunk in paragraph mode', async () => {
|
||||
it('should not show custom button in paragraph mode after save', async () => {
|
||||
mockParentMode = 'paragraph'
|
||||
const mockOnSave = vi.fn()
|
||||
mockAddChildSegment.mockImplementation((_params, options) => {
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import type { FC } from 'react'
|
||||
import type { ChildChunkDetail, SegmentUpdater } from '@/models/datasets'
|
||||
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
|
||||
import { memo, useState } from 'react'
|
||||
import { memo, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
import { useParams } from '@/next/navigation'
|
||||
import { useAddChildSegment } from '@/service/knowledge/use-segment'
|
||||
@@ -32,15 +35,39 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
viewNewlyAddedChildChunk,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [content, setContent] = useState('')
|
||||
const { datasetId, documentId } = useParams<{ datasetId: string, documentId: string }>()
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [addAnother, setAddAnother] = useState(true)
|
||||
const fullScreen = useSegmentListContext(s => s.fullScreen)
|
||||
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
|
||||
const { appSidebarExpand } = useAppStore(useShallow(state => ({
|
||||
appSidebarExpand: state.appSidebarExpand,
|
||||
})))
|
||||
const parentMode = useDocumentContext(s => s.parentMode)
|
||||
|
||||
const isFullDocMode = parentMode === 'full-doc'
|
||||
const refreshTimer = useRef<any>(null)
|
||||
|
||||
const isFullDocMode = useMemo(() => {
|
||||
return parentMode === 'full-doc'
|
||||
}, [parentMode])
|
||||
|
||||
const CustomButton = (
|
||||
<>
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChildChunk?.()
|
||||
}}
|
||||
>
|
||||
{t('operation.view', { ns: 'common' })}
|
||||
</button>
|
||||
</>
|
||||
)
|
||||
|
||||
const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
|
||||
if (actionType === 'esc' || !addAnother)
|
||||
@@ -53,27 +80,26 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
const params: SegmentUpdater = { content: '' }
|
||||
|
||||
if (!content.trim())
|
||||
return toast.add({ type: 'error', title: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
|
||||
return notify({ type: 'error', message: t('segment.contentEmpty', { ns: 'datasetDocuments' }) })
|
||||
|
||||
params.content = content
|
||||
|
||||
setLoading(true)
|
||||
await addChildSegment({ datasetId, documentId, segmentId: chunkId, body: params }, {
|
||||
onSuccess(res) {
|
||||
toast.add({
|
||||
notify({
|
||||
type: 'success',
|
||||
title: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
|
||||
actionProps: isFullDocMode
|
||||
? {
|
||||
children: t('operation.view', { ns: 'common' }),
|
||||
onClick: viewNewlyAddedChildChunk,
|
||||
}
|
||||
: undefined,
|
||||
message: t('segment.childChunkAdded', { ns: 'datasetDocuments' }),
|
||||
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
|
||||
!top-auto !right-auto !mb-[52px] !ml-11`,
|
||||
customComponent: isFullDocMode && CustomButton,
|
||||
})
|
||||
handleCancel('add')
|
||||
setContent('')
|
||||
if (isFullDocMode) {
|
||||
onSave()
|
||||
refreshTimer.current = setTimeout(() => {
|
||||
onSave()
|
||||
}, 3000)
|
||||
}
|
||||
else {
|
||||
onSave(res.data)
|
||||
@@ -85,8 +111,10 @@ const NewChildSegmentModal: FC<NewChildSegmentModalProps> = ({
|
||||
})
|
||||
}
|
||||
|
||||
const count = content.length
|
||||
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
const wordCountText = useMemo(() => {
|
||||
const count = content.length
|
||||
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
}, [content.length])
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
|
||||
@@ -2,10 +2,13 @@ import type { FC } from 'react'
|
||||
import type { FileEntity } from '@/app/components/datasets/common/image-uploader/types'
|
||||
import type { SegmentUpdater } from '@/models/datasets'
|
||||
import { RiCloseLine, RiExpandDiagonalLine } from '@remixicon/react'
|
||||
import { memo, useCallback, useState } from 'react'
|
||||
import { memo, useCallback, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import ImageUploaderInChunk from '@/app/components/datasets/common/image-uploader/image-uploader-in-chunk'
|
||||
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
|
||||
import { ChunkingMode } from '@/models/datasets'
|
||||
@@ -36,6 +39,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
viewNewlyAddedChunk,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [question, setQuestion] = useState('')
|
||||
const [answer, setAnswer] = useState('')
|
||||
const [attachments, setAttachments] = useState<FileEntity[]>([])
|
||||
@@ -46,7 +50,27 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
const fullScreen = useSegmentListContext(s => s.fullScreen)
|
||||
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
|
||||
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
|
||||
const [imageUploaderKey, setImageUploaderKey] = useState(() => Date.now())
|
||||
const { appSidebarExpand } = useAppStore(useShallow(state => ({
|
||||
appSidebarExpand: state.appSidebarExpand,
|
||||
})))
|
||||
const [imageUploaderKey, setImageUploaderKey] = useState(Date.now())
|
||||
const refreshTimer = useRef<any>(null)
|
||||
|
||||
const CustomButton = useMemo(() => (
|
||||
<>
|
||||
<Divider type="vertical" className="mx-1 h-3 bg-divider-regular" />
|
||||
<button
|
||||
type="button"
|
||||
className="text-text-accent system-xs-semibold"
|
||||
onClick={() => {
|
||||
clearTimeout(refreshTimer.current)
|
||||
viewNewlyAddedChunk()
|
||||
}}
|
||||
>
|
||||
{t('operation.view', { ns: 'common' })}
|
||||
</button>
|
||||
</>
|
||||
), [viewNewlyAddedChunk, t])
|
||||
|
||||
const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
|
||||
if (actionType === 'esc' || !addAnother)
|
||||
@@ -63,15 +87,15 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
const params: SegmentUpdater = { content: '', attachment_ids: [] }
|
||||
if (docForm === ChunkingMode.qa) {
|
||||
if (!question.trim()) {
|
||||
return toast.add({
|
||||
return notify({
|
||||
type: 'error',
|
||||
title: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
|
||||
message: t('segment.questionEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
if (!answer.trim()) {
|
||||
return toast.add({
|
||||
return notify({
|
||||
type: 'error',
|
||||
title: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
|
||||
message: t('segment.answerEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -80,9 +104,9 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
}
|
||||
else {
|
||||
if (!question.trim()) {
|
||||
return toast.add({
|
||||
return notify({
|
||||
type: 'error',
|
||||
title: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
|
||||
message: t('segment.contentEmpty', { ns: 'datasetDocuments' }),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -98,13 +122,12 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
setLoading(true)
|
||||
await addSegment({ datasetId, documentId, body: params }, {
|
||||
onSuccess() {
|
||||
toast.add({
|
||||
notify({
|
||||
type: 'success',
|
||||
title: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
|
||||
actionProps: {
|
||||
children: t('operation.view', { ns: 'common' }),
|
||||
onClick: viewNewlyAddedChunk,
|
||||
},
|
||||
message: t('segment.chunkAdded', { ns: 'datasetDocuments' }),
|
||||
className: `!w-[296px] !bottom-0 ${appSidebarExpand === 'expand' ? '!left-[216px]' : '!left-14'}
|
||||
!top-auto !right-auto !mb-[52px] !ml-11`,
|
||||
customComponent: CustomButton,
|
||||
})
|
||||
handleCancel('add')
|
||||
setQuestion('')
|
||||
@@ -112,16 +135,20 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
|
||||
setAttachments([])
|
||||
setImageUploaderKey(Date.now())
|
||||
setKeywords([])
|
||||
onSave()
|
||||
refreshTimer.current = setTimeout(() => {
|
||||
onSave()
|
||||
}, 3000)
|
||||
},
|
||||
onSettled() {
|
||||
setLoading(false)
|
||||
},
|
||||
})
|
||||
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, t, handleCancel, onSave, viewNewlyAddedChunk])
|
||||
}, [docForm, keywords, addSegment, datasetId, documentId, question, answer, attachments, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])
|
||||
|
||||
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
|
||||
const wordCountText = `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
const wordCountText = useMemo(() => {
|
||||
const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
|
||||
return `${formatNumber(count)} ${t('segment.characters', { ns: 'datasetDocuments', count })}`
|
||||
}, [question.length, answer.length, docForm, t])
|
||||
|
||||
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
|
||||
|
||||
|
||||
@@ -21,11 +21,11 @@ vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => (path?: string) => `https://docs.dify.ai/en${path || ''}`,
|
||||
}))
|
||||
|
||||
const mockNotify = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockNotify,
|
||||
},
|
||||
const mockNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}))
|
||||
|
||||
// Mock modal context
|
||||
@@ -164,7 +164,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
// Verify success notification
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
title: 'External Knowledge Base Connected Successfully',
|
||||
message: 'External Knowledge Base Connected Successfully',
|
||||
})
|
||||
|
||||
// Verify navigation back
|
||||
@@ -206,7 +206,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
title: 'Failed to connect External Knowledge Base',
|
||||
message: 'Failed to connect External Knowledge Base',
|
||||
})
|
||||
})
|
||||
|
||||
@@ -228,7 +228,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
title: 'Failed to connect External Knowledge Base',
|
||||
message: 'Failed to connect External Knowledge Base',
|
||||
})
|
||||
})
|
||||
|
||||
@@ -274,7 +274,7 @@ describe('ExternalKnowledgeBaseConnector', () => {
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
title: 'External Knowledge Base Connected Successfully',
|
||||
message: 'External Knowledge Base Connected Successfully',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,13 @@ import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-
|
||||
import * as React from 'react'
|
||||
import { useState } from 'react'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import { useToastContext } from '@/app/components/base/toast/context'
|
||||
import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { createExternalKnowledgeBase } from '@/service/datasets'
|
||||
|
||||
const ExternalKnowledgeBaseConnector = () => {
|
||||
const { notify } = useToastContext()
|
||||
const [loading, setLoading] = useState(false)
|
||||
const router = useRouter()
|
||||
|
||||
@@ -18,7 +19,7 @@ const ExternalKnowledgeBaseConnector = () => {
|
||||
setLoading(true)
|
||||
const result = await createExternalKnowledgeBase({ body: formValue })
|
||||
if (result && result.id) {
|
||||
toast.add({ type: 'success', title: 'External Knowledge Base Connected Successfully' })
|
||||
notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' })
|
||||
trackEvent('create_external_knowledge_base', {
|
||||
provider: formValue.provider,
|
||||
name: formValue.name,
|
||||
@@ -29,7 +30,7 @@ const ExternalKnowledgeBaseConnector = () => {
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Error creating external knowledge base:', error)
|
||||
toast.add({ type: 'error', title: 'Failed to connect External Knowledge Base' })
|
||||
notify({ type: 'error', message: 'Failed to connect External Knowledge Base' })
|
||||
}
|
||||
setLoading(false)
|
||||
}
|
||||
|
||||
@@ -36,13 +36,30 @@ const TransferOwnershipModal = ({ onClose, show }: Props) => {
|
||||
const [stepToken, setStepToken] = useState<string>('')
|
||||
const [newOwner, setNewOwner] = useState<string>('')
|
||||
const [isTransfer, setIsTransfer] = useState<boolean>(false)
|
||||
const timerRef = React.useRef<ReturnType<typeof setInterval> | null>(null)
|
||||
|
||||
React.useEffect(() => {
|
||||
return () => {
|
||||
if (timerRef.current) {
|
||||
clearInterval(timerRef.current)
|
||||
timerRef.current = null
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
const startCount = () => {
|
||||
if (timerRef.current) {
|
||||
clearInterval(timerRef.current)
|
||||
timerRef.current = null
|
||||
}
|
||||
setTime(60)
|
||||
const timer = setInterval(() => {
|
||||
timerRef.current = setInterval(() => {
|
||||
setTime((prev) => {
|
||||
if (prev <= 0) {
|
||||
clearInterval(timer)
|
||||
if (timerRef.current) {
|
||||
clearInterval(timerRef.current)
|
||||
timerRef.current = null
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return prev - 1
|
||||
|
||||
@@ -43,10 +43,10 @@ vi.mock('@/context/provider-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockNotify,
|
||||
},
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
@@ -150,7 +150,7 @@ describe('SystemModel', () => {
|
||||
expect(mockUpdateDefaultModel).toHaveBeenCalledTimes(1)
|
||||
expect(mockNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
title: 'Modified successfully',
|
||||
message: 'Modified successfully',
|
||||
})
|
||||
expect(mockInvalidateDefaultModel).toHaveBeenCalledTimes(5)
|
||||
expect(mockUpdateModelList).toHaveBeenCalledTimes(5)
|
||||
|
||||
@@ -6,13 +6,13 @@ import type {
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { useToastContext } from '@/app/components/base/toast/context'
|
||||
import {
|
||||
Dialog,
|
||||
DialogCloseButton,
|
||||
DialogContent,
|
||||
DialogTitle,
|
||||
} from '@/app/components/base/ui/dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
@@ -64,6 +64,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
isLoading,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useToastContext()
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const { textGenerationModelList } = useProviderContext()
|
||||
const updateModelList = useUpdateModelList()
|
||||
@@ -123,7 +124,7 @@ const SystemModel: FC<SystemModelSelectorProps> = ({
|
||||
},
|
||||
})
|
||||
if (res.result === 'success') {
|
||||
toast.add({ type: 'success', title: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
notify({ type: 'success', message: t('actionMsg.modifiedSuccessfully', { ns: 'common' }) })
|
||||
setOpen(false)
|
||||
|
||||
const allModelTypes = [ModelTypeEnum.textGeneration, ModelTypeEnum.textEmbedding, ModelTypeEnum.rerank, ModelTypeEnum.speech2text, ModelTypeEnum.tts]
|
||||
|
||||
@@ -4,7 +4,7 @@ import { DeleteConfirm } from '../delete-confirm'
|
||||
|
||||
const mockRefetch = vi.fn()
|
||||
const mockDelete = vi.fn()
|
||||
const mockToastAdd = vi.hoisted(() => vi.fn())
|
||||
const mockToast = vi.fn()
|
||||
|
||||
vi.mock('../use-subscription-list', () => ({
|
||||
useSubscriptionList: () => ({ refetch: mockRefetch }),
|
||||
@@ -14,9 +14,9 @@ vi.mock('@/service/use-triggers', () => ({
|
||||
useDeleteTriggerSubscription: () => ({ mutate: mockDelete, isPending: false }),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/ui/toast', () => ({
|
||||
toast: {
|
||||
add: mockToastAdd,
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: (args: { type: string, message: string }) => mockToast(args),
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -42,7 +42,7 @@ describe('DeleteConfirm', () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
|
||||
|
||||
expect(mockDelete).not.toHaveBeenCalled()
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
|
||||
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error' }))
|
||||
})
|
||||
|
||||
it('should allow deletion after matching input name', () => {
|
||||
@@ -87,6 +87,6 @@ describe('DeleteConfirm', () => {
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: /pluginTrigger\.subscription\.list\.item\.actions\.deleteConfirm\.confirm/ }))
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', title: 'network error' }))
|
||||
expect(mockToast).toHaveBeenCalledWith(expect.objectContaining({ type: 'error', message: 'network error' }))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import Input from '@/app/components/base/input'
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogActions,
|
||||
AlertDialogCancelButton,
|
||||
AlertDialogConfirmButton,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogTitle,
|
||||
} from '@/app/components/base/ui/alert-dialog'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useDeleteTriggerSubscription } from '@/service/use-triggers'
|
||||
import { useSubscriptionList } from './use-subscription-list'
|
||||
|
||||
@@ -31,74 +23,58 @@ export const DeleteConfirm = (props: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const [inputName, setInputName] = useState('')
|
||||
|
||||
const handleOpenChange = (open: boolean) => {
|
||||
if (isDeleting)
|
||||
return
|
||||
|
||||
if (!open)
|
||||
onClose(false)
|
||||
}
|
||||
|
||||
const onConfirm = () => {
|
||||
if (workflowsInUse > 0 && inputName !== currentName) {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
|
||||
message: t(`${tPrefix}.confirmInputWarning`, { ns: 'pluginTrigger' }),
|
||||
// temporarily
|
||||
className: 'z-[10000001]',
|
||||
})
|
||||
return
|
||||
}
|
||||
deleteSubscription(currentId, {
|
||||
onSuccess: () => {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
|
||||
message: t(`${tPrefix}.success`, { ns: 'pluginTrigger', name: currentName }),
|
||||
className: 'z-[10000001]',
|
||||
})
|
||||
refetch?.()
|
||||
onClose(true)
|
||||
},
|
||||
onError: (error: unknown) => {
|
||||
toast.add({
|
||||
onError: (error: any) => {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: error instanceof Error ? error.message : t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
|
||||
message: error?.message || t(`${tPrefix}.error`, { ns: 'pluginTrigger', name: currentName }),
|
||||
className: 'z-[10000001]',
|
||||
})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return (
|
||||
<AlertDialog open={isShow} onOpenChange={handleOpenChange}>
|
||||
<AlertDialogContent backdropProps={{ forceRender: true }}>
|
||||
<div className="flex flex-col gap-2 px-6 pb-4 pt-6">
|
||||
<AlertDialogTitle title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })} className="w-full truncate text-text-primary title-2xl-semi-bold">
|
||||
{t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
|
||||
</AlertDialogTitle>
|
||||
<AlertDialogDescription className="w-full whitespace-pre-wrap break-words text-text-tertiary system-md-regular">
|
||||
{workflowsInUse > 0
|
||||
? t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })
|
||||
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
|
||||
</AlertDialogDescription>
|
||||
{workflowsInUse > 0 && (
|
||||
<div className="mt-6">
|
||||
<div className="mb-2 text-text-secondary system-sm-medium">
|
||||
{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}
|
||||
</div>
|
||||
<Confirm
|
||||
title={t(`${tPrefix}.title`, { ns: 'pluginTrigger', name: currentName })}
|
||||
confirmText={t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
|
||||
content={workflowsInUse > 0
|
||||
? (
|
||||
<>
|
||||
{t(`${tPrefix}.contentWithApps`, { ns: 'pluginTrigger', count: workflowsInUse })}
|
||||
<div className="system-sm-medium mb-2 mt-6 text-text-secondary">{t(`${tPrefix}.confirmInputTip`, { ns: 'pluginTrigger', name: currentName })}</div>
|
||||
<Input
|
||||
value={inputName}
|
||||
onChange={e => setInputName(e.target.value)}
|
||||
placeholder={t(`${tPrefix}.confirmInputPlaceholder`, { ns: 'pluginTrigger', name: currentName })}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<AlertDialogActions>
|
||||
<AlertDialogCancelButton disabled={isDeleting}>
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</AlertDialogCancelButton>
|
||||
<AlertDialogConfirmButton loading={isDeleting} disabled={isDeleting} onClick={onConfirm}>
|
||||
{t(`${tPrefix}.confirm`, { ns: 'pluginTrigger' })}
|
||||
</AlertDialogConfirmButton>
|
||||
</AlertDialogActions>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
</>
|
||||
)
|
||||
: t(`${tPrefix}.content`, { ns: 'pluginTrigger' })}
|
||||
isShow={isShow}
|
||||
isLoading={isDeleting}
|
||||
isDisabled={isDeleting}
|
||||
onConfirm={onConfirm}
|
||||
onCancel={() => onClose(false)}
|
||||
maskClosable={false}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import type { Node } from '../types'
|
||||
import { screen } from '@testing-library/react'
|
||||
import CandidateNode from '../candidate-node'
|
||||
import { BlockEnum } from '../types'
|
||||
import { renderWorkflowComponent } from './workflow-test-env'
|
||||
|
||||
vi.mock('../candidate-node-main', () => ({
|
||||
default: ({ candidateNode }: { candidateNode: Node }) => (
|
||||
<div data-testid="candidate-node-main">{candidateNode.id}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
const createCandidateNode = (): Node => ({
|
||||
id: 'candidate-node-1',
|
||||
type: 'custom',
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
type: BlockEnum.Start,
|
||||
title: 'Candidate node',
|
||||
desc: 'candidate',
|
||||
},
|
||||
})
|
||||
|
||||
describe('CandidateNode', () => {
|
||||
it('should not render when candidateNode is missing from the workflow store', () => {
|
||||
renderWorkflowComponent(<CandidateNode />)
|
||||
|
||||
expect(screen.queryByTestId('candidate-node-main')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render CandidateNodeMain with the stored candidate node', () => {
|
||||
renderWorkflowComponent(<CandidateNode />, {
|
||||
initialStoreState: {
|
||||
candidateNode: createCandidateNode(),
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByTestId('candidate-node-main')).toHaveTextContent('candidate-node-1')
|
||||
})
|
||||
})
|
||||
@@ -1,81 +0,0 @@
|
||||
import type { ComponentProps } from 'react'
|
||||
import { render } from '@testing-library/react'
|
||||
import { getBezierPath, Position } from 'reactflow'
|
||||
import CustomConnectionLine from '../custom-connection-line'
|
||||
|
||||
const createConnectionLineProps = (
|
||||
overrides: Partial<ComponentProps<typeof CustomConnectionLine>> = {},
|
||||
): ComponentProps<typeof CustomConnectionLine> => ({
|
||||
fromX: 10,
|
||||
fromY: 20,
|
||||
toX: 70,
|
||||
toY: 80,
|
||||
fromPosition: Position.Right,
|
||||
toPosition: Position.Left,
|
||||
connectionLineType: undefined,
|
||||
connectionStatus: null,
|
||||
...overrides,
|
||||
} as ComponentProps<typeof CustomConnectionLine>)
|
||||
|
||||
describe('CustomConnectionLine', () => {
|
||||
it('should render the bezier path and target marker', () => {
|
||||
const [expectedPath] = getBezierPath({
|
||||
sourceX: 10,
|
||||
sourceY: 20,
|
||||
sourcePosition: Position.Right,
|
||||
targetX: 70,
|
||||
targetY: 80,
|
||||
targetPosition: Position.Left,
|
||||
curvature: 0.16,
|
||||
})
|
||||
|
||||
const { container } = render(
|
||||
<svg>
|
||||
<CustomConnectionLine {...createConnectionLineProps()} />
|
||||
</svg>,
|
||||
)
|
||||
|
||||
const path = container.querySelector('path')
|
||||
const marker = container.querySelector('rect')
|
||||
|
||||
expect(path).toHaveAttribute('fill', 'none')
|
||||
expect(path).toHaveAttribute('stroke', '#D0D5DD')
|
||||
expect(path).toHaveAttribute('stroke-width', '2')
|
||||
expect(path).toHaveAttribute('d', expectedPath)
|
||||
|
||||
expect(marker).toHaveAttribute('x', '70')
|
||||
expect(marker).toHaveAttribute('y', '76')
|
||||
expect(marker).toHaveAttribute('width', '2')
|
||||
expect(marker).toHaveAttribute('height', '8')
|
||||
expect(marker).toHaveAttribute('fill', '#2970FF')
|
||||
})
|
||||
|
||||
it('should update the path when the endpoints change', () => {
|
||||
const [expectedPath] = getBezierPath({
|
||||
sourceX: 30,
|
||||
sourceY: 40,
|
||||
sourcePosition: Position.Right,
|
||||
targetX: 160,
|
||||
targetY: 200,
|
||||
targetPosition: Position.Left,
|
||||
curvature: 0.16,
|
||||
})
|
||||
|
||||
const { container } = render(
|
||||
<svg>
|
||||
<CustomConnectionLine
|
||||
{...createConnectionLineProps({
|
||||
fromX: 30,
|
||||
fromY: 40,
|
||||
toX: 160,
|
||||
toY: 200,
|
||||
})}
|
||||
/>
|
||||
</svg>,
|
||||
)
|
||||
|
||||
expect(container.querySelector('path')).toHaveAttribute('d', expectedPath)
|
||||
expect(container.querySelector('rect')).toHaveAttribute('x', '160')
|
||||
expect(container.querySelector('rect')).toHaveAttribute('y', '196')
|
||||
})
|
||||
})
|
||||
@@ -1,57 +0,0 @@
|
||||
import { render } from '@testing-library/react'
|
||||
import CustomEdgeLinearGradientRender from '../custom-edge-linear-gradient-render'
|
||||
|
||||
describe('CustomEdgeLinearGradientRender', () => {
|
||||
it('should render gradient definition with the provided id and positions', () => {
|
||||
const { container } = render(
|
||||
<svg>
|
||||
<CustomEdgeLinearGradientRender
|
||||
id="edge-gradient"
|
||||
startColor="#123456"
|
||||
stopColor="#abcdef"
|
||||
position={{
|
||||
x1: 10,
|
||||
y1: 20,
|
||||
x2: 30,
|
||||
y2: 40,
|
||||
}}
|
||||
/>
|
||||
</svg>,
|
||||
)
|
||||
|
||||
const gradient = container.querySelector('linearGradient')
|
||||
expect(gradient).toHaveAttribute('id', 'edge-gradient')
|
||||
expect(gradient).toHaveAttribute('gradientUnits', 'userSpaceOnUse')
|
||||
expect(gradient).toHaveAttribute('x1', '10')
|
||||
expect(gradient).toHaveAttribute('y1', '20')
|
||||
expect(gradient).toHaveAttribute('x2', '30')
|
||||
expect(gradient).toHaveAttribute('y2', '40')
|
||||
})
|
||||
|
||||
it('should render start and stop colors at both ends of the gradient', () => {
|
||||
const { container } = render(
|
||||
<svg>
|
||||
<CustomEdgeLinearGradientRender
|
||||
id="gradient-colors"
|
||||
startColor="#111111"
|
||||
stopColor="#222222"
|
||||
position={{
|
||||
x1: 0,
|
||||
y1: 0,
|
||||
x2: 100,
|
||||
y2: 100,
|
||||
}}
|
||||
/>
|
||||
</svg>,
|
||||
)
|
||||
|
||||
const stops = container.querySelectorAll('stop')
|
||||
expect(stops).toHaveLength(2)
|
||||
expect(stops[0]).toHaveAttribute('offset', '0%')
|
||||
expect(stops[0].getAttribute('style')).toContain('stop-color: rgb(17, 17, 17)')
|
||||
expect(stops[0].getAttribute('style')).toContain('stop-opacity: 1')
|
||||
expect(stops[1]).toHaveAttribute('offset', '100%')
|
||||
expect(stops[1].getAttribute('style')).toContain('stop-color: rgb(34, 34, 34)')
|
||||
expect(stops[1].getAttribute('style')).toContain('stop-opacity: 1')
|
||||
})
|
||||
})
|
||||
@@ -1,127 +0,0 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import DSLExportConfirmModal from '../dsl-export-confirm-modal'
|
||||
|
||||
const envList = [
|
||||
{
|
||||
id: 'env-1',
|
||||
name: 'SECRET_TOKEN',
|
||||
value: 'masked-value',
|
||||
value_type: 'secret' as const,
|
||||
description: 'secret token',
|
||||
},
|
||||
]
|
||||
|
||||
const multiEnvList = [
|
||||
...envList,
|
||||
{
|
||||
id: 'env-2',
|
||||
name: 'SERVICE_KEY',
|
||||
value: 'another-secret',
|
||||
value_type: 'secret' as const,
|
||||
description: 'service key',
|
||||
},
|
||||
]
|
||||
|
||||
describe('DSLExportConfirmModal', () => {
|
||||
it('should render environment rows and close when cancel is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfirm = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
render(
|
||||
<DSLExportConfirmModal
|
||||
envList={envList}
|
||||
onConfirm={onConfirm}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.getByText('SECRET_TOKEN')).toBeInTheDocument()
|
||||
expect(screen.getByText('masked-value')).toBeInTheDocument()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
|
||||
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
expect(onConfirm).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should confirm with exportSecrets=false by default', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfirm = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
render(
|
||||
<DSLExportConfirmModal
|
||||
envList={envList}
|
||||
onConfirm={onConfirm}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'workflow.env.export.ignore' }))
|
||||
|
||||
expect(onConfirm).toHaveBeenCalledWith(false)
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should confirm with exportSecrets=true after toggling the checkbox', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfirm = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
render(
|
||||
<DSLExportConfirmModal
|
||||
envList={envList}
|
||||
onConfirm={onConfirm}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('checkbox'))
|
||||
await user.click(screen.getByRole('button', { name: 'workflow.env.export.export' }))
|
||||
|
||||
expect(onConfirm).toHaveBeenCalledWith(true)
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should also toggle exportSecrets when the label text is clicked', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onConfirm = vi.fn()
|
||||
const onClose = vi.fn()
|
||||
|
||||
render(
|
||||
<DSLExportConfirmModal
|
||||
envList={envList}
|
||||
onConfirm={onConfirm}
|
||||
onClose={onClose}
|
||||
/>,
|
||||
)
|
||||
|
||||
await user.click(screen.getByText('workflow.env.export.checkbox'))
|
||||
await user.click(screen.getByRole('button', { name: 'workflow.env.export.export' }))
|
||||
|
||||
expect(onConfirm).toHaveBeenCalledWith(true)
|
||||
expect(onClose).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render border separators for all rows except the last one', () => {
|
||||
render(
|
||||
<DSLExportConfirmModal
|
||||
envList={multiEnvList}
|
||||
onConfirm={vi.fn()}
|
||||
onClose={vi.fn()}
|
||||
/>,
|
||||
)
|
||||
|
||||
const firstNameCell = screen.getByText('SECRET_TOKEN').closest('td')
|
||||
const lastNameCell = screen.getByText('SERVICE_KEY').closest('td')
|
||||
const firstValueCell = screen.getByText('masked-value').closest('td')
|
||||
const lastValueCell = screen.getByText('another-secret').closest('td')
|
||||
|
||||
expect(firstNameCell).toHaveClass('border-b')
|
||||
expect(firstValueCell).toHaveClass('border-b')
|
||||
expect(lastNameCell).not.toHaveClass('border-b')
|
||||
expect(lastValueCell).not.toHaveClass('border-b')
|
||||
})
|
||||
})
|
||||
@@ -1,193 +0,0 @@
|
||||
import type { InputVar } from '../types'
|
||||
import type { PromptVariable } from '@/models/debug'
|
||||
import { screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import ReactFlow, { ReactFlowProvider, useNodes } from 'reactflow'
|
||||
import Features from '../features'
|
||||
import { InputVarType } from '../types'
|
||||
import { createStartNode } from './fixtures'
|
||||
import { renderWorkflowComponent } from './workflow-test-env'
|
||||
|
||||
const mockHandleSyncWorkflowDraft = vi.fn()
|
||||
const mockHandleAddVariable = vi.fn()
|
||||
|
||||
let mockIsChatMode = true
|
||||
let mockNodesReadOnly = false
|
||||
|
||||
vi.mock('../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../hooks')>('../hooks')
|
||||
return {
|
||||
...actual,
|
||||
useIsChatMode: () => mockIsChatMode,
|
||||
useNodesReadOnly: () => ({
|
||||
nodesReadOnly: mockNodesReadOnly,
|
||||
}),
|
||||
useNodesSyncDraft: () => ({
|
||||
handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../nodes/start/use-config', () => ({
|
||||
default: () => ({
|
||||
handleAddVariable: mockHandleAddVariable,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/features/new-feature-panel', () => ({
|
||||
default: ({
|
||||
show,
|
||||
isChatMode,
|
||||
disabled,
|
||||
onChange,
|
||||
onClose,
|
||||
onAutoAddPromptVariable,
|
||||
workflowVariables,
|
||||
}: {
|
||||
show: boolean
|
||||
isChatMode: boolean
|
||||
disabled: boolean
|
||||
onChange: () => void
|
||||
onClose: () => void
|
||||
onAutoAddPromptVariable: (variables: PromptVariable[]) => void
|
||||
workflowVariables: InputVar[]
|
||||
}) => {
|
||||
if (!show)
|
||||
return null
|
||||
|
||||
return (
|
||||
<section aria-label="new feature panel">
|
||||
<div>{isChatMode ? 'chat mode' : 'completion mode'}</div>
|
||||
<div>{disabled ? 'panel disabled' : 'panel enabled'}</div>
|
||||
<ul aria-label="workflow variables">
|
||||
{workflowVariables.map(variable => (
|
||||
<li key={variable.variable}>
|
||||
{`${variable.label}:${variable.variable}`}
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
<button type="button" onClick={onChange}>open features</button>
|
||||
<button type="button" onClick={onClose}>close features</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onAutoAddPromptVariable([{
|
||||
key: 'opening_statement',
|
||||
name: 'Opening Statement',
|
||||
type: 'string',
|
||||
max_length: 200,
|
||||
required: true,
|
||||
}])}
|
||||
>
|
||||
add required variable
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onAutoAddPromptVariable([{
|
||||
key: 'optional_statement',
|
||||
name: 'Optional Statement',
|
||||
type: 'string',
|
||||
max_length: 120,
|
||||
}])}
|
||||
>
|
||||
add optional variable
|
||||
</button>
|
||||
</section>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
const startNode = createStartNode({
|
||||
id: 'start-node',
|
||||
data: {
|
||||
variables: [{ variable: 'existing_variable', label: 'Existing Variable' }],
|
||||
},
|
||||
})
|
||||
|
||||
const DelayedFeatures = () => {
|
||||
const nodes = useNodes()
|
||||
|
||||
if (!nodes.length)
|
||||
return null
|
||||
|
||||
return <Features />
|
||||
}
|
||||
|
||||
const renderFeatures = (options?: Parameters<typeof renderWorkflowComponent>[1]) => {
|
||||
return renderWorkflowComponent(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow nodes={[startNode]} edges={[]} fitView />
|
||||
<DelayedFeatures />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
describe('Features', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsChatMode = true
|
||||
mockNodesReadOnly = false
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should pass workflow context to the feature panel', () => {
|
||||
renderFeatures()
|
||||
|
||||
expect(screen.getByText('chat mode')).toBeInTheDocument()
|
||||
expect(screen.getByText('panel enabled')).toBeInTheDocument()
|
||||
expect(screen.getByRole('list', { name: 'workflow variables' })).toHaveTextContent('Existing Variable:existing_variable')
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should sync the draft and open the workflow feature panel when users change features', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { store } = renderFeatures()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'open features' }))
|
||||
|
||||
expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1)
|
||||
expect(store.getState().showFeaturesPanel).toBe(true)
|
||||
})
|
||||
|
||||
it('should close the workflow feature panel and transform required prompt variables', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { store } = renderFeatures({
|
||||
initialStoreState: {
|
||||
showFeaturesPanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'close features' }))
|
||||
expect(store.getState().showFeaturesPanel).toBe(false)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'add required variable' }))
|
||||
expect(mockHandleAddVariable).toHaveBeenCalledWith({
|
||||
variable: 'opening_statement',
|
||||
label: 'Opening Statement',
|
||||
type: InputVarType.textInput,
|
||||
max_length: 200,
|
||||
required: true,
|
||||
options: [],
|
||||
})
|
||||
})
|
||||
|
||||
it('should default prompt variables to optional when required is omitted', async () => {
|
||||
const user = userEvent.setup()
|
||||
|
||||
renderFeatures()
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'add optional variable' }))
|
||||
expect(mockHandleAddVariable).toHaveBeenCalledWith({
|
||||
variable: 'optional_statement',
|
||||
label: 'Optional Statement',
|
||||
type: InputVarType.textInput,
|
||||
max_length: 120,
|
||||
required: false,
|
||||
options: [],
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -16,8 +16,8 @@ import * as React from 'react'
|
||||
type MockNode = {
|
||||
id: string
|
||||
position: { x: number, y: number }
|
||||
width?: number | null
|
||||
height?: number | null
|
||||
width?: number
|
||||
height?: number
|
||||
parentId?: string
|
||||
data: Record<string, unknown>
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import SyncingDataModal from '../syncing-data-modal'
|
||||
import { renderWorkflowComponent } from './workflow-test-env'
|
||||
|
||||
describe('SyncingDataModal', () => {
|
||||
it('should not render when workflow draft syncing is disabled', () => {
|
||||
const { container } = renderWorkflowComponent(<SyncingDataModal />)
|
||||
|
||||
expect(container).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('should render the fullscreen overlay when workflow draft syncing is enabled', () => {
|
||||
const { container } = renderWorkflowComponent(<SyncingDataModal />, {
|
||||
initialStoreState: {
|
||||
isSyncingWorkflowDraft: true,
|
||||
},
|
||||
})
|
||||
|
||||
const overlay = container.firstElementChild
|
||||
expect(overlay).toHaveClass('absolute', 'inset-0')
|
||||
expect(overlay).toHaveClass('z-[9999]')
|
||||
})
|
||||
})
|
||||
@@ -1,108 +0,0 @@
|
||||
import type * as React from 'react'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import useCheckVerticalScrollbar from '../use-check-vertical-scrollbar'
|
||||
|
||||
const resizeObserve = vi.fn()
|
||||
const resizeDisconnect = vi.fn()
|
||||
const mutationObserve = vi.fn()
|
||||
const mutationDisconnect = vi.fn()
|
||||
|
||||
let resizeCallback: ResizeObserverCallback | null = null
|
||||
let mutationCallback: MutationCallback | null = null
|
||||
|
||||
class MockResizeObserver implements ResizeObserver {
|
||||
observe = resizeObserve
|
||||
unobserve = vi.fn()
|
||||
disconnect = resizeDisconnect
|
||||
|
||||
constructor(callback: ResizeObserverCallback) {
|
||||
resizeCallback = callback
|
||||
}
|
||||
}
|
||||
|
||||
class MockMutationObserver implements MutationObserver {
|
||||
observe = mutationObserve
|
||||
disconnect = mutationDisconnect
|
||||
takeRecords = vi.fn(() => [])
|
||||
|
||||
constructor(callback: MutationCallback) {
|
||||
mutationCallback = callback
|
||||
}
|
||||
}
|
||||
|
||||
const setElementHeights = (element: HTMLElement, scrollHeight: number, clientHeight: number) => {
|
||||
Object.defineProperty(element, 'scrollHeight', {
|
||||
configurable: true,
|
||||
value: scrollHeight,
|
||||
})
|
||||
Object.defineProperty(element, 'clientHeight', {
|
||||
configurable: true,
|
||||
value: clientHeight,
|
||||
})
|
||||
}
|
||||
|
||||
describe('useCheckVerticalScrollbar', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
resizeCallback = null
|
||||
mutationCallback = null
|
||||
vi.stubGlobal('ResizeObserver', MockResizeObserver)
|
||||
vi.stubGlobal('MutationObserver', MockMutationObserver)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals()
|
||||
})
|
||||
|
||||
it('should return false when the element ref is empty', () => {
|
||||
const ref = { current: null } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() => useCheckVerticalScrollbar(ref))
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
expect(resizeObserve).not.toHaveBeenCalled()
|
||||
expect(mutationObserve).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should detect the initial scrollbar state and react to observer updates', () => {
|
||||
const element = document.createElement('div')
|
||||
setElementHeights(element, 200, 100)
|
||||
const ref = { current: element } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() => useCheckVerticalScrollbar(ref))
|
||||
|
||||
expect(result.current).toBe(true)
|
||||
expect(resizeObserve).toHaveBeenCalledWith(element)
|
||||
expect(mutationObserve).toHaveBeenCalledWith(element, {
|
||||
childList: true,
|
||||
subtree: true,
|
||||
characterData: true,
|
||||
})
|
||||
|
||||
setElementHeights(element, 100, 100)
|
||||
act(() => {
|
||||
resizeCallback?.([] as ResizeObserverEntry[], new MockResizeObserver(() => {}))
|
||||
})
|
||||
|
||||
expect(result.current).toBe(false)
|
||||
|
||||
setElementHeights(element, 180, 100)
|
||||
act(() => {
|
||||
mutationCallback?.([] as MutationRecord[], new MockMutationObserver(() => {}))
|
||||
})
|
||||
|
||||
expect(result.current).toBe(true)
|
||||
})
|
||||
|
||||
it('should disconnect observers on unmount', () => {
|
||||
const element = document.createElement('div')
|
||||
setElementHeights(element, 120, 100)
|
||||
const ref = { current: element } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { unmount } = renderHook(() => useCheckVerticalScrollbar(ref))
|
||||
unmount()
|
||||
|
||||
expect(resizeDisconnect).toHaveBeenCalledTimes(1)
|
||||
expect(mutationDisconnect).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,103 +0,0 @@
|
||||
import type * as React from 'react'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import useStickyScroll, { ScrollPosition } from '../use-sticky-scroll'
|
||||
|
||||
const setRect = (element: HTMLElement, top: number, height: number) => {
|
||||
element.getBoundingClientRect = vi.fn(() => new DOMRect(0, top, 100, height))
|
||||
}
|
||||
|
||||
describe('useStickyScroll', () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
const runScroll = (handleScroll: () => void) => {
|
||||
act(() => {
|
||||
handleScroll()
|
||||
vi.advanceTimersByTime(120)
|
||||
})
|
||||
}
|
||||
|
||||
it('should keep the default state when refs are missing', () => {
|
||||
const wrapElemRef = { current: null } as React.RefObject<HTMLElement | null>
|
||||
const nextToStickyELemRef = { current: null } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useStickyScroll({
|
||||
wrapElemRef,
|
||||
nextToStickyELemRef,
|
||||
}),
|
||||
)
|
||||
|
||||
runScroll(result.current.handleScroll)
|
||||
|
||||
expect(result.current.scrollPosition).toBe(ScrollPosition.belowTheWrap)
|
||||
})
|
||||
|
||||
it('should mark the sticky element as below the wrapper when it is outside the visible area', () => {
|
||||
const wrapElement = document.createElement('div')
|
||||
const nextElement = document.createElement('div')
|
||||
setRect(wrapElement, 100, 200)
|
||||
setRect(nextElement, 320, 20)
|
||||
|
||||
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
|
||||
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useStickyScroll({
|
||||
wrapElemRef,
|
||||
nextToStickyELemRef,
|
||||
}),
|
||||
)
|
||||
|
||||
runScroll(result.current.handleScroll)
|
||||
|
||||
expect(result.current.scrollPosition).toBe(ScrollPosition.belowTheWrap)
|
||||
})
|
||||
|
||||
it('should mark the sticky element as showing when it is within the wrapper', () => {
|
||||
const wrapElement = document.createElement('div')
|
||||
const nextElement = document.createElement('div')
|
||||
setRect(wrapElement, 100, 200)
|
||||
setRect(nextElement, 220, 20)
|
||||
|
||||
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
|
||||
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useStickyScroll({
|
||||
wrapElemRef,
|
||||
nextToStickyELemRef,
|
||||
}),
|
||||
)
|
||||
|
||||
runScroll(result.current.handleScroll)
|
||||
|
||||
expect(result.current.scrollPosition).toBe(ScrollPosition.showing)
|
||||
})
|
||||
|
||||
it('should mark the sticky element as above the wrapper when it has scrolled past the top', () => {
|
||||
const wrapElement = document.createElement('div')
|
||||
const nextElement = document.createElement('div')
|
||||
setRect(wrapElement, 100, 200)
|
||||
setRect(nextElement, 90, 20)
|
||||
|
||||
const wrapElemRef = { current: wrapElement } as React.RefObject<HTMLElement | null>
|
||||
const nextToStickyELemRef = { current: nextElement } as React.RefObject<HTMLElement | null>
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useStickyScroll({
|
||||
wrapElemRef,
|
||||
nextToStickyELemRef,
|
||||
}),
|
||||
)
|
||||
|
||||
runScroll(result.current.handleScroll)
|
||||
|
||||
expect(result.current.scrollPosition).toBe(ScrollPosition.aboveTheWrap)
|
||||
})
|
||||
})
|
||||
@@ -1,108 +0,0 @@
|
||||
import type { DataSourceItem } from '../types'
|
||||
import { transformDataSourceToTool } from '../utils'
|
||||
|
||||
const createLocalizedText = (text: string) => ({
|
||||
en_US: text,
|
||||
zh_Hans: text,
|
||||
})
|
||||
|
||||
const createDataSourceItem = (overrides: Partial<DataSourceItem> = {}): DataSourceItem => ({
|
||||
plugin_id: 'plugin-1',
|
||||
plugin_unique_identifier: 'plugin-1@provider',
|
||||
provider: 'provider-a',
|
||||
declaration: {
|
||||
credentials_schema: [{ name: 'api_key' }],
|
||||
provider_type: 'hosted',
|
||||
identity: {
|
||||
author: 'Dify',
|
||||
description: createLocalizedText('Datasource provider'),
|
||||
icon: 'provider-icon',
|
||||
label: createLocalizedText('Provider A'),
|
||||
name: 'provider-a',
|
||||
tags: ['retrieval', 'storage'],
|
||||
},
|
||||
datasources: [
|
||||
{
|
||||
description: createLocalizedText('Search in documents'),
|
||||
identity: {
|
||||
author: 'Dify',
|
||||
label: createLocalizedText('Document Search'),
|
||||
name: 'document_search',
|
||||
provider: 'provider-a',
|
||||
},
|
||||
parameters: [{ name: 'query', type: 'string' }],
|
||||
output_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
result: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
is_authorized: true,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('transformDataSourceToTool', () => {
|
||||
it('should map datasource provider fields to tool shape', () => {
|
||||
const dataSourceItem = createDataSourceItem()
|
||||
|
||||
const result = transformDataSourceToTool(dataSourceItem)
|
||||
|
||||
expect(result).toMatchObject({
|
||||
id: 'plugin-1',
|
||||
provider: 'provider-a',
|
||||
name: 'provider-a',
|
||||
author: 'Dify',
|
||||
description: createLocalizedText('Datasource provider'),
|
||||
icon: 'provider-icon',
|
||||
label: createLocalizedText('Provider A'),
|
||||
type: 'hosted',
|
||||
allow_delete: true,
|
||||
is_authorized: true,
|
||||
is_team_authorization: true,
|
||||
labels: ['retrieval', 'storage'],
|
||||
plugin_id: 'plugin-1',
|
||||
plugin_unique_identifier: 'plugin-1@provider',
|
||||
credentialsSchema: [{ name: 'api_key' }],
|
||||
meta: { version: '' },
|
||||
})
|
||||
expect(result.team_credentials).toEqual({})
|
||||
expect(result.tools).toEqual([
|
||||
{
|
||||
name: 'document_search',
|
||||
author: 'Dify',
|
||||
label: createLocalizedText('Document Search'),
|
||||
description: createLocalizedText('Search in documents'),
|
||||
parameters: [{ name: 'query', type: 'string' }],
|
||||
labels: [],
|
||||
output_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
result: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should fallback to empty arrays when tags and credentials schema are missing', () => {
|
||||
const baseDataSourceItem = createDataSourceItem()
|
||||
const dataSourceItem = createDataSourceItem({
|
||||
declaration: {
|
||||
...baseDataSourceItem.declaration,
|
||||
credentials_schema: undefined as unknown as DataSourceItem['declaration']['credentials_schema'],
|
||||
identity: {
|
||||
...baseDataSourceItem.declaration.identity,
|
||||
tags: undefined as unknown as DataSourceItem['declaration']['identity']['tags'],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
const result = transformDataSourceToTool(dataSourceItem)
|
||||
|
||||
expect(result.labels).toEqual([])
|
||||
expect(result.credentialsSchema).toEqual([])
|
||||
})
|
||||
})
|
||||
@@ -1,57 +0,0 @@
|
||||
import { fireEvent, render } from '@testing-library/react'
|
||||
import ViewTypeSelect, { ViewType } from '../view-type-select'
|
||||
|
||||
const getViewOptions = (container: HTMLElement) => {
|
||||
const options = container.firstElementChild?.children
|
||||
if (!options || options.length !== 2)
|
||||
throw new Error('Expected two view options')
|
||||
return [options[0] as HTMLDivElement, options[1] as HTMLDivElement]
|
||||
}
|
||||
|
||||
describe('ViewTypeSelect', () => {
|
||||
it('should highlight the active view type', () => {
|
||||
const onChange = vi.fn()
|
||||
const { container } = render(
|
||||
<ViewTypeSelect
|
||||
viewType={ViewType.flat}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const [flatOption, treeOption] = getViewOptions(container)
|
||||
|
||||
expect(flatOption).toHaveClass('bg-components-segmented-control-item-active-bg')
|
||||
expect(treeOption).toHaveClass('cursor-pointer')
|
||||
})
|
||||
|
||||
it('should call onChange when switching to a different view type', () => {
|
||||
const onChange = vi.fn()
|
||||
const { container } = render(
|
||||
<ViewTypeSelect
|
||||
viewType={ViewType.flat}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const [, treeOption] = getViewOptions(container)
|
||||
fireEvent.click(treeOption)
|
||||
|
||||
expect(onChange).toHaveBeenCalledWith(ViewType.tree)
|
||||
expect(onChange).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should ignore clicks on the current view type', () => {
|
||||
const onChange = vi.fn()
|
||||
const { container } = render(
|
||||
<ViewTypeSelect
|
||||
viewType={ViewType.tree}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
)
|
||||
|
||||
const [, treeOption] = getViewOptions(container)
|
||||
fireEvent.click(treeOption)
|
||||
|
||||
expect(onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
const useCheckVerticalScrollbar = (ref: React.RefObject<HTMLElement | null>) => {
|
||||
const useCheckVerticalScrollbar = (ref: React.RefObject<HTMLElement>) => {
|
||||
const [hasVerticalScrollbar, setHasVerticalScrollbar] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import ChatVariableButton from '../chat-variable-button'
|
||||
|
||||
let mockTheme: 'light' | 'dark' = 'light'
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({
|
||||
theme: mockTheme,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('ChatVariableButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
it('opens the chat variable panel and closes the other workflow panels', () => {
|
||||
const { store } = renderWorkflowComponent(<ChatVariableButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showEnvPanel: true,
|
||||
showGlobalVariablePanel: true,
|
||||
showDebugAndPreviewPanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(store.getState().showChatVariablePanel).toBe(true)
|
||||
expect(store.getState().showEnvPanel).toBe(false)
|
||||
expect(store.getState().showGlobalVariablePanel).toBe(false)
|
||||
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
|
||||
})
|
||||
|
||||
it('applies the active dark theme styles when the chat variable panel is visible', () => {
|
||||
mockTheme = 'dark'
|
||||
renderWorkflowComponent(<ChatVariableButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showChatVariablePanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
|
||||
})
|
||||
|
||||
it('stays disabled without mutating panel state', () => {
|
||||
const { store } = renderWorkflowComponent(<ChatVariableButton disabled />, {
|
||||
initialStoreState: {
|
||||
showChatVariablePanel: false,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(store.getState().showChatVariablePanel).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -1,63 +0,0 @@
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import EditingTitle from '../editing-title'
|
||||
|
||||
const mockFormatTime = vi.fn()
|
||||
const mockFormatTimeFromNow = vi.fn()
|
||||
|
||||
vi.mock('@/hooks/use-timestamp', () => ({
|
||||
default: () => ({
|
||||
formatTime: mockFormatTime,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
useFormatTimeFromNow: () => ({
|
||||
formatTimeFromNow: mockFormatTimeFromNow,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('EditingTitle', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFormatTime.mockReturnValue('08:00:00')
|
||||
mockFormatTimeFromNow.mockReturnValue('2 hours ago')
|
||||
})
|
||||
|
||||
it('should render autosave, published time, and syncing status when the draft has metadata', () => {
|
||||
const { container } = renderWorkflowComponent(<EditingTitle />, {
|
||||
initialStoreState: {
|
||||
draftUpdatedAt: 1_710_000_000_000,
|
||||
publishedAt: 1_710_003_600_000,
|
||||
isSyncingWorkflowDraft: true,
|
||||
maximizeCanvas: true,
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatTime).toHaveBeenCalledWith(1_710_000_000, 'HH:mm:ss')
|
||||
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(1_710_003_600_000)
|
||||
expect(container.firstChild).toHaveClass('ml-2')
|
||||
expect(container).toHaveTextContent('workflow.common.autoSaved')
|
||||
expect(container).toHaveTextContent('08:00:00')
|
||||
expect(container).toHaveTextContent('workflow.common.published')
|
||||
expect(container).toHaveTextContent('2 hours ago')
|
||||
expect(container).toHaveTextContent('workflow.common.syncingData')
|
||||
})
|
||||
|
||||
it('should render unpublished status without autosave metadata when the workflow has not been published', () => {
|
||||
const { container } = renderWorkflowComponent(<EditingTitle />, {
|
||||
initialStoreState: {
|
||||
draftUpdatedAt: 0,
|
||||
publishedAt: 0,
|
||||
isSyncingWorkflowDraft: false,
|
||||
maximizeCanvas: false,
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatTime).not.toHaveBeenCalled()
|
||||
expect(mockFormatTimeFromNow).not.toHaveBeenCalled()
|
||||
expect(container.firstChild).not.toHaveClass('ml-2')
|
||||
expect(container).toHaveTextContent('workflow.common.unpublished')
|
||||
expect(container).not.toHaveTextContent('workflow.common.autoSaved')
|
||||
expect(container).not.toHaveTextContent('workflow.common.syncingData')
|
||||
})
|
||||
})
|
||||
@@ -1,68 +0,0 @@
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import EnvButton from '../env-button'
|
||||
|
||||
const mockCloseAllInputFieldPanels = vi.fn()
|
||||
let mockTheme: 'light' | 'dark' = 'light'
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({
|
||||
theme: mockTheme,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
|
||||
useInputFieldPanel: () => ({
|
||||
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('EnvButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
it('should open the environment panel and close the other panels when clicked', () => {
|
||||
const { store } = renderWorkflowComponent(<EnvButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showChatVariablePanel: true,
|
||||
showGlobalVariablePanel: true,
|
||||
showDebugAndPreviewPanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(store.getState().showEnvPanel).toBe(true)
|
||||
expect(store.getState().showChatVariablePanel).toBe(false)
|
||||
expect(store.getState().showGlobalVariablePanel).toBe(false)
|
||||
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
|
||||
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should apply the active dark theme styles when the environment panel is visible', () => {
|
||||
mockTheme = 'dark'
|
||||
renderWorkflowComponent(<EnvButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showEnvPanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
|
||||
})
|
||||
|
||||
it('should keep the button disabled when the disabled prop is true', () => {
|
||||
const { store } = renderWorkflowComponent(<EnvButton disabled />, {
|
||||
initialStoreState: {
|
||||
showEnvPanel: false,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(store.getState().showEnvPanel).toBe(false)
|
||||
expect(mockCloseAllInputFieldPanels).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,68 +0,0 @@
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import GlobalVariableButton from '../global-variable-button'
|
||||
|
||||
const mockCloseAllInputFieldPanels = vi.fn()
|
||||
let mockTheme: 'light' | 'dark' = 'light'
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({
|
||||
theme: mockTheme,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
|
||||
useInputFieldPanel: () => ({
|
||||
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('GlobalVariableButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
it('should open the global variable panel and close the other panels when clicked', () => {
|
||||
const { store } = renderWorkflowComponent(<GlobalVariableButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showEnvPanel: true,
|
||||
showChatVariablePanel: true,
|
||||
showDebugAndPreviewPanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(store.getState().showGlobalVariablePanel).toBe(true)
|
||||
expect(store.getState().showEnvPanel).toBe(false)
|
||||
expect(store.getState().showChatVariablePanel).toBe(false)
|
||||
expect(store.getState().showDebugAndPreviewPanel).toBe(false)
|
||||
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should apply the active dark theme styles when the global variable panel is visible', () => {
|
||||
mockTheme = 'dark'
|
||||
renderWorkflowComponent(<GlobalVariableButton disabled={false} />, {
|
||||
initialStoreState: {
|
||||
showGlobalVariablePanel: true,
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
|
||||
})
|
||||
|
||||
it('should keep the button disabled when the disabled prop is true', () => {
|
||||
const { store } = renderWorkflowComponent(<GlobalVariableButton disabled />, {
|
||||
initialStoreState: {
|
||||
showGlobalVariablePanel: false,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
expect(store.getState().showGlobalVariablePanel).toBe(false)
|
||||
expect(mockCloseAllInputFieldPanels).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,109 +0,0 @@
|
||||
import type { VersionHistory } from '@/types/workflow'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import { WorkflowVersion } from '../../types'
|
||||
import RestoringTitle from '../restoring-title'
|
||||
|
||||
const mockFormatTime = vi.fn()
|
||||
const mockFormatTimeFromNow = vi.fn()
|
||||
|
||||
vi.mock('@/hooks/use-timestamp', () => ({
|
||||
default: () => ({
|
||||
formatTime: mockFormatTime,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
useFormatTimeFromNow: () => ({
|
||||
formatTimeFromNow: mockFormatTimeFromNow,
|
||||
}),
|
||||
}))
|
||||
|
||||
const createVersion = (overrides: Partial<VersionHistory> = {}): VersionHistory => ({
|
||||
id: 'version-1',
|
||||
graph: {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
},
|
||||
created_at: 1_700_000_000,
|
||||
created_by: {
|
||||
id: 'user-1',
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com',
|
||||
},
|
||||
hash: 'hash-1',
|
||||
updated_at: 1_700_000_100,
|
||||
updated_by: {
|
||||
id: 'user-2',
|
||||
name: 'Bob',
|
||||
email: 'bob@example.com',
|
||||
},
|
||||
tool_published: false,
|
||||
version: 'v1',
|
||||
marked_name: 'Release 1',
|
||||
marked_comment: '',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('RestoringTitle', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockFormatTime.mockReturnValue('09:30:00')
|
||||
mockFormatTimeFromNow.mockReturnValue('3 hours ago')
|
||||
})
|
||||
|
||||
it('should render draft metadata when the current version is a draft', () => {
|
||||
const currentVersion = createVersion({
|
||||
version: WorkflowVersion.Draft,
|
||||
})
|
||||
|
||||
const { container } = renderWorkflowComponent(<RestoringTitle />, {
|
||||
initialStoreState: {
|
||||
currentVersion,
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(currentVersion.updated_at * 1000)
|
||||
expect(mockFormatTime).toHaveBeenCalledWith(currentVersion.created_at, 'HH:mm:ss')
|
||||
expect(container).toHaveTextContent('workflow.versionHistory.currentDraft')
|
||||
expect(container).toHaveTextContent('workflow.common.viewOnly')
|
||||
expect(container).toHaveTextContent('workflow.common.unpublished')
|
||||
expect(container).toHaveTextContent('3 hours ago 09:30:00')
|
||||
expect(container).toHaveTextContent('Alice')
|
||||
})
|
||||
|
||||
it('should render published metadata and fallback version name when the marked name is empty', () => {
|
||||
const currentVersion = createVersion({
|
||||
marked_name: '',
|
||||
})
|
||||
|
||||
const { container } = renderWorkflowComponent(<RestoringTitle />, {
|
||||
initialStoreState: {
|
||||
currentVersion,
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatTimeFromNow).toHaveBeenCalledWith(currentVersion.created_at * 1000)
|
||||
expect(container).toHaveTextContent('workflow.versionHistory.defaultName')
|
||||
expect(container).toHaveTextContent('workflow.common.published')
|
||||
expect(container).toHaveTextContent('Alice')
|
||||
})
|
||||
|
||||
it('should render an empty creator name when the version creator name is missing', () => {
|
||||
const currentVersion = createVersion({
|
||||
created_by: {
|
||||
id: 'user-1',
|
||||
name: '',
|
||||
email: 'alice@example.com',
|
||||
},
|
||||
})
|
||||
|
||||
const { container } = renderWorkflowComponent(<RestoringTitle />, {
|
||||
initialStoreState: {
|
||||
currentVersion,
|
||||
},
|
||||
})
|
||||
|
||||
expect(container).toHaveTextContent('workflow.common.published')
|
||||
expect(container).not.toHaveTextContent('Alice')
|
||||
})
|
||||
})
|
||||
@@ -1,61 +0,0 @@
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import RunningTitle from '../running-title'
|
||||
|
||||
let mockIsChatMode = false
|
||||
const mockFormatWorkflowRunIdentifier = vi.fn()
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
useIsChatMode: () => mockIsChatMode,
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
formatWorkflowRunIdentifier: (finishedAt?: number) => mockFormatWorkflowRunIdentifier(finishedAt),
|
||||
}))
|
||||
|
||||
describe('RunningTitle', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsChatMode = false
|
||||
mockFormatWorkflowRunIdentifier.mockReturnValue(' (14:30:25)')
|
||||
})
|
||||
|
||||
it('should render the test run title in workflow mode', () => {
|
||||
const { container } = renderWorkflowComponent(<RunningTitle />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: {
|
||||
id: 'history-1',
|
||||
status: 'succeeded',
|
||||
finished_at: 1_700_000_000,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(1_700_000_000)
|
||||
expect(container).toHaveTextContent('Test Run (14:30:25)')
|
||||
expect(container).toHaveTextContent('workflow.common.viewOnly')
|
||||
})
|
||||
|
||||
it('should render the test chat title in chat mode', () => {
|
||||
mockIsChatMode = true
|
||||
|
||||
const { container } = renderWorkflowComponent(<RunningTitle />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: {
|
||||
id: 'history-2',
|
||||
status: 'running',
|
||||
finished_at: undefined,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(undefined)
|
||||
expect(container).toHaveTextContent('Test Chat (14:30:25)')
|
||||
})
|
||||
|
||||
it('should handle missing workflow history data', () => {
|
||||
const { container } = renderWorkflowComponent(<RunningTitle />)
|
||||
|
||||
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(undefined)
|
||||
expect(container).toHaveTextContent('Test Run (14:30:25)')
|
||||
})
|
||||
})
|
||||
@@ -1,53 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { createNode } from '../../__tests__/fixtures'
|
||||
import { resetReactFlowMockState, rfState } from '../../__tests__/reactflow-mock-state'
|
||||
import ScrollToSelectedNodeButton from '../scroll-to-selected-node-button'
|
||||
|
||||
const mockScrollToWorkflowNode = vi.fn()
|
||||
|
||||
vi.mock('reactflow', async () =>
|
||||
(await import('../../__tests__/reactflow-mock-state')).createReactFlowModuleMock())
|
||||
|
||||
vi.mock('../../utils/node-navigation', () => ({
|
||||
scrollToWorkflowNode: (nodeId: string) => mockScrollToWorkflowNode(nodeId),
|
||||
}))
|
||||
|
||||
describe('ScrollToSelectedNodeButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
resetReactFlowMockState()
|
||||
})
|
||||
|
||||
it('should render nothing when there is no selected node', () => {
|
||||
rfState.nodes = [
|
||||
createNode({
|
||||
id: 'node-1',
|
||||
data: { selected: false },
|
||||
}),
|
||||
]
|
||||
|
||||
const { container } = render(<ScrollToSelectedNodeButton />)
|
||||
|
||||
expect(container.firstChild).toBeNull()
|
||||
})
|
||||
|
||||
it('should render the action and scroll to the selected node when clicked', () => {
|
||||
rfState.nodes = [
|
||||
createNode({
|
||||
id: 'node-1',
|
||||
data: { selected: false },
|
||||
}),
|
||||
createNode({
|
||||
id: 'node-2',
|
||||
data: { selected: true },
|
||||
}),
|
||||
]
|
||||
|
||||
render(<ScrollToSelectedNodeButton />)
|
||||
|
||||
fireEvent.click(screen.getByText('workflow.panel.scrollToSelectedNode'))
|
||||
|
||||
expect(mockScrollToWorkflowNode).toHaveBeenCalledWith('node-2')
|
||||
expect(mockScrollToWorkflowNode).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,118 +0,0 @@
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import UndoRedo from '../undo-redo'
|
||||
|
||||
type TemporalSnapshot = {
|
||||
pastStates: unknown[]
|
||||
futureStates: unknown[]
|
||||
}
|
||||
|
||||
const mockUnsubscribe = vi.fn()
|
||||
const mockTemporalSubscribe = vi.fn()
|
||||
const mockHandleUndo = vi.fn()
|
||||
const mockHandleRedo = vi.fn()
|
||||
|
||||
let latestTemporalListener: ((state: TemporalSnapshot) => void) | undefined
|
||||
let mockNodesReadOnly = false
|
||||
|
||||
vi.mock('@/app/components/workflow/header/view-workflow-history', () => ({
|
||||
default: () => <div data-testid="view-workflow-history" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useNodesReadOnly: () => ({
|
||||
nodesReadOnly: mockNodesReadOnly,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/workflow-history-store', () => ({
|
||||
useWorkflowHistoryStore: () => ({
|
||||
store: {
|
||||
temporal: {
|
||||
subscribe: mockTemporalSubscribe,
|
||||
},
|
||||
},
|
||||
shortcutsEnabled: true,
|
||||
setShortcutsEnabled: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/divider', () => ({
|
||||
default: () => <div data-testid="divider" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/operator/tip-popup', () => ({
|
||||
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
|
||||
}))
|
||||
|
||||
describe('UndoRedo', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockNodesReadOnly = false
|
||||
latestTemporalListener = undefined
|
||||
mockTemporalSubscribe.mockImplementation((listener: (state: TemporalSnapshot) => void) => {
|
||||
latestTemporalListener = listener
|
||||
return mockUnsubscribe
|
||||
})
|
||||
})
|
||||
|
||||
it('enables undo and redo when history exists and triggers the callbacks', () => {
|
||||
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
|
||||
|
||||
act(() => {
|
||||
latestTemporalListener?.({
|
||||
pastStates: [{}],
|
||||
futureStates: [{}],
|
||||
})
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.undo' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.redo' }))
|
||||
|
||||
expect(mockHandleUndo).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleRedo).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('keeps the buttons disabled before history is available', () => {
|
||||
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
|
||||
const undoButton = screen.getByRole('button', { name: 'workflow.common.undo' })
|
||||
const redoButton = screen.getByRole('button', { name: 'workflow.common.redo' })
|
||||
|
||||
fireEvent.click(undoButton)
|
||||
fireEvent.click(redoButton)
|
||||
|
||||
expect(undoButton).toBeDisabled()
|
||||
expect(redoButton).toBeDisabled()
|
||||
expect(mockHandleUndo).not.toHaveBeenCalled()
|
||||
expect(mockHandleRedo).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not trigger callbacks when the canvas is read only', () => {
|
||||
mockNodesReadOnly = true
|
||||
render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
|
||||
const undoButton = screen.getByRole('button', { name: 'workflow.common.undo' })
|
||||
const redoButton = screen.getByRole('button', { name: 'workflow.common.redo' })
|
||||
|
||||
act(() => {
|
||||
latestTemporalListener?.({
|
||||
pastStates: [{}],
|
||||
futureStates: [{}],
|
||||
})
|
||||
})
|
||||
|
||||
fireEvent.click(undoButton)
|
||||
fireEvent.click(redoButton)
|
||||
|
||||
expect(undoButton).toBeDisabled()
|
||||
expect(redoButton).toBeDisabled()
|
||||
expect(mockHandleUndo).not.toHaveBeenCalled()
|
||||
expect(mockHandleRedo).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('unsubscribes from the temporal store on unmount', () => {
|
||||
const { unmount } = render(<UndoRedo handleRedo={mockHandleRedo} handleUndo={mockHandleUndo} />)
|
||||
|
||||
unmount()
|
||||
|
||||
expect(mockUnsubscribe).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -1,68 +0,0 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import VersionHistoryButton from '../version-history-button'
|
||||
|
||||
let mockTheme: 'light' | 'dark' = 'light'
|
||||
|
||||
vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({
|
||||
theme: mockTheme,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../../utils')>()
|
||||
return {
|
||||
...actual,
|
||||
getKeyboardKeyCodeBySystem: () => 'ctrl',
|
||||
}
|
||||
})
|
||||
|
||||
describe('VersionHistoryButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockTheme = 'light'
|
||||
})
|
||||
|
||||
it('should call onClick when the button is clicked', () => {
|
||||
const onClick = vi.fn()
|
||||
render(<VersionHistoryButton onClick={onClick} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button'))
|
||||
|
||||
expect(onClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should trigger onClick when the version history shortcut is pressed', () => {
|
||||
const onClick = vi.fn()
|
||||
render(<VersionHistoryButton onClick={onClick} />)
|
||||
|
||||
const keyboardEvent = new KeyboardEvent('keydown', {
|
||||
key: 'H',
|
||||
ctrlKey: true,
|
||||
shiftKey: true,
|
||||
bubbles: true,
|
||||
cancelable: true,
|
||||
})
|
||||
Object.defineProperty(keyboardEvent, 'keyCode', { value: 72 })
|
||||
Object.defineProperty(keyboardEvent, 'which', { value: 72 })
|
||||
window.dispatchEvent(keyboardEvent)
|
||||
|
||||
expect(keyboardEvent.defaultPrevented).toBe(true)
|
||||
expect(onClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should render the tooltip popup content on hover', async () => {
|
||||
render(<VersionHistoryButton onClick={vi.fn()} />)
|
||||
|
||||
fireEvent.mouseEnter(screen.getByRole('button'))
|
||||
|
||||
expect(await screen.findByText('workflow.common.versionHistory')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should apply dark theme styles when the theme is dark', () => {
|
||||
mockTheme = 'dark'
|
||||
render(<VersionHistoryButton onClick={vi.fn()} />)
|
||||
|
||||
expect(screen.getByRole('button')).toHaveClass('border-black/5', 'bg-white/10', 'backdrop-blur-sm')
|
||||
})
|
||||
})
|
||||
@@ -1,276 +0,0 @@
|
||||
import type { WorkflowRunHistory, WorkflowRunHistoryResponse } from '@/types/workflow'
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import { ControlMode, WorkflowRunningStatus } from '../../types'
|
||||
import ViewHistory from '../view-history'
|
||||
|
||||
const mockUseWorkflowRunHistory = vi.fn()
|
||||
const mockFormatTimeFromNow = vi.fn((value: number) => `from-now:${value}`)
|
||||
const mockCloseAllInputFieldPanels = vi.fn()
|
||||
const mockHandleNodesCancelSelected = vi.fn()
|
||||
const mockHandleCancelDebugAndPreviewPanel = vi.fn()
|
||||
const mockFormatWorkflowRunIdentifier = vi.fn((finishedAt?: number, status?: string) => ` (${status || finishedAt || 'unknown'})`)
|
||||
|
||||
let mockIsChatMode = false
|
||||
|
||||
vi.mock('../../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../../hooks')>('../../hooks')
|
||||
return {
|
||||
...actual,
|
||||
useIsChatMode: () => mockIsChatMode,
|
||||
useNodesInteractions: () => ({
|
||||
handleNodesCancelSelected: mockHandleNodesCancelSelected,
|
||||
}),
|
||||
useWorkflowInteractions: () => ({
|
||||
handleCancelDebugAndPreviewPanel: mockHandleCancelDebugAndPreviewPanel,
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/service/use-workflow', () => ({
|
||||
useWorkflowRunHistory: (url?: string, enabled?: boolean) => mockUseWorkflowRunHistory(url, enabled),
|
||||
}))
|
||||
|
||||
vi.mock('@/hooks/use-format-time-from-now', () => ({
|
||||
useFormatTimeFromNow: () => ({
|
||||
formatTimeFromNow: mockFormatTimeFromNow,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/rag-pipeline/hooks', () => ({
|
||||
useInputFieldPanel: () => ({
|
||||
closeAllInputFieldPanels: mockCloseAllInputFieldPanels,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/loading', () => ({
|
||||
default: () => <div data-testid="loading" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/portal-to-follow-elem', () => {
|
||||
const PortalContext = React.createContext({ open: false })
|
||||
|
||||
return {
|
||||
PortalToFollowElem: ({
|
||||
children,
|
||||
open,
|
||||
}: {
|
||||
children?: React.ReactNode
|
||||
open: boolean
|
||||
}) => <PortalContext.Provider value={{ open }}>{children}</PortalContext.Provider>,
|
||||
PortalToFollowElemTrigger: ({
|
||||
children,
|
||||
onClick,
|
||||
}: {
|
||||
children?: React.ReactNode
|
||||
onClick?: () => void
|
||||
}) => <div data-testid="portal-trigger" onClick={onClick}>{children}</div>,
|
||||
PortalToFollowElemContent: ({
|
||||
children,
|
||||
}: {
|
||||
children?: React.ReactNode
|
||||
}) => {
|
||||
const { open } = React.useContext(PortalContext)
|
||||
return open ? <div data-testid="portal-content">{children}</div> : null
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../../utils', async () => {
|
||||
const actual = await vi.importActual<typeof import('../../utils')>('../../utils')
|
||||
return {
|
||||
...actual,
|
||||
formatWorkflowRunIdentifier: (finishedAt?: number, status?: string) => mockFormatWorkflowRunIdentifier(finishedAt, status),
|
||||
}
|
||||
})
|
||||
|
||||
const createHistoryItem = (overrides: Partial<WorkflowRunHistory> = {}): WorkflowRunHistory => ({
|
||||
id: 'run-1',
|
||||
version: 'v1',
|
||||
graph: {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
},
|
||||
inputs: {},
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
outputs: {},
|
||||
elapsed_time: 1,
|
||||
total_tokens: 2,
|
||||
total_steps: 3,
|
||||
created_at: 100,
|
||||
finished_at: 120,
|
||||
created_by_account: {
|
||||
id: 'user-1',
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com',
|
||||
},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('ViewHistory', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsChatMode = false
|
||||
mockUseWorkflowRunHistory.mockReturnValue({
|
||||
data: { data: [] } satisfies WorkflowRunHistoryResponse,
|
||||
isLoading: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('defers fetching until the history popup is opened and renders the empty state', () => {
|
||||
renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
|
||||
hooksStoreProps: {
|
||||
handleBackupDraft: vi.fn(),
|
||||
},
|
||||
})
|
||||
|
||||
expect(mockUseWorkflowRunHistory).toHaveBeenCalledWith('/history', false)
|
||||
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
|
||||
|
||||
expect(mockUseWorkflowRunHistory).toHaveBeenLastCalledWith('/history', true)
|
||||
expect(screen.getByText('workflow.common.notRunning')).toBeInTheDocument()
|
||||
expect(screen.getByText('workflow.common.showRunHistory')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders the icon trigger variant and loading state, and clears log modals on trigger click', () => {
|
||||
const onClearLogAndMessageModal = vi.fn()
|
||||
mockUseWorkflowRunHistory.mockReturnValue({
|
||||
data: { data: [] } satisfies WorkflowRunHistoryResponse,
|
||||
isLoading: true,
|
||||
})
|
||||
|
||||
renderWorkflowComponent(
|
||||
<ViewHistory
|
||||
historyUrl="/history"
|
||||
onClearLogAndMessageModal={onClearLogAndMessageModal}
|
||||
/>,
|
||||
{
|
||||
hooksStoreProps: {
|
||||
handleBackupDraft: vi.fn(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.viewRunHistory' }))
|
||||
|
||||
expect(onClearLogAndMessageModal).toHaveBeenCalledTimes(1)
|
||||
expect(screen.getByTestId('loading')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders workflow run history items and updates the workflow store when one is selected', () => {
|
||||
const handleBackupDraft = vi.fn()
|
||||
const pausedRun = createHistoryItem({
|
||||
id: 'run-paused',
|
||||
status: WorkflowRunningStatus.Paused,
|
||||
created_at: 101,
|
||||
finished_at: 0,
|
||||
})
|
||||
const failedRun = createHistoryItem({
|
||||
id: 'run-failed',
|
||||
status: WorkflowRunningStatus.Failed,
|
||||
created_at: 102,
|
||||
finished_at: 130,
|
||||
})
|
||||
const succeededRun = createHistoryItem({
|
||||
id: 'run-succeeded',
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
created_at: 103,
|
||||
finished_at: 140,
|
||||
})
|
||||
|
||||
mockUseWorkflowRunHistory.mockReturnValue({
|
||||
data: {
|
||||
data: [pausedRun, failedRun, succeededRun],
|
||||
} satisfies WorkflowRunHistoryResponse,
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
const { store } = renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: failedRun,
|
||||
showInputsPanel: true,
|
||||
showEnvPanel: true,
|
||||
controlMode: ControlMode.Pointer,
|
||||
},
|
||||
hooksStoreProps: {
|
||||
handleBackupDraft,
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
|
||||
|
||||
expect(screen.getByText('Test Run (paused)')).toBeInTheDocument()
|
||||
expect(screen.getByText('Test Run (failed)')).toBeInTheDocument()
|
||||
expect(screen.getByText('Test Run (succeeded)')).toBeInTheDocument()
|
||||
|
||||
fireEvent.click(screen.getByText('Test Run (succeeded)'))
|
||||
|
||||
expect(store.getState().historyWorkflowData).toEqual(succeededRun)
|
||||
expect(store.getState().showInputsPanel).toBe(false)
|
||||
expect(store.getState().showEnvPanel).toBe(false)
|
||||
expect(store.getState().controlMode).toBe(ControlMode.Hand)
|
||||
expect(mockCloseAllInputFieldPanels).toHaveBeenCalledTimes(1)
|
||||
expect(handleBackupDraft).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleNodesCancelSelected).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleCancelDebugAndPreviewPanel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('renders chat history labels without workflow status icons in chat mode', () => {
|
||||
mockIsChatMode = true
|
||||
const chatRun = createHistoryItem({
|
||||
id: 'chat-run',
|
||||
status: WorkflowRunningStatus.Failed,
|
||||
})
|
||||
|
||||
mockUseWorkflowRunHistory.mockReturnValue({
|
||||
data: {
|
||||
data: [chatRun],
|
||||
} satisfies WorkflowRunHistoryResponse,
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
renderWorkflowComponent(<ViewHistory historyUrl="/history" withText />, {
|
||||
hooksStoreProps: {
|
||||
handleBackupDraft: vi.fn(),
|
||||
},
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
|
||||
|
||||
expect(screen.getByText('Test Chat (failed)')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('closes the popup from the close button and clears log modals', () => {
|
||||
const onClearLogAndMessageModal = vi.fn()
|
||||
mockUseWorkflowRunHistory.mockReturnValue({
|
||||
data: { data: [] } satisfies WorkflowRunHistoryResponse,
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
renderWorkflowComponent(
|
||||
<ViewHistory
|
||||
historyUrl="/history"
|
||||
withText
|
||||
onClearLogAndMessageModal={onClearLogAndMessageModal}
|
||||
/>,
|
||||
{
|
||||
hooksStoreProps: {
|
||||
handleBackupDraft: vi.fn(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.common.showRunHistory' }))
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' }))
|
||||
|
||||
expect(onClearLogAndMessageModal).toHaveBeenCalledTimes(1)
|
||||
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { FC } from 'react'
|
||||
import type { CommonNodeType } from '../types'
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useNodes } from 'reactflow'
|
||||
import { cn } from '@/utils/classnames'
|
||||
@@ -10,15 +11,21 @@ const ScrollToSelectedNodeButton: FC = () => {
|
||||
const nodes = useNodes<CommonNodeType>()
|
||||
const selectedNode = nodes.find(node => node.data.selected)
|
||||
|
||||
const handleScrollToSelectedNode = useCallback(() => {
|
||||
if (!selectedNode)
|
||||
return
|
||||
scrollToWorkflowNode(selectedNode.id)
|
||||
}, [selectedNode])
|
||||
|
||||
if (!selectedNode)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex h-6 cursor-pointer items-center justify-center whitespace-nowrap rounded-md border-[0.5px] border-effects-highlight bg-components-actionbar-bg px-3 text-text-tertiary shadow-lg backdrop-blur-sm transition-colors duration-200 system-xs-medium hover:text-text-accent',
|
||||
'system-xs-medium flex h-6 cursor-pointer items-center justify-center whitespace-nowrap rounded-md border-[0.5px] border-effects-highlight bg-components-actionbar-bg px-3 text-text-tertiary shadow-lg backdrop-blur-sm transition-colors duration-200 hover:text-text-accent',
|
||||
)}
|
||||
onClick={() => scrollToWorkflowNode(selectedNode.id)}
|
||||
onClick={handleScrollToSelectedNode}
|
||||
>
|
||||
{t('panel.scrollToSelectedNode', { ns: 'workflow' })}
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import type { FC } from 'react'
|
||||
import {
|
||||
RiArrowGoBackLine,
|
||||
RiArrowGoForwardFill,
|
||||
} from '@remixicon/react'
|
||||
import { memo, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ViewWorkflowHistory from '@/app/components/workflow/header/view-workflow-history'
|
||||
@@ -29,34 +33,28 @@ const UndoRedo: FC<UndoRedoProps> = ({ handleUndo, handleRedo }) => {
|
||||
return (
|
||||
<div className="flex items-center space-x-0.5 rounded-lg border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-lg backdrop-blur-[5px]">
|
||||
<TipPopup title={t('common.undo', { ns: 'workflow' })!} shortcuts={['ctrl', 'z']}>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('common.undo', { ns: 'workflow' })!}
|
||||
<div
|
||||
data-tooltip-id="workflow.undo"
|
||||
disabled={nodesReadOnly || buttonsDisabled.undo}
|
||||
className={
|
||||
cn('system-sm-medium flex h-8 w-8 cursor-pointer select-none items-center rounded-md px-1.5 text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary', (nodesReadOnly || buttonsDisabled.undo)
|
||||
&& 'cursor-not-allowed text-text-disabled hover:bg-transparent hover:text-text-disabled')
|
||||
}
|
||||
onClick={handleUndo}
|
||||
onClick={() => !nodesReadOnly && !buttonsDisabled.undo && handleUndo()}
|
||||
>
|
||||
<span className="i-ri-arrow-go-back-line h-4 w-4" />
|
||||
</button>
|
||||
<RiArrowGoBackLine className="h-4 w-4" />
|
||||
</div>
|
||||
</TipPopup>
|
||||
<TipPopup title={t('common.redo', { ns: 'workflow' })!} shortcuts={['ctrl', 'y']}>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('common.redo', { ns: 'workflow' })!}
|
||||
<div
|
||||
data-tooltip-id="workflow.redo"
|
||||
disabled={nodesReadOnly || buttonsDisabled.redo}
|
||||
className={
|
||||
cn('system-sm-medium flex h-8 w-8 cursor-pointer select-none items-center rounded-md px-1.5 text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary', (nodesReadOnly || buttonsDisabled.redo)
|
||||
&& 'cursor-not-allowed text-text-disabled hover:bg-transparent hover:text-text-disabled')
|
||||
}
|
||||
onClick={handleRedo}
|
||||
onClick={() => !nodesReadOnly && !buttonsDisabled.redo && handleRedo()}
|
||||
>
|
||||
<span className="i-ri-arrow-go-forward-fill h-4 w-4" />
|
||||
</button>
|
||||
<RiArrowGoForwardFill className="h-4 w-4" />
|
||||
</div>
|
||||
</TipPopup>
|
||||
<Divider type="vertical" className="mx-0.5 h-3.5" />
|
||||
<ViewWorkflowHistory />
|
||||
|
||||
@@ -73,18 +73,15 @@ const ViewHistory = ({
|
||||
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
|
||||
{
|
||||
withText && (
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('common.showRunHistory', { ns: 'workflow' })}
|
||||
className={cn(
|
||||
'flex h-8 items-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3 shadow-xs',
|
||||
'cursor-pointer text-[13px] font-medium text-components-button-secondary-text hover:bg-components-button-secondary-bg-hover',
|
||||
open && 'bg-components-button-secondary-bg-hover',
|
||||
)}
|
||||
<div className={cn(
|
||||
'flex h-8 items-center rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3 shadow-xs',
|
||||
'cursor-pointer text-[13px] font-medium text-components-button-secondary-text hover:bg-components-button-secondary-bg-hover',
|
||||
open && 'bg-components-button-secondary-bg-hover',
|
||||
)}
|
||||
>
|
||||
<span className="i-custom-vender-line-time-clock-play mr-1 h-4 w-4" />
|
||||
{t('common.showRunHistory', { ns: 'workflow' })}
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
@@ -92,16 +89,14 @@ const ViewHistory = ({
|
||||
<Tooltip
|
||||
popupContent={t('common.viewRunHistory', { ns: 'workflow' })}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('common.viewRunHistory', { ns: 'workflow' })}
|
||||
<div
|
||||
className={cn('group flex h-7 w-7 cursor-pointer items-center justify-center rounded-md hover:bg-state-accent-hover', open && 'bg-state-accent-hover')}
|
||||
onClick={() => {
|
||||
onClearLogAndMessageModal?.()
|
||||
}}
|
||||
>
|
||||
<span className={cn('i-custom-vender-line-time-clock-play', 'h-4 w-4 group-hover:text-components-button-secondary-accent-text', open ? 'text-components-button-secondary-accent-text' : 'text-components-button-ghost-text')} />
|
||||
</button>
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
@@ -115,9 +110,7 @@ const ViewHistory = ({
|
||||
>
|
||||
<div className="sticky top-0 flex items-center justify-between bg-components-panel-bg px-4 pt-3 text-base font-semibold text-text-primary">
|
||||
<div className="grow">{t('common.runHistory', { ns: 'workflow' })}</div>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={t('operation.close', { ns: 'common' })}
|
||||
<div
|
||||
className="flex h-6 w-6 shrink-0 cursor-pointer items-center justify-center"
|
||||
onClick={() => {
|
||||
onClearLogAndMessageModal?.()
|
||||
@@ -125,7 +118,7 @@ const ViewHistory = ({
|
||||
}}
|
||||
>
|
||||
<span className="i-ri-close-line h-4 w-4 text-text-tertiary" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
isLoading && (
|
||||
|
||||
@@ -1,36 +1,54 @@
|
||||
import type { CommonNodeType } from '../../../types'
|
||||
import { fireEvent, screen } from '@testing-library/react'
|
||||
import { renderWorkflowComponent } from '../../../__tests__/workflow-test-env'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { BlockEnum, NodeRunningStatus } from '../../../types'
|
||||
import NodeControl from './node-control'
|
||||
|
||||
const {
|
||||
mockHandleNodeSelect,
|
||||
mockSetInitShowLastRunTab,
|
||||
mockSetPendingSingleRun,
|
||||
mockCanRunBySingle,
|
||||
} = vi.hoisted(() => ({
|
||||
mockHandleNodeSelect: vi.fn(),
|
||||
mockSetInitShowLastRunTab: vi.fn(),
|
||||
mockSetPendingSingleRun: vi.fn(),
|
||||
mockCanRunBySingle: vi.fn(() => true),
|
||||
}))
|
||||
|
||||
let mockPluginInstallLocked = false
|
||||
vi.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../../hooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('../../../hooks')>('../../../hooks')
|
||||
return {
|
||||
...actual,
|
||||
useNodesInteractions: () => ({
|
||||
handleNodeSelect: mockHandleNodeSelect,
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children, popupContent }: { children: React.ReactNode, popupContent: string }) => (
|
||||
<div data-testid="tooltip" data-content={popupContent}>{children}</div>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/icons/src/vender/line/mediaAndDevices', () => ({
|
||||
Stop: ({ className }: { className?: string }) => <div data-testid="stop-icon" className={className} />,
|
||||
}))
|
||||
|
||||
vi.mock('../../../hooks', () => ({
|
||||
useNodesInteractions: () => ({
|
||||
handleNodeSelect: mockHandleNodeSelect,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/store', () => ({
|
||||
useWorkflowStore: () => ({
|
||||
getState: () => ({
|
||||
setInitShowLastRunTab: mockSetInitShowLastRunTab,
|
||||
setPendingSingleRun: mockSetPendingSingleRun,
|
||||
}),
|
||||
}
|
||||
})
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../../utils', async () => {
|
||||
const actual = await vi.importActual<typeof import('../../../utils')>('../../../utils')
|
||||
return {
|
||||
...actual,
|
||||
canRunBySingle: mockCanRunBySingle,
|
||||
}
|
||||
})
|
||||
vi.mock('../../../utils', () => ({
|
||||
canRunBySingle: mockCanRunBySingle,
|
||||
}))
|
||||
|
||||
vi.mock('./panel-operator', () => ({
|
||||
default: ({ onOpenChange }: { onOpenChange: (open: boolean) => void }) => (
|
||||
@@ -41,16 +59,6 @@ vi.mock('./panel-operator', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
function NodeControlHarness({ id, data }: { id: string, data: CommonNodeType, selected?: boolean }) {
|
||||
return (
|
||||
<NodeControl
|
||||
id={id}
|
||||
data={data}
|
||||
pluginInstallLocked={mockPluginInstallLocked}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
const makeData = (overrides: Partial<CommonNodeType> = {}): CommonNodeType => ({
|
||||
type: BlockEnum.Code,
|
||||
title: 'Node',
|
||||
@@ -65,71 +73,65 @@ const makeData = (overrides: Partial<CommonNodeType> = {}): CommonNodeType => ({
|
||||
describe('NodeControl', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockPluginInstallLocked = false
|
||||
mockCanRunBySingle.mockReturnValue(true)
|
||||
})
|
||||
|
||||
// Run/stop behavior should be driven by the workflow store, not CSS classes.
|
||||
describe('Single Run Actions', () => {
|
||||
it('should trigger a single run through the workflow store', () => {
|
||||
const { store } = renderWorkflowComponent(
|
||||
<NodeControlHarness id="node-1" data={makeData()} />,
|
||||
)
|
||||
it('should trigger a single run and show the hover control when plugins are not locked', () => {
|
||||
const { container } = render(
|
||||
<NodeControl
|
||||
id="node-1"
|
||||
data={makeData()}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.panel.runThisStep' }))
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper.className).toContain('group-hover:flex')
|
||||
expect(screen.getByTestId('tooltip')).toHaveAttribute('data-content', 'panel.runThisStep')
|
||||
|
||||
expect(store.getState().initShowLastRunTab).toBe(true)
|
||||
expect(store.getState().pendingSingleRun).toEqual({ nodeId: 'node-1', action: 'run' })
|
||||
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-1')
|
||||
})
|
||||
fireEvent.click(screen.getByTestId('tooltip').parentElement!)
|
||||
|
||||
it('should trigger stop when the node is already single-running', () => {
|
||||
const { store } = renderWorkflowComponent(
|
||||
<NodeControlHarness
|
||||
id="node-2"
|
||||
data={makeData({
|
||||
selected: true,
|
||||
_singleRunningStatus: NodeRunningStatus.Running,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'workflow.debug.variableInspect.trigger.stop' }))
|
||||
|
||||
expect(store.getState().pendingSingleRun).toEqual({ nodeId: 'node-2', action: 'stop' })
|
||||
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-2')
|
||||
})
|
||||
expect(mockSetInitShowLastRunTab).toHaveBeenCalledWith(true)
|
||||
expect(mockSetPendingSingleRun).toHaveBeenCalledWith({ nodeId: 'node-1', action: 'run' })
|
||||
expect(mockHandleNodeSelect).toHaveBeenCalledWith('node-1')
|
||||
})
|
||||
|
||||
// Capability gating should hide the run control while leaving panel actions available.
|
||||
describe('Availability', () => {
|
||||
it('should keep the panel operator available when the plugin is install-locked', () => {
|
||||
mockPluginInstallLocked = true
|
||||
it('should render the stop action, keep locked controls hidden by default, and stay open when panel operator opens', () => {
|
||||
const { container } = render(
|
||||
<NodeControl
|
||||
id="node-2"
|
||||
pluginInstallLocked
|
||||
data={makeData({
|
||||
selected: true,
|
||||
_singleRunningStatus: NodeRunningStatus.Running,
|
||||
isInIteration: true,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
|
||||
renderWorkflowComponent(
|
||||
<NodeControlHarness
|
||||
id="node-3"
|
||||
data={makeData({
|
||||
selected: true,
|
||||
})}
|
||||
/>,
|
||||
)
|
||||
const wrapper = container.firstChild as HTMLElement
|
||||
expect(wrapper.className).not.toContain('group-hover:flex')
|
||||
expect(wrapper.className).toContain('!flex')
|
||||
expect(screen.getByTestId('stop-icon')).toBeInTheDocument()
|
||||
|
||||
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
|
||||
})
|
||||
fireEvent.click(screen.getByTestId('stop-icon').parentElement!)
|
||||
|
||||
it('should hide the run control when single-node execution is not supported', () => {
|
||||
mockCanRunBySingle.mockReturnValue(false)
|
||||
expect(mockSetPendingSingleRun).toHaveBeenCalledWith({ nodeId: 'node-2', action: 'stop' })
|
||||
|
||||
renderWorkflowComponent(
|
||||
<NodeControlHarness
|
||||
id="node-4"
|
||||
data={makeData()}
|
||||
/>,
|
||||
)
|
||||
fireEvent.click(screen.getByRole('button', { name: 'open panel' }))
|
||||
expect(wrapper.className).toContain('!flex')
|
||||
})
|
||||
|
||||
expect(screen.queryByRole('button', { name: 'workflow.panel.runThisStep' })).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
|
||||
})
|
||||
it('should hide the run control when single-node execution is not supported', () => {
|
||||
mockCanRunBySingle.mockReturnValue(false)
|
||||
|
||||
render(
|
||||
<NodeControl
|
||||
id="node-3"
|
||||
data={makeData()}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('button', { name: 'open panel' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import type { FC } from 'react'
|
||||
import type { Node } from '../../../types'
|
||||
import {
|
||||
RiPlayLargeLine,
|
||||
} from '@remixicon/react'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
@@ -51,9 +54,7 @@ const NodeControl: FC<NodeControlProps> = ({
|
||||
>
|
||||
{
|
||||
canRunBySingle(data.type, isChildNode) && (
|
||||
<button
|
||||
type="button"
|
||||
aria-label={isSingleRunning ? t('debug.variableInspect.trigger.stop', { ns: 'workflow' }) : t('panel.runThisStep', { ns: 'workflow' })}
|
||||
<div
|
||||
className={`flex h-5 w-5 items-center justify-center rounded-md ${isSingleRunning && 'cursor-pointer hover:bg-state-base-hover'}`}
|
||||
onClick={() => {
|
||||
const action = isSingleRunning ? 'stop' : 'run'
|
||||
@@ -75,11 +76,11 @@ const NodeControl: FC<NodeControlProps> = ({
|
||||
popupContent={t('panel.runThisStep', { ns: 'workflow' })}
|
||||
asChild={false}
|
||||
>
|
||||
<span className="i-ri-play-large-line h-3 w-3" />
|
||||
<RiPlayLargeLine className="h-3 w-3" />
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<PanelOperator
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { OutputVar } from '../../../code/types'
|
||||
import type { ToastHandle } from '@/app/components/base/toast'
|
||||
import type { VarType } from '@/app/components/workflow/types'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
|
||||
import RemoveButton from '../remove-button'
|
||||
import VarTypePicker from './var-type-picker'
|
||||
@@ -29,6 +30,7 @@ const OutputVarList: FC<Props> = ({
|
||||
onRemove,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [toastHandler, setToastHandler] = useState<ToastHandle>()
|
||||
|
||||
const list = outputKeyOrders.map((key) => {
|
||||
return {
|
||||
@@ -40,17 +42,20 @@ const OutputVarList: FC<Props> = ({
|
||||
const { run: validateVarInput } = useDebounceFn((existingVariables: typeof list, newKey: string) => {
|
||||
const result = checkKeys([newKey], true)
|
||||
if (!result.isValid) {
|
||||
toast.add({
|
||||
setToastHandler(Toast.notify({
|
||||
type: 'error',
|
||||
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
})
|
||||
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
}))
|
||||
return
|
||||
}
|
||||
if (existingVariables.some(key => key.variable?.trim() === newKey.trim())) {
|
||||
toast.add({
|
||||
setToastHandler(Toast.notify({
|
||||
type: 'error',
|
||||
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
})
|
||||
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
}))
|
||||
}
|
||||
else {
|
||||
toastHandler?.clear?.()
|
||||
}
|
||||
}, { wait: 500 })
|
||||
|
||||
@@ -61,6 +66,7 @@ const OutputVarList: FC<Props> = ({
|
||||
replaceSpaceWithUnderscoreInVarNameInput(e.target)
|
||||
const newKey = e.target.value
|
||||
|
||||
toastHandler?.clear?.()
|
||||
validateVarInput(list.toSpliced(index, 1), newKey)
|
||||
|
||||
const newOutputs = produce(outputs, (draft) => {
|
||||
@@ -69,7 +75,7 @@ const OutputVarList: FC<Props> = ({
|
||||
})
|
||||
onChange(newOutputs, index, newKey)
|
||||
}
|
||||
}, [list, onChange, outputs, validateVarInput])
|
||||
}, [list, onChange, outputs, outputKeyOrders, validateVarInput])
|
||||
|
||||
const handleVarTypeChange = useCallback((index: number) => {
|
||||
return (value: string) => {
|
||||
@@ -79,7 +85,7 @@ const OutputVarList: FC<Props> = ({
|
||||
})
|
||||
onChange(newOutputs)
|
||||
}
|
||||
}, [list, onChange, outputs])
|
||||
}, [list, onChange, outputs, outputKeyOrders])
|
||||
|
||||
const handleVarRemove = useCallback((index: number) => {
|
||||
return () => {
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { ToastHandle } from '@/app/components/base/toast'
|
||||
import type { ValueSelector, Var, Variable } from '@/app/components/workflow/types'
|
||||
import { RiDraggable } from '@remixicon/react'
|
||||
import { useDebounceFn } from 'ahooks'
|
||||
import { produce } from 'immer'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ReactSortable } from 'react-sortablejs'
|
||||
import { v4 as uuid4 } from 'uuid'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { checkKeys, replaceSpaceWithUnderscoreInVarNameInput } from '@/utils/var'
|
||||
@@ -41,6 +42,7 @@ const VarList: FC<Props> = ({
|
||||
isSupportFileVar = true,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [toastHandle, setToastHandle] = useState<ToastHandle>()
|
||||
|
||||
const listWithIds = useMemo(() => list.map((item) => {
|
||||
const id = uuid4()
|
||||
@@ -53,17 +55,20 @@ const VarList: FC<Props> = ({
|
||||
const { run: validateVarInput } = useDebounceFn((list: Variable[], newKey: string) => {
|
||||
const result = checkKeys([newKey], true)
|
||||
if (!result.isValid) {
|
||||
toast.add({
|
||||
setToastHandle(Toast.notify({
|
||||
type: 'error',
|
||||
title: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
})
|
||||
message: t(`varKeyError.${result.errorMessageKey}`, { ns: 'appDebug', key: result.errorKey }),
|
||||
}))
|
||||
return
|
||||
}
|
||||
if (list.some(item => item.variable?.trim() === newKey.trim())) {
|
||||
toast.add({
|
||||
setToastHandle(Toast.notify({
|
||||
type: 'error',
|
||||
title: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
})
|
||||
message: t('varKeyError.keyAlreadyExists', { ns: 'appDebug', key: newKey }),
|
||||
}))
|
||||
}
|
||||
else {
|
||||
toastHandle?.clear?.()
|
||||
}
|
||||
}, { wait: 500 })
|
||||
|
||||
@@ -73,6 +78,7 @@ const VarList: FC<Props> = ({
|
||||
|
||||
const newKey = e.target.value
|
||||
|
||||
toastHandle?.clear?.()
|
||||
validateVarInput(list.toSpliced(index, 1), newKey)
|
||||
|
||||
onVarNameChange?.(list[index].variable, newKey)
|
||||
|
||||
@@ -1,68 +1,90 @@
|
||||
import type { WebhookTriggerNodeType } from '../types'
|
||||
import type { NodePanelProps } from '@/app/components/workflow/types'
|
||||
import type { PanelProps } from '@/types/workflow'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import Panel from '../panel'
|
||||
|
||||
const {
|
||||
mockHandleStatusCodeChange,
|
||||
mockGenerateWebhookUrl,
|
||||
mockHandleMethodChange,
|
||||
mockHandleContentTypeChange,
|
||||
mockHandleHeadersChange,
|
||||
mockHandleParamsChange,
|
||||
mockHandleBodyChange,
|
||||
mockHandleResponseBodyChange,
|
||||
} = vi.hoisted(() => ({
|
||||
mockHandleStatusCodeChange: vi.fn(),
|
||||
mockGenerateWebhookUrl: vi.fn(),
|
||||
mockHandleMethodChange: vi.fn(),
|
||||
mockHandleContentTypeChange: vi.fn(),
|
||||
mockHandleHeadersChange: vi.fn(),
|
||||
mockHandleParamsChange: vi.fn(),
|
||||
mockHandleBodyChange: vi.fn(),
|
||||
mockHandleResponseBodyChange: vi.fn(),
|
||||
}))
|
||||
|
||||
const mockConfigState = {
|
||||
readOnly: false,
|
||||
inputs: {
|
||||
method: 'POST',
|
||||
webhook_url: 'https://example.com/webhook',
|
||||
webhook_debug_url: '',
|
||||
content_type: 'application/json',
|
||||
headers: [],
|
||||
params: [],
|
||||
body: [],
|
||||
status_code: 200,
|
||||
response_body: 'ok',
|
||||
variables: [],
|
||||
},
|
||||
}
|
||||
|
||||
vi.mock('../use-config', () => ({
|
||||
DEFAULT_STATUS_CODE: 200,
|
||||
MAX_STATUS_CODE: 399,
|
||||
normalizeStatusCode: (statusCode: number) => Math.min(Math.max(statusCode, 200), 399),
|
||||
useConfig: () => ({
|
||||
readOnly: mockConfigState.readOnly,
|
||||
inputs: mockConfigState.inputs,
|
||||
handleMethodChange: mockHandleMethodChange,
|
||||
handleContentTypeChange: mockHandleContentTypeChange,
|
||||
handleHeadersChange: mockHandleHeadersChange,
|
||||
handleParamsChange: mockHandleParamsChange,
|
||||
handleBodyChange: mockHandleBodyChange,
|
||||
readOnly: false,
|
||||
inputs: {
|
||||
method: 'POST',
|
||||
webhook_url: 'https://example.com/webhook',
|
||||
webhook_debug_url: '',
|
||||
content_type: 'application/json',
|
||||
headers: [],
|
||||
params: [],
|
||||
body: [],
|
||||
status_code: 200,
|
||||
response_body: '',
|
||||
},
|
||||
handleMethodChange: vi.fn(),
|
||||
handleContentTypeChange: vi.fn(),
|
||||
handleHeadersChange: vi.fn(),
|
||||
handleParamsChange: vi.fn(),
|
||||
handleBodyChange: vi.fn(),
|
||||
handleStatusCodeChange: mockHandleStatusCodeChange,
|
||||
handleResponseBodyChange: mockHandleResponseBodyChange,
|
||||
handleResponseBodyChange: vi.fn(),
|
||||
generateWebhookUrl: mockGenerateWebhookUrl,
|
||||
}),
|
||||
}))
|
||||
|
||||
const getStatusCodeInput = () => {
|
||||
return screen.getAllByDisplayValue('200')
|
||||
.find(element => element.getAttribute('aria-hidden') !== 'true') as HTMLInputElement
|
||||
}
|
||||
vi.mock('@/app/components/base/input-with-copy', () => ({
|
||||
default: () => <div data-testid="input-with-copy" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/select', () => ({
|
||||
SimpleSelect: () => <div data-testid="simple-select" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/tooltip', () => ({
|
||||
default: ({ children }: { children: React.ReactNode }) => <>{children}</>,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/field', () => ({
|
||||
default: ({ title, children }: { title: React.ReactNode, children: React.ReactNode }) => (
|
||||
<section>
|
||||
<div>{title}</div>
|
||||
{children}
|
||||
</section>
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/output-vars', () => ({
|
||||
default: () => <div data-testid="output-vars" />,
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/nodes/_base/components/split', () => ({
|
||||
default: () => <div data-testid="split" />,
|
||||
}))
|
||||
|
||||
vi.mock('../components/header-table', () => ({
|
||||
default: () => <div data-testid="header-table" />,
|
||||
}))
|
||||
|
||||
vi.mock('../components/parameter-table', () => ({
|
||||
default: () => <div data-testid="parameter-table" />,
|
||||
}))
|
||||
|
||||
vi.mock('../components/paragraph-input', () => ({
|
||||
default: () => <div data-testid="paragraph-input" />,
|
||||
}))
|
||||
|
||||
vi.mock('../utils/render-output-vars', () => ({
|
||||
OutputVariablesContent: () => <div data-testid="output-variables-content" />,
|
||||
}))
|
||||
|
||||
describe('WebhookTriggerPanel', () => {
|
||||
const panelProps: NodePanelProps<WebhookTriggerNodeType> = {
|
||||
@@ -78,7 +100,7 @@ describe('WebhookTriggerPanel', () => {
|
||||
body: [],
|
||||
async_mode: false,
|
||||
status_code: 200,
|
||||
response_body: 'ok',
|
||||
response_body: '',
|
||||
variables: [],
|
||||
},
|
||||
panelProps: {} as PanelProps,
|
||||
@@ -86,65 +108,26 @@ describe('WebhookTriggerPanel', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockConfigState.readOnly = false
|
||||
mockConfigState.inputs = {
|
||||
method: 'POST',
|
||||
webhook_url: 'https://example.com/webhook',
|
||||
webhook_debug_url: '',
|
||||
content_type: 'application/json',
|
||||
headers: [],
|
||||
params: [],
|
||||
body: [],
|
||||
status_code: 200,
|
||||
response_body: 'ok',
|
||||
variables: [],
|
||||
}
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render the real panel fields without generating a new webhook url when one already exists', () => {
|
||||
render(<Panel {...panelProps} />)
|
||||
it('should update the status code when users enter a parseable value', () => {
|
||||
render(<Panel {...panelProps} />)
|
||||
|
||||
expect(screen.getByDisplayValue('https://example.com/webhook')).toBeInTheDocument()
|
||||
expect(screen.getByText('application/json')).toBeInTheDocument()
|
||||
expect(screen.getByDisplayValue('ok')).toBeInTheDocument()
|
||||
expect(mockGenerateWebhookUrl).not.toHaveBeenCalled()
|
||||
})
|
||||
fireEvent.change(screen.getByRole('textbox'), { target: { value: '201' } })
|
||||
|
||||
it('should request a webhook url when the node is writable and missing one', async () => {
|
||||
mockConfigState.inputs = {
|
||||
...mockConfigState.inputs,
|
||||
webhook_url: '',
|
||||
}
|
||||
|
||||
render(<Panel {...panelProps} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockGenerateWebhookUrl).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(201)
|
||||
})
|
||||
|
||||
describe('Status Code Input', () => {
|
||||
it('should update the status code when users enter a parseable value', () => {
|
||||
render(<Panel {...panelProps} />)
|
||||
it('should ignore clear changes until the value is committed', () => {
|
||||
render(<Panel {...panelProps} />)
|
||||
|
||||
fireEvent.change(getStatusCodeInput(), { target: { value: '201' } })
|
||||
const input = screen.getByRole('textbox')
|
||||
fireEvent.change(input, { target: { value: '' } })
|
||||
|
||||
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(201)
|
||||
})
|
||||
expect(mockHandleStatusCodeChange).not.toHaveBeenCalled()
|
||||
|
||||
it('should ignore clear changes until the value is committed', () => {
|
||||
render(<Panel {...panelProps} />)
|
||||
fireEvent.blur(input)
|
||||
|
||||
const input = getStatusCodeInput()
|
||||
fireEvent.change(input, { target: { value: '' } })
|
||||
|
||||
expect(mockHandleStatusCodeChange).not.toHaveBeenCalled()
|
||||
|
||||
fireEvent.blur(input)
|
||||
|
||||
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(200)
|
||||
})
|
||||
expect(mockHandleStatusCodeChange).toHaveBeenCalledWith(200)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,225 +0,0 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { act, render, screen, waitFor } from '@testing-library/react'
|
||||
import ReactFlow, { ReactFlowProvider } from 'reactflow'
|
||||
import { FlowType } from '@/types/common'
|
||||
import { BlockEnum } from '../../types'
|
||||
import AddBlock from '../add-block'
|
||||
|
||||
type BlockSelectorMockProps = {
|
||||
open: boolean
|
||||
onOpenChange: (open: boolean) => void
|
||||
disabled: boolean
|
||||
onSelect: (type: BlockEnum, pluginDefaultValue?: Record<string, unknown>) => void
|
||||
placement: string
|
||||
offset: {
|
||||
mainAxis: number
|
||||
crossAxis: number
|
||||
}
|
||||
trigger: (open: boolean) => ReactNode
|
||||
popupClassName: string
|
||||
availableBlocksTypes: BlockEnum[]
|
||||
showStartTab: boolean
|
||||
}
|
||||
|
||||
const {
|
||||
mockHandlePaneContextmenuCancel,
|
||||
mockWorkflowStoreSetState,
|
||||
mockGenerateNewNode,
|
||||
mockGetNodeCustomTypeByNodeDataType,
|
||||
} = vi.hoisted(() => ({
|
||||
mockHandlePaneContextmenuCancel: vi.fn(),
|
||||
mockWorkflowStoreSetState: vi.fn(),
|
||||
mockGenerateNewNode: vi.fn(({ type, data }: { type: string, data: Record<string, unknown> }) => ({
|
||||
newNode: {
|
||||
id: 'generated-node',
|
||||
type,
|
||||
data,
|
||||
},
|
||||
})),
|
||||
mockGetNodeCustomTypeByNodeDataType: vi.fn((type: string) => `${type}-custom`),
|
||||
}))
|
||||
|
||||
let latestBlockSelectorProps: BlockSelectorMockProps | null = null
|
||||
let mockNodesReadOnly = false
|
||||
let mockIsChatMode = false
|
||||
let mockFlowType: FlowType = FlowType.appFlow
|
||||
|
||||
const mockAvailableNextBlocks = [BlockEnum.Answer, BlockEnum.Code]
|
||||
const mockNodesMetaDataMap = {
|
||||
[BlockEnum.Answer]: {
|
||||
defaultValue: {
|
||||
title: 'Answer',
|
||||
desc: '',
|
||||
type: BlockEnum.Answer,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
vi.mock('@/app/components/workflow/block-selector', () => ({
|
||||
default: (props: BlockSelectorMockProps) => {
|
||||
latestBlockSelectorProps = props
|
||||
return (
|
||||
<div data-testid="block-selector">
|
||||
{props.trigger(props.open)}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
useAvailableBlocks: () => ({
|
||||
availableNextBlocks: mockAvailableNextBlocks,
|
||||
}),
|
||||
useIsChatMode: () => mockIsChatMode,
|
||||
useNodesMetaData: () => ({
|
||||
nodesMap: mockNodesMetaDataMap,
|
||||
}),
|
||||
useNodesReadOnly: () => ({
|
||||
nodesReadOnly: mockNodesReadOnly,
|
||||
}),
|
||||
usePanelInteractions: () => ({
|
||||
handlePaneContextmenuCancel: mockHandlePaneContextmenuCancel,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../hooks-store', () => ({
|
||||
useHooksStore: (selector: (state: { configsMap?: { flowType?: FlowType } }) => unknown) =>
|
||||
selector({ configsMap: { flowType: mockFlowType } }),
|
||||
}))
|
||||
|
||||
vi.mock('../../store', () => ({
|
||||
useWorkflowStore: () => ({
|
||||
setState: mockWorkflowStoreSetState,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../utils', () => ({
|
||||
generateNewNode: mockGenerateNewNode,
|
||||
getNodeCustomTypeByNodeDataType: mockGetNodeCustomTypeByNodeDataType,
|
||||
}))
|
||||
|
||||
vi.mock('../tip-popup', () => ({
|
||||
default: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
}))
|
||||
|
||||
const renderWithReactFlow = (nodes: Array<{ id: string, position: { x: number, y: number }, data: { type: BlockEnum } }>) => {
|
||||
return render(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow nodes={nodes} edges={[]} fitView />
|
||||
<AddBlock />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
)
|
||||
}
|
||||
|
||||
describe('AddBlock', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
latestBlockSelectorProps = null
|
||||
mockNodesReadOnly = false
|
||||
mockIsChatMode = false
|
||||
mockFlowType = FlowType.appFlow
|
||||
})
|
||||
|
||||
// Rendering and selector configuration.
|
||||
describe('Rendering', () => {
|
||||
it('should pass the selector props for a writable app workflow', async () => {
|
||||
renderWithReactFlow([])
|
||||
|
||||
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
|
||||
|
||||
expect(screen.getByTestId('block-selector')).toBeInTheDocument()
|
||||
expect(latestBlockSelectorProps).toMatchObject({
|
||||
disabled: false,
|
||||
availableBlocksTypes: mockAvailableNextBlocks,
|
||||
showStartTab: true,
|
||||
placement: 'right-start',
|
||||
popupClassName: '!min-w-[256px]',
|
||||
})
|
||||
expect(latestBlockSelectorProps?.offset).toEqual({
|
||||
mainAxis: 4,
|
||||
crossAxis: -8,
|
||||
})
|
||||
})
|
||||
|
||||
it('should hide the start tab for chat mode and rag pipeline flows', async () => {
|
||||
mockIsChatMode = true
|
||||
const { rerender } = renderWithReactFlow([])
|
||||
|
||||
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
|
||||
|
||||
expect(latestBlockSelectorProps?.showStartTab).toBe(false)
|
||||
|
||||
mockIsChatMode = false
|
||||
mockFlowType = FlowType.ragPipeline
|
||||
rerender(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow nodes={[]} edges={[]} fitView />
|
||||
<AddBlock />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
)
|
||||
|
||||
expect(latestBlockSelectorProps?.showStartTab).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// User interactions that bridge selector state and workflow state.
|
||||
describe('User Interactions', () => {
|
||||
it('should cancel the pane context menu when the selector closes', async () => {
|
||||
renderWithReactFlow([])
|
||||
|
||||
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
|
||||
|
||||
act(() => {
|
||||
latestBlockSelectorProps?.onOpenChange(false)
|
||||
})
|
||||
|
||||
expect(mockHandlePaneContextmenuCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should create a candidate node with an incremented title when a block is selected', async () => {
|
||||
renderWithReactFlow([
|
||||
{ id: 'node-1', position: { x: 0, y: 0 }, data: { type: BlockEnum.Answer } },
|
||||
{ id: 'node-2', position: { x: 80, y: 0 }, data: { type: BlockEnum.Answer } },
|
||||
])
|
||||
|
||||
await waitFor(() => expect(latestBlockSelectorProps).not.toBeNull())
|
||||
|
||||
act(() => {
|
||||
latestBlockSelectorProps?.onSelect(BlockEnum.Answer, { pluginId: 'plugin-1' })
|
||||
})
|
||||
|
||||
expect(mockGetNodeCustomTypeByNodeDataType).toHaveBeenCalledWith(BlockEnum.Answer)
|
||||
expect(mockGenerateNewNode).toHaveBeenCalledWith({
|
||||
type: 'answer-custom',
|
||||
data: {
|
||||
title: 'Answer 3',
|
||||
desc: '',
|
||||
type: BlockEnum.Answer,
|
||||
pluginId: 'plugin-1',
|
||||
_isCandidate: true,
|
||||
},
|
||||
position: {
|
||||
x: 0,
|
||||
y: 0,
|
||||
},
|
||||
})
|
||||
expect(mockWorkflowStoreSetState).toHaveBeenCalledWith({
|
||||
candidateNode: {
|
||||
id: 'generated-node',
|
||||
type: 'answer-custom',
|
||||
data: {
|
||||
title: 'Answer 3',
|
||||
desc: '',
|
||||
type: BlockEnum.Answer,
|
||||
pluginId: 'plugin-1',
|
||||
_isCandidate: true,
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,136 +0,0 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { ControlMode } from '../../types'
|
||||
import Control from '../control'
|
||||
|
||||
type WorkflowStoreState = {
|
||||
controlMode: ControlMode
|
||||
maximizeCanvas: boolean
|
||||
}
|
||||
|
||||
const {
|
||||
mockHandleAddNote,
|
||||
mockHandleLayout,
|
||||
mockHandleModeHand,
|
||||
mockHandleModePointer,
|
||||
mockHandleToggleMaximizeCanvas,
|
||||
} = vi.hoisted(() => ({
|
||||
mockHandleAddNote: vi.fn(),
|
||||
mockHandleLayout: vi.fn(),
|
||||
mockHandleModeHand: vi.fn(),
|
||||
mockHandleModePointer: vi.fn(),
|
||||
mockHandleToggleMaximizeCanvas: vi.fn(),
|
||||
}))
|
||||
|
||||
let mockNodesReadOnly = false
|
||||
let mockStoreState: WorkflowStoreState
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
useNodesReadOnly: () => ({
|
||||
nodesReadOnly: mockNodesReadOnly,
|
||||
getNodesReadOnly: () => mockNodesReadOnly,
|
||||
}),
|
||||
useWorkflowCanvasMaximize: () => ({
|
||||
handleToggleMaximizeCanvas: mockHandleToggleMaximizeCanvas,
|
||||
}),
|
||||
useWorkflowMoveMode: () => ({
|
||||
handleModePointer: mockHandleModePointer,
|
||||
handleModeHand: mockHandleModeHand,
|
||||
}),
|
||||
useWorkflowOrganize: () => ({
|
||||
handleLayout: mockHandleLayout,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../hooks', () => ({
|
||||
useOperator: () => ({
|
||||
handleAddNote: mockHandleAddNote,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('../../store', () => ({
|
||||
useStore: (selector: (state: WorkflowStoreState) => unknown) => selector(mockStoreState),
|
||||
}))
|
||||
|
||||
vi.mock('../add-block', () => ({
|
||||
default: () => <div data-testid="add-block" />,
|
||||
}))
|
||||
|
||||
vi.mock('../more-actions', () => ({
|
||||
default: () => <div data-testid="more-actions" />,
|
||||
}))
|
||||
|
||||
vi.mock('../tip-popup', () => ({
|
||||
default: ({
|
||||
children,
|
||||
title,
|
||||
}: {
|
||||
children?: ReactNode
|
||||
title?: string
|
||||
}) => <div data-testid={title}>{children}</div>,
|
||||
}))
|
||||
|
||||
describe('Control', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockNodesReadOnly = false
|
||||
mockStoreState = {
|
||||
controlMode: ControlMode.Pointer,
|
||||
maximizeCanvas: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Rendering and visual states for control buttons.
|
||||
describe('Rendering', () => {
|
||||
it('should render the child action groups and highlight the active pointer mode', () => {
|
||||
render(<Control />)
|
||||
|
||||
expect(screen.getByTestId('add-block')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('more-actions')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('workflow.common.pointerMode').firstElementChild).toHaveClass('bg-state-accent-active')
|
||||
expect(screen.getByTestId('workflow.common.handMode').firstElementChild).not.toHaveClass('bg-state-accent-active')
|
||||
expect(screen.getByTestId('workflow.panel.maximize')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should switch the maximize tooltip and active style when the canvas is maximized', () => {
|
||||
mockStoreState = {
|
||||
controlMode: ControlMode.Hand,
|
||||
maximizeCanvas: true,
|
||||
}
|
||||
|
||||
render(<Control />)
|
||||
|
||||
expect(screen.getByTestId('workflow.common.handMode').firstElementChild).toHaveClass('bg-state-accent-active')
|
||||
expect(screen.getByTestId('workflow.panel.minimize').firstElementChild).toHaveClass('bg-state-accent-active')
|
||||
})
|
||||
})
|
||||
|
||||
// User interactions exposed by the control bar.
|
||||
describe('User Interactions', () => {
|
||||
it('should trigger the note, mode, organize, and maximize handlers', () => {
|
||||
render(<Control />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('workflow.nodes.note.addNote').firstElementChild as HTMLElement)
|
||||
fireEvent.click(screen.getByTestId('workflow.common.pointerMode').firstElementChild as HTMLElement)
|
||||
fireEvent.click(screen.getByTestId('workflow.common.handMode').firstElementChild as HTMLElement)
|
||||
fireEvent.click(screen.getByTestId('workflow.panel.organizeBlocks').firstElementChild as HTMLElement)
|
||||
fireEvent.click(screen.getByTestId('workflow.panel.maximize').firstElementChild as HTMLElement)
|
||||
|
||||
expect(mockHandleAddNote).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleModePointer).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleModeHand).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleLayout).toHaveBeenCalledTimes(1)
|
||||
expect(mockHandleToggleMaximizeCanvas).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('should block note creation when the workflow is read only', () => {
|
||||
mockNodesReadOnly = true
|
||||
|
||||
render(<Control />)
|
||||
|
||||
fireEvent.click(screen.getByTestId('workflow.nodes.note.addNote').firstElementChild as HTMLElement)
|
||||
|
||||
expect(mockHandleAddNote).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,323 +0,0 @@
|
||||
import type { Shape as HooksStoreShape } from '../../hooks-store/store'
|
||||
import type { RunFile } from '../../types'
|
||||
import type { FileUpload } from '@/app/components/base/features/types'
|
||||
import { screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import ReactFlow, { ReactFlowProvider } from 'reactflow'
|
||||
import { TransferMethod } from '@/types/app'
|
||||
import { FlowType } from '@/types/common'
|
||||
import { createStartNode } from '../../__tests__/fixtures'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import { InputVarType, WorkflowRunningStatus } from '../../types'
|
||||
import InputsPanel from '../inputs-panel'
|
||||
|
||||
const mockCheckInputsForm = vi.fn()
|
||||
const mockNotify = vi.fn()
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
useParams: () => ({}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast/context', () => ({
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
close: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/chat/chat/check-input-forms-hooks', () => ({
|
||||
useCheckInputsForms: () => ({
|
||||
checkInputsForm: mockCheckInputsForm,
|
||||
}),
|
||||
}))
|
||||
|
||||
const fileSettingsWithImage = {
|
||||
enabled: true,
|
||||
image: {
|
||||
enabled: true,
|
||||
},
|
||||
allowed_file_upload_methods: [TransferMethod.remote_url],
|
||||
number_limits: 3,
|
||||
image_file_size_limit: 10,
|
||||
} satisfies FileUpload & { image_file_size_limit: number }
|
||||
|
||||
const uploadedRunFile = {
|
||||
transfer_method: TransferMethod.remote_url,
|
||||
upload_file_id: 'file-2',
|
||||
} as unknown as RunFile
|
||||
|
||||
const uploadingRunFile = {
|
||||
transfer_method: TransferMethod.local_file,
|
||||
} as unknown as RunFile
|
||||
|
||||
const createHooksStoreProps = (
|
||||
overrides: Partial<HooksStoreShape> = {},
|
||||
): Partial<HooksStoreShape> => ({
|
||||
handleRun: vi.fn(),
|
||||
configsMap: {
|
||||
flowId: 'flow-1',
|
||||
flowType: FlowType.appFlow,
|
||||
fileSettings: fileSettingsWithImage,
|
||||
},
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const renderInputsPanel = (
|
||||
startNode: ReturnType<typeof createStartNode>,
|
||||
options?: Parameters<typeof renderWorkflowComponent>[1],
|
||||
) => {
|
||||
return renderWorkflowComponent(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow nodes={[startNode]} edges={[]} fitView />
|
||||
<InputsPanel onRun={vi.fn()} />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
options,
|
||||
)
|
||||
}
|
||||
|
||||
describe('InputsPanel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCheckInputsForm.mockReturnValue(true)
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render current inputs, defaults, and the image uploader from the start node', () => {
|
||||
renderInputsPanel(
|
||||
createStartNode({
|
||||
data: {
|
||||
variables: [
|
||||
{
|
||||
type: InputVarType.textInput,
|
||||
variable: 'question',
|
||||
label: 'Question',
|
||||
required: true,
|
||||
default: 'default question',
|
||||
},
|
||||
{
|
||||
type: InputVarType.number,
|
||||
variable: 'count',
|
||||
label: 'Count',
|
||||
required: false,
|
||||
default: '2',
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
{
|
||||
initialStoreState: {
|
||||
inputs: {
|
||||
question: 'overridden question',
|
||||
},
|
||||
},
|
||||
hooksStoreProps: createHooksStoreProps(),
|
||||
},
|
||||
)
|
||||
|
||||
expect(screen.getByDisplayValue('overridden question')).toHaveFocus()
|
||||
expect(screen.getByRole('spinbutton')).toHaveValue(2)
|
||||
expect(screen.getByText('common.imageUploader.pasteImageLink')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('User Interactions', () => {
|
||||
it('should update workflow inputs and image files when users edit the form', async () => {
|
||||
const user = userEvent.setup()
|
||||
const { store } = renderInputsPanel(
|
||||
createStartNode({
|
||||
data: {
|
||||
variables: [
|
||||
{
|
||||
type: InputVarType.textInput,
|
||||
variable: 'question',
|
||||
label: 'Question',
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
{
|
||||
hooksStoreProps: createHooksStoreProps(),
|
||||
},
|
||||
)
|
||||
|
||||
await user.type(screen.getByPlaceholderText('Question'), 'changed question')
|
||||
expect(store.getState().inputs).toEqual({ question: 'changed question' })
|
||||
|
||||
await user.click(screen.getByText('common.imageUploader.pasteImageLink'))
|
||||
await user.type(
|
||||
await screen.findByPlaceholderText('common.imageUploader.pasteImageLinkInputPlaceholder'),
|
||||
'https://example.com/image.png',
|
||||
)
|
||||
await user.click(screen.getByRole('button', { name: 'common.operation.ok' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(store.getState().files).toEqual([{
|
||||
type: 'image',
|
||||
transfer_method: TransferMethod.remote_url,
|
||||
url: 'https://example.com/image.png',
|
||||
upload_file_id: '',
|
||||
}])
|
||||
})
|
||||
})
|
||||
|
||||
it('should not start a run when input validation fails', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockCheckInputsForm.mockReturnValue(false)
|
||||
const onRun = vi.fn()
|
||||
const handleRun = vi.fn()
|
||||
|
||||
renderWorkflowComponent(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow
|
||||
nodes={[
|
||||
createStartNode({
|
||||
data: {
|
||||
variables: [
|
||||
{
|
||||
type: InputVarType.textInput,
|
||||
variable: 'question',
|
||||
label: 'Question',
|
||||
required: true,
|
||||
default: 'default question',
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
]}
|
||||
edges={[]}
|
||||
fitView
|
||||
/>
|
||||
<InputsPanel onRun={onRun} />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
{
|
||||
hooksStoreProps: createHooksStoreProps({ handleRun }),
|
||||
},
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' }))
|
||||
|
||||
expect(mockCheckInputsForm).toHaveBeenCalledWith(
|
||||
{ question: 'default question' },
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({ variable: 'question' }),
|
||||
expect.objectContaining({ variable: '__image' }),
|
||||
]),
|
||||
)
|
||||
expect(onRun).not.toHaveBeenCalled()
|
||||
expect(handleRun).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should start a run with processed inputs when validation succeeds', async () => {
|
||||
const user = userEvent.setup()
|
||||
const onRun = vi.fn()
|
||||
const handleRun = vi.fn()
|
||||
|
||||
renderWorkflowComponent(
|
||||
<div style={{ width: 800, height: 600 }}>
|
||||
<ReactFlowProvider>
|
||||
<ReactFlow
|
||||
nodes={[
|
||||
createStartNode({
|
||||
data: {
|
||||
variables: [
|
||||
{
|
||||
type: InputVarType.textInput,
|
||||
variable: 'question',
|
||||
label: 'Question',
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
type: InputVarType.checkbox,
|
||||
variable: 'confirmed',
|
||||
label: 'Confirmed',
|
||||
required: false,
|
||||
},
|
||||
],
|
||||
},
|
||||
}),
|
||||
]}
|
||||
edges={[]}
|
||||
fitView
|
||||
/>
|
||||
<InputsPanel onRun={onRun} />
|
||||
</ReactFlowProvider>
|
||||
</div>,
|
||||
{
|
||||
initialStoreState: {
|
||||
inputs: {
|
||||
question: 'run this',
|
||||
confirmed: 'truthy',
|
||||
},
|
||||
files: [uploadedRunFile],
|
||||
},
|
||||
hooksStoreProps: createHooksStoreProps({
|
||||
handleRun,
|
||||
configsMap: {
|
||||
flowId: 'flow-1',
|
||||
flowType: FlowType.appFlow,
|
||||
fileSettings: {
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
)
|
||||
|
||||
await user.click(screen.getByRole('button', { name: 'workflow.singleRun.startRun' }))
|
||||
|
||||
expect(onRun).toHaveBeenCalledTimes(1)
|
||||
expect(handleRun).toHaveBeenCalledWith({
|
||||
inputs: {
|
||||
question: 'run this',
|
||||
confirmed: true,
|
||||
},
|
||||
files: [uploadedRunFile],
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Disabled States', () => {
|
||||
it('should disable the run button while a local file is still uploading', () => {
|
||||
renderInputsPanel(createStartNode(), {
|
||||
initialStoreState: {
|
||||
files: [uploadingRunFile],
|
||||
},
|
||||
hooksStoreProps: createHooksStoreProps({
|
||||
configsMap: {
|
||||
flowId: 'flow-1',
|
||||
flowType: FlowType.appFlow,
|
||||
fileSettings: {
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
}),
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('should disable the run button while the workflow is already running', () => {
|
||||
renderInputsPanel(createStartNode(), {
|
||||
initialStoreState: {
|
||||
workflowRunningData: {
|
||||
result: {
|
||||
status: WorkflowRunningStatus.Running,
|
||||
inputs_truncated: false,
|
||||
process_data_truncated: false,
|
||||
outputs_truncated: false,
|
||||
},
|
||||
tracing: [],
|
||||
},
|
||||
},
|
||||
hooksStoreProps: createHooksStoreProps(),
|
||||
})
|
||||
|
||||
expect(screen.getByRole('button', { name: 'workflow.singleRun.startRun' })).toBeDisabled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,163 +0,0 @@
|
||||
import type { WorkflowRunDetailResponse } from '@/models/log'
|
||||
import { act, screen } from '@testing-library/react'
|
||||
import { createEdge, createNode } from '../../__tests__/fixtures'
|
||||
import { renderWorkflowComponent } from '../../__tests__/workflow-test-env'
|
||||
import Record from '../record'
|
||||
|
||||
const mockHandleUpdateWorkflowCanvas = vi.fn()
|
||||
const mockFormatWorkflowRunIdentifier = vi.fn((finishedAt?: number) => finishedAt ? ' (Finished)' : ' (Running)')
|
||||
|
||||
let latestGetResultCallback: ((res: WorkflowRunDetailResponse) => void) | undefined
|
||||
|
||||
vi.mock('@/app/components/workflow/hooks', () => ({
|
||||
useWorkflowUpdate: () => ({
|
||||
handleUpdateWorkflowCanvas: mockHandleUpdateWorkflowCanvas,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/run', () => ({
|
||||
default: ({
|
||||
runDetailUrl,
|
||||
tracingListUrl,
|
||||
getResultCallback,
|
||||
}: {
|
||||
runDetailUrl: string
|
||||
tracingListUrl: string
|
||||
getResultCallback: (res: WorkflowRunDetailResponse) => void
|
||||
}) => {
|
||||
latestGetResultCallback = getResultCallback
|
||||
return (
|
||||
<div
|
||||
data-run-detail-url={runDetailUrl}
|
||||
data-testid="run"
|
||||
data-tracing-list-url={tracingListUrl}
|
||||
/>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/workflow/utils', () => ({
|
||||
formatWorkflowRunIdentifier: (finishedAt?: number) => mockFormatWorkflowRunIdentifier(finishedAt),
|
||||
}))
|
||||
|
||||
const createRunDetail = (overrides: Partial<WorkflowRunDetailResponse> = {}): WorkflowRunDetailResponse => ({
|
||||
id: 'run-1',
|
||||
version: '1',
|
||||
graph: {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
},
|
||||
inputs: '{}',
|
||||
inputs_truncated: false,
|
||||
status: 'succeeded',
|
||||
outputs: '{}',
|
||||
outputs_truncated: false,
|
||||
total_steps: 1,
|
||||
created_by_role: 'account',
|
||||
created_at: 1,
|
||||
finished_at: 2,
|
||||
...overrides,
|
||||
})
|
||||
|
||||
describe('Record', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
latestGetResultCallback = undefined
|
||||
})
|
||||
|
||||
it('renders the run title and passes run and trace URLs to the run panel', () => {
|
||||
const getWorkflowRunAndTraceUrl = vi.fn((runId?: string) => ({
|
||||
runUrl: `/runs/${runId}`,
|
||||
traceUrl: `/traces/${runId}`,
|
||||
}))
|
||||
|
||||
renderWorkflowComponent(<Record />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: {
|
||||
id: 'run-1',
|
||||
status: 'succeeded',
|
||||
finished_at: 1700000000000,
|
||||
},
|
||||
},
|
||||
hooksStoreProps: {
|
||||
getWorkflowRunAndTraceUrl,
|
||||
},
|
||||
})
|
||||
|
||||
expect(screen.getByText('Test Run (Finished)')).toBeInTheDocument()
|
||||
expect(screen.getByTestId('run')).toHaveAttribute('data-run-detail-url', '/runs/run-1')
|
||||
expect(screen.getByTestId('run')).toHaveAttribute('data-tracing-list-url', '/traces/run-1')
|
||||
expect(getWorkflowRunAndTraceUrl).toHaveBeenCalledTimes(2)
|
||||
expect(getWorkflowRunAndTraceUrl).toHaveBeenNthCalledWith(1, 'run-1')
|
||||
expect(getWorkflowRunAndTraceUrl).toHaveBeenNthCalledWith(2, 'run-1')
|
||||
expect(mockFormatWorkflowRunIdentifier).toHaveBeenCalledWith(1700000000000)
|
||||
})
|
||||
|
||||
it('updates the workflow canvas with a fallback viewport when the response omits one', () => {
|
||||
const nodes = [createNode({ id: 'node-1' })]
|
||||
const edges = [createEdge({ id: 'edge-1' })]
|
||||
|
||||
renderWorkflowComponent(<Record />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: {
|
||||
id: 'run-1',
|
||||
status: 'succeeded',
|
||||
},
|
||||
},
|
||||
hooksStoreProps: {
|
||||
getWorkflowRunAndTraceUrl: () => ({ runUrl: '/runs/run-1', traceUrl: '/traces/run-1' }),
|
||||
},
|
||||
})
|
||||
|
||||
expect(latestGetResultCallback).toBeDefined()
|
||||
|
||||
act(() => {
|
||||
latestGetResultCallback?.(createRunDetail({
|
||||
graph: {
|
||||
nodes,
|
||||
edges,
|
||||
},
|
||||
}))
|
||||
})
|
||||
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
|
||||
nodes,
|
||||
edges,
|
||||
viewport: { x: 0, y: 0, zoom: 1 },
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the response viewport when one is available', () => {
|
||||
const nodes = [createNode({ id: 'node-1' })]
|
||||
const edges = [createEdge({ id: 'edge-1' })]
|
||||
const viewport = { x: 12, y: 24, zoom: 0.75 }
|
||||
|
||||
renderWorkflowComponent(<Record />, {
|
||||
initialStoreState: {
|
||||
historyWorkflowData: {
|
||||
id: 'run-1',
|
||||
status: 'succeeded',
|
||||
},
|
||||
},
|
||||
hooksStoreProps: {
|
||||
getWorkflowRunAndTraceUrl: () => ({ runUrl: '/runs/run-1', traceUrl: '/traces/run-1' }),
|
||||
},
|
||||
})
|
||||
|
||||
act(() => {
|
||||
latestGetResultCallback?.(createRunDetail({
|
||||
graph: {
|
||||
nodes,
|
||||
edges,
|
||||
viewport,
|
||||
},
|
||||
}))
|
||||
})
|
||||
|
||||
expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({
|
||||
nodes,
|
||||
edges,
|
||||
viewport,
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,7 @@ import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import VersionInfoModal from '@/app/components/app/app-publisher/version-info-modal'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { toast } from '@/app/components/base/ui/toast'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useSelector as useAppContextSelector } from '@/context/app-context'
|
||||
import { useDeleteWorkflow, useInvalidAllLastRun, useResetWorkflowVersionHistory, useUpdateWorkflow, useWorkflowVersionHistory } from '@/service/use-workflow'
|
||||
import { useDSL, useNodesSyncDraft, useWorkflowRun } from '../../hooks'
|
||||
@@ -118,9 +118,9 @@ export const VersionHistoryPanel = ({
|
||||
break
|
||||
case VersionHistoryContextMenuOptions.copyId:
|
||||
copy(item.id)
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.copyIdSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
break
|
||||
case VersionHistoryContextMenuOptions.exportDSL:
|
||||
@@ -152,17 +152,17 @@ export const VersionHistoryPanel = ({
|
||||
workflowStore.setState({ backupDraft: undefined })
|
||||
handleSyncWorkflowDraft(true, false, {
|
||||
onSuccess: () => {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.restoreSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
deleteAllInspectVars()
|
||||
invalidAllLastRun()
|
||||
},
|
||||
onError: () => {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.restoreFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
@@ -177,18 +177,18 @@ export const VersionHistoryPanel = ({
|
||||
await deleteWorkflow(deleteVersionUrl?.(id) || '', {
|
||||
onSuccess: () => {
|
||||
setDeleteConfirmOpen(false)
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.deleteSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
resetWorkflowVersionHistory()
|
||||
deleteAllInspectVars()
|
||||
invalidAllLastRun()
|
||||
},
|
||||
onError: () => {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.deleteFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
@@ -207,16 +207,16 @@ export const VersionHistoryPanel = ({
|
||||
}, {
|
||||
onSuccess: () => {
|
||||
setEditModalOpen(false)
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
title: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.updateSuccess', { ns: 'workflow' }),
|
||||
})
|
||||
resetWorkflowVersionHistory()
|
||||
},
|
||||
onError: () => {
|
||||
toast.add({
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
title: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
|
||||
message: t('versionHistory.action.updateFailure', { ns: 'workflow' }),
|
||||
})
|
||||
},
|
||||
onSettled: () => {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user