mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 21:15:10 +00:00
enhance the dynamic NodeData type inference and validation process.
This commit is contained in:
@@ -53,6 +53,7 @@ from core.workflow.graph_events import (
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunAbortedEvent
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -299,7 +300,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
graph_config=workflow.graph_dict, config=NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from sqlalchemy.sql.operators import from_
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
@@ -13,6 +14,7 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@@ -63,7 +65,7 @@ logger = logging.getLogger(__name__)
|
||||
class Node(Generic[NodeDataT]):
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
_node_data_type: ClassVar[type[NodeDataT]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""
|
||||
@@ -169,7 +171,7 @@ class Node(Generic[NodeDataT]):
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
def _extract_node_data_type_from_generic(cls) -> type[NodeDataT] | None:
|
||||
"""
|
||||
Extract the node data type from the generic parameter `Node[T]`.
|
||||
|
||||
@@ -205,7 +207,7 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
@@ -222,19 +224,13 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
self._node_data = self._node_data_type.model_validate(config["data"], from_attributes=True)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@@ -255,9 +251,6 @@ class Node(Generic[NodeDataT]):
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@@ -374,7 +367,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@@ -412,13 +405,13 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
node_id = config["id"]
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_data = cls._node_data_type.model_validate(config["data"], from_attributes=True)
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -428,7 +421,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@@ -483,23 +476,23 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._node_data.error_strategy
|
||||
return self.node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._node_data.retry_config
|
||||
return self.node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._node_data.title
|
||||
return self.node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._node_data.desc
|
||||
return self.node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._node_data.default_value_dict
|
||||
return self.node_data.default_value_dict
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
|
||||
@@ -25,6 +25,7 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from extensions.otel.runtime import is_instrument_flag_enabled
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
@@ -302,10 +303,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
"data": node_data,
|
||||
}
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node: Node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
|
||||
Reference in New Issue
Block a user