enhance the dynamic NodeData type inference and validation process.

This commit is contained in:
Yanli 盐粒
2026-01-30 22:58:24 +08:00
parent 4809ad9bf1
commit fdff4b15bb
3 changed files with 23 additions and 31 deletions

View File

@@ -53,6 +53,7 @@ from core.workflow.graph_events import (
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -299,7 +300,7 @@ class WorkflowBasedAppRunner:
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=target_node_config
graph_config=workflow.graph_dict, config=NodeConfigDictAdapter.validate_python(target_node_config)
)
except NotImplementedError:
variable_mapping = {}

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from sqlalchemy.sql.operators import from_
import importlib
import logging
@@ -13,6 +14,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -63,7 +65,7 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
_node_data_type: ClassVar[type[NodeDataT]]
def __init_subclass__(cls, **kwargs: Any) -> None:
"""
@@ -169,7 +171,7 @@ class Node(Generic[NodeDataT]):
bucket["latest"] = bucket[latest_key]
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
def _extract_node_data_type_from_generic(cls) -> type[NodeDataT] | None:
"""
Extract the node data type from the generic parameter `Node[T]`.
@@ -205,7 +207,7 @@ class Node(Generic[NodeDataT]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
@@ -222,19 +224,13 @@ class Node(Generic[NodeDataT]):
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
node_id = config["id"]
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
self._node_data = self._node_data_type.model_validate(config["data"], from_attributes=True)
self.post_init()
@@ -255,9 +251,6 @@ class Node(Generic[NodeDataT]):
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
@@ -374,7 +367,7 @@ class Node(Generic[NodeDataT]):
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
config: NodeConfigDict,
) -> Mapping[str, Sequence[str]]:
"""Extracts references variable selectors from node configuration.
@@ -412,13 +405,13 @@ class Node(Generic[NodeDataT]):
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_id = config["id"]
# Pass raw dict data instead of creating NodeData instance
node_data = cls._node_data_type.model_validate(config["data"], from_attributes=True)
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
return data
@@ -428,7 +421,7 @@ class Node(Generic[NodeDataT]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: NodeDataT,
) -> Mapping[str, Sequence[str]]:
return {}
@@ -483,23 +476,23 @@ class Node(Generic[NodeDataT]):
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._node_data.error_strategy
return self.node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
return self._node_data.retry_config
return self.node_data.retry_config
def _get_title(self) -> str:
"""Get the node title."""
return self._node_data.title
return self.node_data.title
def _get_description(self) -> str | None:
"""Get the node description."""
return self._node_data.desc
return self.node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._node_data.default_value_dict
return self.node_data.default_value_dict
# Public interface properties that delegate to abstract methods
@property

View File

@@ -25,6 +25,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from extensions.otel.runtime import is_instrument_flag_enabled
from factories import file_factory
from models.enums import UserFrom
@@ -302,10 +303,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = {
"id": node_id,
"data": node_data,
}
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node: Node = node_cls(
id=str(uuid.uuid4()),
config=node_config,