From fdff4b15bb40e4287cd7a7a9d340e7f692784666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?= Date: Fri, 30 Jan 2026 22:58:24 +0800 Subject: [PATCH] enhance the dynamic NodeData type inference and validation process. --- api/core/app/apps/workflow_app_runner.py | 3 +- api/core/workflow/nodes/base/node.py | 45 ++++++++++-------------- api/core/workflow/workflow_entry.py | 6 ++-- 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 13b7865f55..ef140d1490 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -53,6 +53,7 @@ from core.workflow.graph_events import ( NodeRunSucceededEvent, ) from core.workflow.graph_events.graph import GraphRunAbortedEvent +from core.workflow.entities.graph_config import NodeConfigDictAdapter from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -299,7 +300,7 @@ class WorkflowBasedAppRunner: try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=target_node_config + graph_config=workflow.graph_dict, config=NodeConfigDictAdapter.validate_python(target_node_config) ) except NotImplementedError: variable_mapping = {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 63e0260341..8762393669 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,4 +1,5 @@ from __future__ import annotations +from sqlalchemy.sql.operators import from_ import importlib import logging @@ -13,6 +14,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus from core.workflow.graph_events import ( GraphNodeEventBase, @@ -63,7 +65,7 @@ logger = logging.getLogger(__name__) class Node(Generic[NodeDataT]): node_type: ClassVar[NodeType] execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData + _node_data_type: ClassVar[type[NodeDataT]] def __init_subclass__(cls, **kwargs: Any) -> None: """ @@ -169,7 +171,7 @@ class Node(Generic[NodeDataT]): bucket["latest"] = bucket[latest_key] @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: + def _extract_node_data_type_from_generic(cls) -> type[NodeDataT] | None: """ Extract the node data type from the generic parameter `Node[T]`. @@ -205,7 +207,7 @@ class Node(Generic[NodeDataT]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, ) -> None: @@ -222,19 +224,13 @@ class Node(Generic[NodeDataT]): self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required.") + node_id = config["id"] self._node_id = node_id self._node_execution_id: str = "" self._start_at = naive_utc_now() - raw_node_data = config.get("data") or {} - if not isinstance(raw_node_data, Mapping): - raise ValueError("Node config data must be a mapping.") - - self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data) + self._node_data = self._node_data_type.model_validate(config["data"], from_attributes=True) self.post_init() @@ -255,9 +251,6 @@ class Node(Generic[NodeDataT]): self._node_execution_id = str(uuid4()) return self._node_execution_id - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: - return cast(NodeDataT, self._node_data_type.model_validate(data)) - @abstractmethod def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ @@ -374,7 +367,7 @@ class Node(Generic[NodeDataT]): cls, *, graph_config: Mapping[str, Any], - config: Mapping[str, Any], + config: NodeConfigDict, ) -> Mapping[str, Sequence[str]]: """Extracts references variable selectors from node configuration. @@ -412,13 +405,13 @@ class Node(Generic[NodeDataT]): :param config: node config :return: """ - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + node_id = config["id"] - # Pass raw dict data instead of creating NodeData instance + node_data = cls._node_data_type.model_validate(config["data"], from_attributes=True) data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) + graph_config=graph_config, + node_id=node_id, + node_data=node_data, ) return data @@ -428,7 +421,7 @@ class Node(Generic[NodeDataT]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: NodeDataT, ) -> Mapping[str, Sequence[str]]: return {} @@ -483,23 +476,23 @@ class Node(Generic[NodeDataT]): def _get_error_strategy(self) -> ErrorStrategy | None: """Get the error strategy for this node.""" - return self._node_data.error_strategy + return self.node_data.error_strategy def _get_retry_config(self) -> RetryConfig: """Get the retry configuration for this node.""" - return self._node_data.retry_config + return self.node_data.retry_config def _get_title(self) -> str: """Get the node title.""" - return self._node_data.title + return self.node_data.title def _get_description(self) -> str | None: """Get the node description.""" - return self._node_data.desc + return self.node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict + return self.node_data.default_value_dict # Public interface properties that delegate to abstract methods @property diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4219b26226..521afd3634 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -25,6 +25,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.entities.graph_config import NodeConfigDictAdapter from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.enums import UserFrom @@ -302,10 +303,7 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config = { - "id": node_id, - "data": node_data, - } + node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) node: Node = node_cls( id=str(uuid.uuid4()), config=node_config,