mirror of
https://github.com/langgenius/dify.git
synced 2026-03-13 11:17:07 +00:00
Compare commits
25 Commits
dependabot
...
refactor/w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42f930d00e | ||
|
|
4717168fe2 | ||
|
|
7fd3bd81ab | ||
|
|
0dcfac5b84 | ||
|
|
b66097b5f3 | ||
|
|
ceaa399351 | ||
|
|
dc50e4c4f2 | ||
|
|
157208ab1e | ||
|
|
3dabdc8282 | ||
|
|
ed5511ce28 | ||
|
|
68982f910e | ||
|
|
c43307dae1 | ||
|
|
b44b37518a | ||
|
|
b170eabaf3 | ||
|
|
e99628b76f | ||
|
|
60fe5e7f00 | ||
|
|
245f6b824d | ||
|
|
7d2054d4f4 | ||
|
|
07e19c0748 | ||
|
|
135b3a15a6 | ||
|
|
0045e387f5 | ||
|
|
44713a5c0f | ||
|
|
d5724aebde | ||
|
|
c59685748c | ||
|
|
36c1f4d506 |
@@ -43,7 +43,6 @@ forbidden_modules =
|
||||
extensions.ext_redis
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
|
||||
@@ -90,9 +89,6 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> core.model_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.provider_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.llm.llm_utils -> core.model_manager
|
||||
dify_graph.nodes.llm.protocols -> core.model_manager
|
||||
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
|
||||
@@ -100,8 +96,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.entities
|
||||
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
|
||||
@@ -110,12 +104,10 @@ ignore_imports =
|
||||
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
dify_graph.nodes.llm.node -> core.model_manager
|
||||
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
|
||||
@@ -126,15 +118,11 @@ ignore_imports =
|
||||
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
|
||||
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
|
||||
dify_graph.nodes.llm.node -> models.dataset
|
||||
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.signature
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
dify_graph.nodes.llm.node -> models.model
|
||||
dify_graph.nodes.agent.agent_node -> services
|
||||
dify_graph.nodes.tool.tool_node -> services
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
|
||||
@@ -114,6 +114,7 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
@@ -22,7 +23,6 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
9
api/core/agent/errors.py
Normal file
9
api/core/agent/errors.py
Normal file
@@ -0,0 +1,9 @@
|
||||
class AgentMaxIterationError(Exception):
|
||||
"""Raised when an agent runner exceeds the configured max iteration count."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
@@ -25,7 +26,6 @@ from dify_graph.model_runtime.entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from dify_graph.nodes.agent.exc import AgentMaxIterationError
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
|
||||
@@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,10 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@@ -30,6 +33,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
@@ -63,7 +67,6 @@ from dify_graph.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from dify_graph.graph_events.graph import GraphRunAbortedEvent
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -308,7 +311,7 @@ class WorkflowBasedAppRunner:
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
# Use the variable pool from graph_runtime_state instead of creating a new one
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
@@ -336,6 +339,18 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
@staticmethod
|
||||
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
|
||||
raw_agent_strategy = event.extras.get("agent_strategy")
|
||||
if raw_agent_strategy is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return AgentStrategyInfo.model_validate(raw_agent_strategy)
|
||||
except ValidationError:
|
||||
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
|
||||
return None
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle event
|
||||
@@ -421,7 +436,7 @@ class WorkflowBasedAppRunner:
|
||||
start_at=event.start_at,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
agent_strategy=self._build_agent_strategy_info(event),
|
||||
provider_type=event.provider_type,
|
||||
provider_id=event.provider_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .agent_strategy import AgentStrategyInfo
|
||||
|
||||
__all__ = ["AgentStrategyInfo"]
|
||||
|
||||
8
api/core/app/entities/agent_strategy.py
Normal file
8
api/core/app/entities/agent_strategy.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class AgentStrategyInfo(BaseModel):
|
||||
name: str
|
||||
icon: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
|
||||
@@ -314,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
in_iteration_id: str | None = None
|
||||
in_loop_id: str | None = None
|
||||
start_at: datetime
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
|
||||
@@ -4,8 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
iteration_id: str | None = None
|
||||
loop_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
agent_strategy: AgentStrategyInfo | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
||||
@@ -193,7 +193,8 @@ class LLMGenerator:
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
@@ -279,7 +280,8 @@ class LLMGenerator:
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
error_step = "handle unexpected exception"
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /):
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any):
|
||||
|
||||
@@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.get_credentials_schema_by_type(CredentialType.API_KEY)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
:param credential_type: the type of the credential, as CredentialType or str; str values
|
||||
are normalized via CredentialType.of and may raise ValueError for invalid values.
|
||||
:return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an
|
||||
empty list for CredentialType.UNAUTHORIZED or missing schemas.
|
||||
|
||||
Reads from self.entity.oauth_schema and self.entity.credentials_schema.
|
||||
Raises ValueError for invalid credential types.
|
||||
"""
|
||||
if credential_type == CredentialType.OAUTH2.value:
|
||||
if isinstance(credential_type, str):
|
||||
credential_type = CredentialType.of(credential_type)
|
||||
if credential_type == CredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
if credential_type == CredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
if credential_type == CredentialType.UNAUTHORIZED:
|
||||
return []
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
|
||||
@@ -137,6 +137,7 @@ class ToolFileManager:
|
||||
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from .node_factory import DifyNodeFactory
|
||||
from .node_resolution import ensure_workflow_nodes_registered
|
||||
from .workflow_entry import WorkflowEntry
|
||||
|
||||
__all__ = ["DifyNodeFactory", "WorkflowEntry", "ensure_workflow_nodes_registered"]
|
||||
|
||||
__all__ = ["DifyNodeFactory", "WorkflowEntry"]
|
||||
|
||||
@@ -22,6 +22,13 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.summary_index.summary_index import SummaryIndex
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
|
||||
from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
PluginAgentStrategyPresentationProvider,
|
||||
PluginAgentStrategyResolver,
|
||||
)
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
@@ -39,7 +46,6 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
|
||||
from dify_graph.nodes.http_request import build_http_request_config
|
||||
from dify_graph.nodes.llm.entities import LLMNodeData
|
||||
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@@ -97,10 +103,7 @@ class DefaultWorkflowCodeExecutor:
|
||||
@final
|
||||
class DifyNodeFactory(NodeFactory):
|
||||
"""
|
||||
Default implementation of NodeFactory that uses the traditional node mapping.
|
||||
|
||||
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
|
||||
and instantiating the appropriate node class.
|
||||
Default implementation of NodeFactory that resolves node classes from the live registry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -143,6 +146,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
|
||||
self._agent_strategy_resolver = PluginAgentStrategyResolver()
|
||||
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
|
||||
self._agent_runtime_support = AgentRuntimeSupport()
|
||||
self._agent_message_transformer = AgentMessageTransformer()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
|
||||
@@ -219,6 +226,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
NodeType.TOOL: lambda: {
|
||||
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
|
||||
},
|
||||
NodeType.AGENT: lambda: {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
"presentation_provider": self._agent_strategy_presentation_provider,
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
@@ -238,16 +251,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
@staticmethod
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
|
||||
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
|
||||
42
api/core/workflow/node_resolution.py
Normal file
42
api/core/workflow/node_resolution.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, get_node_type_classes_mapping
|
||||
|
||||
_WORKFLOW_NODE_MODULES = ("core.workflow.nodes.agent",)
|
||||
_workflow_nodes_registered = False
|
||||
|
||||
|
||||
def ensure_workflow_nodes_registered() -> None:
|
||||
"""Import workflow-local node modules so they can register with `Node.__init_subclass__`."""
|
||||
global _workflow_nodes_registered
|
||||
|
||||
if _workflow_nodes_registered:
|
||||
return
|
||||
|
||||
for module_name in _WORKFLOW_NODE_MODULES:
|
||||
import_module(module_name)
|
||||
|
||||
_workflow_nodes_registered = True
|
||||
|
||||
|
||||
def get_workflow_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
ensure_workflow_nodes_registered()
|
||||
return get_node_type_classes_mapping()
|
||||
|
||||
|
||||
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = get_workflow_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
4
api/core/workflow/nodes/agent/__init__.py
Normal file
4
api/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .agent_node import AgentNode
|
||||
from .entities import AgentNodeData
|
||||
|
||||
__all__ = ["AgentNode", "AgentNodeData"]
|
||||
188
api/core/workflow/nodes/agent/agent_node.py
Normal file
188
api/core/workflow/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.enums import NodeType, SystemVariableKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from .entities import AgentNodeData
|
||||
from .exceptions import (
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
)
|
||||
from .message_transformer import AgentMessageTransformer
|
||||
from .runtime_support import AgentRuntimeSupport
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
node_type = NodeType.AGENT
|
||||
|
||||
_strategy_resolver: AgentStrategyResolver
|
||||
_presentation_provider: AgentStrategyPresentationProvider
|
||||
_runtime_support: AgentRuntimeSupport
|
||||
_message_transformer: AgentMessageTransformer
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
strategy_resolver: AgentStrategyResolver,
|
||||
presentation_provider: AgentStrategyPresentationProvider,
|
||||
runtime_support: AgentRuntimeSupport,
|
||||
message_transformer: AgentMessageTransformer,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._strategy_resolver = strategy_resolver
|
||||
self._presentation_provider = presentation_provider
|
||||
self._runtime_support = runtime_support
|
||||
self._message_transformer = message_transformer
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
event.extras["agent_strategy"] = {
|
||||
"name": self.node_data.agent_strategy_name,
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = self._strategy_resolver.resolve(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
parameters = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
)
|
||||
parameters_for_log = self._runtime_support.build_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
invoke_from=dify_ctx.invoke_from,
|
||||
for_log=True,
|
||||
)
|
||||
credentials = self._runtime_support.build_credentials(parameters=parameters)
|
||||
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._message_transformer.transform(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self._presentation_provider.get_icon(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
),
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
@@ -11,9 +11,9 @@ from dify_graph.enums import NodeType
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_provider_name: str
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
agent_strategy_label: str
|
||||
memory: MemoryConfig | None = None
|
||||
# The version of the tool parameter.
|
||||
# If this value is None, it indicates this is a previous version
|
||||
@@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError):
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentMaxIterationError(AgentNodeError):
|
||||
"""Exception raised when the agent exceeds the maximum iteration limit."""
|
||||
|
||||
def __init__(self, max_iteration: int):
|
||||
self.max_iteration = max_iteration
|
||||
super().__init__(
|
||||
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
|
||||
f"The agent was unable to complete the task within the allowed number of iterations."
|
||||
)
|
||||
292
api/core/workflow/nodes/agent/message_transformer.py
Normal file
292
api/core/workflow/nodes/agent/message_transformer.py
Normal file
@@ -0,0 +1,292 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayFileSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
|
||||
|
||||
|
||||
class AgentMessageTransformer:
|
||||
def transform(
|
||||
self,
|
||||
*,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
40
api/core/workflow/nodes/agent/plugin_strategy_adapter.py
Normal file
40
api/core/workflow/nodes/agent/plugin_strategy_adapter.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
|
||||
|
||||
|
||||
class PluginAgentStrategyResolver(AgentStrategyResolver):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy:
|
||||
return get_plugin_agent_strategy(
|
||||
tenant_id=tenant_id,
|
||||
agent_strategy_provider_name=agent_strategy_provider_name,
|
||||
agent_strategy_name=agent_strategy_name,
|
||||
)
|
||||
|
||||
|
||||
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
try:
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
|
||||
)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
return current_plugin.declaration.icon
|
||||
276
api/core/workflow/nodes/agent/runtime_support.py
Normal file
276
api/core/workflow/nodes/agent/runtime_support.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from dify_graph.enums import SystemVariableKey
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id,
|
||||
app_id,
|
||||
entity,
|
||||
invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=app_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
self,
|
||||
*,
|
||||
variable_pool: VariablePool,
|
||||
app_id: str,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
@staticmethod
|
||||
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]:
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value)
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
39
api/core/workflow/nodes/agent/strategy_protocols.py
Normal file
39
api/core/workflow/nodes/agent/strategy_protocols.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
|
||||
class ResolvedAgentStrategy(Protocol):
|
||||
meta_version: str | None
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
params: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
credentials: InvokeCredentials | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]: ...
|
||||
|
||||
|
||||
class AgentStrategyResolver(Protocol):
|
||||
def resolve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
agent_strategy_provider_name: str,
|
||||
agent_strategy_name: str,
|
||||
) -> ResolvedAgentStrategy: ...
|
||||
|
||||
|
||||
class AgentStrategyPresentationProvider(Protocol):
|
||||
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...
|
||||
@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
|
||||
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
|
||||
from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.node_resolution import resolve_workflow_node_class
|
||||
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
@@ -23,7 +24,6 @@ from dify_graph.graph_engine.protocols.command_channel import CommandChannel
|
||||
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from dify_graph.nodes import NodeType
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
@@ -343,7 +343,7 @@ class WorkflowEntry:
|
||||
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
|
||||
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_start_reason import WorkflowStartReason
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentNodeStrategyInit(BaseModel):
|
||||
"""Agent node strategy initialization data."""
|
||||
|
||||
name: str
|
||||
icon: str | None = None
|
||||
@@ -4,7 +4,6 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from dify_graph.entities import AgentNodeStrategyInit
|
||||
from dify_graph.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@@ -13,8 +12,8 @@ from .base import GraphNodeEventBase
|
||||
class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
node_title: str
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
|
||||
@@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@@ -4,7 +4,8 @@ class InvokeError(ValueError):
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
if description is not None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
return self.description or self.__class__.__name__
|
||||
|
||||
@@ -282,7 +282,8 @@ class ModelProviderFactory:
|
||||
all_model_type_models.append(model_schema)
|
||||
|
||||
simple_provider_schema = provider_schema.to_simple_provider()
|
||||
simple_provider_schema.models.extend(all_model_type_models)
|
||||
if model_type:
|
||||
simple_provider_schema.models = all_model_type_models
|
||||
|
||||
providers.append(simple_provider_schema)
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .agent_node import AgentNode
|
||||
|
||||
__all__ = ["AgentNode"]
|
||||
@@ -1,761 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from dify_graph.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from dify_graph.file import File, FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
@@ -8,10 +8,10 @@ from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from functools import singledispatchmethod
|
||||
from types import MappingProxyType
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeVar, get_args, get_origin
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
@@ -349,6 +349,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def populate_start_event(self, event: NodeRunStartedEvent) -> None:
|
||||
"""Allow subclasses to enrich the started event without cross-node imports in the base class."""
|
||||
_ = event
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@@ -362,39 +366,10 @@ class Node(Generic[NodeDataT]):
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
from dify_graph.nodes.tool.tool_node import ToolNode
|
||||
|
||||
if isinstance(self, ToolNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
if isinstance(self, DatasourceNode):
|
||||
plugin_id = getattr(self.node_data, "plugin_id", "")
|
||||
provider_name = getattr(self.node_data, "provider_name", "")
|
||||
|
||||
start_event.provider_id = f"{plugin_id}/{provider_name}"
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
|
||||
|
||||
if isinstance(self, TriggerEventNode):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
if isinstance(self, AgentNode):
|
||||
start_event.agent_strategy = AgentNodeStrategyInit(
|
||||
name=cast(AgentNodeData, self.node_data).agent_strategy_name,
|
||||
icon=self.agent_strategy_icon,
|
||||
)
|
||||
|
||||
# ===
|
||||
try:
|
||||
self.populate_start_event(start_event)
|
||||
except Exception:
|
||||
logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True)
|
||||
yield start_event
|
||||
|
||||
try:
|
||||
@@ -513,10 +488,8 @@ class Node(Generic[NodeDataT]):
|
||||
@abstractmethod
|
||||
def version(cls) -> str:
|
||||
"""`node_version` returns the version of current node type."""
|
||||
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
|
||||
#
|
||||
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
|
||||
# in `api/dify_graph/nodes/__init__.py`.
|
||||
# NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so
|
||||
# `Node.get_node_type_classes_mapping()` can resolve numeric versions and `latest`.
|
||||
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
|
||||
|
||||
@classmethod
|
||||
@@ -524,7 +497,9 @@ class Node(Generic[NodeDataT]):
|
||||
"""Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry.
|
||||
|
||||
Import all modules under dify_graph.nodes so subclasses register themselves on import.
|
||||
Then we return a readonly view of the registry to avoid accidental mutation.
|
||||
Callers that rely on workflow-local nodes defined outside `dify_graph.nodes` must import
|
||||
those modules before invoking this method so they can register through `__init_subclass__`.
|
||||
We then return a readonly view of the registry to avoid accidental mutation.
|
||||
"""
|
||||
# Import all node modules to ensure they are loaded (thus registered)
|
||||
import dify_graph.nodes as _nodes_pkg
|
||||
|
||||
@@ -48,6 +48,10 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
self.datasource_manager = datasource_manager
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}"
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
|
||||
@@ -486,14 +486,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
|
||||
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
if node_type not in node_mapping:
|
||||
continue
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = node_mapping[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
|
||||
@@ -316,14 +316,15 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.node_mapping import get_node_type_classes_mapping
|
||||
|
||||
typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
node_type = typed_sub_node_config["data"].type
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
node_mapping = get_node_type_classes_mapping()
|
||||
if node_type not in node_mapping:
|
||||
continue
|
||||
node_version = str(typed_sub_node_config["data"].version)
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
node_cls = node_mapping[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=typed_sub_node_config
|
||||
|
||||
@@ -5,5 +5,24 @@ from dify_graph.nodes.base.node import Node
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks dify_graph.nodes
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping()
|
||||
|
||||
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
|
||||
"""Return the live node registry after importing all `dify_graph.nodes` modules."""
|
||||
return Node.get_node_type_classes_mapping()
|
||||
|
||||
|
||||
def resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
|
||||
node_mapping = get_node_type_classes_mapping().get(node_type)
|
||||
if not node_mapping:
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
raise ValueError(f"No latest version class found for node type: {node_type}")
|
||||
return node_class
|
||||
|
||||
|
||||
# Snapshot kept for compatibility with older tests; production paths should use the live helpers.
|
||||
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping()
|
||||
|
||||
@@ -65,6 +65,10 @@ class ToolNode(Node[ToolNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = self.node_data.provider_id
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Run the tool node
|
||||
|
||||
@@ -32,6 +32,10 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def populate_start_event(self, event) -> None:
|
||||
event.provider_id = self.node_data.provider_id
|
||||
event.provider_type = self.node_data.provider_type
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the plugin trigger node.
|
||||
|
||||
@@ -87,7 +87,7 @@ dependencies = [
|
||||
"flask-restx~=1.3.2",
|
||||
"packaging~=23.2",
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.20.4",
|
||||
"weaviate-client==4.17.0",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
@@ -202,28 +202,28 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
|
||||
############################################################
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"chromadb==1.5.5",
|
||||
"clickhouse-connect~=0.14.0",
|
||||
"alibabacloud_tea_openapi~=0.3.9",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.10.0",
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
"couchbase~=4.5.0",
|
||||
"elasticsearch==9.3.0",
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==3.1.0",
|
||||
"oracledb==3.4.2",
|
||||
"oracledb==3.3.0",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.4.2",
|
||||
"pymilvus~=2.6.9",
|
||||
"pymochow==2.3.6",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
"pymochow==2.2.9",
|
||||
"pyobvector~=0.2.17",
|
||||
"qdrant-client==1.17.0",
|
||||
"qdrant-client==1.9.0",
|
||||
"intersystems-irispython>=5.1.0",
|
||||
"tablestore==6.4.1",
|
||||
"tcvectordb~=2.0.0",
|
||||
"tidb-vector==0.0.15",
|
||||
"upstash-vector==0.8.0",
|
||||
"tablestore==6.3.7",
|
||||
"tcvectordb~=1.6.4",
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
"volcengine-compat~=1.0.0",
|
||||
"weaviate-client==4.20.4",
|
||||
"xinference-client~=2.2.0",
|
||||
"weaviate-client==4.17.0",
|
||||
"xinference-client~=1.2.2",
|
||||
"mo-vector~=0.1.13",
|
||||
"mysql-connector-python>=9.3.0",
|
||||
]
|
||||
|
||||
@@ -36,6 +36,7 @@ from core.rag.entities.event import (
|
||||
)
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
@@ -48,7 +49,6 @@ from dify_graph.graph_events.base import GraphNodeEventBase
|
||||
from dify_graph.node_events.base import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
@@ -381,7 +381,7 @@ class RagPipelineService:
|
||||
"""
|
||||
# return default block config
|
||||
default_block_configs: list[dict[str, Any]] = []
|
||||
for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items():
|
||||
for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
filters = None
|
||||
if node_type is NodeType.HTTP_REQUEST:
|
||||
@@ -410,12 +410,13 @@ class RagPipelineService:
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_mapping = get_workflow_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||
if node_type_enum not in node_mapping:
|
||||
return None
|
||||
|
||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||
node_class = node_mapping[node_type_enum][LATEST_VERSION]
|
||||
final_filters = dict(filters) if filters else {}
|
||||
if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters:
|
||||
final_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config(
|
||||
|
||||
@@ -14,6 +14,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.node_resolution import LATEST_VERSION, get_workflow_node_type_classes_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
@@ -34,7 +35,6 @@ from dify_graph.nodes.human_input.entities import (
|
||||
)
|
||||
from dify_graph.nodes.human_input.enums import HumanInputFormKind
|
||||
from dify_graph.nodes.human_input.human_input_node import HumanInputNode
|
||||
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from dify_graph.nodes.start.entities import StartNodeData
|
||||
from dify_graph.repositories.human_input_form_repository import FormCreateParams
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -619,7 +619,7 @@ class WorkflowService:
|
||||
"""
|
||||
# return default block config
|
||||
default_block_configs: list[Mapping[str, object]] = []
|
||||
for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items():
|
||||
for node_type, node_class_mapping in get_workflow_node_type_classes_mapping().items():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
filters = None
|
||||
if node_type is NodeType.HTTP_REQUEST:
|
||||
@@ -650,12 +650,13 @@ class WorkflowService:
|
||||
:return:
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
node_mapping = get_workflow_node_type_classes_mapping()
|
||||
|
||||
# return default block config
|
||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||
if node_type_enum not in node_mapping:
|
||||
return {}
|
||||
|
||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||
node_class = node_mapping[node_type_enum][LATEST_VERSION]
|
||||
resolved_filters = dict(filters) if filters else {}
|
||||
if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters:
|
||||
resolved_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config(
|
||||
|
||||
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
313
api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin endpoints
|
||||
|
||||
Tests endpoint structure (method existence) for all plugin APIs, plus
|
||||
handler-level logic tests for representative non-streaming endpoints.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.plugin.plugin import (
|
||||
PluginFetchAppInfoApi,
|
||||
PluginInvokeAppApi,
|
||||
PluginInvokeEncryptApi,
|
||||
PluginInvokeLLMApi,
|
||||
PluginInvokeLLMWithStructuredOutputApi,
|
||||
PluginInvokeModerationApi,
|
||||
PluginInvokeParameterExtractorNodeApi,
|
||||
PluginInvokeQuestionClassifierNodeApi,
|
||||
PluginInvokeRerankApi,
|
||||
PluginInvokeSpeech2TextApi,
|
||||
PluginInvokeSummaryApi,
|
||||
PluginInvokeTextEmbeddingApi,
|
||||
PluginInvokeToolApi,
|
||||
PluginInvokeTTSApi,
|
||||
PluginUploadFileRequestApi,
|
||||
)
|
||||
|
||||
|
||||
def _extract_raw_post(cls):
|
||||
"""Extract the raw post() method from a plugin endpoint class.
|
||||
|
||||
Plugin endpoint methods are wrapped by several decorators (get_user_tenant,
|
||||
setup_required, plugin_inner_api_only, plugin_data). These decorators
|
||||
use @wraps where possible. This helper ensures we retrieve the original
|
||||
post(self, user_model, tenant_model, payload) function by unwrapping
|
||||
and, if necessary, walking the closure of the innermost wrapper.
|
||||
"""
|
||||
bottom = inspect.unwrap(cls.post)
|
||||
|
||||
# If unwrap() didn't get us to the raw function (e.g. if a decorator
|
||||
# missed @wraps), try to extract it from the closure if it looks like
|
||||
# a plugin_data or similar wrapper that closes over 'view_func'.
|
||||
if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars:
|
||||
try:
|
||||
idx = bottom.__code__.co_freevars.index("view_func")
|
||||
return bottom.__closure__[idx].cell_contents
|
||||
except (AttributeError, TypeError, IndexError):
|
||||
pass
|
||||
|
||||
return bottom
|
||||
|
||||
|
||||
class TestPluginInvokeLLMApi:
|
||||
"""Test PluginInvokeLLMApi endpoint structure"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeLLMWithStructuredOutputApi:
|
||||
"""Test PluginInvokeLLMWithStructuredOutputApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeLLMWithStructuredOutputApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTextEmbeddingApi:
|
||||
"""Test PluginInvokeTextEmbeddingApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTextEmbeddingApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeRerankApi:
|
||||
"""Test PluginInvokeRerankApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeRerankApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeTTSApi:
|
||||
"""Test PluginInvokeTTSApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeTTSApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeSpeech2TextApi:
|
||||
"""Test PluginInvokeSpeech2TextApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSpeech2TextApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeModerationApi:
|
||||
"""Test PluginInvokeModerationApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeModerationApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeToolApi:
|
||||
"""Test PluginInvokeToolApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeToolApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeParameterExtractorNodeApi:
|
||||
"""Test PluginInvokeParameterExtractorNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeParameterExtractorNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeQuestionClassifierNodeApi:
|
||||
"""Test PluginInvokeQuestionClassifierNodeApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeQuestionClassifierNodeApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeAppApi:
|
||||
"""Test PluginInvokeAppApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeAppApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginInvokeEncryptApi:
|
||||
"""Test PluginInvokeEncryptApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeEncryptApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() delegates to PluginEncrypter and returns model_dump output"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act — extract raw post() bypassing all decorators including plugin_data
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload)
|
||||
assert result["data"] == {"encrypted": "data"}
|
||||
assert result.get("error") == ""
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginEncrypter")
|
||||
def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask):
|
||||
"""Test that post() catches exceptions and returns error response"""
|
||||
# Arrange
|
||||
mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed")
|
||||
mock_tenant = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginInvokeEncryptApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
assert "encrypt failed" in result["error"]
|
||||
|
||||
|
||||
class TestPluginInvokeSummaryApi:
|
||||
"""Test PluginInvokeSummaryApi endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginInvokeSummaryApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
|
||||
class TestPluginUploadFileRequestApi:
|
||||
"""Test PluginUploadFileRequestApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginUploadFileRequestApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin")
|
||||
def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask):
|
||||
"""Test that post() generates a signed URL and returns it"""
|
||||
# Arrange
|
||||
mock_get_url.return_value = "https://storage.example.com/signed-upload-url"
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-id"
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.filename = "test.pdf"
|
||||
mock_payload.mimetype = "application/pdf"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginUploadFileRequestApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_get_url.assert_called_once_with(
|
||||
filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id"
|
||||
)
|
||||
assert result["data"]["url"] == "https://storage.example.com/signed-upload-url"
|
||||
|
||||
|
||||
class TestPluginFetchAppInfoApi:
|
||||
"""Test PluginFetchAppInfoApi endpoint structure and handler logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return PluginFetchAppInfoApi()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation")
|
||||
def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask):
|
||||
"""Test that post() fetches app info and returns it"""
|
||||
# Arrange
|
||||
mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"}
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_user = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.app_id = "app-123"
|
||||
|
||||
# Act
|
||||
raw_post = _extract_raw_post(PluginFetchAppInfoApi)
|
||||
result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload)
|
||||
|
||||
# Assert
|
||||
mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id")
|
||||
assert result["data"] == {"app_name": "My App", "mode": "chat"}
|
||||
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Unit tests for inner_api plugin decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.plugin.wraps import (
|
||||
TenantUserPayload,
|
||||
get_user,
|
||||
get_user_tenant,
|
||||
plugin_data,
|
||||
)
|
||||
|
||||
|
||||
class TestTenantUserPayload:
|
||||
"""Test TenantUserPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload passes validation"""
|
||||
data = {"tenant_id": "tenant123", "user_id": "user456"}
|
||||
payload = TenantUserPayload.model_validate(data)
|
||||
assert payload.tenant_id == "tenant123"
|
||||
assert payload.user_id == "user456"
|
||||
|
||||
def test_missing_tenant_id(self):
|
||||
"""Test missing tenant_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"user_id": "user456"})
|
||||
|
||||
def test_missing_user_id(self):
|
||||
"""Test missing user_id raises ValidationError"""
|
||||
with pytest.raises(ValidationError):
|
||||
TenantUserPayload.model_validate({"tenant_id": "tenant123"})
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test get_user function"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test returning existing user when found by ID"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user123"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
mock_session.query.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_return_existing_anonymous_user_by_session_id(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test returning existing anonymous user by session_id"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.session_id = "anonymous_session"
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "anonymous_session")
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask):
|
||||
"""Test creating new user when not found in database"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_new_user = MagicMock()
|
||||
mock_enduser_class.return_value = mock_new_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", "user123")
|
||||
|
||||
# Assert
|
||||
assert result == mock_new_user
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.refresh.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_use_default_session_id_when_user_id_none(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test using default session ID when user_id is None"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.return_value.where.return_value.first.return_value = mock_user
|
||||
|
||||
# Act
|
||||
with app.app_context():
|
||||
result = get_user("tenant123", None)
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.EndUser")
|
||||
@patch("controllers.inner_api.plugin.wraps.Session")
|
||||
@patch("controllers.inner_api.plugin.wraps.db")
|
||||
def test_should_raise_error_on_database_exception(
|
||||
self, mock_db, mock_session_class, mock_enduser_class, app: Flask
|
||||
):
|
||||
"""Test raising ValueError when database operation fails"""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.query.side_effect = Exception("Database error")
|
||||
|
||||
# Act & Assert
|
||||
with app.app_context():
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
get_user("tenant123", "user123")
|
||||
|
||||
|
||||
class TestGetUserTenant:
|
||||
"""Test get_user_tenant decorator"""
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that decorator injects tenant_model and user_model into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user456"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
|
||||
def test_should_raise_error_when_tenant_id_missing(self, app: Flask):
|
||||
"""Test that Pydantic ValidationError is raised when tenant_id is missing from payload"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert - Pydantic validates payload before manual check
|
||||
with app.test_request_context(json={"user_id": "user456"}):
|
||||
with pytest.raises(ValidationError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_when_tenant_not_found(self, app: Flask):
|
||||
"""Test that ValueError is raised when tenant is not found"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}):
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ValueError, match="tenant not found"):
|
||||
protected_view()
|
||||
|
||||
@patch("controllers.inner_api.plugin.wraps.Tenant")
|
||||
def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch):
|
||||
"""Test that default session ID is used when user_id is empty string"""
|
||||
|
||||
# Arrange
|
||||
@get_user_tenant
|
||||
def protected_view(tenant_model, user_model, **kwargs):
|
||||
return {"tenant": tenant_model, "user": user_model}
|
||||
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant123"
|
||||
mock_user = MagicMock()
|
||||
|
||||
# Act - use empty string for user_id to trigger default logic
|
||||
with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}):
|
||||
monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False)
|
||||
with patch("controllers.inner_api.plugin.wraps.db.session.query") as mock_query:
|
||||
with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_get_user.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result["tenant"] == mock_tenant
|
||||
assert result["user"] == mock_user
|
||||
from models.model import DefaultEndUserSessionID
|
||||
|
||||
mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID)
|
||||
|
||||
|
||||
class PluginTestPayload:
|
||||
"""Simple test payload class"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.value = data.get("value")
|
||||
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
return cls(data)
|
||||
|
||||
|
||||
class TestPluginData:
|
||||
"""Test plugin_data decorator"""
|
||||
|
||||
def test_should_inject_valid_payload(self, app: Flask):
|
||||
"""Test that valid payload is injected into kwargs"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "test_data"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "test_data"
|
||||
|
||||
def test_should_raise_error_on_invalid_json(self, app: Flask):
|
||||
"""Test that ValueError is raised when JSON parsing fails"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert - Malformed JSON triggers ValueError
|
||||
with app.test_request_context(data="not valid json", content_type="application/json"):
|
||||
with pytest.raises(ValueError):
|
||||
protected_view()
|
||||
|
||||
def test_should_raise_error_on_invalid_payload(self, app: Flask):
|
||||
"""Test that ValueError is raised when payload validation fails"""
|
||||
|
||||
# Arrange
|
||||
class InvalidPayload:
|
||||
@classmethod
|
||||
def model_validate(cls, data: dict):
|
||||
raise Exception("Validation failed")
|
||||
|
||||
@plugin_data(payload_type=InvalidPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(json={"data": "test"}):
|
||||
with pytest.raises(ValueError, match="invalid payload"):
|
||||
protected_view()
|
||||
|
||||
def test_should_work_as_parameterized_decorator(self, app: Flask):
|
||||
"""Test that decorator works when used with parentheses"""
|
||||
|
||||
# Arrange
|
||||
@plugin_data(payload_type=PluginTestPayload)
|
||||
def protected_view(payload, **kwargs):
|
||||
return payload
|
||||
|
||||
# Act
|
||||
with app.test_request_context(json={"value": "parameterized"}):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result.value == "parameterized"
|
||||
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
309
api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Unit tests for inner_api auth decorators
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.inner_api.wraps import (
|
||||
billing_inner_api_only,
|
||||
enterprise_inner_api_only,
|
||||
enterprise_inner_api_user_auth,
|
||||
plugin_inner_api_only,
|
||||
)
|
||||
|
||||
|
||||
class TestBillingInnerApiOnly:
|
||||
"""Test billing_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@billing_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiOnly:
|
||||
"""Test enterprise_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when INNER_API is enabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that 404 is returned when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 401 is returned when X-Inner-Api-Key header is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiUserAuth:
|
||||
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
|
||||
|
||||
def test_should_pass_through_when_inner_api_disabled(self, app: Flask):
|
||||
"""Test that request passes through when INNER_API is disabled"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_header_missing(self, app: Flask):
|
||||
"""Test that request passes through when Authorization header is missing"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_authorization_format_invalid(self, app: Flask):
|
||||
"""Test that request passes through when Authorization format is invalid (no colon)"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"Authorization": "invalid_format"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask):
|
||||
"""Test that request passes through when HMAC signature is invalid"""
|
||||
|
||||
# Arrange
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user", "no_user")
|
||||
|
||||
# Act - use wrong signature
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "no_user"
|
||||
|
||||
def test_should_inject_user_when_hmac_signature_valid(self, app: Flask):
|
||||
"""Test that user is injected when HMAC signature is valid"""
|
||||
# Arrange
|
||||
from base64 import b64encode
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
@enterprise_inner_api_user_auth
|
||||
def protected_view(**kwargs):
|
||||
return kwargs.get("user")
|
||||
|
||||
# Calculate valid HMAC signature
|
||||
user_id = "user123"
|
||||
inner_api_key = "valid_key"
|
||||
data_to_sign = f"DIFY {user_id}"
|
||||
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
|
||||
valid_signature = b64encode(signature.digest()).decode("utf-8")
|
||||
|
||||
# Create mock user
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
|
||||
):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch("controllers.inner_api.wraps.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_user
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
class TestPluginInnerApiOnly:
|
||||
"""Test plugin_inner_api_only decorator"""
|
||||
|
||||
def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask):
|
||||
"""Test that valid API key allows access when PLUGIN_DAEMON_KEY is set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask):
|
||||
"""Test that 404 is returned when PLUGIN_DAEMON_KEY is not set"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_404_when_api_key_invalid(self, app: Flask):
|
||||
"""Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)"""
|
||||
|
||||
# Arrange
|
||||
@plugin_inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"):
|
||||
with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
206
api/tests/unit_tests/controllers/inner_api/test_mail.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Unit tests for inner_api mail module
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.mail import (
|
||||
BaseMail,
|
||||
BillingMail,
|
||||
EnterpriseMail,
|
||||
InnerMailPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestInnerMailPayload:
|
||||
"""Test InnerMailPayload Pydantic model"""
|
||||
|
||||
def test_valid_payload_with_all_fields(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
"substitutions": {"key": "value"},
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions == {"key": "value"}
|
||||
|
||||
def test_valid_payload_without_substitutions(self):
|
||||
"""Test valid payload without optional substitutions"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert payload.to == ["test@example.com"]
|
||||
assert payload.subject == "Test Subject"
|
||||
assert payload.body == "Test Body"
|
||||
assert payload.substitutions is None
|
||||
|
||||
def test_empty_to_list_fails_validation(self):
|
||||
"""Test that empty 'to' list fails validation due to min_length=1"""
|
||||
data = {
|
||||
"to": [],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_multiple_recipients_allowed(self):
|
||||
"""Test that multiple recipients are allowed"""
|
||||
data = {
|
||||
"to": ["user1@example.com", "user2@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
payload = InnerMailPayload.model_validate(data)
|
||||
assert len(payload.to) == 2
|
||||
assert "user1@example.com" in payload.to
|
||||
assert "user2@example.com" in payload.to
|
||||
|
||||
def test_missing_to_field_fails_validation(self):
|
||||
"""Test that missing 'to' field fails validation"""
|
||||
data = {
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_subject_fails_validation(self):
|
||||
"""Test that missing 'subject' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"body": "Test Body",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
def test_missing_body_fails_validation(self):
|
||||
"""Test that missing 'body' field fails validation"""
|
||||
data = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
}
|
||||
with pytest.raises(ValidationError):
|
||||
InnerMailPayload.model_validate(data)
|
||||
|
||||
|
||||
class TestBaseMail:
|
||||
"""Test BaseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BaseMail API instance"""
|
||||
return BaseMail()
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_sends_email_task(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends inner email task"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
json={
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
):
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Test Subject",
|
||||
"body": "Test Body",
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Test Subject",
|
||||
body="Test Body",
|
||||
substitutions=None,
|
||||
)
|
||||
|
||||
@patch("controllers.inner_api.mail.send_inner_email_task")
|
||||
def test_post_with_substitutions(self, mock_task, api_instance, app: Flask):
|
||||
"""Test that POST sends email with substitutions"""
|
||||
# Arrange
|
||||
mock_task.delay.return_value = None
|
||||
|
||||
# Act
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"to": ["test@example.com"],
|
||||
"subject": "Hello {{name}}",
|
||||
"body": "Welcome {{name}}!",
|
||||
"substitutions": {"name": "John"},
|
||||
}
|
||||
result = api_instance.post()
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "success"}, 200)
|
||||
mock_task.delay.assert_called_once_with(
|
||||
to=["test@example.com"],
|
||||
subject="Hello {{name}}",
|
||||
body="Welcome {{name}}!",
|
||||
substitutions={"name": "John"},
|
||||
)
|
||||
|
||||
|
||||
class TestEnterpriseMail:
|
||||
"""Test EnterpriseMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create EnterpriseMail API instance"""
|
||||
return EnterpriseMail()
|
||||
|
||||
def test_has_enterprise_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has enterprise_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
|
||||
assert enterprise_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that EnterpriseMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
|
||||
|
||||
class TestBillingMail:
|
||||
"""Test BillingMail API endpoint"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
"""Create BillingMail API instance"""
|
||||
return BillingMail()
|
||||
|
||||
def test_has_billing_inner_api_only_decorator(self, api_instance):
|
||||
"""Test that BillingMail has billing_inner_api_only decorator"""
|
||||
# Check method_decorators
|
||||
from controllers.inner_api.wraps import billing_inner_api_only
|
||||
|
||||
assert billing_inner_api_only in api_instance.method_decorators
|
||||
|
||||
def test_has_setup_required_decorator(self, api_instance):
|
||||
"""Test that BillingMail has setup_required decorator"""
|
||||
# Check by decorator name instead of object reference
|
||||
decorator_names = [d.__name__ for d in api_instance.method_decorators]
|
||||
assert "setup_required" in decorator_names
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Unit tests for inner_api workspace module
|
||||
|
||||
Tests Pydantic model validation and endpoint handler logic.
|
||||
Auth/setup decorators are tested separately in test_auth_wraps.py;
|
||||
handler tests use inspect.unwrap() to bypass them and focus on business logic.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.inner_api.workspace.workspace import (
|
||||
EnterpriseWorkspace,
|
||||
EnterpriseWorkspaceNoOwnerEmail,
|
||||
WorkspaceCreatePayload,
|
||||
WorkspaceOwnerlessPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestWorkspaceCreatePayload:
|
||||
"""Test WorkspaceCreatePayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with all fields passes validation"""
|
||||
data = {
|
||||
"name": "My Workspace",
|
||||
"owner_email": "owner@example.com",
|
||||
}
|
||||
payload = WorkspaceCreatePayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
assert payload.owner_email == "owner@example.com"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {"owner_email": "owner@example.com"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
def test_missing_owner_email_fails_validation(self):
|
||||
"""Test that missing owner_email fails validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceCreatePayload.model_validate(data)
|
||||
assert "owner_email" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestWorkspaceOwnerlessPayload:
|
||||
"""Test WorkspaceOwnerlessPayload Pydantic model validation"""
|
||||
|
||||
def test_valid_payload(self):
|
||||
"""Test valid payload with name passes validation"""
|
||||
data = {"name": "My Workspace"}
|
||||
payload = WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert payload.name == "My Workspace"
|
||||
|
||||
def test_missing_name_fails_validation(self):
|
||||
"""Test that missing name fails validation"""
|
||||
data = {}
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
WorkspaceOwnerlessPayload.model_validate(data)
|
||||
assert "name" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspace:
|
||||
"""Test EnterpriseWorkspace API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspace()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that EnterpriseWorkspace has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace and assigns the owner account"""
|
||||
# Arrange
|
||||
mock_account = MagicMock()
|
||||
mock_account.email = "owner@example.com"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["name"] == "My Workspace"
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.db")
|
||||
def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask):
|
||||
"""Test that post() returns 404 when the owner account does not exist"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result == ({"message": "owner account not found."}, 404)
|
||||
|
||||
|
||||
class TestEnterpriseWorkspaceNoOwnerEmail:
|
||||
"""Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic.
|
||||
|
||||
Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
and exercise the core business logic directly.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def api_instance(self):
|
||||
return EnterpriseWorkspaceNoOwnerEmail()
|
||||
|
||||
def test_has_post_method(self, api_instance):
|
||||
"""Test that endpoint has post method"""
|
||||
assert hasattr(api_instance, "post")
|
||||
assert callable(api_instance.post)
|
||||
|
||||
@patch("controllers.inner_api.workspace.workspace.tenant_was_created")
|
||||
@patch("controllers.inner_api.workspace.workspace.TenantService")
|
||||
def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask):
|
||||
"""Test that post() creates a workspace without an owner and returns expected fields"""
|
||||
# Arrange
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.id = "tenant-id"
|
||||
mock_tenant.name = "My Workspace"
|
||||
mock_tenant.encrypt_public_key = "pub-key"
|
||||
mock_tenant.plan = "sandbox"
|
||||
mock_tenant.status = "normal"
|
||||
mock_tenant.custom_config = None
|
||||
mock_tenant.created_at = now
|
||||
mock_tenant.updated_at = now
|
||||
mock_tenant_svc.create_tenant.return_value = mock_tenant
|
||||
|
||||
# Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py)
|
||||
unwrapped_post = inspect.unwrap(api_instance.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {"name": "My Workspace"}
|
||||
result = unwrapped_post(api_instance)
|
||||
|
||||
# Assert
|
||||
assert result["message"] == "enterprise workspace created."
|
||||
assert result["tenant"]["id"] == "tenant-id"
|
||||
assert result["tenant"]["encrypt_public_key"] == "pub-key"
|
||||
assert result["tenant"]["custom_config"] == {}
|
||||
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
80
api/tests/unit_tests/core/agent/conftest.py
Normal file
80
api/tests/unit_tests/core/agent/conftest.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
|
||||
|
||||
class DummyTool:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class DummyPromptEntity:
|
||||
def __init__(self, first_prompt):
|
||||
self.first_prompt = first_prompt
|
||||
|
||||
|
||||
class DummyAgentConfig:
|
||||
def __init__(self, prompt_entity=None):
|
||||
self.prompt = prompt_entity
|
||||
|
||||
|
||||
class DummyAppConfig:
|
||||
def __init__(self, agent=None):
|
||||
self.agent = agent
|
||||
|
||||
|
||||
class DummyScratchpadUnit:
|
||||
def __init__(
|
||||
self,
|
||||
final=False,
|
||||
thought=None,
|
||||
action_str=None,
|
||||
observation=None,
|
||||
agent_response=None,
|
||||
):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_tool_factory():
|
||||
def _factory(name):
|
||||
return DummyTool(name)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_prompt_entity_factory():
|
||||
def _factory(first_prompt):
|
||||
return DummyPromptEntity(first_prompt)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_agent_config_factory():
|
||||
def _factory(prompt_entity=None):
|
||||
return DummyAgentConfig(prompt_entity)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_app_config_factory():
|
||||
def _factory(agent=None):
|
||||
return DummyAppConfig(agent)
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_scratchpad_unit_factory():
|
||||
def _factory(**kwargs):
|
||||
return DummyScratchpadUnit(**kwargs)
|
||||
|
||||
return _factory
|
||||
@@ -1,70 +1,255 @@
|
||||
"""Unit tests for CotAgentOutputParser.
|
||||
|
||||
Verifies expected parsing behavior for streaming content and JSON payloads,
|
||||
including edge cases such as empty/non-string content and malformed JSON.
|
||||
Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real
|
||||
model output structures. Implementation under test:
|
||||
core.agent.output_parser.cot_output_parser.CotAgentOutputParser.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from dify_graph.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta
|
||||
|
||||
|
||||
def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]:
|
||||
for i in range(len(text)):
|
||||
yield LLMResultChunk(
|
||||
model="model",
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])),
|
||||
@pytest.fixture
|
||||
def mock_action_class(mocker):
|
||||
mock_action = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action",
|
||||
mock_action,
|
||||
)
|
||||
return mock_action
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usage_dict():
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_chunk():
|
||||
def _make_chunk(content=None, usage=None):
|
||||
delta = SimpleNamespace(
|
||||
message=SimpleNamespace(content=content),
|
||||
usage=usage,
|
||||
)
|
||||
return SimpleNamespace(delta=delta)
|
||||
|
||||
return _make_chunk
|
||||
|
||||
|
||||
def test_cot_output_parser():
|
||||
test_cases = [
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with json
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# code block with JSON
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {'
|
||||
'}\n```"}```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# list
|
||||
{
|
||||
"input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block
|
||||
{
|
||||
"input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}',
|
||||
"action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"},
|
||||
"output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}',
|
||||
},
|
||||
# no code block and json
|
||||
{"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"},
|
||||
]
|
||||
# ============================================================
|
||||
# Test Suite
|
||||
# ============================================================
|
||||
|
||||
parser = CotAgentOutputParser()
|
||||
usage_dict = {}
|
||||
for test_case in test_cases:
|
||||
# mock llm_response as a generator by text
|
||||
llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"])
|
||||
results = parser.handle_react_stream_output(llm_response, usage_dict)
|
||||
output = ""
|
||||
for result in results:
|
||||
if isinstance(result, str):
|
||||
output += result
|
||||
elif isinstance(result, AgentScratchpadUnit.Action):
|
||||
if test_case["action"]:
|
||||
assert result.to_dict() == test_case["action"]
|
||||
output += json.dumps(result.to_dict())
|
||||
if test_case["output"]:
|
||||
assert output == test_case["output"]
|
||||
|
||||
class TestCotAgentOutputParser:
|
||||
"""Validate CotAgentOutputParser streaming + JSON parsing behavior.
|
||||
|
||||
Lifecycle: no explicit setup/teardown; relies on pytest fixtures for
|
||||
lightweight chunk/action doubles. Invariants: non-string/empty content
|
||||
yields no output, usage gets recorded when provided, and valid action JSON
|
||||
results in Action instantiation. Usage: invoke via pytest (e.g.,
|
||||
`pytest -k TestCotAgentOutputParser`).
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Basic streaming & usage
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_stream_plain_text(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("hello world")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "".join(result) == "hello world"
|
||||
|
||||
def test_stream_empty_string(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_stream_none_content(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk(None)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.parametrize("content", [123, 12.5, [], {}, object()])
|
||||
def test_non_string_content(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result == []
|
||||
|
||||
def test_usage_update(self, make_chunk, usage_dict) -> None:
|
||||
usage_data = {"tokens": 99}
|
||||
chunks = [make_chunk("abc", usage=usage_data)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert usage_dict["usage"] == usage_data
|
||||
|
||||
# --------------------------------------------------------
|
||||
# JSON parsing (direct + streaming)
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '{"action": "search", "input": "query"}'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="search", action_input="query")
|
||||
|
||||
def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '[{"action": "lookup", "input": "abc"}]'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None:
|
||||
content = '{"foo": "bar"}'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# Expect the serialized JSON to be yielded as a single element.
|
||||
assert result == [json.dumps({"foo": "bar"})]
|
||||
|
||||
def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None:
|
||||
content = "{invalid json}"
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert any("invalid json" in str(r) for r in result)
|
||||
|
||||
def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
chunks = [
|
||||
make_chunk('{"action": '),
|
||||
make_chunk('"multi", '),
|
||||
make_chunk('"input": "step"}'),
|
||||
]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="multi", action_input="step")
|
||||
|
||||
def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('{"foo": "bar"')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('{"foo": "bar"' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Code block JSON extraction
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = """```json
|
||||
{"action": "lookup", "input": "abc"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc")
|
||||
|
||||
def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
# Multiple JSON objects inside single code fence (invalid combined JSON)
|
||||
# Parser should safely ignore invalid combined block
|
||||
content = """```json
|
||||
{"action": "a1", "input": "x"}
|
||||
{"action": "a2", "input": "y"}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# No valid parsed action expected due to invalid combined JSON
|
||||
assert mock_action_class.call_count == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None:
|
||||
content = """```json
|
||||
{invalid}
|
||||
```"""
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert result
|
||||
|
||||
def test_unclosed_code_block(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk('```json {"a":1}')]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
assert any('```json {"a":1}' in item for item in result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Action / Thought prefix handling
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
" action: something",
|
||||
" ACTION: something",
|
||||
" thought: reasoning",
|
||||
" THOUGHT: reasoning",
|
||||
],
|
||||
)
|
||||
def test_prefix_handling(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
joined = "".join(str(item) for item in result)
|
||||
expected_word = "something" if "action:" in content.lower() else "reasoning"
|
||||
assert expected_word in joined
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("xaction: test")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "x" in "".join(map(str, result))
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Mixed streaming scenarios
|
||||
# --------------------------------------------------------
|
||||
|
||||
def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = 'start {"action": "mix", "input": "1"} end'
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
# JSON action should be parsed
|
||||
mock_action_class.assert_called_once()
|
||||
# Ensure surrounding text is streamed (character-level)
|
||||
joined = "".join(str(r) for r in result if not isinstance(r, MagicMock))
|
||||
assert "start" in joined
|
||||
assert "end" in joined
|
||||
|
||||
def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None:
|
||||
content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```'
|
||||
chunks = [make_chunk(content)]
|
||||
list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert mock_action_class.call_count == 2
|
||||
|
||||
def test_backtick_noise(self, make_chunk, usage_dict) -> None:
|
||||
chunks = [make_chunk("text with ` random ` backticks")]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert "text with" in "".join(result)
|
||||
|
||||
# --------------------------------------------------------
|
||||
# Boundary & edge inputs
|
||||
# --------------------------------------------------------
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content",
|
||||
[
|
||||
"```",
|
||||
"{",
|
||||
"}",
|
||||
"```json",
|
||||
"action:",
|
||||
"thought:",
|
||||
" ",
|
||||
],
|
||||
)
|
||||
def test_edge_inputs(self, make_chunk, usage_dict, content) -> None:
|
||||
chunks = [make_chunk(content)]
|
||||
result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict))
|
||||
assert all(isinstance(item, str) for item in result)
|
||||
joined = "".join(result)
|
||||
if content == " ":
|
||||
assert result == [] or joined == content
|
||||
if content in {"```", "{", "}", "```json"}:
|
||||
assert content in joined
|
||||
if content.lower() in {"action:", "thought:"}:
|
||||
assert "action:" not in joined.lower()
|
||||
assert "thought:" not in joined.lower()
|
||||
|
||||
174
api/tests/unit_tests/core/agent/strategy/test_base.py
Normal file
174
api/tests/unit_tests/core/agent/strategy/test_base.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
|
||||
|
||||
class DummyStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
Concrete implementation for testing BaseAgentStrategy
|
||||
"""
|
||||
|
||||
def __init__(self, return_values=None, raise_exception=None):
|
||||
self.return_values = return_values or []
|
||||
self.raise_exception = raise_exception
|
||||
self.received_args = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id=None,
|
||||
app_id=None,
|
||||
message_id=None,
|
||||
credentials=None,
|
||||
) -> Generator:
|
||||
self.received_args = (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
credentials,
|
||||
)
|
||||
|
||||
if self.raise_exception:
|
||||
raise self.raise_exception
|
||||
|
||||
yield from self.return_values
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInstantiation:
|
||||
def test_cannot_instantiate_abstract_class(self) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
BaseAgentStrategy()
|
||||
|
||||
|
||||
class TestBaseAgentStrategyInvoke:
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
return MagicMock(name="AgentInvokeMessage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credentials(self):
|
||||
return MagicMock(name="InvokeCredentials")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("params", "user_id", "conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
({"key": "value"}, "user1", "conv1", "app1", "msg1"),
|
||||
({}, "user2", None, None, None),
|
||||
({"a": 1}, "", "", "", ""),
|
||||
({"nested": {"x": 1}}, "user3", None, "app3", None),
|
||||
],
|
||||
)
|
||||
def test_invoke_success(
|
||||
self,
|
||||
mock_message,
|
||||
mock_credentials,
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[mock_message])
|
||||
|
||||
# Act
|
||||
result = list(
|
||||
strategy.invoke(
|
||||
params=params,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
credentials=mock_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [mock_message]
|
||||
assert strategy.received_args == (
|
||||
params,
|
||||
user_id,
|
||||
conversation_id,
|
||||
app_id,
|
||||
message_id,
|
||||
mock_credentials,
|
||||
)
|
||||
|
||||
def test_invoke_multiple_yields(self, mock_message) -> None:
|
||||
# Arrange
|
||||
messages = [mock_message, MagicMock(), MagicMock()]
|
||||
strategy = DummyStrategy(return_values=messages)
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == messages
|
||||
|
||||
def test_invoke_empty_generator(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
# Act
|
||||
result = list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
|
||||
def test_invoke_propagates_exception(self) -> None:
|
||||
# Arrange
|
||||
strategy = DummyStrategy(raise_exception=ValueError("failure"))
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="failure"):
|
||||
list(strategy.invoke(params={}, user_id="user"))
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_params",
|
||||
[
|
||||
None,
|
||||
"",
|
||||
123,
|
||||
[],
|
||||
],
|
||||
)
|
||||
def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None:
|
||||
"""
|
||||
Base class does not validate types — ensure pass-through behavior
|
||||
"""
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params=invalid_params, user_id="user"))
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_invoke_none_user_id(self) -> None:
|
||||
strategy = DummyStrategy(return_values=[])
|
||||
|
||||
result = list(strategy.invoke(params={}, user_id=None))
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestBaseAgentStrategyGetParameters:
|
||||
def test_get_parameters_default_empty_list(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
result = strategy.get_parameters()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result == []
|
||||
|
||||
def test_get_parameters_returns_new_list_each_time(self) -> None:
|
||||
strategy = DummyStrategy()
|
||||
|
||||
first = strategy.get_parameters()
|
||||
second = strategy.get_parameters()
|
||||
|
||||
assert first == second == []
|
||||
assert first is not second
|
||||
272
api/tests/unit_tests/core/agent/strategy/test_plugin.py
Normal file
272
api/tests/unit_tests/core/agent/strategy/test_plugin.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# File: tests/unit_tests/core/agent/strategy/test_plugin.py
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
|
||||
# ============================================================
|
||||
# Fixtures
|
||||
# ============================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parameter():
|
||||
def _factory(name="param", return_value="initialized"):
|
||||
param = MagicMock()
|
||||
param.name = name
|
||||
param.init_frontend_parameter = MagicMock(return_value=return_value)
|
||||
return param
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_declaration(mock_parameter):
|
||||
param1 = mock_parameter("param1", "init1")
|
||||
param2 = mock_parameter("param2", "init2")
|
||||
|
||||
identity = MagicMock()
|
||||
identity.provider = "provider_x"
|
||||
identity.name = "strategy_x"
|
||||
|
||||
declaration = MagicMock()
|
||||
declaration.parameters = [param1, param2]
|
||||
declaration.identity = identity
|
||||
|
||||
return declaration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(mock_declaration):
|
||||
return PluginAgentStrategy(
|
||||
tenant_id="tenant_123",
|
||||
declaration=mock_declaration,
|
||||
meta_version="v1",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Initialization Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestPluginAgentStrategyInitialization:
|
||||
def test_init_sets_attributes(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version="meta_v",
|
||||
)
|
||||
|
||||
assert strategy.tenant_id == "tenant_test"
|
||||
assert strategy.declaration == mock_declaration
|
||||
assert strategy.meta_version == "meta_v"
|
||||
|
||||
def test_init_meta_version_none(self, mock_declaration) -> None:
|
||||
strategy = PluginAgentStrategy(
|
||||
tenant_id="tenant_test",
|
||||
declaration=mock_declaration,
|
||||
meta_version=None,
|
||||
)
|
||||
|
||||
assert strategy.meta_version is None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# get_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestGetParameters:
|
||||
def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None:
|
||||
result = strategy.get_parameters()
|
||||
assert result == mock_declaration.parameters
|
||||
|
||||
|
||||
# ============================================================
|
||||
# initialize_parameters Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInitializeParameters:
|
||||
def test_initialize_parameters_success(self, strategy, mock_declaration) -> None:
|
||||
params = {"param1": "value1"}
|
||||
|
||||
result = strategy.initialize_parameters(params.copy())
|
||||
|
||||
assert result["param1"] == "init1"
|
||||
assert result["param2"] == "init2"
|
||||
|
||||
mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1")
|
||||
mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_params",
|
||||
[
|
||||
{},
|
||||
{"param1": None},
|
||||
{"param1": ""},
|
||||
{"param1": 0},
|
||||
{"param1": []},
|
||||
{"param1": {}, "param2": "value"},
|
||||
],
|
||||
)
|
||||
def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None:
|
||||
result = strategy.initialize_parameters(input_params.copy())
|
||||
|
||||
for param in strategy.declaration.parameters:
|
||||
assert param.name in result
|
||||
|
||||
def test_initialize_parameters_invalid_input_type(self, strategy) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.initialize_parameters(None)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# _invoke Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestInvoke:
|
||||
def test_invoke_success_all_arguments(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mock_convert = mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={"converted": True},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={"param1": "value"},
|
||||
user_id="user_1",
|
||||
conversation_id="conv_1",
|
||||
app_id="app_1",
|
||||
message_id="msg_1",
|
||||
credentials=None,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == ["msg1", "msg2"]
|
||||
mock_convert.assert_called_once()
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
call_kwargs = mock_manager.invoke.call_args.kwargs
|
||||
assert call_kwargs["tenant_id"] == "tenant_123"
|
||||
assert call_kwargs["user_id"] == "user_1"
|
||||
assert call_kwargs["agent_provider"] == "provider_x"
|
||||
assert call_kwargs["agent_strategy"] == "strategy_x"
|
||||
assert call_kwargs["agent_params"] == {"converted": True}
|
||||
assert call_kwargs["conversation_id"] == "conv_1"
|
||||
assert call_kwargs["app_id"] == "app_1"
|
||||
assert call_kwargs["message_id"] == "msg_1"
|
||||
assert call_kwargs["context"] is not None
|
||||
|
||||
def test_invoke_with_credentials(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
# Patch PluginInvokeContext to bypass pydantic validation
|
||||
mock_context = MagicMock()
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginInvokeContext",
|
||||
return_value=mock_context,
|
||||
)
|
||||
|
||||
credentials = MagicMock()
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
credentials=credentials,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("conversation_id", "app_id", "message_id"),
|
||||
[
|
||||
(None, None, None),
|
||||
("conv", None, None),
|
||||
(None, "app", None),
|
||||
(None, None, "msg"),
|
||||
],
|
||||
)
|
||||
def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke = MagicMock(return_value=iter([]))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
result = list(
|
||||
strategy._invoke(
|
||||
params={},
|
||||
user_id="user_1",
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_manager.invoke.assert_called_once()
|
||||
|
||||
def test_invoke_convert_raises_exception(self, strategy, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
side_effect=ValueError("conversion failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
|
||||
def test_invoke_manager_raises_exception(self, strategy, mocker) -> None:
|
||||
mock_manager = MagicMock()
|
||||
mock_manager.invoke.side_effect = RuntimeError("invoke failed")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.PluginAgentClient",
|
||||
return_value=mock_manager,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.strategy.plugin.convert_parameters_to_plugin_format",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
list(strategy._invoke(params={}, user_id="user_1"))
|
||||
802
api/tests/unit_tests/core/agent/test_base_agent_runner.py
Normal file
802
api/tests/unit_tests/core/agent/test_base_agent_runner.py
Normal file
@@ -0,0 +1,802 @@
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.agent.base_agent_runner as module
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
|
||||
# ==========================================================
|
||||
# Fixtures
|
||||
# ==========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mocker):
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, mock_db_session):
|
||||
r = BaseAgentRunner.__new__(BaseAgentRunner)
|
||||
r.tenant_id = "tenant"
|
||||
r.user_id = "user"
|
||||
r.agent_thought_count = 0
|
||||
r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1")
|
||||
r.app_config = mocker.MagicMock()
|
||||
r.app_config.app_id = "app1"
|
||||
r.app_config.agent = None
|
||||
r.dataset_tools = []
|
||||
r.application_generate_entity = mocker.MagicMock(invoke_from="test")
|
||||
r._current_thoughts = []
|
||||
return r
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _repack_app_generate_entity
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestRepack:
|
||||
def test_sets_empty_if_none(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = None
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == ""
|
||||
|
||||
def test_keeps_existing(self, runner, mocker):
|
||||
entity = mocker.MagicMock()
|
||||
entity.app_config.prompt_template.simple_prompt_template = "abc"
|
||||
result = runner._repack_app_generate_entity(entity)
|
||||
assert result.app_config.prompt_template.simple_prompt_template == "abc"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# update_prompt_message_tool
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestUpdatePromptTool:
|
||||
def build_param(self, mocker, **kwargs):
|
||||
p = mocker.MagicMock()
|
||||
p.form = kwargs.get("form")
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
p.type = mock_type
|
||||
|
||||
p.name = kwargs.get("name", "p1")
|
||||
p.llm_description = "desc"
|
||||
p.input_schema = kwargs.get("input_schema")
|
||||
p.options = kwargs.get("options")
|
||||
p.required = kwargs.get("required", False)
|
||||
return p
|
||||
|
||||
def test_skip_non_llm(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form="NOT_LLM")
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_enum_and_required(self, runner, mocker):
|
||||
option = mocker.MagicMock(value="opt1")
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
options=[option],
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool = mocker.MagicMock()
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert "p1" in result.parameters["required"]
|
||||
|
||||
def test_skip_file_type_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM)
|
||||
param.type = module.ToolParameter.ToolParameterType.FILE
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"] == {}
|
||||
|
||||
def test_duplicate_required_not_duplicated(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = self.build_param(
|
||||
mocker,
|
||||
form=module.ToolParameter.ToolParameterForm.LLM,
|
||||
required=True,
|
||||
)
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": ["p1"]}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["required"].count("p1") == 1
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# create_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestCreateAgentThought:
|
||||
def test_with_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=10)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"])
|
||||
assert result == "10"
|
||||
assert runner.agent_thought_count == 1
|
||||
|
||||
def test_without_files(self, runner, mock_db_session, mocker):
|
||||
mock_thought = mocker.MagicMock(id=11)
|
||||
mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought)
|
||||
|
||||
result = runner.create_agent_thought("m", "msg", "tool", "input", [])
|
||||
assert result == "11"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# save_agent_thought
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestSaveAgentThought:
|
||||
def setup_agent(self, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;tool2"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
return agent
|
||||
|
||||
def test_not_found(self, runner, mock_db_session):
|
||||
mock_db_session.scalar.return_value = None
|
||||
with pytest.raises(ValueError):
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
def test_full_update(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mock_label = mocker.MagicMock()
|
||||
mock_label.to_dict.return_value = {"en_US": "label"}
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label)
|
||||
|
||||
usage = mocker.MagicMock(
|
||||
prompt_tokens=1,
|
||||
prompt_price_unit=Decimal("0.1"),
|
||||
prompt_unit_price=Decimal("0.1"),
|
||||
completion_tokens=2,
|
||||
completion_price_unit=Decimal("0.2"),
|
||||
completion_unit_price=Decimal("0.2"),
|
||||
total_tokens=3,
|
||||
total_price=Decimal("0.3"),
|
||||
)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1;tool2",
|
||||
{"a": 1},
|
||||
"thought",
|
||||
{"b": 2},
|
||||
{"meta": 1},
|
||||
"answer",
|
||||
["f1"],
|
||||
usage,
|
||||
)
|
||||
|
||||
assert agent.answer == "answer"
|
||||
assert agent.tokens == 3
|
||||
assert "tool1" in json.loads(agent.tool_labels_str)
|
||||
|
||||
def test_label_fallback_when_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
agent.tool = "unknown_tool"
|
||||
mock_db_session.scalar.return_value = agent
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "unknown_tool" in labels
|
||||
|
||||
def test_json_failure_paths(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
bad_obj = MagicMock()
|
||||
bad_obj.__str__.return_value = "bad"
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
bad_obj,
|
||||
None,
|
||||
bad_obj,
|
||||
bad_obj,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_messages_ids_none(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, None, None)
|
||||
assert mock_db_session.commit.called
|
||||
|
||||
def test_success_dict_serialization(self, runner, mock_db_session, mocker):
|
||||
agent = self.setup_agent(mocker)
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
None,
|
||||
{"a": 1},
|
||||
None,
|
||||
{"b": 2},
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_user_prompt
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeUserPrompt:
|
||||
def test_no_files(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_with_files_no_config(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
msg = mocker.MagicMock(id="1", query="hello", app_model_config=None)
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
def test_image_detail_low_fallback(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[])
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result.content == "hello"
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# organize_agent_history
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestOrganizeHistory:
|
||||
def test_empty(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_answer_only(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert any(isinstance(x, module.AssistantPromptMessage) for x in result)
|
||||
|
||||
def test_skip_current_message(self, runner, mock_db_session, mocker):
|
||||
msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None)
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert result == []
|
||||
|
||||
def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input="invalid",
|
||||
observation="invalid",
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_empty_tool_name_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=";", thought="thinking")
|
||||
msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_valid_json_tool_flow(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=json.dumps({"tool1": {"x": 1}}),
|
||||
observation=json.dumps({"tool1": "obs"}),
|
||||
thought="thinking",
|
||||
)
|
||||
|
||||
msg = mocker.MagicMock(
|
||||
id="m100",
|
||||
agent_thoughts=[thought],
|
||||
answer=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _convert_tool_to_prompt_message_tool (new coverage)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestConvertToolToPromptMessageTool:
|
||||
def test_basic_conversion(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
runtime_param = mocker.MagicMock()
|
||||
runtime_param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param.name = "param1"
|
||||
runtime_param.llm_description = "desc"
|
||||
runtime_param.required = True
|
||||
runtime_param.input_schema = None
|
||||
runtime_param.options = None
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
runtime_param.type = mock_type
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [runtime_param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert entity == tool_entity
|
||||
|
||||
def test_full_conversion_multiple_params(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
# LLM param with input_schema override
|
||||
param1 = mocker.MagicMock()
|
||||
param1.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param1.name = "p1"
|
||||
param1.llm_description = "desc"
|
||||
param1.required = True
|
||||
param1.input_schema = {"type": "integer"}
|
||||
param1.options = None
|
||||
param1.type = mocker.MagicMock()
|
||||
|
||||
# SYSTEM_FILES param should be skipped
|
||||
param2 = mocker.MagicMock()
|
||||
param2.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param2.name = "file_param"
|
||||
param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param1, param2]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert entity == tool_entity
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# _init_prompt_tools additional branches
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestInitPromptToolsExtended:
|
||||
def test_agent_tool_branch(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="agent_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity"))
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert "agent_tool" in tools
|
||||
|
||||
def test_exception_in_conversion(self, runner, mocker):
|
||||
agent_tool = mocker.MagicMock(tool_name="bad_tool")
|
||||
runner.app_config.agent = mocker.MagicMock(tools=[agent_tool])
|
||||
mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception)
|
||||
|
||||
tools, prompts = runner._init_prompt_tools()
|
||||
assert tools == {}
|
||||
|
||||
|
||||
# ==========================================================
|
||||
# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS)
|
||||
# ==========================================================
|
||||
|
||||
|
||||
class TestAdditionalCoverage:
|
||||
def test_update_prompt_with_input_schema(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "p1"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.options = None
|
||||
param.input_schema = {"type": "number"}
|
||||
|
||||
mock_type = mocker.MagicMock()
|
||||
mock_type.as_normal_type.return_value = "string"
|
||||
param.type = mock_type
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
assert result.parameters["properties"]["p1"]["type"] == "number"
|
||||
|
||||
def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {"tool1": {"en_US": "existing"}}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert labels["tool1"]["en_US"] == "existing"
|
||||
|
||||
def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None)
|
||||
assert agent.tool_meta_str == "meta_string"
|
||||
|
||||
def test_convert_dataset_retriever_tool(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
assert prompt is not None
|
||||
|
||||
def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()]
|
||||
|
||||
file_config = mocker.MagicMock()
|
||||
file_config.image_config = mocker.MagicMock(detail=None)
|
||||
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"])
|
||||
mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock())
|
||||
|
||||
mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw))
|
||||
mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
msg = mocker.MagicMock(id="1", query="hello")
|
||||
msg.app_model_config.to_dict.return_value = {}
|
||||
|
||||
result = runner.organize_agent_user_prompt(msg)
|
||||
assert result is not None
|
||||
|
||||
def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(tool=None, thought="thinking")
|
||||
msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1;tool2",
|
||||
tool_input=json.dumps({"tool1": {}, "tool2": {}}),
|
||||
observation=json.dumps({"tool1": "o1", "tool2": "o2"}),
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
# ================= Additional Surgical Coverage =================
|
||||
|
||||
def test_convert_tool_select_enum_branch(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = True
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
assert prompt_tool is not None
|
||||
|
||||
|
||||
class TestConvertDatasetRetrieverTool:
|
||||
def test_required_param_added(self, runner, mocker):
|
||||
ds_tool = mocker.MagicMock()
|
||||
ds_tool.entity.identity.name = "ds"
|
||||
ds_tool.entity.description.llm = "desc"
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.name = "query"
|
||||
param.llm_description = "desc"
|
||||
param.required = True
|
||||
|
||||
ds_tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool)
|
||||
|
||||
assert prompt is not None
|
||||
|
||||
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"])
|
||||
|
||||
llm = mocker.MagicMock()
|
||||
llm.get_model_schema.return_value = mocker.MagicMock(
|
||||
features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION]
|
||||
)
|
||||
model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c")
|
||||
|
||||
app_config = mocker.MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.agent = None
|
||||
app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"})
|
||||
app_config.additional_features = mocker.MagicMock(show_retrieve_source=True)
|
||||
|
||||
app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"])
|
||||
message = mocker.MagicMock(id="msg1", conversation_id="conv1")
|
||||
|
||||
runner = BaseAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=app_generate,
|
||||
conversation=mocker.MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=mocker.MagicMock(),
|
||||
config=mocker.MagicMock(),
|
||||
queue_manager=mocker.MagicMock(),
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
assert runner.stream_tool_call is True
|
||||
assert runner.files == ["file1"]
|
||||
assert runner.dataset_tools == ["ds_tool"]
|
||||
assert runner.agent_thought_count == 2
|
||||
|
||||
|
||||
class TestBaseAgentRunnerCoverage:
|
||||
def test_convert_tool_skips_non_llm_param(self, runner, mocker):
|
||||
tool = mocker.MagicMock(tool_name="tool1")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = "NOT_LLM"
|
||||
param.type = mocker.MagicMock()
|
||||
|
||||
tool_entity = mocker.MagicMock()
|
||||
tool_entity.entity.description.llm = "desc"
|
||||
tool_entity.get_merged_runtime_parameters.return_value = [param]
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity)
|
||||
mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw))
|
||||
|
||||
prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool)
|
||||
|
||||
assert prompt_tool.parameters["properties"] == {}
|
||||
|
||||
def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker):
|
||||
dataset_tool = mocker.MagicMock()
|
||||
dataset_tool.entity.identity.name = "ds"
|
||||
runner.dataset_tools = [dataset_tool]
|
||||
|
||||
mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock())
|
||||
|
||||
tools, prompt_tools = runner._init_prompt_tools()
|
||||
|
||||
assert tools["ds"] == dataset_tool
|
||||
assert len(prompt_tools) == 1
|
||||
|
||||
def test_update_prompt_message_tool_select_enum(self, runner, mocker):
|
||||
tool = mocker.MagicMock()
|
||||
|
||||
option1 = mocker.MagicMock(value="A")
|
||||
option2 = mocker.MagicMock(value="B")
|
||||
|
||||
param = mocker.MagicMock()
|
||||
param.form = module.ToolParameter.ToolParameterForm.LLM
|
||||
param.name = "select_param"
|
||||
param.required = False
|
||||
param.llm_description = "desc"
|
||||
param.input_schema = None
|
||||
param.options = [option1, option2]
|
||||
param.type = module.ToolParameter.ToolParameterType.SELECT
|
||||
|
||||
tool.get_runtime_parameters.return_value = [param]
|
||||
|
||||
prompt_tool = mocker.MagicMock()
|
||||
prompt_tool.parameters = {"properties": {}, "required": []}
|
||||
|
||||
result = runner.update_prompt_message_tool(tool, prompt_tool)
|
||||
|
||||
assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"]
|
||||
|
||||
def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
tool_input = {"a": 1}
|
||||
observation = {"b": 2}
|
||||
tool_meta = {"c": 3}
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def dumps_side_effect(value, *args, **kwargs):
|
||||
if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False:
|
||||
raise TypeError("fail")
|
||||
return real_dumps(value, *args, **kwargs)
|
||||
|
||||
mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect)
|
||||
|
||||
runner.save_agent_thought(
|
||||
"id",
|
||||
"tool1",
|
||||
tool_input,
|
||||
None,
|
||||
observation,
|
||||
tool_meta,
|
||||
None,
|
||||
[],
|
||||
None,
|
||||
)
|
||||
|
||||
assert isinstance(agent.tool_input, str)
|
||||
assert isinstance(agent.observation, str)
|
||||
assert isinstance(agent.tool_meta_str, str)
|
||||
|
||||
def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker):
|
||||
agent = mocker.MagicMock()
|
||||
agent.tool = "tool1;;"
|
||||
agent.tool_labels = {}
|
||||
agent.thought = ""
|
||||
mock_db_session.scalar.return_value = agent
|
||||
|
||||
mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None)
|
||||
|
||||
runner.save_agent_thought("id", None, None, None, None, None, None, [], None)
|
||||
|
||||
labels = json.loads(agent.tool_labels_str)
|
||||
assert "" not in labels
|
||||
|
||||
def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker):
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[])
|
||||
|
||||
system_message = module.SystemPromptMessage(content="sys")
|
||||
|
||||
result = runner.organize_agent_history([system_message])
|
||||
|
||||
assert system_message in result
|
||||
|
||||
def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker):
|
||||
thought = mocker.MagicMock(
|
||||
tool="tool1",
|
||||
tool_input=None,
|
||||
observation=None,
|
||||
thought="thinking",
|
||||
)
|
||||
msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None)
|
||||
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg]
|
||||
mocker.patch.object(module, "extract_thread_messages", return_value=[msg])
|
||||
mocker.patch("uuid.uuid4", return_value="uuid")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"organize_agent_user_prompt",
|
||||
return_value=module.UserPromptMessage(content="user"),
|
||||
)
|
||||
|
||||
result = runner.organize_agent_history([])
|
||||
|
||||
assert any(isinstance(item, module.ToolPromptMessage) for item in result)
|
||||
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
551
api/tests/unit_tests/core/agent/test_cot_agent_runner.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class DummyRunner(CotAgentRunner):
|
||||
"""Concrete implementation for testing abstract methods."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
# Minimal required defaults
|
||||
self.history_prompt_messages = []
|
||||
self.memory = None
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
return_value=[],
|
||||
)
|
||||
# Prepare required constructor dependencies for BaseAgentRunner
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock()
|
||||
application_generate_entity.model_conf.stop = []
|
||||
application_generate_entity.model_conf.provider = "openai"
|
||||
application_generate_entity.model_conf.parameters = {}
|
||||
application_generate_entity.trace_manager = None
|
||||
application_generate_entity.invoke_from = "test"
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock()
|
||||
app_config.agent.max_iteration = 1
|
||||
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.model = "test-model"
|
||||
|
||||
queue_manager = MagicMock()
|
||||
message = MagicMock()
|
||||
|
||||
runner = DummyRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=model_config,
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Patch internal methods to isolate behavior
|
||||
runner._repack_app_generate_entity = MagicMock()
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.create_agent_thought = MagicMock(return_value="thought-id")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
class TestFillInputs:
|
||||
@pytest.mark.parametrize(
|
||||
("instruction", "inputs", "expected"),
|
||||
[
|
||||
("Hello {{name}}", {"name": "John"}, "Hello John"),
|
||||
("No placeholders", {"name": "John"}, "No placeholders"),
|
||||
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
|
||||
("{{x}}", {"x": None}, "None"),
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
action_str="action1",
|
||||
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
|
||||
observation="obs1",
|
||||
)
|
||||
sp2 = AgentScratchpadUnit(
|
||||
agent_response="final",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
|
||||
observation=None,
|
||||
)
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
# Simulate final state via action name
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
action=None,
|
||||
observation="obs",
|
||||
)
|
||||
# Non-final state: provide a non-final action
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Thought:" in result
|
||||
assert "Action:" in result
|
||||
assert "Observation:" in result
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("result", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, [])
|
||||
assert response == "result"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
)
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
|
||||
usage_1 = LLMUsage.empty_usage()
|
||||
usage_1.prompt_tokens = 1
|
||||
usage_1.completion_tokens = 1
|
||||
usage_1.total_tokens = 2
|
||||
usage_1.prompt_price = 1
|
||||
usage_1.completion_price = 1
|
||||
usage_1.total_price = 2
|
||||
|
||||
usage_2 = LLMUsage.empty_usage()
|
||||
usage_2.prompt_tokens = 1
|
||||
usage_2.completion_tokens = 1
|
||||
usage_2.total_tokens = 2
|
||||
usage_2.prompt_price = 1
|
||||
usage_2.completion_price = 1
|
||||
usage_2.total_price = 2
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
handle_output = mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[
|
||||
[action],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_side_effect(chunks, usage_dict):
|
||||
call_index = handle_output.call_count
|
||||
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
|
||||
return [action] if call_index == 1 else []
|
||||
|
||||
handle_output.side_effect = _handle_side_effect
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
final_usage = results[-1].delta.usage
|
||||
assert final_usage is not None
|
||||
assert final_usage.prompt_tokens == 2
|
||||
assert final_usage.completion_tokens == 2
|
||||
assert final_usage.total_tokens == 4
|
||||
assert final_usage.prompt_price == 2
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
# First iteration → action
|
||||
# Second iteration → no action (empty list)
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[[action], []],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.app_config.agent.max_iteration = 5
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
|
||||
|
||||
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
|
||||
|
||||
runner.history_prompt_messages = [assistant, tool_msg]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
runner.history_prompt_messages = [assistant]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
|
||||
runner._init_react_state("new-query")
|
||||
|
||||
assert runner._query == "new-query"
|
||||
assert runner._agent_scratchpad == []
|
||||
assert runner._historic_prompt_messages == ["historic"]
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
|
||||
)
|
||||
|
||||
message_file_ids = []
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
|
||||
|
||||
assert response == "ok"
|
||||
assert message_file_ids == ["file1"]
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
|
||||
# Should silently continue on exception
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
|
||||
runner.history_prompt_messages = [user_message]
|
||||
|
||||
mock_transform = mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
)
|
||||
mock_transform.return_value.get_prompt.return_value = ["final"]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
return_value=mock_transform,
|
||||
)
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=["thinking"],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
# Remove invalid branch: Pydantic enforces str|dict for action_input
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "12345"
|
||||
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
215
api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyAgentConfig,
|
||||
DummyAppConfig,
|
||||
DummyTool,
|
||||
)
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyPromptEntity as DummyPrompt,
|
||||
)
|
||||
|
||||
|
||||
class DummyFileUploadConfig:
|
||||
def __init__(self, image_config=None):
|
||||
self.image_config = image_config
|
||||
|
||||
|
||||
class DummyImageConfig:
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, file_upload_config=None):
|
||||
self.file_upload_config = file_upload_config
|
||||
|
||||
|
||||
class DummyUnit:
|
||||
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
|
||||
runner._instruction = "test_instruction"
|
||||
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
|
||||
runner._query = "user query"
|
||||
runner._agent_scratchpad = []
|
||||
runner.files = []
|
||||
runner.application_generate_entity = DummyGenerateEntity()
|
||||
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
|
||||
return runner
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_chat_agent_runner.jsonable_encoder",
|
||||
return_value=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
|
||||
result = runner._organize_system_prompt()
|
||||
|
||||
assert "test_instruction" in result.content
|
||||
assert "tool1" in result.content
|
||||
assert "tool2" in result.content
|
||||
assert "tool1, tool2" in result.content
|
||||
|
||||
def test_organize_system_prompt_missing_agent(self, runner):
|
||||
runner.app_config = DummyAppConfig(agent=None)
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
def test_organize_system_prompt_missing_prompt(self, runner):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
|
||||
def test_organize_user_query_no_files(self, runner, files):
|
||||
runner.files = files
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "query"
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
|
||||
image_config = DummyImageConfig(detail="high")
|
||||
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
|
||||
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assert "system" in result
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(final=True, agent_response="done")
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(
|
||||
final=False,
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
observation="observe",
|
||||
)
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: thinking" in combined
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
units = [
|
||||
DummyUnit(final=False, thought="t1"),
|
||||
DummyUnit(final=True, agent_response="done"),
|
||||
]
|
||||
runner._agent_scratchpad = units
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: t1" in combined
|
||||
assert "Final Answer: done" in combined
|
||||
@@ -0,0 +1,234 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, dummy_tool_factory):
|
||||
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
|
||||
|
||||
runner._instruction = "Test instruction"
|
||||
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
|
||||
runner._query = "What is Python?"
|
||||
runner._agent_scratchpad = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_completion_agent_runner.jsonable_encoder",
|
||||
side_effect=lambda tools: [{"name": t.name} for t in tools],
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_instruction_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
)
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
result = runner._organize_instruction_prompt()
|
||||
|
||||
assert "Test instruction" in result
|
||||
assert "toolA" in result
|
||||
assert "toolB" in result
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_historic_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[user_msg, assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_prompt_messages Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
|
||||
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
|
||||
content = result[0].content
|
||||
|
||||
assert "History" in content
|
||||
assert "Thought: Thinking" in content
|
||||
assert "Action: Act" in content
|
||||
assert "Observation: Obs" in content
|
||||
assert "Final Answer: Done" in content
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = None
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert "Question: What is Python?" in result[0].content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("thought", "action", "observation"),
|
||||
[
|
||||
("T", None, None),
|
||||
("T", "A", None),
|
||||
("T", None, "O"),
|
||||
],
|
||||
)
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(
|
||||
final=False,
|
||||
thought=thought,
|
||||
action_str=action,
|
||||
observation=observation,
|
||||
)
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
content = result[0].content
|
||||
|
||||
assert "Thought:" in content
|
||||
if action:
|
||||
assert "Action:" in content
|
||||
if observation:
|
||||
assert "Observation:" in content
|
||||
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
452
api/tests/unit_tests/core/agent/test_fc_agent_runner.py
Normal file
@@ -0,0 +1,452 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Dummy Helper Classes
|
||||
# ==============================
|
||||
|
||||
|
||||
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = pt
|
||||
usage.completion_tokens = ct
|
||||
usage.total_tokens = tt
|
||||
usage.prompt_price = 0
|
||||
usage.completion_price = 0
|
||||
usage.total_price = 0
|
||||
return usage
|
||||
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
|
||||
self.content: str | None = content
|
||||
self.tool_calls: list[Any] = tool_calls or []
|
||||
|
||||
|
||||
class DummyDelta:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
|
||||
|
||||
|
||||
class DummyResult:
|
||||
def __init__(
|
||||
self,
|
||||
message: DummyMessage | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
prompt_messages: list[DummyMessage] | None = None,
|
||||
):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
self.prompt_messages: list[DummyMessage] = prompt_messages or []
|
||||
self.system_fingerprint: str = ""
|
||||
|
||||
|
||||
# ==============================
|
||||
# Fixtures
|
||||
# ==============================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Patch streaming chunk models to avoid validation on dummy message objects
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock(max_iteration=2)
|
||||
app_config.prompt_template = MagicMock(simple_prompt_template="system")
|
||||
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
|
||||
application_generate_entity.trace_manager = MagicMock()
|
||||
application_generate_entity.invoke_from = "test"
|
||||
application_generate_entity.app_config = MagicMock(app_id="app")
|
||||
application_generate_entity.file_upload_config = None
|
||||
|
||||
queue_manager = MagicMock()
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
|
||||
message = MagicMock(id="msg1")
|
||||
conversation = MagicMock(id="conv1")
|
||||
|
||||
runner = FunctionCallAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=MagicMock(),
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Manually inject required attributes normally set by BaseAgentRunner
|
||||
runner.tenant_id = "tenant"
|
||||
runner.application_generate_entity = application_generate_entity
|
||||
runner.conversation = conversation
|
||||
runner.app_config = app_config
|
||||
runner.model_config = MagicMock()
|
||||
runner.config = MagicMock()
|
||||
runner.queue_manager = queue_manager
|
||||
runner.message = message
|
||||
runner.user_id = "user"
|
||||
runner.model_instance = model_instance
|
||||
|
||||
runner.stream_tool_call = False
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.agent_callback = MagicMock()
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.create_agent_thought = MagicMock(return_value="thought1")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ==============================
|
||||
# Tool Call Checks
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
|
||||
# ==============================
|
||||
# Extract Tool Calls
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = ""
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
tool_call.function.arguments = json.dumps({"x": 2})
|
||||
|
||||
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_blocking_tool_calls(result)
|
||||
|
||||
assert calls == [("2", "block", {"x": 2})]
|
||||
|
||||
|
||||
# ==============================
|
||||
# System Message Initialization
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ==============================
|
||||
# Organize User Query
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
return_value=file_content,
|
||||
)
|
||||
|
||||
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
|
||||
runner.files = ["file1"]
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Clear User Prompt Images
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
|
||||
image = MagicMock()
|
||||
image.type = "image"
|
||||
image.data = "img"
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.__class__.__name__ = "UserPromptMessage"
|
||||
user_msg.content = [text, image]
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
|
||||
user_msg = UserPromptMessage(content=[text, image, document])
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
|
||||
assert result[0].content == "hello\n[image]\n[file]"
|
||||
|
||||
|
||||
# ==============================
|
||||
# Run Method Tests
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def flaky_dumps(obj, *args, **kwargs):
|
||||
if kwargs.get("ensure_ascii") is False:
|
||||
return real_dumps(obj, *args, **kwargs)
|
||||
raise TypeError("boom")
|
||||
|
||||
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "missing"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
tool_instance = MagicMock()
|
||||
prompt_tool = MagicMock()
|
||||
prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
|
||||
|
||||
tool_invoke_meta = MagicMock()
|
||||
tool_invoke_meta.to_dict.return_value = {"ok": True}
|
||||
mocker.patch(
|
||||
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], tool_invoke_meta),
|
||||
)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
assert any(
|
||||
isinstance(call.args[0], QueueMessageFileEvent)
|
||||
and call.args[0].message_file_id == "file1"
|
||||
and call.args[1] == PublishFrom.APPLICATION_MANAGER
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query"))
|
||||
324
api/tests/unit_tests/core/agent/test_plugin_entities.py
Normal file
324
api/tests/unit_tests/core/agent/test_plugin_entities.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""Unit tests for core.agent.plugin_entities.
|
||||
|
||||
Covers entities such as AgentFeature, AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on
|
||||
Pydantic ValidationError behavior and pytest fixtures for validation and
|
||||
mocking; ensure entity invariants and validation rules remain stable.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.agent.plugin_entities import (
|
||||
AgentFeature,
|
||||
AgentProviderEntityWithPlugin,
|
||||
AgentStrategyEntity,
|
||||
AgentStrategyIdentity,
|
||||
AgentStrategyParameter,
|
||||
AgentStrategyProviderEntity,
|
||||
AgentStrategyProviderIdentity,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity
|
||||
|
||||
# =========================================================
|
||||
# Fixtures
|
||||
# =========================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyIdentity)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_identity(mocker):
|
||||
return mocker.MagicMock(spec=AgentStrategyProviderIdentity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameterType Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameterType:
|
||||
@pytest.mark.parametrize(
|
||||
"enum_member",
|
||||
list(AgentStrategyParameter.AgentStrategyParameterType),
|
||||
)
|
||||
def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
return_value="normalized",
|
||||
)
|
||||
|
||||
result = enum_member.as_normal_type()
|
||||
|
||||
mock_func.assert_called_once_with(enum_member)
|
||||
assert result == "normalized"
|
||||
|
||||
def test_as_normal_type_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.as_normal_type",
|
||||
side_effect=RuntimeError("boom"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
enum_member.as_normal_type()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("enum_member", "value"),
|
||||
[
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.STRING, None),
|
||||
(AgentStrategyParameter.AgentStrategyParameterType.FILES, []),
|
||||
],
|
||||
)
|
||||
def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
return_value="casted",
|
||||
)
|
||||
|
||||
result = enum_member.cast_value(value)
|
||||
|
||||
mock_func.assert_called_once_with(enum_member, value)
|
||||
assert result == "casted"
|
||||
|
||||
def test_cast_value_propagates_exception(self, mocker) -> None:
|
||||
enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.cast_parameter_value",
|
||||
side_effect=ValueError("invalid"),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
enum_member.cast_value("bad")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyParameter Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyParameter:
|
||||
def test_valid_creation_minimal(self) -> None:
|
||||
# bypass base PluginParameter required fields using model_construct
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=None,
|
||||
)
|
||||
assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
assert param.help is None
|
||||
|
||||
def test_valid_creation_with_help(self) -> None:
|
||||
help_obj = I18nObject(en_US="test")
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
help=help_obj,
|
||||
)
|
||||
assert param.help == help_obj
|
||||
|
||||
@pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}])
|
||||
def test_invalid_type_raises_validation_error(self, invalid_type) -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y"))
|
||||
|
||||
assert any(error["loc"] == ("type",) for error in exc_info.value.errors())
|
||||
|
||||
def test_init_frontend_parameter_calls_external(self, mocker) -> None:
|
||||
mock_func = mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
return_value="frontend",
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
result = param.init_frontend_parameter("value")
|
||||
|
||||
mock_func.assert_called_once_with(param, param.type, "value")
|
||||
assert result == "frontend"
|
||||
|
||||
def test_init_frontend_parameter_propagates_exception(self, mocker) -> None:
|
||||
mocker.patch(
|
||||
"core.agent.plugin_entities.init_frontend_parameter",
|
||||
side_effect=RuntimeError("error"),
|
||||
)
|
||||
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
param.init_frontend_parameter("value")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyProviderEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyProviderEntity:
|
||||
def test_creation_with_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="plugin-123",
|
||||
)
|
||||
assert entity.plugin_id == "plugin-123"
|
||||
|
||||
def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(
|
||||
identity=mock_provider_identity,
|
||||
plugin_id="",
|
||||
)
|
||||
assert entity.plugin_id == ""
|
||||
|
||||
def test_creation_without_plugin_id(self, mock_provider_identity) -> None:
|
||||
entity = AgentStrategyProviderEntity(identity=mock_provider_identity)
|
||||
assert entity.plugin_id is None
|
||||
|
||||
def test_invalid_identity_raises(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyProviderEntity(identity="invalid")
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentStrategyEntity Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentStrategyEntity:
|
||||
def test_parameters_default_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_none_converted_to_empty(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=None,
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
def test_parameters_preserved(self, mock_identity) -> None:
|
||||
param = AgentStrategyParameter.model_construct(
|
||||
type=AgentStrategyParameter.AgentStrategyParameterType.STRING,
|
||||
name="test",
|
||||
label="label",
|
||||
)
|
||||
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[param],
|
||||
)
|
||||
assert entity.parameters == [param]
|
||||
|
||||
def test_invalid_parameters_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters="invalid",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"features",
|
||||
[
|
||||
None,
|
||||
[],
|
||||
[AgentFeature.HISTORY_MESSAGES],
|
||||
],
|
||||
)
|
||||
def test_features_valid(self, mock_identity, features) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features=features,
|
||||
)
|
||||
assert entity.features == features
|
||||
|
||||
def test_invalid_features_type_raises(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
features="invalid",
|
||||
)
|
||||
|
||||
def test_output_schema_and_meta_version(self, mock_identity) -> None:
|
||||
entity = AgentStrategyEntity(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
output_schema={"type": "object"},
|
||||
meta_version="v1",
|
||||
)
|
||||
assert entity.output_schema == {"type": "object"}
|
||||
assert entity.meta_version == "v1"
|
||||
|
||||
def test_missing_required_fields_raise(self, mock_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentStrategyEntity(identity=mock_identity)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# AgentProviderEntityWithPlugin Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestAgentProviderEntityWithPlugin:
|
||||
def test_default_strategies_empty(self, mock_provider_identity) -> None:
|
||||
entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity)
|
||||
assert entity.strategies == []
|
||||
|
||||
def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None:
|
||||
strategy = AgentStrategyEntity.model_construct(
|
||||
identity=mock_identity,
|
||||
description=I18nObject(en_US="test"),
|
||||
parameters=[],
|
||||
)
|
||||
|
||||
entity = AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies=[strategy],
|
||||
)
|
||||
assert entity.strategies == [strategy]
|
||||
|
||||
def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentProviderEntityWithPlugin(
|
||||
identity=mock_provider_identity,
|
||||
strategies="invalid",
|
||||
)
|
||||
|
||||
|
||||
# =========================================================
|
||||
# Inheritance Smoke Tests
|
||||
# =========================================================
|
||||
|
||||
|
||||
class TestInheritanceBehavior:
|
||||
def test_agent_strategy_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyIdentity, ToolIdentity)
|
||||
|
||||
def test_agent_strategy_provider_identity_inherits(self) -> None:
|
||||
assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity)
|
||||
0
api/tests/unit_tests/core/app/apps/__init__.py
Normal file
0
api/tests/unit_tests/core/app/apps/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAdvancedChatAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow)
|
||||
|
||||
assert app_config.workflow_id == "wf-1"
|
||||
assert app_config.app_mode == AppMode.ADVANCED_CHAT
|
||||
|
||||
def test_config_validate_filters_keys(self):
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = kwargs.get("config") if kwargs else args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 7),
|
||||
),
|
||||
):
|
||||
filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={})
|
||||
|
||||
assert filtered["file_upload"] == 1
|
||||
assert filtered["opening_statement"] == 2
|
||||
assert filtered["suggested_questions_after_answer"] == 3
|
||||
assert filtered["speech_to_text"] == 4
|
||||
assert filtered["text_to_speech"] == 5
|
||||
assert filtered["retriever_resource"] == 6
|
||||
assert filtered["sensitive_word_avoidance"] == 7
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,96 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateResponseConverter:
|
||||
def test_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_stream_simple_response_includes_node_events(self):
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t1",
|
||||
workflow_run_id="r1",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="e1",
|
||||
node_id="n1",
|
||||
node_type="answer",
|
||||
title="Answer",
|
||||
index=1,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
elapsed_time=0.1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_start,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=node_finish,
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert converted[0] == "ping"
|
||||
assert converted[1]["event"] == "node_started"
|
||||
assert converted[2]["event"] == "node_finished"
|
||||
assert converted[3]["event"] == "error"
|
||||
@@ -0,0 +1,600 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AnnotationReply,
|
||||
AnnotationReplyAccount,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
from core.base.tts.app_generator_tts_publisher import AudioTrunk
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from models.enums import MessageStatus
|
||||
from models.model import AppMode, EndUser
|
||||
|
||||
|
||||
def _make_pipeline():
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity.model_construct(
|
||||
task_id="task",
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
user_id="user",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
trace_manager=None,
|
||||
workflow_run_id="run-id",
|
||||
)
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
query="hello",
|
||||
created_at=datetime.utcnow(),
|
||||
status=MessageStatus.NORMAL,
|
||||
answer="",
|
||||
)
|
||||
conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT)
|
||||
workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={})
|
||||
user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session")
|
||||
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None),
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=False,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=lambda **kwargs: None,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
class TestAdvancedChatGenerateTaskPipeline:
|
||||
def test_ensure_workflow_initialized_raises(self):
|
||||
pipeline = _make_pipeline()
|
||||
|
||||
with pytest.raises(ValueError, match="workflow run not initialized"):
|
||||
pipeline._ensure_workflow_initialized()
|
||||
|
||||
def test_to_blocking_response_returns_message_end(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.answer = "done"
|
||||
|
||||
def _gen():
|
||||
yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"})
|
||||
|
||||
response = pipeline._to_blocking_response(_gen())
|
||||
|
||||
assert response.data.answer == "done"
|
||||
assert response.data.metadata == {"k": "v"}
|
||||
|
||||
def test_handle_text_chunk_event_updates_state(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager = SimpleNamespace(
|
||||
message_to_stream_response=lambda **kwargs: MessageEndStreamResponse(
|
||||
task_id="task", id="message-id", metadata={}
|
||||
)
|
||||
)
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=None)
|
||||
|
||||
responses = list(pipeline._handle_text_chunk_event(event))
|
||||
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert responses
|
||||
|
||||
def test_listen_audio_msg_returns_audio_stream(self):
|
||||
pipeline = _make_pipeline()
|
||||
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
|
||||
|
||||
response = pipeline._listen_audio_msg(publisher=publisher, task_id="task")
|
||||
|
||||
assert isinstance(response, MessageAudioStreamResponse)
|
||||
|
||||
def test_handle_ping_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task")
|
||||
|
||||
responses = list(pipeline._handle_ping_event(QueuePingEvent()))
|
||||
|
||||
assert isinstance(responses[0], PingStreamResponse)
|
||||
|
||||
def test_handle_error_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
pipeline._database_session = _fake_session
|
||||
|
||||
responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom"))))
|
||||
|
||||
assert isinstance(responses[0], ValueError)
|
||||
|
||||
def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace())
|
||||
|
||||
responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent()))
|
||||
|
||||
assert pipeline._workflow_run_id == "run-id"
|
||||
assert responses == ["started"]
|
||||
|
||||
def test_message_end_to_stream_response_strips_annotation_reply(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._task_state.metadata.annotation_reply = AnnotationReply(
|
||||
id="ann",
|
||||
account=AnnotationReplyAccount(id="acc", name="acc"),
|
||||
)
|
||||
|
||||
response = pipeline._message_end_to_stream_response()
|
||||
|
||||
assert "annotation_reply" not in response.metadata
|
||||
|
||||
def test_handle_output_moderation_chunk_publishes_stop(self):
|
||||
pipeline = _make_pipeline()
|
||||
events: list[object] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return True
|
||||
|
||||
def get_final_output(self):
|
||||
return "final"
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
pipeline._base_task_pipeline.queue_manager = SimpleNamespace(
|
||||
publish=lambda event, pub_from: events.append(event)
|
||||
)
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("ignored")
|
||||
|
||||
assert result is True
|
||||
assert pipeline._task_state.answer == "final"
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in events)
|
||||
assert any(isinstance(event, QueueStopEvent) for event in events)
|
||||
|
||||
def test_handle_node_succeeded_event_records_files(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [
|
||||
{"type": "file", "transfer_method": "local"}
|
||||
]
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
event = SimpleNamespace(
|
||||
node_type=NodeType.ANSWER,
|
||||
outputs={"k": "v"},
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
)
|
||||
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
assert responses == ["done"]
|
||||
assert pipeline._recorded_files
|
||||
|
||||
def test_iteration_and_loop_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: (
|
||||
"iter_start"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next"
|
||||
pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: (
|
||||
"iter_done"
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start"
|
||||
pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next"
|
||||
pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done"
|
||||
|
||||
iter_start = QueueIterationStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_next = QueueIterationNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
iter_done = QueueIterationCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_start = QueueLoopStartEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_next = QueueLoopNextEvent(
|
||||
index=1,
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
node_run_index=1,
|
||||
)
|
||||
loop_done = QueueLoopCompletedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="LLM",
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_index=1,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"]
|
||||
assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"]
|
||||
assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"]
|
||||
assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"]
|
||||
assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"]
|
||||
assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"]
|
||||
|
||||
def test_workflow_finish_handlers(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_run_id = "run-id"
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"]
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: None
|
||||
pipeline._save_message = lambda **kwargs: None
|
||||
pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None
|
||||
pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom")
|
||||
pipeline._base_task_pipeline.error_to_stream_response = lambda err: err
|
||||
pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id")
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace(scalar=lambda *args, **kwargs: None)
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={})))
|
||||
assert len(succeeded_responses) == 2
|
||||
assert isinstance(succeeded_responses[0], MessageEndStreamResponse)
|
||||
assert succeeded_responses[1] == "finish"
|
||||
|
||||
partial_success_responses = list(
|
||||
pipeline._handle_workflow_partial_success_event(
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
)
|
||||
)
|
||||
assert len(partial_success_responses) == 2
|
||||
assert isinstance(partial_success_responses[0], MessageEndStreamResponse)
|
||||
assert partial_success_responses[1] == "finish"
|
||||
assert (
|
||||
list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0]
|
||||
== "finish"
|
||||
)
|
||||
assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [
|
||||
"pause"
|
||||
]
|
||||
|
||||
def test_node_failure_handlers(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish"
|
||||
pipeline._save_output_for_event = lambda event, node_execution_id: None
|
||||
|
||||
failed_event = QueueNodeFailedEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
exc_event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"]
|
||||
assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"]
|
||||
|
||||
def test_handle_text_chunk_event_tracks_streaming_metrics(self):
|
||||
pipeline = _make_pipeline()
|
||||
published: list[object] = []
|
||||
|
||||
class _Publisher:
|
||||
def publish(self, message):
|
||||
published.append(message)
|
||||
|
||||
pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")
|
||||
|
||||
event = SimpleNamespace(text="hi", from_variable_selector=["a"])
|
||||
queue_message = SimpleNamespace(event=event)
|
||||
|
||||
responses = list(
|
||||
pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message)
|
||||
)
|
||||
|
||||
assert responses == ["chunk"]
|
||||
assert pipeline._task_state.is_streaming_response is True
|
||||
assert pipeline._task_state.first_token_time is not None
|
||||
assert pipeline._task_state.last_token_time is not None
|
||||
assert pipeline._task_state.answer == "hi"
|
||||
assert published == [queue_message]
|
||||
|
||||
def test_handle_output_moderation_chunk_appends_token(self):
|
||||
pipeline = _make_pipeline()
|
||||
seen: list[str] = []
|
||||
|
||||
class _Moderation:
|
||||
def should_direct_output(self):
|
||||
return False
|
||||
|
||||
def append_new_token(self, text):
|
||||
seen.append(text)
|
||||
|
||||
pipeline._base_task_pipeline.output_moderation_handler = _Moderation()
|
||||
|
||||
result = pipeline._handle_output_moderation_chunk("token")
|
||||
|
||||
assert result is False
|
||||
assert seen == ["token"]
|
||||
|
||||
def test_handle_retriever_and_annotation_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
calls = {"retriever": 0, "annotation": 0}
|
||||
|
||||
def _hit_retriever(event):
|
||||
calls["retriever"] += 1
|
||||
|
||||
def _hit_annotation(event):
|
||||
calls["annotation"] += 1
|
||||
|
||||
pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever
|
||||
pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation
|
||||
|
||||
retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[])
|
||||
annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann")
|
||||
|
||||
assert list(pipeline._handle_retriever_resources_event(retriever_event)) == []
|
||||
assert list(pipeline._handle_annotation_reply_event(annotation_event)) == []
|
||||
assert calls == {"retriever": 1, "annotation": 1}
|
||||
|
||||
def test_handle_message_replace_event(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
|
||||
event = QueueMessageReplaceEvent(
|
||||
text="new",
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_message_replace_event(event)) == ["replace"]
|
||||
|
||||
def test_handle_human_input_events(self):
|
||||
pipeline = _make_pipeline()
|
||||
persisted: list[str] = []
|
||||
pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved")
|
||||
pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled"
|
||||
pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout"
|
||||
|
||||
filled_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
rendered_content="content",
|
||||
action_id="action",
|
||||
action_text="action",
|
||||
)
|
||||
timeout_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="title",
|
||||
expiration_time=datetime.utcnow(),
|
||||
)
|
||||
|
||||
assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"]
|
||||
assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"]
|
||||
assert persisted == ["saved"]
|
||||
|
||||
def test_save_message_strips_markdown_and_sets_usage(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._recorded_files = [
|
||||
{
|
||||
"type": "image",
|
||||
"transfer_method": "remote",
|
||||
"remote_url": "http://example.com/file.png",
|
||||
"related_id": "file-id",
|
||||
}
|
||||
]
|
||||
pipeline._task_state.answer = " hello"
|
||||
pipeline._task_state.is_streaming_response = True
|
||||
pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1
|
||||
pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2
|
||||
|
||||
message = SimpleNamespace(
|
||||
id="message-id",
|
||||
status=MessageStatus.PAUSED,
|
||||
answer="",
|
||||
updated_at=None,
|
||||
provider_response_latency=None,
|
||||
message_tokens=None,
|
||||
message_unit_price=None,
|
||||
message_price_unit=None,
|
||||
answer_tokens=None,
|
||||
answer_unit_price=None,
|
||||
answer_price_unit=None,
|
||||
total_price=None,
|
||||
currency=None,
|
||||
message_metadata=None,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
from_account_id=None,
|
||||
from_end_user_id="end-user",
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def scalar(self, *args, **kwargs):
|
||||
return message
|
||||
|
||||
def add_all(self, items):
|
||||
self.items = items
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state)
|
||||
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
assert message.answer == "hello"
|
||||
assert message.message_metadata
|
||||
|
||||
def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)))
|
||||
|
||||
assert responses == ["end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_handle_message_end_event_applies_output_moderation(self, monkeypatch):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
|
||||
start_at=0.0,
|
||||
)
|
||||
pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe"
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace"
|
||||
pipeline._message_end_to_stream_response = lambda: "end"
|
||||
|
||||
saved: list[str] = []
|
||||
|
||||
def _save_message(**kwargs):
|
||||
saved.append("saved")
|
||||
|
||||
pipeline._save_message = _save_message
|
||||
|
||||
@contextmanager
|
||||
def _fake_session():
|
||||
yield SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(pipeline, "_database_session", _fake_session)
|
||||
|
||||
responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent()))
|
||||
|
||||
assert responses == ["replace", "end"]
|
||||
assert saved == ["saved"]
|
||||
|
||||
def test_dispatch_event_handles_node_exception(self):
|
||||
pipeline = _make_pipeline()
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed"
|
||||
pipeline._save_output_for_event = lambda *args, **kwargs: None
|
||||
|
||||
event = QueueNodeExceptionEvent(
|
||||
node_execution_id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
process_data={},
|
||||
error="err",
|
||||
)
|
||||
|
||||
assert list(pipeline._dispatch_event(event)) == ["failed"]
|
||||
@@ -0,0 +1,302 @@
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import (
|
||||
AgentChatAppConfigManager,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerGetAppConfig:
|
||||
def test_get_app_config_override_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"ignored": True}
|
||||
|
||||
override_config = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.variables == "variables"
|
||||
assert result.external_data_variables == "external"
|
||||
|
||||
def test_get_app_config_conversation_specific(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
conversation = mocker.MagicMock()
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=conversation,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_dict == app_model_config.to_dict.return_value
|
||||
assert result.app_model_config_from.value == "conversation-specific-config"
|
||||
|
||||
def test_get_app_config_latest_config(self, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert")
|
||||
mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert")
|
||||
mocker.patch.object(AgentChatAppConfigManager, "convert_features")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert",
|
||||
return_value=("variables", "external"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig",
|
||||
side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
result = AgentChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=None,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from.value == "app-latest-config"
|
||||
|
||||
|
||||
class TestAgentChatAppConfigManagerConfigValidate:
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {},
|
||||
"user_input_form": {},
|
||||
"file_upload": {},
|
||||
"prompt_template": {},
|
||||
"agent_mode": {},
|
||||
"opening_statement": {},
|
||||
"suggested_questions_after_answer": {},
|
||||
"speech_to_text": {},
|
||||
"text_to_speech": {},
|
||||
"retriever_resource": {},
|
||||
"dataset": {},
|
||||
"moderation": {},
|
||||
"extra": "value",
|
||||
}
|
||||
|
||||
def return_with_key(key):
|
||||
return config, [key]
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("model"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("file_upload"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda app_mode, cfg: return_with_key("prompt_template"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
AgentChatAppConfigManager,
|
||||
"validate_agent_mode_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("opening_statement"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("speech_to_text"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("text_to_speech"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda cfg: return_with_key("retriever_resource"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=lambda tenant_id, cfg: return_with_key("moderation"),
|
||||
)
|
||||
|
||||
filtered = AgentChatAppConfigManager.config_validate("tenant", config)
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"prompt_template",
|
||||
"agent_mode",
|
||||
"opening_statement",
|
||||
"suggested_questions_after_answer",
|
||||
"speech_to_text",
|
||||
"text_to_speech",
|
||||
"retriever_resource",
|
||||
"dataset",
|
||||
"moderation",
|
||||
}
|
||||
assert "extra" not in filtered
|
||||
|
||||
|
||||
class TestValidateAgentModeAndSetDefaults:
|
||||
def test_defaults_when_missing(self):
|
||||
config = {}
|
||||
updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert "agent_mode" in updated
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
assert keys == ["agent_mode"]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"agent_mode",
|
||||
["invalid", 123],
|
||||
)
|
||||
def test_agent_mode_type_validation(self, agent_mode):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode})
|
||||
|
||||
def test_agent_mode_empty_list_defaults(self):
|
||||
config = {"agent_mode": []}
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"] == []
|
||||
|
||||
def test_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}})
|
||||
|
||||
def test_strategy_must_be_valid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}}
|
||||
)
|
||||
|
||||
def test_tools_must_be_list(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_requires_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}}
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_must_be_uuid(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_dataset_id_not_exists(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=False,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}},
|
||||
)
|
||||
|
||||
def test_old_tool_enabled_must_be_bool(self):
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant",
|
||||
{"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"])
|
||||
def test_new_style_tool_requires_fields(self, missing_key):
|
||||
tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"}
|
||||
tool.pop(missing_key, None)
|
||||
with pytest.raises(ValueError):
|
||||
AgentChatAppConfigManager.validate_agent_mode_and_set_defaults(
|
||||
"tenant", {"agent_mode": {"enabled": True, "tools": [tool]}}
|
||||
)
|
||||
|
||||
def test_valid_old_and_new_style_tools(self, mocker):
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists",
|
||||
return_value=True,
|
||||
)
|
||||
dataset_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"agent_mode": {
|
||||
"enabled": True,
|
||||
"strategy": PlanningStrategy.ROUTER.value,
|
||||
"tools": [
|
||||
{"dataset": {"id": dataset_id}},
|
||||
{
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "p1",
|
||||
"tool_name": "tool",
|
||||
"tool_parameters": {},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config)
|
||||
assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False
|
||||
assert updated["agent_mode"]["tools"][1]["enabled"] is False
|
||||
@@ -0,0 +1,296 @@
|
||||
import contextlib
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id):
|
||||
self.id = user_id
|
||||
self.session_id = f"session-{user_id}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = AgentChatAppGenerator()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.current_app",
|
||||
new=mocker.MagicMock(_get_current_object=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx")
|
||||
return gen
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorGenerate:
|
||||
def test_generate_rejects_blocking_mode(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False)
|
||||
|
||||
def test_generate_requires_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock())
|
||||
|
||||
def test_generate_rejects_non_string_query(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": 123, "inputs": {}},
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
def test_generate_override_requires_debugger(self, generator, mocker):
|
||||
app_model = mocker.MagicMock()
|
||||
user = DummyAccount("user")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_success_with_debugger_override(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate",
|
||||
return_value={"validated": True},
|
||||
)
|
||||
app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[])
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ConversationService.get_conversation",
|
||||
return_value=mocker.MagicMock(id="conv"),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=queue_manager,
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {
|
||||
"query": "hello",
|
||||
"inputs": {"name": "world"},
|
||||
"conversation_id": "conv",
|
||||
"model_config": {"model": {"provider": "p"}},
|
||||
"files": [{"id": "f1"}],
|
||||
}
|
||||
|
||||
result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
thread_obj.start.assert_called_once()
|
||||
|
||||
def test_generate_without_file_config(self, generator, mocker):
|
||||
app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat")
|
||||
app_model_config = mocker.MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "p"}}
|
||||
|
||||
user = DummyAccount("user")
|
||||
|
||||
generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config)
|
||||
generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1})
|
||||
generator._init_generate_records = mocker.MagicMock(
|
||||
return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg"))
|
||||
)
|
||||
generator._handle_response = mocker.MagicMock(return_value="response")
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config",
|
||||
return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings",
|
||||
return_value=["file-obj"],
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.TraceQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager",
|
||||
return_value=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
thread_obj = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.threading.Thread",
|
||||
return_value=thread_obj,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP)
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity",
|
||||
return_value=app_entity,
|
||||
)
|
||||
|
||||
args = {"query": "hello", "inputs": {"name": "world"}}
|
||||
|
||||
result = generator.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
class TestAgentChatAppGeneratorWorker:
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_context(self, mocker):
|
||||
@contextlib.contextmanager
|
||||
def ctx_manager(*args, **kwargs):
|
||||
yield
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager)
|
||||
|
||||
def test_generate_worker_handles_generate_task_stopped(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = GenerateTaskStoppedError()
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error",
|
||||
[
|
||||
InvokeAuthorizationError("bad"),
|
||||
ValidationError.from_exception_data("TestModel", []),
|
||||
ValueError("bad"),
|
||||
Exception("bad"),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_publishes_errors(self, generator, mocker, error):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = error
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called
|
||||
|
||||
def test_generate_worker_logs_value_error_when_debug(self, generator, mocker):
|
||||
queue_manager = mocker.MagicMock()
|
||||
generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock())
|
||||
|
||||
runner = mocker.MagicMock()
|
||||
runner.run.side_effect = ValueError("bad")
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner)
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close")
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True))
|
||||
logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger")
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=mocker.MagicMock(),
|
||||
context=mocker.MagicMock(),
|
||||
application_generate_entity=mocker.MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
logger.exception.assert_called_once()
|
||||
@@ -0,0 +1,413 @@
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMMode
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return AgentChatAppRunner()
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_moderation_error_direct_output(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
user_id="user",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
annotation = mocker.MagicMock(id="anno", content="answer")
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation)
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
|
||||
queue_manager = mocker.MagicMock()
|
||||
runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
queue_manager.publish.assert_called_once()
|
||||
runner.direct_output.assert_called_once()
|
||||
|
||||
def test_run_hosting_moderation_short_circuits(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock()
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(),
|
||||
conversation_id=None,
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=True)
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
def test_run_model_schema_missing(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = None
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_runner"),
|
||||
[
|
||||
(LLMMode.CHAT, "CotChatAgentRunner"),
|
||||
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
|
||||
],
|
||||
)
|
||||
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: mode}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
runner_instance.run.assert_called_once()
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_invalid_llm_mode_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = [ModelFeature.TOOL_CALL]
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
def test_run_conversation_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_message_not_found(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING)
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
|
||||
)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
|
||||
|
||||
generate_entity = mocker.MagicMock(
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
query="q",
|
||||
files=[],
|
||||
stream=True,
|
||||
model_conf=mocker.MagicMock(
|
||||
provider_model_bundle=mocker.MagicMock(),
|
||||
model="m",
|
||||
provider="p",
|
||||
credentials={"k": "v"},
|
||||
),
|
||||
conversation_id="conv",
|
||||
invoke_from=mocker.MagicMock(),
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
mocker.patch.object(runner, "check_hosting_moderation", return_value=False)
|
||||
|
||||
model_schema = mocker.MagicMock()
|
||||
model_schema.features = []
|
||||
model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}
|
||||
|
||||
llm_instance = mocker.MagicMock()
|
||||
llm_instance.model_type_instance.get_model_schema.return_value = model_schema
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance)
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
@@ -0,0 +1,162 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterBlocking:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"a": 1},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"a": 1}
|
||||
|
||||
def test_convert_blocking_simple_response_with_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
|
||||
def test_convert_blocking_simple_response_with_non_dict_metadata(self):
|
||||
blocking = ChatbotAppBlockingResponse.model_construct(
|
||||
task_id="task",
|
||||
data=ChatbotAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="agent-chat",
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
|
||||
class TestAgentChatAppGenerateResponseConverterStream:
|
||||
def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
def _gen():
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=2,
|
||||
stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=3,
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="m1",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
"extra": "ignored",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"id": "a"},
|
||||
"usage": {"prompt_tokens": 1},
|
||||
},
|
||||
),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="conv",
|
||||
message_id="msg",
|
||||
created_at=4,
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")),
|
||||
)
|
||||
|
||||
return _gen()
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
assert items[1]["event"] == "message"
|
||||
assert "answer" in items[1]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert items[3]["event"] == "error"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream()))
|
||||
assert items[0] == "ping"
|
||||
# Assert the message event structure and content at items[1]
|
||||
assert items[1]["event"] == "message"
|
||||
assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"]
|
||||
assert items[2]["event"] == "message_end"
|
||||
assert "metadata" in items[2]
|
||||
metadata = items[2]["metadata"]
|
||||
assert "annotation_reply" not in metadata
|
||||
assert "usage" not in metadata
|
||||
assert metadata["retriever_resources"] == [
|
||||
{
|
||||
"segment_id": "s1",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "content",
|
||||
"summary": "summary",
|
||||
}
|
||||
]
|
||||
assert items[3]["event"] == "error"
|
||||
0
api/tests/unit_tests/core/app/apps/chat/__init__.py
Normal file
0
api/tests/unit_tests/core/app/apps/chat/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestChatAppConfigManager:
|
||||
def test_get_app_config_uses_override_dict(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value)
|
||||
app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"})
|
||||
override = {"model": "override"}
|
||||
|
||||
model_entity = ModelConfigEntity(provider="p", model="m")
|
||||
prompt_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hi",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert",
|
||||
return_value=None,
|
||||
),
|
||||
patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])),
|
||||
):
|
||||
app_config = ChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
conversation=None,
|
||||
override_config_dict=override,
|
||||
)
|
||||
|
||||
assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert app_config.app_model_config_dict == override
|
||||
assert app_config.app_mode == AppMode.CHAT
|
||||
|
||||
def test_config_validate_filters_related_keys(self):
|
||||
config = {"extra": 1}
|
||||
|
||||
def _add_key(key, value):
|
||||
def _inner(*args, **kwargs):
|
||||
config = args[-1]
|
||||
config = {**config, key: value}
|
||||
return config, [key]
|
||||
|
||||
return _inner
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("model", 1),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("inputs", 2),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("file_upload", 3),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("prompt", 4),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("dataset", 5),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("opening_statement", 6),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("suggested_questions_after_answer", 7),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("speech_to_text", 8),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("text_to_speech", 9),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("retriever_resource", 10),
|
||||
),
|
||||
patch(
|
||||
"core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults",
|
||||
side_effect=_add_key("sensitive_word_avoidance", 11),
|
||||
),
|
||||
):
|
||||
filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config)
|
||||
|
||||
assert filtered["model"] == 1
|
||||
assert filtered["inputs"] == 2
|
||||
assert filtered["file_upload"] == 3
|
||||
assert filtered["prompt"] == 4
|
||||
assert filtered["dataset"] == 5
|
||||
assert filtered["opening_statement"] == 6
|
||||
assert filtered["suggested_questions_after_answer"] == 7
|
||||
assert filtered["speech_to_text"] == 8
|
||||
assert filtered["text_to_speech"] == 9
|
||||
assert filtered["retriever_resource"] == 10
|
||||
assert filtered["sensitive_word_avoidance"] == 11
|
||||
@@ -0,0 +1,280 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
|
||||
class DummyQueueManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.published = []
|
||||
|
||||
def publish_error(self, error, pub_from):
|
||||
self.published.append((error, pub_from))
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestChatAppGenerator:
|
||||
def test_generate_requires_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_rejects_non_string_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
args={"query": 1, "inputs": {}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_debugger_overrides_model_config(self):
|
||||
generator = ChatAppGenerator()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
user = SimpleNamespace(id="user-1", session_id="session-1")
|
||||
args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}
|
||||
|
||||
with (
|
||||
patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config",
|
||||
return_value=SimpleNamespace(
|
||||
variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT
|
||||
),
|
||||
),
|
||||
patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None),
|
||||
patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity),
|
||||
patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager),
|
||||
patch(
|
||||
"core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True}
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})),
|
||||
patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}),
|
||||
patch.object(
|
||||
ChatAppGenerator,
|
||||
"_init_generate_records",
|
||||
return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")),
|
||||
),
|
||||
patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}),
|
||||
patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f),
|
||||
patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread,
|
||||
):
|
||||
mock_thread.return_value.start.return_value = None
|
||||
result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_generate_rejects_model_config_override_for_non_debugger(self):
|
||||
generator = ChatAppGenerator()
|
||||
with pytest.raises(ValueError):
|
||||
with (
|
||||
patch.object(
|
||||
ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})
|
||||
),
|
||||
):
|
||||
generator.generate(
|
||||
app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value),
|
||||
user=SimpleNamespace(id="u1", session_id="s1"),
|
||||
args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_worker_handles_exceptions(self):
|
||||
generator = ChatAppGenerator()
|
||||
queue_manager = DummyQueueManager()
|
||||
entity = DummyGenerateEntity(task_id="t1", user_id="u1")
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
assert queue_manager.published
|
||||
|
||||
with (
|
||||
patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()),
|
||||
patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()),
|
||||
patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()),
|
||||
patch("core.app.apps.chat.app_generator.db.session.close"),
|
||||
):
|
||||
generator._generate_worker(
|
||||
flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))),
|
||||
application_generate_entity=entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppRunner:
|
||||
def test_run_raises_when_app_missing(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[]
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
def test_run_moderation_error_direct_output(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_annotation_reply_short_circuits(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
annotation = SimpleNamespace(id="ann-1", content="answer")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
):
|
||||
queue_manager = DummyQueueManager()
|
||||
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published)
|
||||
mock_direct.assert_called_once()
|
||||
|
||||
def test_run_returns_when_hosting_moderation_blocks(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=False,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestChatAppGenerateResponseConverter:
|
||||
def test_convert_blocking_simple_response_metadata(self):
|
||||
data = ChatbotAppBlockingResponse.Data(
|
||||
id="msg-1",
|
||||
mode="chat",
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
answer="hi",
|
||||
metadata={"usage": {"total_tokens": 1}},
|
||||
created_at=1,
|
||||
)
|
||||
blocking = ChatbotAppBlockingResponse(task_id="t1", data=data)
|
||||
|
||||
response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "usage" not in response["metadata"]
|
||||
|
||||
def test_convert_stream_responses(self):
|
||||
def stream() -> Generator[ChatbotAppStreamResponse, None, None]:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=PingStreamResponse(task_id="t1"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")),
|
||||
)
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id="c1",
|
||||
message_id="m1",
|
||||
created_at=1,
|
||||
stream_response=MessageEndStreamResponse(task_id="t1", id="m1"),
|
||||
)
|
||||
|
||||
full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
assert full[0] == "ping"
|
||||
assert full[1]["event"] == "message"
|
||||
assert full[2]["event"] == "error"
|
||||
|
||||
simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
assert simple[0] == "ping"
|
||||
assert simple[-1]["event"] == "message_end"
|
||||
162
api/tests/unit_tests/core/app/apps/completion/test_app_runner.py
Normal file
162
api/tests/unit_tests/core/app/apps/completion/test_app_runner.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.completion.app_runner as module
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.moderation.base import ModerationError
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CompletionAppRunner()
|
||||
|
||||
|
||||
def _build_app_config(dataset=None, external_tools=None, additional_features=None):
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app1"
|
||||
app_config.tenant_id = "tenant"
|
||||
app_config.prompt_template = MagicMock()
|
||||
app_config.dataset = dataset
|
||||
app_config.external_data_variables = external_tools or []
|
||||
app_config.additional_features = additional_features
|
||||
app_config.app_model_config_dict = {"file_upload": {"enabled": True}}
|
||||
return app_config
|
||||
|
||||
|
||||
def _build_generate_entity(app_config, file_upload_config=None):
|
||||
model_conf = MagicMock(
|
||||
provider_model_bundle="bundle",
|
||||
model="model",
|
||||
parameters={"max_tokens": 10},
|
||||
stop=["stop"],
|
||||
)
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
model_conf=model_conf,
|
||||
inputs={"qvar": "query_from_input"},
|
||||
query="original_query",
|
||||
files=[],
|
||||
file_upload_config=file_upload_config,
|
||||
stream=True,
|
||||
user_id="user",
|
||||
invoke_from=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppRunner:
|
||||
def test_run_app_not_found(self, runner, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
|
||||
def test_run_moderation_error_outputs_direct(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked"))
|
||||
runner.direct_output = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_hosting_moderation_stops(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_dataset_and_external_tools_flow(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
session.close = MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
retrieve_config = MagicMock(query_variable="qvar")
|
||||
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
|
||||
additional_features = MagicMock(show_retrieve_source=True)
|
||||
app_config = _build_app_config(
|
||||
dataset=dataset_config,
|
||||
external_tools=["tool"],
|
||||
additional_features=additional_features,
|
||||
)
|
||||
|
||||
file_upload_config = MagicMock()
|
||||
file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])])
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs)
|
||||
runner.check_hosting_moderation = MagicMock(return_value=False)
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
dataset_retrieval = MagicMock()
|
||||
dataset_retrieval.retrieve.return_value = ("ctx", ["file1"])
|
||||
mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval)
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.invoke_llm.return_value = "invoke_result"
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
|
||||
dataset_retrieval.retrieve.assert_called_once()
|
||||
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_uses_low_image_detail_default(self, runner, mocker):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
|
||||
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
assert (
|
||||
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
|
||||
== ImagePromptMessageContent.DETAIL.LOW
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.completion.app_config_manager as module
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestCompletionAppConfigManager:
|
||||
def test_get_app_config_with_override(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"]))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
app_model_config=app_model_config,
|
||||
override_config_dict=override_config,
|
||||
)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS
|
||||
assert result.app_model_config_dict == override_config
|
||||
assert result.variables == ["v1"]
|
||||
assert result.external_data_variables == ["ext1"]
|
||||
assert result.app_mode == AppMode.COMPLETION
|
||||
|
||||
def test_get_app_config_without_override_uses_model_config(self, mocker):
|
||||
app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value)
|
||||
app_model_config = MagicMock(id="cfg1")
|
||||
app_model_config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
|
||||
mocker.patch.object(module.ModelConfigManager, "convert", return_value="model")
|
||||
mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt")
|
||||
mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation")
|
||||
mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset")
|
||||
mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features")
|
||||
mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], []))
|
||||
mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config)
|
||||
|
||||
assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
|
||||
assert result.app_model_config_dict == {"model": {"provider": "x"}}
|
||||
|
||||
def test_config_validate_filters_related_keys(self, mocker):
|
||||
config = {
|
||||
"model": {"provider": "x"},
|
||||
"variables": ["v"],
|
||||
"file_upload": {"enabled": True},
|
||||
"prompt": {"template": "t"},
|
||||
"dataset": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"more_like_this": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.ModelConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["model"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.BasicVariablesConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["variables"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PromptTemplateConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["prompt"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DatasetConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["dataset"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.MoreLikeThisConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["more_like_this"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = CompletionAppConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert "extra" not in filtered
|
||||
assert set(filtered.keys()) == {
|
||||
"model",
|
||||
"variables",
|
||||
"file_upload",
|
||||
"prompt",
|
||||
"dataset",
|
||||
"tts",
|
||||
"more_like_this",
|
||||
"moderation",
|
||||
}
|
||||
@@ -0,0 +1,321 @@
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import core.app.apps.completion.app_generator as module
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = CompletionAppGenerator()
|
||||
|
||||
mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn)
|
||||
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app)))
|
||||
|
||||
thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=thread)
|
||||
|
||||
mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock())
|
||||
mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_app_model():
|
||||
return MagicMock(tenant_id="tenant", id="app1", mode="completion")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", session_id="session")
|
||||
|
||||
|
||||
def _build_app_model_config():
|
||||
config = MagicMock(id="cfg")
|
||||
config.to_dict.return_value = {"model": {"provider": "x"}}
|
||||
return config
|
||||
|
||||
|
||||
class TestCompletionAppGenerator:
|
||||
def test_generate_invalid_query_type(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": 123, "inputs": {}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
def test_generate_override_not_debugger(self, generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": {}},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
def test_generate_success_no_file_config(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings")
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": []},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_not_called()
|
||||
|
||||
def test_generate_success_with_files(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
conversation = MagicMock(id="conv", mode="completion")
|
||||
message = MagicMock(id="msg")
|
||||
mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message))
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]},
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
module.file_factory.build_from_mappings.assert_called_once()
|
||||
|
||||
def test_generate_override_model_config_debugger(self, generator, mocker):
|
||||
app_model_config = _build_app_model_config()
|
||||
mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config)
|
||||
|
||||
override_config = {"model": {"provider": "override"}}
|
||||
mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config)
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
generator.generate(
|
||||
app_model=_build_app_model(),
|
||||
user=_build_user(),
|
||||
args={"query": "q", "inputs": {}, "files": [], "model_config": override_config},
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert get_app_config.call_args.kwargs["override_config_dict"] == override_config
|
||||
|
||||
def test_generate_more_like_this_message_not_found(self, generator, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=_build_app_model(),
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_disabled(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False})
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_app_model_config_missing(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = None
|
||||
|
||||
message = MagicMock()
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(MoreLikeThisDisabledError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_message_config_none(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock(app_model_config=None)
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
def test_generate_more_like_this_success(self, generator, mocker):
|
||||
app_model = _build_app_model()
|
||||
app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True})
|
||||
|
||||
message = MagicMock()
|
||||
message.message_files = [{"id": "f"}]
|
||||
message.inputs = {"a": 1}
|
||||
message.query = "q"
|
||||
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.to_dict.return_value = {
|
||||
"model": {"completion_params": {"temperature": 0.1}},
|
||||
"file_upload": {"enabled": True},
|
||||
}
|
||||
message.app_model_config = app_model_config
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = message
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
file_extra_config = MagicMock()
|
||||
mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config)
|
||||
mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"])
|
||||
|
||||
app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={}))
|
||||
get_app_config = mocker.patch.object(
|
||||
module.CompletionAppConfigManager,
|
||||
"get_app_config",
|
||||
return_value=app_config,
|
||||
)
|
||||
mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_init_generate_records",
|
||||
return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")),
|
||||
)
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id="msg",
|
||||
user=_build_user(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
override_dict = get_app_config.call_args.kwargs["override_config_dict"]
|
||||
assert override_dict["model"]["completion_params"]["temperature"] == 0.9
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("error", "should_publish"),
|
||||
[
|
||||
(GenerateTaskStoppedError(), False),
|
||||
(InvokeAuthorizationError("bad"), True),
|
||||
(
|
||||
ValidationError.from_exception_data(
|
||||
"Model",
|
||||
[{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}],
|
||||
),
|
||||
True,
|
||||
),
|
||||
(ValueError("bad"), True),
|
||||
(RuntimeError("boom"), True),
|
||||
],
|
||||
)
|
||||
def test_generate_worker_error_handling(self, generator, mocker, error, should_publish):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
|
||||
session = mocker.MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(generator, "_get_message", return_value=MagicMock())
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = error
|
||||
mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert queue_manager.publish_error.called is should_publish
|
||||
@@ -0,0 +1,153 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
MessageStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestCompletionAppGenerateResponseConverter:
|
||||
def test_convert_blocking_full_response(self):
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata={"k": "v"},
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
|
||||
assert result["event"] == "message"
|
||||
assert result["task_id"] == "task"
|
||||
assert result["message_id"] == "msg"
|
||||
assert result["answer"] == "answer"
|
||||
assert result["metadata"] == {"k": "v"}
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_simplified(self):
|
||||
metadata = {
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
"extra": "x",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
}
|
||||
blocking = CompletionAppBlockingResponse(
|
||||
task_id="task",
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata=metadata,
|
||||
created_at=123,
|
||||
),
|
||||
)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert "annotation_reply" not in result["metadata"]
|
||||
assert "usage" not in result["metadata"]
|
||||
assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s"
|
||||
assert "extra" not in result["metadata"]["retriever_resources"][0]
|
||||
|
||||
def test_convert_blocking_simple_response_metadata_not_dict(self):
|
||||
data = CompletionAppBlockingResponse.Data.model_construct(
|
||||
id="id",
|
||||
mode="completion",
|
||||
message_id="msg",
|
||||
answer="answer",
|
||||
metadata="bad",
|
||||
created_at=123,
|
||||
)
|
||||
blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data)
|
||||
|
||||
result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert result["metadata"] == {}
|
||||
|
||||
def test_convert_stream_full_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
assert result[2]["event"] == "message"
|
||||
|
||||
def test_convert_stream_simple_response(self):
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
message_id="m",
|
||||
created_at=1,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=MessageEndStreamResponse(
|
||||
task_id="t",
|
||||
id="end",
|
||||
metadata={
|
||||
"retriever_resources": [
|
||||
{
|
||||
"segment_id": "s",
|
||||
"position": 1,
|
||||
"document_name": "doc",
|
||||
"score": 0.9,
|
||||
"content": "c",
|
||||
"summary": "sum",
|
||||
}
|
||||
],
|
||||
"annotation_reply": {"a": 1},
|
||||
"usage": {"t": 2},
|
||||
},
|
||||
),
|
||||
message_id="m",
|
||||
created_at=2,
|
||||
)
|
||||
yield CompletionAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
message_id="m",
|
||||
created_at=3,
|
||||
)
|
||||
|
||||
result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "message_end"
|
||||
assert "annotation_reply" not in result[1]["metadata"]
|
||||
assert "usage" not in result[1]["metadata"]
|
||||
assert result[2]["event"] == "error"
|
||||
@@ -0,0 +1,55 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import core.app.apps.pipeline.pipeline_config_manager as module
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_get_pipeline_config(mocker):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe1")
|
||||
workflow = MagicMock(id="wf1")
|
||||
|
||||
mocker.patch.object(
|
||||
module.WorkflowVariablesConfigManager,
|
||||
"convert_rag_pipeline_variable",
|
||||
return_value=["var1"],
|
||||
)
|
||||
mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start")
|
||||
|
||||
assert result.tenant_id == "tenant"
|
||||
assert result.app_id == "pipe1"
|
||||
assert result.workflow_id == "wf1"
|
||||
assert result.app_mode == AppMode.RAG_PIPELINE
|
||||
assert result.rag_pipeline_variables == ["var1"]
|
||||
|
||||
|
||||
def test_config_validate_filters_related_keys(mocker):
|
||||
config = {
|
||||
"file_upload": {"enabled": True},
|
||||
"tts": {"enabled": True},
|
||||
"moderation": {"enabled": True},
|
||||
"extra": "drop",
|
||||
}
|
||||
|
||||
mocker.patch.object(
|
||||
module.FileUploadConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["file_upload"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.TextToSpeechConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["tts"]),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.SensitiveWordAvoidanceConfigManager,
|
||||
"validate_and_set_defaults",
|
||||
return_value=(config, ["moderation"]),
|
||||
)
|
||||
|
||||
filtered = PipelineConfigManager.config_validate("tenant", config)
|
||||
|
||||
assert set(filtered.keys()) == {"file_upload", "tts", "moderation"}
|
||||
@@ -0,0 +1,111 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
def test_convert_blocking_full_and_simple_response():
|
||||
blocking = WorkflowAppBlockingResponse(
|
||||
task_id="task",
|
||||
workflow_run_id="run",
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id="id",
|
||||
workflow_id="wf",
|
||||
status=WorkflowExecutionStatus.SUCCEEDED,
|
||||
outputs={"k": "v"},
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
total_tokens=10,
|
||||
total_steps=1,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
),
|
||||
)
|
||||
|
||||
full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking)
|
||||
simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking)
|
||||
|
||||
assert full == simple
|
||||
assert full["workflow_run_id"] == "run"
|
||||
assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_convert_stream_full_response():
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=PingStreamResponse(task_id="t"),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
yield WorkflowAppStreamResponse(
|
||||
stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")),
|
||||
workflow_run_id="run",
|
||||
)
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream()))
|
||||
|
||||
assert result[0] == "ping"
|
||||
assert result[1]["event"] == "error"
|
||||
assert result[1]["code"] == "invalid_param"
|
||||
|
||||
|
||||
def test_convert_stream_simple_response_node_ignore_details():
|
||||
node_start = NodeStartStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
created_at=1,
|
||||
),
|
||||
)
|
||||
node_finish = NodeFinishStreamResponse(
|
||||
task_id="t",
|
||||
workflow_run_id="run",
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id="nid",
|
||||
node_id="node",
|
||||
node_type="type",
|
||||
title="Title",
|
||||
index=1,
|
||||
predecessor_node_id=None,
|
||||
inputs={"a": 1},
|
||||
inputs_truncated=False,
|
||||
process_data=None,
|
||||
process_data_truncated=False,
|
||||
outputs={"b": 2},
|
||||
outputs_truncated=False,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=0.1,
|
||||
execution_metadata=None,
|
||||
created_at=1,
|
||||
finished_at=2,
|
||||
files=[],
|
||||
),
|
||||
)
|
||||
|
||||
def stream() -> Generator[AppStreamResponse, None, None]:
|
||||
yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run")
|
||||
yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run")
|
||||
|
||||
result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream()))
|
||||
|
||||
assert result[0]["event"] == "node_started"
|
||||
assert result[0]["data"]["inputs"] is None
|
||||
assert result[1]["event"] == "node_finished"
|
||||
assert result[1]["data"]["inputs"] is None
|
||||
@@ -0,0 +1,699 @@
|
||||
import contextlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_generator as module
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
|
||||
|
||||
class FakeRagPipelineGenerateEntity(SimpleNamespace):
|
||||
class SingleIterationRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
class SingleLoopRunEntity(SimpleNamespace):
|
||||
pass
|
||||
|
||||
def model_dump(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mocker):
|
||||
gen = module.PipelineGenerator()
|
||||
|
||||
mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity)
|
||||
mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs)
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock()))
|
||||
mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock()))
|
||||
|
||||
return gen
|
||||
|
||||
|
||||
def _build_pipeline_dataset():
|
||||
return SimpleNamespace(
|
||||
id="ds",
|
||||
name="dataset",
|
||||
description="desc",
|
||||
chunk_structure="chunk",
|
||||
built_in_field_enabled=True,
|
||||
tenant_id="tenant",
|
||||
)
|
||||
|
||||
|
||||
def _build_pipeline():
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
pipeline.retrieve_dataset.return_value = _build_pipeline_dataset()
|
||||
return pipeline
|
||||
|
||||
|
||||
def _build_workflow():
|
||||
return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant")
|
||||
|
||||
|
||||
def _build_user():
|
||||
return MagicMock(id="user", name="User", session_id="session")
|
||||
|
||||
|
||||
def _build_args():
|
||||
return {
|
||||
"inputs": {"k": "v"},
|
||||
"start_node_id": "start",
|
||||
"datasource_type": DatasourceProviderType.LOCAL_FILE.value,
|
||||
"datasource_info_list": [{"name": "file"}],
|
||||
}
|
||||
|
||||
|
||||
def _patch_session(mocker, session):
|
||||
mocker.patch.object(module, "Session", return_value=session)
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
|
||||
def _dummy_preserve(*args, **kwargs):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self):
|
||||
self.scalar = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def test_generate_dataset_missing(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=_build_workflow(),
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_debugger_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
datasource_info_list = [{"name": "file1"}, {"name": "file2"}]
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=datasource_info_list,
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1)
|
||||
|
||||
document1 = SimpleNamespace(
|
||||
id="doc1",
|
||||
position=1,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file1",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
document2 = SimpleNamespace(
|
||||
id="doc2",
|
||||
position=2,
|
||||
data_source_type=DatasourceProviderType.LOCAL_FILE,
|
||||
data_source_info="{}",
|
||||
name="file2",
|
||||
indexing_status="",
|
||||
error=None,
|
||||
enabled=True,
|
||||
)
|
||||
mocker.patch.object(generator, "_build_document", side_effect=[document1, document2])
|
||||
|
||||
mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock())
|
||||
|
||||
db_session = MagicMock()
|
||||
mocker.patch.object(module.db, "session", db_session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
task_proxy = MagicMock()
|
||||
mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy)
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result["batch"]
|
||||
assert len(result["documents"]) == 2
|
||||
task_proxy.delay.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_is_retry_calls_generate(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_format_datasource_info_list",
|
||||
return_value=[{"name": "file"}],
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]),
|
||||
)
|
||||
mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"result": "ok"})
|
||||
|
||||
result = generator.generate(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=_build_user(),
|
||||
args=_build_args(),
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
streaming=True,
|
||||
is_retry=True,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
def test_generate_worker_handles_errors(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
runner_instance.run.side_effect = ValueError("bad")
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
queue_manager.publish_error.assert_called_once()
|
||||
|
||||
|
||||
def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
flask_app.app_context.return_value = contextlib.nullcontext()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
mocker.patch.object(module.db, "session", MagicMock(close=MagicMock()))
|
||||
mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock())
|
||||
|
||||
application_generate_entity = FakeRagPipelineGenerateEntity(
|
||||
app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
session = DummySession()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")]
|
||||
_patch_session(mocker, session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
mocker.patch.object(module, "PipelineRunner", return_value=runner_instance)
|
||||
|
||||
generator._generate_worker(
|
||||
flask_app=flask_app,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
context=contextlib.nullcontext(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session"
|
||||
|
||||
|
||||
def test_generate_raises_when_workflow_not_found(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_success_returns_converted(generator, mocker):
|
||||
flask_app = MagicMock()
|
||||
mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve)
|
||||
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={})
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager)
|
||||
|
||||
worker_thread = MagicMock()
|
||||
mocker.patch.object(module.threading, "Thread", return_value=worker_thread)
|
||||
|
||||
mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
mocker.patch.object(generator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted")
|
||||
|
||||
result = generator._generate(
|
||||
flask_app=flask_app,
|
||||
context=contextlib.nullcontext(),
|
||||
pipeline=_build_pipeline(),
|
||||
workflow_id="wf",
|
||||
user=_build_user(),
|
||||
application_generate_entity=FakeRagPipelineGenerateEntity(
|
||||
task_id="t",
|
||||
app_config=SimpleNamespace(app_id="pipe"),
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
|
||||
|
||||
def test_single_iteration_generate_validates_inputs(generator, mocker):
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
_build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None}
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_dataset_required(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
pipeline.retrieve_dataset.return_value = None
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
)
|
||||
|
||||
|
||||
def test_single_iteration_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_iteration_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_single_loop_generate_success(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
|
||||
session = DummySession()
|
||||
_patch_session(mocker, session)
|
||||
|
||||
mocker.patch.object(
|
||||
module.PipelineConfigManager,
|
||||
"get_pipeline_config",
|
||||
return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(
|
||||
module.DifyCoreRepositoryFactory,
|
||||
"create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock()))
|
||||
|
||||
mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock())
|
||||
mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock())
|
||||
|
||||
mocker.patch.object(generator, "_generate", return_value={"ok": True})
|
||||
|
||||
result = generator.single_loop_generate(
|
||||
pipeline,
|
||||
_build_workflow(),
|
||||
"node",
|
||||
_build_user(),
|
||||
{"inputs": {"a": 1}},
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
|
||||
|
||||
def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker):
|
||||
pipeline = _build_pipeline()
|
||||
workflow = _build_workflow()
|
||||
app_entity = FakeRagPipelineGenerateEntity(task_id="t")
|
||||
|
||||
task_pipeline = MagicMock()
|
||||
task_pipeline.process.side_effect = ValueError("I/O operation on closed file.")
|
||||
mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=app_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=MagicMock(),
|
||||
user=_build_user(),
|
||||
draft_var_saver_factory=MagicMock(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
|
||||
def test_build_document_sets_metadata_for_builtin_fields(generator, mocker):
|
||||
class DummyDocument(SimpleNamespace):
|
||||
pass
|
||||
|
||||
mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs))
|
||||
|
||||
document = generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=True,
|
||||
datasource_type=DatasourceProviderType.LOCAL_FILE,
|
||||
datasource_info={"name": "file"},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
assert document.name == "file"
|
||||
assert document.doc_metadata
|
||||
|
||||
|
||||
def test_build_document_invalid_datasource_type(generator):
|
||||
with pytest.raises(ValueError):
|
||||
generator._build_document(
|
||||
tenant_id="tenant",
|
||||
dataset_id="ds",
|
||||
built_in_field_enabled=False,
|
||||
datasource_type="invalid",
|
||||
datasource_info={},
|
||||
created_from="rag-pipeline",
|
||||
position=1,
|
||||
account=_build_user(),
|
||||
batch="batch",
|
||||
document_form="text",
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_non_online_drive(generator):
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
[{"name": "file"}],
|
||||
_build_pipeline(),
|
||||
_build_workflow(),
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"name": "file"}]
|
||||
|
||||
|
||||
def test_format_datasource_info_list_missing_node_data(generator):
|
||||
workflow = MagicMock(graph_dict={"nodes": []})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
|
||||
def test_format_datasource_info_list_online_drive_folder(generator, mocker):
|
||||
workflow = MagicMock(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start",
|
||||
"data": {
|
||||
"plugin_id": "p",
|
||||
"provider_name": "provider",
|
||||
"datasource_name": "drive",
|
||||
"credential_id": "cred",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
runtime = MagicMock()
|
||||
runtime.runtime = SimpleNamespace(credentials=None)
|
||||
runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=runtime,
|
||||
)
|
||||
mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"})
|
||||
|
||||
mocker.patch.object(
|
||||
generator,
|
||||
"_get_files_in_folder",
|
||||
side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}),
|
||||
)
|
||||
|
||||
result = generator._format_datasource_info_list(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
[{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}],
|
||||
_build_pipeline(),
|
||||
workflow,
|
||||
"start",
|
||||
_build_user(),
|
||||
)
|
||||
|
||||
assert result == [{"id": "f"}]
|
||||
|
||||
|
||||
def test_get_files_in_folder_recurses_and_collects(generator):
|
||||
class File:
|
||||
def __init__(self, id, name, type):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.type = type
|
||||
|
||||
class FilesPage:
|
||||
def __init__(self, files, is_truncated=False, next_page_parameters=None):
|
||||
self.files = files
|
||||
self.is_truncated = is_truncated
|
||||
self.next_page_parameters = next_page_parameters
|
||||
|
||||
class Result:
|
||||
def __init__(self, result):
|
||||
self.result = result
|
||||
|
||||
class Runtime:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def datasource_provider_type(self):
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def online_drive_browse_files(self, user_id, request, provider_type):
|
||||
self.calls.append(request.next_page_parameters)
|
||||
if request.prefix == "fd":
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
if request.next_page_parameters is None:
|
||||
return iter(
|
||||
[
|
||||
Result(
|
||||
[FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})]
|
||||
)
|
||||
]
|
||||
)
|
||||
return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])])
|
||||
|
||||
runtime = Runtime()
|
||||
all_files = []
|
||||
|
||||
generator._get_files_in_folder(
|
||||
datasource_runtime=runtime,
|
||||
prefix="root",
|
||||
bucket="b",
|
||||
user_id="user",
|
||||
all_files=all_files,
|
||||
datasource_info={},
|
||||
)
|
||||
|
||||
assert {f["id"] for f in all_files} == {"f1", "f2"}
|
||||
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_queue_manager as module
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult
|
||||
|
||||
|
||||
def test_publish_sets_stop_listen_and_raises_on_stopped(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_stop_events_trigger_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
for event in [
|
||||
QueueErrorEvent(error=ValueError("bad")),
|
||||
QueueMessageEndEvent(llm_result=LLMResult.model_construct()),
|
||||
QueueWorkflowSucceededEvent(),
|
||||
QueueWorkflowFailedEvent(error="failed", exceptions_count=1),
|
||||
QueueWorkflowPartialSuccessEvent(exceptions_count=1),
|
||||
]:
|
||||
manager.stop_listen.reset_mock()
|
||||
manager._publish(event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
|
||||
def test_publish_non_stop_event_no_stop_listen(mocker):
|
||||
manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag")
|
||||
manager._q = mocker.MagicMock()
|
||||
manager.stop_listen = mocker.MagicMock()
|
||||
manager._is_stopped = mocker.MagicMock(return_value=False)
|
||||
|
||||
non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent)
|
||||
manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE)
|
||||
manager.stop_listen.assert_not_called()
|
||||
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Unit tests for PipelineRunner behavior.
|
||||
Asserts correct event handling, error propagation, and user invocation logic.
|
||||
Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies.
|
||||
Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities.
|
||||
"""
|
||||
|
||||
"""Unit tests for PipelineRunner behavior.
|
||||
|
||||
This module validates core control-flow outcomes for
|
||||
``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph
|
||||
initialization guards, invoke-source to user-source resolution, and failed-run
|
||||
event handling. Invariants asserted here include strict graph-config
|
||||
validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing
|
||||
error paths driven by ``GraphRunFailedEvent`` through mocked collaborators.
|
||||
Primary collaborators include ``PipelineRunner``,
|
||||
``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``,
|
||||
``UserFrom``, and patched DB/runtime dependencies used by the runner.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.pipeline.pipeline_runner as module
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from dify_graph.graph_events import GraphRunFailedEvent
|
||||
|
||||
|
||||
def _build_app_generate_entity() -> SimpleNamespace:
|
||||
app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant")
|
||||
return SimpleNamespace(
|
||||
app_config=app_config,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_id="user",
|
||||
trace_manager=MagicMock(),
|
||||
inputs={"input1": "v1"},
|
||||
files=[],
|
||||
workflow_execution_id="run",
|
||||
document_id="doc",
|
||||
original_document_id=None,
|
||||
batch="batch",
|
||||
dataset_id="ds",
|
||||
datasource_type="local_file",
|
||||
datasource_info={"name": "file"},
|
||||
start_node_id="start",
|
||||
call_depth=0,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
queue_manager = MagicMock()
|
||||
variable_loader = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow_execution_repository = MagicMock()
|
||||
workflow_node_execution_repository = MagicMock()
|
||||
|
||||
return PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=variable_loader,
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
)
|
||||
|
||||
|
||||
def test_get_app_id(runner):
|
||||
assert runner._get_app_id() == "pipe"
|
||||
|
||||
|
||||
def test_get_workflow_returns_workflow(mocker, runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = workflow
|
||||
mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query)))
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_invalid_config(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": "bad", "edges": []}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
workflow.graph_dict = {"nodes": [], "edges": "bad"}
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_init_rag_pipeline_graph_not_found(mocker, runner):
|
||||
workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []})
|
||||
mocker.patch.object(module.Graph, "init", return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock())
|
||||
|
||||
|
||||
def test_update_document_status_on_failure(mocker, runner):
|
||||
document = MagicMock()
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = document
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
|
||||
runner._update_document_status(event, document_id="doc", dataset_id="ds")
|
||||
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pipeline_not_found(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.invoke_from = InvokeFrom.WEB_APP
|
||||
app_generate_entity.single_iteration_run = None
|
||||
app_generate_entity.single_loop_run = None
|
||||
|
||||
query = MagicMock()
|
||||
query.where.return_value.first.return_value = None
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_workflow_not_initialized(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value = query_pipeline
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
runner.get_workflow = MagicMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run()
|
||||
|
||||
|
||||
def test_run_single_iteration_path(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
app_generate_entity.single_iteration_run = MagicMock()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(
|
||||
return_value=MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={},
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
)
|
||||
runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state"))
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = [MagicMock()]
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._prepare_single_node_execution.assert_called_once()
|
||||
runner._handle_event.assert_called()
|
||||
|
||||
|
||||
def test_run_normal_path_builds_graph(mocker):
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
query_pipeline = MagicMock()
|
||||
query_pipeline.where.return_value.first.return_value = pipeline
|
||||
|
||||
query_end_user = MagicMock()
|
||||
query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess")
|
||||
|
||||
session = MagicMock()
|
||||
session.query.side_effect = [query_end_user, query_pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
workflow = MagicMock(
|
||||
id="wf",
|
||||
tenant_id="tenant",
|
||||
app_id="pipe",
|
||||
graph_dict={"nodes": [], "edges": []},
|
||||
environment_variables=[],
|
||||
rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}],
|
||||
type="rag-pipeline",
|
||||
version="v1",
|
||||
)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=MagicMock(),
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT)
|
||||
runner.get_workflow = MagicMock(return_value=workflow)
|
||||
runner._init_rag_pipeline_graph = MagicMock(return_value="graph")
|
||||
runner._update_document_status = MagicMock()
|
||||
runner._handle_event = MagicMock()
|
||||
|
||||
mocker.patch.object(
|
||||
module.RAGPipelineVariable,
|
||||
"model_validate",
|
||||
return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"),
|
||||
)
|
||||
mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs))
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = []
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
runner._init_rag_pipeline_graph.assert_called_once()
|
||||
@@ -1,3 +1,5 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default():
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestBaseAppGeneratorExtras:
|
||||
def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="file",
|
||||
label="file",
|
||||
type=VariableEntityType.FILE,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="file_list",
|
||||
label="file_list",
|
||||
type=VariableEntityType.FILE_LIST,
|
||||
required=False,
|
||||
allowed_file_types=[],
|
||||
allowed_file_extensions=[],
|
||||
allowed_file_upload_methods=[],
|
||||
),
|
||||
VariableEntity(
|
||||
variable="json",
|
||||
label="json",
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mapping",
|
||||
lambda mapping, tenant_id, config, strict_type_validation=False: "file-object",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_generator.file_factory.build_from_mappings",
|
||||
lambda mappings, tenant_id, config: ["file-1", "file-2"],
|
||||
)
|
||||
|
||||
user_inputs = {
|
||||
"file": {"id": "file-id"},
|
||||
"file_list": [{"id": "file-1"}, {"id": "file-2"}],
|
||||
"json": {"key": "value"},
|
||||
}
|
||||
|
||||
prepared = base_app_generator._prepare_user_inputs(
|
||||
user_inputs=user_inputs,
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert prepared["file"] == "file-object"
|
||||
assert prepared["file_list"] == ["file-1", "file-2"]
|
||||
assert prepared["json"] == {"key": "value"}
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_dict_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": {"unexpected": "dict"}},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_prepare_user_inputs_rejects_invalid_list_inputs(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
variables = [
|
||||
VariableEntity(
|
||||
variable="text",
|
||||
label="text",
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
required=False,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must be a string"):
|
||||
base_app_generator._prepare_user_inputs(
|
||||
user_inputs={"text": [{"unexpected": "dict"}]},
|
||||
variables=variables,
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
def test_convert_to_event_stream(self):
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True}
|
||||
|
||||
def _gen():
|
||||
yield {"delta": "hi"}
|
||||
yield "ping"
|
||||
|
||||
converted = list(base_app_generator.convert_to_event_stream(_gen()))
|
||||
|
||||
assert converted[0].startswith("data: ")
|
||||
assert "\n\n" in converted[0]
|
||||
assert converted[1] == "event: ping\n\n"
|
||||
|
||||
def test_get_draft_var_saver_factory_debugger(self):
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from dify_graph.enums import NodeType
|
||||
from models import Account
|
||||
|
||||
base_app_generator = BaseAppGenerator()
|
||||
account = Account(name="Tester", email="tester@example.com")
|
||||
account.id = "account-id"
|
||||
account.tenant_id = "tenant-id"
|
||||
|
||||
factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account)
|
||||
saver = factory(
|
||||
session=MagicMock(),
|
||||
app_id="app-id",
|
||||
node_id="node-id",
|
||||
node_type=NodeType.START,
|
||||
node_execution_id="node-exec-id",
|
||||
)
|
||||
|
||||
assert saver is not None
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent
|
||||
|
||||
|
||||
class DummyQueueManager(AppQueueManager):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.published = []
|
||||
|
||||
def _publish(self, event, pub_from):
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
class TestBaseAppQueueManager:
|
||||
def test_init_requires_user_id(self):
|
||||
with pytest.raises(ValueError):
|
||||
DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
def test_publish_error_records_event(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert isinstance(manager.published[0][0], QueueErrorEvent)
|
||||
|
||||
def test_set_stop_flag_checks_user(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.get.return_value = b"end-user-u1"
|
||||
AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_set_stop_flag_no_user_check(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id="t1")
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
def test_is_stopped_reads_cache(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
mock_redis.get.return_value = b"1"
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
assert manager._is_stopped() is True
|
||||
|
||||
def test_check_for_sqlalchemy_models_raises(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API)
|
||||
|
||||
bad = SimpleNamespace(_sa_instance_state=True)
|
||||
with pytest.raises(TypeError):
|
||||
manager._check_for_sqlalchemy_models(bad)
|
||||
442
api/tests/unit_tests/core/app/apps/test_base_app_runner.py
Normal file
442
api/tests/unit_tests/core/app/apps/test_base_app_runner.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
AdvancedCompletionPromptTemplateEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class _DummyParameterRule:
|
||||
def __init__(self, name: str, use_template: str | None = None) -> None:
|
||||
self.name = name
|
||||
self.use_template = use_template
|
||||
|
||||
|
||||
class _QueueRecorder:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[object] = []
|
||||
|
||||
def publish(self, event, pub_from):
|
||||
_ = pub_from
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
class TestAppRunner:
|
||||
def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: 100},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80),
|
||||
)
|
||||
|
||||
runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")])
|
||||
|
||||
assert model_config.parameters["max_tokens"] == 20
|
||||
|
||||
def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
model_schema = SimpleNamespace(
|
||||
model_properties={},
|
||||
parameter_rules=[_DummyParameterRule("max_tokens")],
|
||||
)
|
||||
model_config = SimpleNamespace(
|
||||
provider_model_bundle=object(),
|
||||
model="mock",
|
||||
model_schema=model_schema,
|
||||
parameters={"max_tokens": 30},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ModelInstance",
|
||||
lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10),
|
||||
)
|
||||
|
||||
assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1
|
||||
|
||||
def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None)
|
||||
|
||||
runner.direct_output(
|
||||
queue_manager=queue,
|
||||
app_generate_entity=app_generate_entity,
|
||||
prompt_messages=[],
|
||||
text="hi",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_direct_publishes_end_event(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
llm_result = LLMResult(
|
||||
model="mock",
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content="done"),
|
||||
usage=LLMUsage.empty_usage(),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=llm_result,
|
||||
queue_manager=queue,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
|
||||
def test_handle_invoke_result_invalid_type_raises(self):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=["unexpected"],
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def test_organize_prompt_messages_simple_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["STOP"])
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="hello",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.SimplePromptTransform.get_prompt",
|
||||
lambda self, **kwargs: (["simple-message"], ["simple-stop"]),
|
||||
)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["simple-message"]
|
||||
assert stop == ["simple-stop"]
|
||||
|
||||
def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="completion", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="answer",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"),
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-completion-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-completion-message"]
|
||||
assert stop == ["<END>"]
|
||||
memory_config = captured["memory_config"]
|
||||
assert memory_config.role_prefix.user == "U"
|
||||
assert memory_config.role_prefix.assistant == "A"
|
||||
|
||||
def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
model_config = SimpleNamespace(mode="chat", stop=["<END>"])
|
||||
captured: dict[str, object] = {}
|
||||
prompt_template_entity = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_advanced_prompt(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return ["advanced-chat-message"]
|
||||
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt)
|
||||
|
||||
prompt_messages, stop = runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=model_config,
|
||||
prompt_template_entity=prompt_template_entity,
|
||||
inputs={},
|
||||
files=[],
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert prompt_messages == ["advanced-chat-message"]
|
||||
assert stop == ["<END>"]
|
||||
assert len(captured["prompt_template"]) == 2
|
||||
|
||||
def test_organize_prompt_messages_advanced_missing_templates_raise(self):
|
||||
runner = AppRunner()
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="completion", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"):
|
||||
runner.organize_prompt_messages(
|
||||
app_record=SimpleNamespace(mode=AppMode.CHAT.value),
|
||||
model_config=SimpleNamespace(mode="chat", stop=[]),
|
||||
prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED),
|
||||
inputs={},
|
||||
files=[],
|
||||
)
|
||||
|
||||
def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
warning_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger)
|
||||
|
||||
image_content = ImagePromptMessageContent(
|
||||
url="https://example.com/image.png", format="png", mime_type="image/png"
|
||||
)
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="stream-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage.model_construct(
|
||||
content=[
|
||||
"a",
|
||||
TextPromptMessageContent(data="b"),
|
||||
SimpleNamespace(data="c"),
|
||||
image_content,
|
||||
]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
stream=True,
|
||||
agent=False,
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueLLMChunkEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.message.content == "abc"
|
||||
warning_logger.assert_called_once()
|
||||
|
||||
def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
exception_logger = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger)
|
||||
|
||||
monkeypatch.setattr(
|
||||
runner,
|
||||
"_handle_multimodal_image_content",
|
||||
MagicMock(side_effect=RuntimeError("failed to save image")),
|
||||
)
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
def _stream():
|
||||
yield LLMResultChunk(
|
||||
model="agent-model",
|
||||
prompt_messages=[AssistantPromptMessage(content="prompt")],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(
|
||||
url="https://example.com/image.png",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
TextPromptMessageContent(data="done"),
|
||||
]
|
||||
),
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
runner._handle_invoke_result_stream(
|
||||
invoke_result=_stream(),
|
||||
queue_manager=queue,
|
||||
agent=True,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
)
|
||||
|
||||
assert isinstance(queue.events[0], QueueAgentMessageEvent)
|
||||
assert isinstance(queue.events[-1], QueueMessageEndEvent)
|
||||
assert queue.events[-1].llm_result.usage == usage
|
||||
exception_logger.assert_called_once()
|
||||
|
||||
def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
|
||||
class _ToggleBool:
|
||||
def __init__(self, values: list[bool]):
|
||||
self._values = values
|
||||
self._index = 0
|
||||
|
||||
def __bool__(self):
|
||||
value = self._values[min(self._index, len(self._values) - 1)]
|
||||
self._index += 1
|
||||
return value
|
||||
|
||||
content = SimpleNamespace(
|
||||
url=_ToggleBool([False, False]),
|
||||
base64_data=_ToggleBool([True, False]),
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock())
|
||||
monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session))
|
||||
|
||||
queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock())
|
||||
|
||||
runner._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id="message-id",
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
queue_manager.publish.assert_not_called()
|
||||
|
||||
def test_check_hosting_moderation_direct_output_called(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
queue = _QueueRecorder()
|
||||
app_generate_entity = SimpleNamespace(stream=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.HostingModerationFeature.check",
|
||||
lambda self, application_generate_entity, prompt_messages: True,
|
||||
)
|
||||
direct_output = MagicMock()
|
||||
monkeypatch.setattr(runner, "direct_output", direct_output)
|
||||
|
||||
result = runner.check_hosting_moderation(
|
||||
application_generate_entity=app_generate_entity,
|
||||
queue_manager=queue,
|
||||
prompt_messages=[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert direct_output.called
|
||||
|
||||
def test_fill_in_inputs_from_external_data_tools(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.ExternalDataFetch.fetch",
|
||||
lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"},
|
||||
)
|
||||
|
||||
result = runner.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
external_data_tools=[],
|
||||
inputs={},
|
||||
query="q",
|
||||
)
|
||||
|
||||
assert result == {"foo": "bar"}
|
||||
|
||||
def test_moderation_for_inputs_returns_result(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.InputModeration.check",
|
||||
lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""),
|
||||
)
|
||||
app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None)
|
||||
|
||||
result = runner.moderation_for_inputs(
|
||||
app_id="app",
|
||||
tenant_id="tenant",
|
||||
app_generate_entity=app_generate_entity,
|
||||
inputs={},
|
||||
query="q",
|
||||
message_id="msg",
|
||||
)
|
||||
|
||||
assert result == (True, {}, "")
|
||||
|
||||
def test_query_app_annotations_to_reply(self, monkeypatch):
|
||||
runner = AppRunner()
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.base_app_runner.AnnotationReplyFeature.query",
|
||||
lambda self, app_record, message, query, user_id, invoke_from: "reply",
|
||||
)
|
||||
|
||||
response = runner.query_app_annotations_to_reply(
|
||||
app_record=SimpleNamespace(),
|
||||
message=SimpleNamespace(),
|
||||
query="hello",
|
||||
user_id="user",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
)
|
||||
|
||||
assert response == "reply"
|
||||
7
api/tests/unit_tests/core/app/apps/test_exc.py
Normal file
7
api/tests/unit_tests/core/app/apps/test_exc.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
|
||||
|
||||
class TestAppsExceptions:
|
||||
def test_generate_task_stopped_error(self):
|
||||
err = GenerateTaskStoppedError("stopped")
|
||||
assert str(err) == "stopped"
|
||||
@@ -13,9 +13,11 @@ from core.app.app_config.entities import (
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
|
||||
|
||||
class DummyModelConf:
|
||||
@@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity():
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
|
||||
|
||||
class TestMessageBasedAppGeneratorExtras:
|
||||
def test_handle_response_closed_file_raises_stopped(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
class _Pipeline:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
_ = kwargs
|
||||
|
||||
def process(self):
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline",
|
||||
_Pipeline,
|
||||
)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
generator._handle_response(
|
||||
application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)),
|
||||
queue_manager=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
user=SimpleNamespace(),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
def test_get_app_model_config_requires_valid_config(self, monkeypatch):
|
||||
generator = MessageBasedAppGenerator()
|
||||
app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model, conversation=None)
|
||||
|
||||
conversation = SimpleNamespace(app_model_config_id="missing-id")
|
||||
monkeypatch.setattr(
|
||||
message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None))
|
||||
)
|
||||
|
||||
with pytest.raises(AppModelConfigBrokenError):
|
||||
generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation)
|
||||
|
||||
def test_get_conversation_introduction_handles_missing_inputs(self):
|
||||
app_config = _make_app_config(AppMode.CHAT)
|
||||
app_config.additional_features.opening_statement = "Hello {{name}}"
|
||||
entity = _make_chat_generate_entity(app_config)
|
||||
entity.inputs = {}
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
assert generator._get_conversation_introduction(entity) == "Hello {name}"
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent
|
||||
|
||||
|
||||
class TestMessageBasedAppQueueManager:
|
||||
def test_publish_stops_on_terminal_events(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager.stop_listen = Mock()
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
|
||||
manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock())
|
||||
manager.stop_listen.assert_called_once()
|
||||
|
||||
def test_publish_raises_when_stopped(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=True)
|
||||
|
||||
with pytest.raises(GenerateTaskStoppedError):
|
||||
manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def test_publish_enqueues_message_end(self):
|
||||
with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis:
|
||||
mock_redis.setex.return_value = True
|
||||
manager = MessageBasedAppQueueManager(
|
||||
task_id="t1",
|
||||
user_id="u1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
conversation_id="c1",
|
||||
app_mode="chat",
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
manager._is_stopped = Mock(return_value=False)
|
||||
manager.stop_listen = Mock()
|
||||
|
||||
manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
assert manager._q.qsize() == 1
|
||||
29
api/tests/unit_tests/core/app/apps/test_message_generator.py
Normal file
29
api/tests/unit_tests/core/app/apps/test_message_generator.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestMessageGenerator:
|
||||
def test_get_response_topic(self):
|
||||
channel = Mock()
|
||||
channel.topic.return_value = "topic"
|
||||
|
||||
with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel):
|
||||
topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1")
|
||||
|
||||
assert topic == "topic"
|
||||
expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1")
|
||||
channel.topic.assert_called_once_with(expected_key)
|
||||
|
||||
def test_retrieve_events_passes_arguments(self):
|
||||
with (
|
||||
patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"),
|
||||
patch(
|
||||
"core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}])
|
||||
) as mock_stream,
|
||||
):
|
||||
events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2))
|
||||
|
||||
assert events == [{"event": "ping"}]
|
||||
mock_stream.assert_called_once()
|
||||
@@ -6,6 +6,7 @@ import queue
|
||||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
@@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
|
||||
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
|
||||
|
||||
def test_normalize_terminal_events_defaults():
|
||||
assert _normalize_terminal_events(None) == {
|
||||
StreamEvent.WORKFLOW_FINISHED.value,
|
||||
StreamEvent.WORKFLOW_PAUSED.value,
|
||||
}
|
||||
|
||||
|
||||
def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0]
|
||||
|
||||
def fake_time():
|
||||
return times.pop(0)
|
||||
|
||||
monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time)
|
||||
|
||||
generator = stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=10.0,
|
||||
ping_interval=1.0,
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
# next receive yields None -> ping interval triggers
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import NodeType
|
||||
from dify_graph.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunIterationSucceededEvent,
|
||||
NodeRunLoopFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
)
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
|
||||
|
||||
class TestWorkflowBasedAppRunner:
|
||||
def test_resolve_user_from(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT
|
||||
assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER
|
||||
|
||||
def test_init_graph_validates_graph_structure(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes or edges not found"):
|
||||
runner._init_graph(
|
||||
graph_config={},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="nodes in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": {}, "edges": []},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="edges in workflow graph must be a list"):
|
||||
runner._init_graph(
|
||||
graph_config={"nodes": [], "edges": {}},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
)
|
||||
|
||||
def test_prepare_single_node_execution_requires_run(self):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
|
||||
workflow = SimpleNamespace(environment_variables=[], graph_dict={})
|
||||
|
||||
with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"):
|
||||
runner._prepare_single_node_execution(workflow, None, None)
|
||||
|
||||
def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
|
||||
graph_config = {
|
||||
"nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}],
|
||||
"edges": [],
|
||||
}
|
||||
workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.Graph.init",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
class _NodeCls:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(graph_config, config):
|
||||
return {}
|
||||
|
||||
from core.app.apps import workflow_app_runner
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_app_runner,
|
||||
"resolve_workflow_node_class",
|
||||
lambda **_kwargs: _NodeCls,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.load_into_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
node_id="node-1",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
assert variable_pool is graph_runtime_state.variable_pool
|
||||
|
||||
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append((event, publish_from))
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_runtime_state.register_paused_node("node-1")
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
emails: list[dict] = []
|
||||
|
||||
class _Dispatch:
|
||||
def apply_async(self, *, kwargs, queue):
|
||||
emails.append({"kwargs": kwargs, "queue": queue})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.dispatch_human_input_email_task",
|
||||
_Dispatch(),
|
||||
)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Node",
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, GraphRunStartedEvent())
|
||||
runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True}))
|
||||
runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={}))
|
||||
|
||||
assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published)
|
||||
assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published)
|
||||
paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent))
|
||||
assert paused_event.paused_nodes == ["node-1"]
|
||||
assert emails
|
||||
|
||||
def test_handle_node_events_publishes_queue_events(self):
|
||||
published: list[object] = []
|
||||
|
||||
class _QueueManager:
|
||||
def publish(self, event, publish_from):
|
||||
published.append(event)
|
||||
|
||||
runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
|
||||
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStartedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
node_title="Start",
|
||||
start_at=datetime.utcnow(),
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunStreamChunkEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
selector=["node", "text"],
|
||||
chunk="hi",
|
||||
is_final=False,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunAgentLogEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.START,
|
||||
message_id="msg",
|
||||
label="label",
|
||||
node_execution_id="exec",
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="done",
|
||||
data={},
|
||||
metadata={},
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunIterationSucceededEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Iter",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={"ok": True},
|
||||
metadata={},
|
||||
steps=1,
|
||||
),
|
||||
)
|
||||
runner._handle_event(
|
||||
workflow_entry,
|
||||
NodeRunLoopFailedEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
node_title="Loop",
|
||||
start_at=datetime.utcnow(),
|
||||
inputs={},
|
||||
outputs={},
|
||||
metadata={},
|
||||
steps=1,
|
||||
error="boom",
|
||||
),
|
||||
)
|
||||
|
||||
assert any(isinstance(event, QueueTextChunkEvent) for event in published)
|
||||
assert any(isinstance(event, QueueAgentLogEvent) for event in published)
|
||||
assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
|
||||
assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user