Compare commits

...

51 Commits

Author SHA1 Message Date
Yanli 盐粒
217f402da4 Revert "fix: validate node config ids and document compatibility defaults"
This reverts commit 465f36e691.
2026-03-01 18:34:42 +08:00
Yanli 盐粒
465f36e691 fix: validate node config ids and document compatibility defaults 2026-03-01 18:16:41 +08:00
Yanli 盐粒
93e4c108f8 docs: document node data type fallback in initializer 2026-03-01 17:58:49 +08:00
Yanli 盐粒
13d45769d7 fix: address PR feedback on node config validation 2026-03-01 17:54:11 +08:00
Yanli 盐粒
94b31b101e fix: resolve basedpyright config type mismatches 2026-03-01 17:28:17 +08:00
Yanli 盐粒
262a720673 refactor: remove unused _hydrate_node_data in Node 2026-03-01 17:18:48 +08:00
Yanli 盐粒
b3d074f04e chore: normalize webhook import ordering 2026-03-01 17:11:46 +08:00
Yanli 盐粒
af7c2ca5e9 merge: resolve origin/main conflicts 2026-03-01 17:11:15 +08:00
Yanli 盐粒
834a7bedf6 Merge remote-tracking branch 'origin/main' into yanli/pydantic-model-node-config-data 2026-02-12 11:18:44 +08:00
Yanli 盐粒
03fb284031 make ruff happy 2026-02-12 11:06:26 +08:00
Yanli 盐粒
994402b510 merge main 2026-02-12 11:05:38 +08:00
autofix-ci[bot]
c5ed658c1e [autofix.ci] apply automated fixes 2026-02-09 12:29:56 +00:00
盐粒 Yanli
d15502d7ac Merge branch 'main' into yanli/pydantic-model-node-config-data 2026-02-09 20:28:04 +08:00
autofix-ci[bot]
747badbccd [autofix.ci] apply automated fixes 2026-02-06 10:23:49 +00:00
盐粒 Yanli
93cbaac910 Merge branch 'main' into yanli/pydantic-model-node-config-data 2026-02-06 18:21:41 +08:00
盐粒 Yanli
b98e71e52a Merge branch 'main' into yanli/pydantic-model-node-config-data 2026-02-05 15:39:15 +08:00
Yanli 盐粒
1575c0db02 skip the no title test in test_webhook_data_validation_errors 2026-02-03 18:25:35 +08:00
Yanli 盐粒
98fc3cda3c make BaseNodeData.title have a default value 2026-02-03 18:17:43 +08:00
Yanli 盐粒
4e18718dc8 again update the Node.__init__ interface 2026-02-03 18:16:35 +08:00
Yanli 盐粒
b6bdd9996a update MockNodeFactory 2026-02-03 18:04:38 +08:00
Yanli 盐粒
f1a1791954 fix create_node interface 2026-02-03 18:04:16 +08:00
Yanli 盐粒
3ae27377ec fix Node.__init__ interface again 2026-02-03 18:04:06 +08:00
Yanli 盐粒
4aefb98711 ruff 2026-02-03 17:28:28 +08:00
Yanli 盐粒
e8b8320a3d fix the Node.__init__ interface 2026-02-03 17:28:20 +08:00
Yanli 盐粒
4ab0f3c1d3 ruff 2026-02-03 17:22:39 +08:00
Yanli 盐粒
1c242800b0 fix test_workflow_entry 2026-02-03 17:22:29 +08:00
Yanli 盐粒
2b4e4905bf ruff 2026-02-03 16:45:16 +08:00
Yanli 盐粒
4fdd621e1b revert the title to be required 2026-02-03 16:45:01 +08:00
Yanli 盐粒
afadf54a2b resolve conflict 2026-02-03 16:34:12 +08:00
盐粒 Yanli
59c729c376 Merge branch 'main' into yanli/pydantic-model-node-config-data 2026-02-03 16:26:08 +08:00
Yanli 盐粒
36996d44a4 make type checker happy 2026-02-03 16:15:23 +08:00
Yanli 盐粒
07eb7504e3 add default value of type for every core.workflow.nodes.*.entities 2026-02-03 15:43:11 +08:00
Yanli 盐粒
4a6798790d add type ignore comment 2026-02-03 15:42:21 +08:00
Yanli 盐粒
c0e3cf6f18 update import linter rule 2026-02-03 14:58:59 +08:00
autofix-ci[bot]
3b82ef08f1 [autofix.ci] apply automated fixes 2026-02-03 06:57:32 +00:00
Yanli 盐粒
da26cd1c12 fix import lints by moving BaseNodeData from core.workflow.nodes to core.workflow.entities 2026-02-03 14:53:21 +08:00
Yanli 盐粒
b17ddd0ea7 fix the default_factory for BaseNodeData 2026-02-03 14:21:40 +08:00
Yanli 盐粒
4f6f110014 test: fix integration webhook/http node data usage 2026-02-03 14:15:43 +08:00
Yanli 盐粒
98f302c720 test: make dify config tests ignore .env 2026-02-03 14:15:43 +08:00
Yanli 盐粒
3f957f7c0f chore: minor update 2026-02-03 14:15:43 +08:00
Yanli 盐粒
1a8b9ee1aa chore: apply ruff auto-fixes 2026-02-03 14:15:43 +08:00
Yanli 盐粒
ae65f0909c Fix workflow node data type requirement (follow #31723) 2026-02-03 14:15:43 +08:00
Yanli 盐粒
f3c6ca4a6f lint 2026-02-03 14:15:43 +08:00
Yanli 盐粒
f1b6c70893 use typed model (WebhookData and SegmentType) for the webhook service 2026-02-03 14:15:43 +08:00
Yanli 盐粒
c3bf1ac541 update the schedule trigger with typed model 2026-02-03 14:15:43 +08:00
Yanli 盐粒
1b1c1424ae make lint happy 2026-02-03 14:15:18 +08:00
Yanli 盐粒
05b109841d solve conflicts 2026-02-03 14:14:28 +08:00
Yanli 盐粒
23e00c1397 make lint happy 2026-02-03 14:06:44 +08:00
Yanli 盐粒
fdff4b15bb enhance the dynamic NodeData type inference and validation process. 2026-02-03 14:06:44 +08:00
Yanli 盐粒
4809ad9bf1 more trival updates 2026-02-03 14:06:44 +08:00
Yanli 盐粒
8d85f51a3a Refactor node configuration handling to use NodeConfigDict and improve type safety 2026-02-03 14:06:44 +08:00
69 changed files with 757 additions and 564 deletions

View File

@@ -31,6 +31,7 @@ from core.app.entities.queue_entities import (
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -310,7 +311,7 @@ class WorkflowBasedAppRunner:
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=target_node_config
graph_config=workflow.graph_dict, config=NodeConfigDictAdapter.validate_python(target_node_config)
)
except NotImplementedError:
variable_mapping = {}

View File

@@ -17,7 +17,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from core.workflow.enums import NodeType
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
@@ -115,7 +115,7 @@ class DifyNodeFactory(NodeFactory):
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
@@ -124,14 +124,12 @@ class DifyNodeFactory(NodeFactory):
:raises ValueError: if node type is unknown or configuration is invalid
"""
# Get node_id from config
node_id = node_config["id"]
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
node_id = typed_node_config["id"]
# Get node type from config
node_data = node_config["data"]
try:
node_type = NodeType(node_data["type"])
except ValueError:
raise ValueError(f"Unknown node type: {node_data['type']}")
node_data = typed_node_config["data"]
node_type = node_data.type
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
@@ -139,7 +137,7 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
node_version = str(node_data.get("version", "1"))
node_version = str(node_data.version)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
@@ -149,7 +147,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
@@ -159,7 +157,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
@@ -169,7 +167,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_request_config=self._http_request_config,
@@ -179,11 +177,12 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
node_data_dict = node_data.model_dump()
model_instance = self._build_model_instance_for_llm_node(node_data_dict)
memory = self._build_memory_for_llm_node(node_data=node_data_dict, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
@@ -195,7 +194,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
@@ -204,7 +203,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
@@ -213,17 +212,17 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.DOCUMENT_EXTRACTOR:
return DocumentExtractorNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
unstructured_api_config=self._document_extractor_unstructured_api_config,
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
model_instance = self._build_model_instance_for_llm_node(node_data.model_dump())
return QuestionClassifierNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
@@ -232,10 +231,10 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
model_instance = self._build_model_instance_for_llm_node(node_data.model_dump())
return ParameterExtractorNode(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
@@ -245,7 +244,7 @@ class DifyNodeFactory(NodeFactory):
return node_class(
id=node_id,
config=node_config,
config=typed_node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)

View File

@@ -19,6 +19,7 @@ from core.trigger.debug.events import (
build_plugin_pool_key,
build_webhook_pool_key,
)
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
app_id: str
user_id: str
tenant_id: str
node_config: Mapping[str, Any]
node_config: NodeConfigDict
node_id: str
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
self.tenant_id = tenant_id
self.user_id = user_id
self.app_id = app_id
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
def poll(self) -> TriggerDebugEvent | None:
from services.trigger.trigger_service import TriggerService
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
pool_key: str = build_plugin_pool_key(
name=plugin_trigger_data.event_name,

View File

@@ -0,0 +1,195 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.workflow.enums import ErrorStrategy, NodeType
_NumberType = Union[int, float]
class BaseNodeError(ValueError):
    """Root of the node-error hierarchy; every node error is a ValueError."""


class DefaultValueTypeError(BaseNodeError):
    """Signals that a configured default value does not match its declared type."""
class RetryConfig(BaseModel):
    """node retry config

    All defaults describe a disabled policy (no retries, zero interval), so a
    node config that omits the retry section gets a no-retry behavior.
    """

    max_retries: int = 0  # max retry times (attempts after the initial failure)
    retry_interval: int = 0  # retry interval in milliseconds
    retry_enabled: bool = False  # whether retry is enabled

    @property
    def retry_interval_seconds(self) -> float:
        """Retry interval converted from milliseconds to seconds."""
        return self.retry_interval / 1000
class VariableSelector(BaseModel):
    """
    Variable Selector.

    Binds a variable name to the selector path used to locate its value.
    """

    # Name the selected value is exposed under.
    variable: str
    # Path segments addressing the source value — presumably resolved against
    # the workflow variable pool; confirm with callers.
    value_selector: Sequence[str]
class OutputVariableType(StrEnum):
    """Declared type tags for node output variables.

    The member values are the exact strings that appear in node configs:
    scalar types first, then ``array[...]`` element-typed variants, then the
    permissive ``any`` forms.
    """

    STRING = "string"
    NUMBER = "number"
    INTEGER = "integer"
    SECRET = "secret"
    BOOLEAN = "boolean"
    OBJECT = "object"
    FILE = "file"
    ARRAY = "array"
    ARRAY_STRING = "array[string]"
    ARRAY_NUMBER = "array[number]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_BOOLEAN = "array[boolean]"
    ARRAY_FILE = "array[file]"
    ANY = "any"
    ARRAY_ANY = "array[any]"
class OutputVariableEntity(BaseModel):
    """
    Output Variable Entity.

    Describes one declared output: its name, its declared type (defaulting to
    ANY), and the selector path used to resolve its value.
    """

    variable: str
    value_type: OutputVariableType = OutputVariableType.ANY
    value_selector: Sequence[str]

    @field_validator("value_type", mode="before")
    @classmethod
    def normalize_value_type(cls, v: Any) -> Any:
        """
        Normalize value_type to handle case-insensitive array types.

        Converts 'Array[...]' to 'array[...]' for backward compatibility.
        """
        is_legacy_array_spelling = isinstance(v, str) and v.startswith("Array[")
        return v.lower() if is_legacy_array_spelling else v
class DefaultValueType(StrEnum):
    """Types a node default value may declare.

    A subset of the output-variable type spellings; ``array[file]`` entries
    are accepted without element validation (see DefaultValue.validate_value_type).
    """

    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"
    ARRAY_NUMBER = "array[number]"
    ARRAY_STRING = "array[string]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
    """A fallback value a node provides for output ``key``.

    ``value`` may arrive as a string (JSON-encoded or numeric text); the
    after-validator coerces it to the declared ``type`` and raises
    DefaultValueTypeError when it cannot.
    """

    # The fallback payload; may be a raw string that gets coerced below.
    value: Any = None
    # Declared type the value must conform to after coercion.
    type: DefaultValueType
    # Output name this default applies to.
    key: str

    @staticmethod
    def _parse_json(value: str) -> Any:
        """Unified JSON parsing handler; wraps JSONDecodeError in DefaultValueTypeError."""
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

    @staticmethod
    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
        """Unified array type validation: value is a list whose items are all element_type."""
        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

    @staticmethod
    def _convert_number(value: str) -> float:
        """Unified number conversion handler: str -> float, or DefaultValueTypeError."""
        try:
            return float(value)
        except ValueError:
            raise DefaultValueTypeError(f"Cannot convert to number: {value}")

    @model_validator(mode="after")
    def validate_value_type(self) -> DefaultValue:
        """Coerce string input and validate ``value`` against the declared ``type``.

        Raises:
            DefaultValueTypeError: unsupported type, failed string conversion,
                or a value/element that does not match the declared type.
        """
        # Type validation configuration: per-type expected runtime type,
        # optional element type for arrays, and a str -> value converter.
        type_validators: dict[DefaultValueType, dict[str, Any]] = {
            DefaultValueType.STRING: {
                "type": str,
                "converter": lambda x: x,
            },
            DefaultValueType.NUMBER: {
                # NOTE(review): isinstance against this int|float union also
                # accepts bool (bool subclasses int) — confirm that is intended.
                "type": _NumberType,
                "converter": self._convert_number,
            },
            DefaultValueType.OBJECT: {
                "type": dict,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_NUMBER: {
                "type": list,
                "element_type": _NumberType,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_STRING: {
                "type": list,
                "element_type": str,
                "converter": self._parse_json,
            },
            DefaultValueType.ARRAY_OBJECT: {
                "type": list,
                "element_type": dict,
                "converter": self._parse_json,
            },
        }
        validator: dict[str, Any] = type_validators.get(self.type, {})
        if not validator:
            if self.type == DefaultValueType.ARRAY_FILES:
                # Handle files type: file arrays are accepted as-is, unvalidated.
                return self
            raise DefaultValueTypeError(f"Unsupported type: {self.type}")
        # Handle string input cases: coerce JSON/numeric text to the target type.
        if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
            self.value = validator["converter"](self.value)
        # Validate base type
        if not isinstance(self.value, validator["type"]):
            # NOTE(review): `.__name__` on `_NumberType` (a typing Union) may
            # itself raise on some Python versions — confirm this error path
            # with a NUMBER-typed mismatch.
            raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
        # Validate array element types
        if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
            raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
        return self
class BaseNodeData(ABC, BaseModel):
    """Typed base for every workflow node's ``data`` payload.

    Extra keys are allowed so node-specific fields survive validation; the
    common fields (title, description, version, error strategy, defaults, and
    retry policy) all carry permissive defaults.
    """

    model_config = ConfigDict(extra="allow")

    type: NodeType
    title: str = ""
    desc: str | None = None
    version: str = "1"
    error_strategy: ErrorStrategy | None = None
    default_value: list[DefaultValue] | None = None
    retry_config: RetryConfig = Field(default_factory=RetryConfig)

    @property
    def default_value_dict(self) -> dict[str, Any]:
        """Map each configured default's key to its value; empty when none are set."""
        configured = self.default_value or []
        return {entry.key: entry.value for entry in configured}

View File

@@ -4,21 +4,18 @@ import sys
from pydantic import TypeAdapter, with_config
from core.workflow.entities.base_node import BaseNodeData
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
data: BaseNodeData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)

View File

@@ -3,12 +3,12 @@ from __future__ import annotations
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from typing import Any, Protocol, cast, final
from pydantic import TypeAdapter
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
from core.workflow.nodes.base.node import Node
from libs.typing import is_str
@@ -28,7 +28,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
"""
Create a Node instance from node configuration data.
@@ -115,10 +115,7 @@ class Graph:
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
if node_data.type.is_start_node:
start_node_id = nid
break
@@ -197,6 +194,27 @@ class Graph:
return nodes
@staticmethod
def _is_custom_note_node(node_config: object) -> bool:
"""
Check whether a raw graph node is a UI-only custom note node.
Custom note nodes are not executable workflow nodes, and their `data.type`
field may be empty, so they must be filtered out before strict node schema
validation.
"""
if not isinstance(node_config, Mapping):
return False
if node_config.get("type") == "custom-note":
return True
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return False
return node_data.get("type") == "custom-note"
@classmethod
def new(cls) -> GraphBuilder:
"""Create a fluent builder for assembling a graph programmatically."""
@@ -302,13 +320,13 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[object], node_configs)
node_configs = [node_config for node_config in node_configs if not cls._is_custom_note_node(node_config)]
node_configs = _ListNodeConfigDict.validate_python(node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@@ -367,12 +367,11 @@ class AgentNode(Node[AgentNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AgentNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:

View File

@@ -5,10 +5,12 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
class AgentNodeData(BaseNodeData):
type: NodeType = NodeType.AGENT
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str # redundancy

View File

@@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = AnswerNodeData.model_validate(node_data)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
_ = graph_config # Explicitly mark as unused
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

View File

@@ -3,6 +3,7 @@ from enum import StrEnum, auto
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
Answer Node Data.
"""
type: NodeType = NodeType.ANSWER
answer: str = Field(..., description="answer template string")

View File

@@ -1,185 +1,18 @@
from __future__ import annotations
import json
from abc import ABC
from builtins import type as type_
from collections.abc import Sequence
from enum import StrEnum
from typing import Any, Union
from pydantic import BaseModel
from pydantic import BaseModel, field_validator, model_validator
from core.workflow.enums import ErrorStrategy
from .exc import DefaultValueTypeError
_NumberType = Union[int, float]
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class VariableSelector(BaseModel):
"""
Variable Selector.
"""
variable: str
value_selector: Sequence[str]
class OutputVariableType(StrEnum):
STRING = "string"
NUMBER = "number"
INTEGER = "integer"
SECRET = "secret"
BOOLEAN = "boolean"
OBJECT = "object"
FILE = "file"
ARRAY = "array"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_FILE = "array[file]"
ANY = "any"
ARRAY_ANY = "array[any]"
class OutputVariableEntity(BaseModel):
"""
Output Variable Entity.
"""
variable: str
value_type: OutputVariableType = OutputVariableType.ANY
value_selector: Sequence[str]
@field_validator("value_type", mode="before")
@classmethod
def normalize_value_type(cls, v: Any) -> Any:
"""
Normalize value_type to handle case-insensitive array types.
Converts 'Array[...]' to 'array[...]' for backward compatibility.
"""
if isinstance(v, str) and v.startswith("Array["):
return v.lower()
return v
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
class DefaultValue(BaseModel):
value: Any = None
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": _NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": _NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator: dict[str, Any] = type_validators.get(self.type, {})
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: str | None = None
version: str = "1"
error_strategy: ErrorStrategy | None = None
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
from core.workflow.entities.base_node import (
BaseNodeData,
BaseNodeError,
DefaultValue,
DefaultValueType,
DefaultValueTypeError,
OutputVariableEntity,
OutputVariableType,
RetryConfig,
VariableSelector,
)
class BaseIterationNodeData(BaseNodeData):
@@ -210,3 +43,20 @@ class BaseLoopState(BaseModel):
pass
metadata: MetaData
__all__ = [
"BaseIterationNodeData",
"BaseIterationState",
"BaseLoopNodeData",
"BaseLoopState",
"BaseNodeData",
"BaseNodeError",
"DefaultValue",
"DefaultValueType",
"DefaultValueTypeError",
"OutputVariableEntity",
"OutputVariableType",
"RetryConfig",
"VariableSelector",
]

View File

@@ -1,10 +1,6 @@
class BaseNodeError(ValueError):
"""Base class for node errors."""
from core.workflow.entities.base_node import BaseNodeError, DefaultValueTypeError
pass
class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""
pass
__all__ = [
"BaseNodeError",
"DefaultValueTypeError",
]

View File

@@ -13,6 +13,7 @@ from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@@ -76,7 +77,7 @@ class Node(Generic[NodeDataT]):
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
_node_data_type: ClassVar[type[NodeDataT]] # type: ignore[misc] # assigned per-subclass in __init_subclass__
def __init_subclass__(cls, **kwargs: Any) -> None:
"""
@@ -130,11 +131,11 @@ class Node(Generic[NodeDataT]):
Later, in __init__:
::
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
config["data"] ──► _node_data_type.model_validate()
CodeNodeData instance
(stored in self._node_data)
Example:
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
@@ -182,7 +183,7 @@ class Node(Generic[NodeDataT]):
bucket["latest"] = bucket[latest_key]
@classmethod
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
def _extract_node_data_type_from_generic(cls) -> type[NodeDataT] | None:
"""
Extract the node data type from the generic parameter `Node[T]`.
@@ -208,7 +209,7 @@ class Node(Generic[NodeDataT]):
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
return candidate
return candidate # type: ignore[return-value]
return None
@@ -218,10 +219,16 @@ class Node(Generic[NodeDataT]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: dict[str, Any] | NodeConfigDict, # NodeConfigDict
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
) -> None:
"""
Initialize a node and hydrate typed node data from config.
For backward compatibility, when incoming mapping-based `data` misses a
`type`, this constructor fills it with `self.node_type` before validation.
"""
self._graph_init_params = graph_init_params
self.id = id
self.tenant_id = graph_init_params.tenant_id
@@ -235,19 +242,25 @@ class Node(Generic[NodeDataT]):
self.graph_runtime_state = graph_runtime_state
self.state: NodeState = NodeState.UNKNOWN # node execution state
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
if "id" not in config:
raise ValueError("node config missing required 'id' field")
node_id = config["id"]
self._node_id = node_id
self._node_execution_id: str = ""
self._start_at = naive_utc_now()
raw_node_data = config.get("data") or {}
if not isinstance(raw_node_data, Mapping):
raise ValueError("Node config data must be a mapping.")
if "data" not in config:
raise ValueError(f"node config for node {node_id} missing required 'data' field")
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
if isinstance(config["data"], BaseNodeData):
self._node_data = self._node_data_type.model_validate(config["data"], from_attributes=True)
elif isinstance(config["data"], dict):
if "type" not in config["data"]:
config["data"]["type"] = self.node_type
self._node_data = self._node_data_type.model_validate(config["data"])
else:
raise TypeError(f"node config 'data' field must be a dict or {self._node_data_type.__name__} instance")
self.post_init()
@@ -291,9 +304,6 @@ class Node(Generic[NodeDataT]):
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@abstractmethod
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
"""
@@ -342,8 +352,6 @@ class Node(Generic[NodeDataT]):
start_event.provider_id = getattr(self.node_data, "provider_id", "")
start_event.provider_type = getattr(self.node_data, "provider_type", "")
from typing import cast
from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.agent.entities import AgentNodeData
@@ -410,7 +418,7 @@ class Node(Generic[NodeDataT]):
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
config: NodeConfigDict,
) -> Mapping[str, Sequence[str]]:
"""Extracts references variable selectors from node configuration.
@@ -448,13 +456,13 @@ class Node(Generic[NodeDataT]):
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_id = config["id"]
# Pass raw dict data instead of creating NodeData instance
node_data = cls._node_data_type.model_validate(config["data"], from_attributes=True)
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
return data
@@ -464,7 +472,7 @@ class Node(Generic[NodeDataT]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: NodeDataT,
) -> Mapping[str, Sequence[str]]:
return {}
@@ -519,23 +527,23 @@ class Node(Generic[NodeDataT]):
def _get_error_strategy(self) -> ErrorStrategy | None:
"""Get the error strategy for this node."""
return self._node_data.error_strategy
return self.node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
return self._node_data.retry_config
return self.node_data.retry_config
def _get_title(self) -> str:
"""Get the node title."""
return self._node_data.title
return self.node_data.title
def _get_description(self) -> str | None:
"""Get the node description."""
return self._node_data.desc
return self.node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._node_data.default_value_dict
return self.node_data.default_value_dict
# Public interface properties that delegate to abstract methods
@property

View File

@@ -3,6 +3,7 @@ from decimal import Decimal
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Protocol, cast
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}
@property

View File

@@ -3,6 +3,7 @@ from typing import Annotated, Literal
from pydantic import AfterValidator, BaseModel
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.variables.types import SegmentType
@@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.CODE
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
@@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
@@ -180,7 +181,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DatasourceNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -189,7 +190,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
:param node_data: node data
:return:
"""
typed_node_data = DatasourceNodeData.model_validate(node_data)
typed_node_data = node_data
result = {}
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:

View File

@@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
@@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
type: NodeType = NodeType.DATASOURCE
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@@ -1,10 +1,12 @@
from collections.abc import Sequence
from dataclasses import dataclass
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
class DocumentExtractorNodeData(BaseNodeData):
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
variable_selector: Sequence[str]

View File

@@ -21,6 +21,7 @@ from docx.table import Table
from docx.text.paragraph import Paragraph
from core.helper import ssrf_proxy
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod, file_manager
from core.workflow.node_events import NodeRunResult
@@ -53,7 +54,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -118,12 +119,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: DocumentExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector}
_ = graph_config # Explicitly mark as unused
return {node_id + ".files": node_data.variable_selector}
def _extract_text_by_mime_type(

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity
@@ -8,6 +9,7 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
type: NodeType = NodeType.END
outputs: list[OutputVariableEntity]

View File

@@ -8,6 +8,7 @@ import charset_normalizer
import httpx
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
@@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
Code Node Data.
"""
type: NodeType = NodeType.HTTP_REQUEST
method: Literal[
"get",
"post",

View File

@@ -3,6 +3,7 @@ import mimetypes
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.file import File, FileTransferMethod
from core.workflow.node_events import NodeRunResult
@@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -163,18 +164,16 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HttpRequestNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
_ = graph_config # Explicitly mark as unused
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
match body_type:
case "none":
pass

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import (
@@ -64,7 +65,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: dict[str, Any] | NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository | None = None,
@@ -342,7 +343,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: HumanInputNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
@@ -351,5 +352,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
return node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@@ -2,6 +2,7 @@ from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.utils.condition.entities import Condition
@@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
If Else Node Data.
"""
type: NodeType = NodeType.IF_ELSE
class Case(BaseModel):
"""
Case entity representing a single logical condition group

View File

@@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {}
for case in typed_node_data.cases or []:
_ = graph_config # Explicitly mark as unused
for case in node_data.cases or []:
for condition in case.conditions:
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
var_mapping[key] = condition.variable_selector

View File

@@ -3,6 +3,7 @@ from typing import Any
from pydantic import Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
@@ -17,6 +18,7 @@ class IterationNodeData(BaseIterationNodeData):
Iteration Node Data.
"""
type: NodeType = NodeType.ITERATION
parent_loop_id: str | None = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
@@ -31,7 +33,7 @@ class IterationStartNodeData(BaseNodeData):
Iteration Start Node Data.
"""
pass
type: NodeType = NodeType.ITERATION_START
class IterationState(BaseIterationState):

View File

@@ -8,6 +8,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@@ -460,21 +461,24 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector,
f"{node_id}.input_selector": node_data.iterator_selector,
}
iteration_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
nodes_value = graph_config.get("nodes", [])
for node in nodes_value:
if not isinstance(node, Mapping):
continue
node_data_value = node.get("data", {})
if not isinstance(node_data_value, Mapping):
continue
if node_data_value.get("iteration_id") == node_id:
in_iteration_node_id = node.get("id")
if in_iteration_node_id:
iteration_node_ids.add(in_iteration_node_id)
@@ -497,7 +501,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=NodeConfigDictAdapter.validate_python(sub_node_config)
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:

View File

@@ -3,6 +3,7 @@ from typing import Literal, Union
from pydantic import BaseModel
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
Knowledge index Node Data.
"""
type: str = "knowledge-index"
type: NodeType = NodeType.KNOWLEDGE_INDEX
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None

View File

@@ -3,6 +3,7 @@ from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
@@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
Knowledge retrieval Node Data.
"""
type: str = "knowledge-retrieval"
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
query_variable_selector: list[str] | None | str = None
query_attachment_selector: list[str] | None | str = None
dataset_ids: list[str]

View File

@@ -6,6 +6,7 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
@@ -48,7 +49,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
@@ -260,15 +261,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: KnowledgeRetrievalNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
if typed_node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
_ = graph_config # Explicitly mark as unused
if node_data.query_variable_selector:
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
if node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
return variable_mapping

View File

@@ -3,6 +3,7 @@ from enum import StrEnum
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
class ListOperatorNodeData(BaseNodeData):
type: NodeType = NodeType.LIST_OPERATOR
variable: Sequence[str] = Field(default_factory=list)
filter_by: FilterBy
order_by: OrderByConfig

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData):
type: NodeType = NodeType.LLM
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)

View File

@@ -44,6 +44,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -119,7 +120,7 @@ class LLMNode(Node[LLMNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
@@ -951,14 +952,11 @@ class LLMNode(Node[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LLMNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
_ = graph_config # Explicitly mark as unused
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
prompt_template = node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list):
for prompt in prompt_template:
@@ -976,7 +974,7 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = typed_node_data.memory
memory = node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@@ -984,16 +982,16 @@ class LLMNode(Node[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.memory:
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
if typed_node_data.prompt_config:
if node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
@@ -1006,7 +1004,7 @@ class LLMNode(Node[LLMNodeData]):
break
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

View File

@@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal
from pydantic import AfterValidator, BaseModel, Field, field_validator
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.variables.types import SegmentType
@@ -39,6 +40,7 @@ class LoopVariableData(BaseModel):
class LoopNodeData(BaseLoopNodeData):
type: NodeType = NodeType.LOOP
loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
@@ -58,7 +60,7 @@ class LoopStartNodeData(BaseNodeData):
Loop Start Node Data.
"""
pass
type: NodeType = NodeType.LOOP_START
class LoopEndNodeData(BaseNodeData):
@@ -66,7 +68,7 @@ class LoopEndNodeData(BaseNodeData):
Loop End Node Data.
"""
pass
type: NodeType = NodeType.LOOP_END
class LoopState(BaseLoopState):

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# Extract loop node IDs statically from graph_config
@@ -327,7 +325,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
graph_config=graph_config, config=NodeConfigDictAdapter.validate_python(sub_node_config)
)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
@@ -342,7 +340,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in typed_node_data.loop_variables or []:
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

View File

@@ -8,6 +8,7 @@ from pydantic import (
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
from core.workflow.variables.types import SegmentType
@@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
Parameter Extractor Node Data.
"""
type: NodeType = NodeType.PARAMETER_EXTRACTOR
model: ModelConfig
query: list[str]
parameters: list[ParameterConfig]

View File

@@ -24,6 +24,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
@@ -101,7 +102,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -834,15 +835,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ParameterExtractorNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
@@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
class QuestionClassifierNodeData(BaseNodeData):
type: NodeType = NodeType.QUESTION_CLASSIFIER
query_variable_selector: list[str]
model: ModelConfig
classes: list[ClassConfig]

View File

@@ -10,6 +10,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import (
NodeExecutionType,
NodeType,
@@ -60,7 +61,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -246,16 +247,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: QuestionClassifierNodeData,
) -> Mapping[str, Sequence[str]]:
# graph_config is not used in this node type
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
_ = graph_config # Explicitly mark as unused
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

View File

@@ -3,6 +3,7 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
Start Node Data
"""
type: NodeType = NodeType.START
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@@ -1,3 +1,4 @@
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
Template Transform Node Data.
"""
type: NodeType = NodeType.TEMPLATE_TRANSFORM
variables: list[VariableSelector]
template: str

View File

@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -26,7 +27,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
@@ -87,12 +88,10 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in typed_node_data.variables
for variable_selector in node_data.variables
}

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
@@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = NodeType.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]

View File

@@ -467,7 +467,7 @@ class ToolNode(Node[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: ToolNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -476,9 +476,8 @@ class ToolNode(Node[ToolNodeData]):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
typed_node_data = node_data
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]

View File

@@ -4,6 +4,7 @@ from typing import Any, Literal, Union
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.trigger.entities.entities import EventParameter
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
@@ -11,6 +12,8 @@ from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
class TriggerEventNodeData(BaseNodeData):
"""Plugin trigger node data"""
type: NodeType = NodeType.TRIGGER_PLUGIN
class TriggerEventInput(BaseModel):
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@@ -38,7 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
raise ValueError("value must be a string, int, float, bool or dict")
return type
title: str
desc: str | None = None
plugin_id: str = Field(..., description="Plugin ID")
provider_id: str = Field(..., description="Provider ID")

View File

@@ -2,6 +2,7 @@ from typing import Literal, Union
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
Trigger Schedule Node Data
"""
type: NodeType = NodeType.TRIGGER_SCHEDULE
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")

View File

@@ -1,10 +1,25 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, Field, field_validator
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType
_WEBHOOK_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_BOOLEAN,
SegmentType.ARRAY_OBJECT,
SegmentType.FILE,
}
)
class Method(StrEnum):
@@ -28,26 +43,31 @@ class WebhookParameter(BaseModel):
"""Parameter definition for headers, query params, or body."""
name: str
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook parameter type: {v}")
return v
class WebhookBodyParameter(BaseModel):
"""Body parameter with type information."""
name: str
type: Literal[
"string",
"number",
"boolean",
"object",
"array[string]",
"array[number]",
"array[boolean]",
"array[object]",
"file",
] = "string"
type: SegmentType = SegmentType.STRING
required: bool = False
@field_validator("type", mode="after")
@classmethod
def validate_type(cls, v: SegmentType) -> SegmentType:
if v not in _WEBHOOK_ALLOWED_TYPES:
raise ValueError(f"Unsupported webhook body parameter type: {v}")
return v
class WebhookData(BaseNodeData):
"""
@@ -57,6 +77,7 @@ class WebhookData(BaseNodeData):
class SyncMode(StrEnum):
SYNC = "async"  # only async delivery is supported
type: NodeType = NodeType.TRIGGER_WEBHOOK
method: Method = Method.GET
content_type: ContentType = Field(default=ContentType.JSON)
headers: Sequence[WebhookParameter] = Field(default_factory=list)

View File

@@ -151,7 +151,7 @@ class TriggerWebhookNode(Node[WebhookData]):
outputs[param_name] = raw_data
continue
if param_type == "file":
if param_type == SegmentType.FILE:
# Get File object (already processed by webhook controller)
files = webhook_data.get("files", {})
if files and isinstance(files, dict):

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.types import SegmentType
@@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
Variable Aggregator Node Data.
"""
type: NodeType = NodeType.VARIABLE_AGGREGATOR
output_type: str
variables: list[list[str]]
advanced_settings: AdvancedSettings | None = None

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@@ -52,21 +53,19 @@ class VariableAssignerNode(Node[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector
mapping[key] = node_data.assigned_variable_selector
selector_key = ".".join(typed_node_data.input_variable_selector)
selector_key = ".".join(node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.input_variable_selector
mapping[key] = node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:

View File

@@ -1,6 +1,7 @@
from collections.abc import Sequence
from enum import StrEnum
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -11,6 +12,7 @@ class WriteMode(StrEnum):
class VariableAssignerData(BaseNodeData):
type: NodeType = NodeType.VARIABLE_ASSIGNER
assigned_variable_selector: Sequence[str]
write_mode: WriteMode
input_variable_selector: Sequence[str]

View File

@@ -3,6 +3,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.workflow.enums import NodeType
from core.workflow.nodes.base import BaseNodeData
from .enums import InputType, Operation
@@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel):
class VariableAssignerNodeData(BaseNodeData):
type: NodeType = NodeType.VARIABLE_ASSIGNER
version: str = "2"
items: Sequence[VariableOperationItem] = Field(default_factory=list)

View File

@@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def __init__(
self,
id: str,
config: Mapping[str, Any],
config: NodeConfigDict,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
):
@@ -94,13 +95,11 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
node_data: VariableAssignerNodeData,
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
_ = graph_config # Explicitly mark as unused
var_mapping: dict[str, Sequence[str]] = {}
for item in typed_node_data.items:
for item in node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping

View File

@@ -10,7 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.file.models import File
from core.workflow.graph import Graph
@@ -148,7 +148,7 @@ class WorkflowEntry:
node_config_data = node_config["data"]
# Get node type
node_type = NodeType(node_config_data["type"])
node_type = node_config_data.type
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -303,10 +303,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config: NodeConfigDict = {
"id": node_id,
"data": cast(NodeConfigData, node_data),
}
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,

View File

@@ -234,7 +234,7 @@ class Workflow(Base): # bug
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
"""Extract a node configuration from the workflow graph by node ID.
A node configuration is a dictionary containing the node's properties, including
the node's id, title, and its data as a dict.
the node's id and its validated `data` payload as a `BaseNodeData` model.
"""
workflow_graph = self.graph_dict
@@ -252,12 +252,9 @@ class Workflow(Base): # bug
return NodeConfigDictAdapter.validate_python(node_config)
@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config_data.get("type"))
return node_type
return node_config["data"].type
@staticmethod
def get_enclosing_node_type_and_id(

View File

@@ -1,14 +1,18 @@
import json
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.nodes import NodeType
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
from core.workflow.nodes.trigger_schedule.entities import (
ScheduleConfig,
SchedulePlanUpdate,
TriggerScheduleNodeData,
VisualConfig,
)
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
from models.account import Account, TenantAccountJoin
@@ -176,26 +180,26 @@ class ScheduleService:
return next_run_at
@staticmethod
def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig:
"""
Converts user-friendly visual schedule settings to cron expression.
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
"""
node_data = node_config.get("data", {})
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node_config.get("id", "start")
node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True)
mode = node_data.mode
timezone = node_data.timezone
node_id = node_config["id"]
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
cron_expression = node_data.cron_expression
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = str(node_data.get("frequency"))
frequency = str(node_data.frequency or "")
if not frequency:
raise ScheduleConfigError("Frequency is required for visual mode")
visual_config = VisualConfig(**node_data.get("visual_config", {}))
visual_config = VisualConfig.model_validate(node_data.visual_config or {})
cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config)
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for visual mode")
@@ -239,19 +243,21 @@ class ScheduleService:
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
continue
mode = node_data.get("mode", "visual")
timezone = node_data.get("timezone", "UTC")
node_id = node.get("id", "start")
trigger_data = TriggerScheduleNodeData.model_validate(node_data)
mode = trigger_data.mode
timezone = trigger_data.timezone
cron_expression = None
if mode == "cron":
cron_expression = node_data.get("cron_expression")
cron_expression = trigger_data.cron_expression
if not cron_expression:
raise ScheduleConfigError("Cron expression is required for cron mode")
elif mode == "visual":
frequency = node_data.get("frequency")
visual_config_dict = node_data.get("visual_config", {})
visual_config = VisualConfig(**visual_config_dict)
frequency = trigger_data.frequency
if not frequency:
raise ScheduleConfigError("Frequency is required for visual mode")
visual_config = VisualConfig.model_validate(trigger_data.visual_config or {})
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
else:
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")

View File

@@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent
from core.trigger.provider import PluginTriggerProviderController
from core.trigger.trigger_manager import TriggerManager
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from extensions.ext_database import db
@@ -41,7 +42,7 @@ class TriggerService:
@classmethod
def invoke_trigger_event(
cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent
) -> TriggerInvokeEventResponse:
"""Invoke a trigger event."""
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
@@ -50,7 +51,7 @@ class TriggerService:
)
if not subscription:
raise ValueError("Subscription not found")
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True)
request = TriggerHttpRequestCachingService.get_request(event.request_id)
payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
# invoke trigger

View File

@@ -2,7 +2,7 @@ import json
import logging
import mimetypes
import secrets
from collections.abc import Mapping
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import orjson
@@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from core.workflow.enums import NodeType
from core.workflow.file.models import FileTransferMethod
from core.workflow.variables.types import SegmentType
from core.workflow.nodes.trigger_webhook.entities import (
ContentType,
WebhookBodyParameter,
WebhookData,
WebhookParameter,
)
from core.workflow.variables.types import ArrayValidation, SegmentType
from enums.quota_type import QuotaType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -57,7 +64,7 @@ class WebhookService:
@classmethod
def get_webhook_trigger_and_workflow(
cls, webhook_id: str, is_debug: bool = False
) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]:
"""Get webhook trigger, workflow, and node configuration.
Args:
@@ -129,13 +136,13 @@ class WebhookService:
if not workflow:
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
node_config = workflow.get_node_config_by_id(webhook_trigger.node_id)
node_config = NodeConfigDictAdapter.validate_python(workflow.get_node_config_by_id(webhook_trigger.node_id))
return webhook_trigger, workflow, node_config
@classmethod
def extract_and_validate_webhook_data(
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
) -> dict[str, Any]:
"""Extract and validate webhook data in a single unified process.
@@ -153,7 +160,7 @@ class WebhookService:
raw_data = cls.extract_webhook_data(webhook_trigger)
# Validate HTTP metadata (method, content-type)
node_data = node_config.get("data", {})
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
validation_result = cls._validate_http_metadata(raw_data, node_data)
if not validation_result["valid"]:
raise ValueError(validation_result["error"])
@@ -192,7 +199,7 @@ class WebhookService:
content_type = cls._extract_content_type(dict(request.headers))
# Route to appropriate extractor based on content type
extractors = {
extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = {
"application/json": cls._extract_json_body,
"application/x-www-form-urlencoded": cls._extract_form_body,
"multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
@@ -214,7 +221,7 @@ class WebhookService:
return data
@classmethod
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
"""Process and validate webhook data according to node configuration.
Args:
@@ -230,18 +237,13 @@ class WebhookService:
result = raw_data.copy()
# Validate and process headers
cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))
cls._validate_required_headers(raw_data["headers"], node_data.headers)
# Process query parameters with type conversion and validation
result["query_params"] = cls._process_parameters(
raw_data["query_params"], node_data.get("params", []), is_form_data=True
)
result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True)
# Process body parameters based on content type
configured_content_type = node_data.get("content_type", "application/json").lower()
result["body"] = cls._process_body_parameters(
raw_data["body"], node_data.get("body", []), configured_content_type
)
result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type)
return result
@@ -424,7 +426,11 @@ class WebhookService:
@classmethod
def _process_parameters(
cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
cls,
raw_params: dict[str, str],
param_configs: Sequence[WebhookParameter],
*,
is_form_data: bool = False,
) -> dict[str, Any]:
"""Process parameters with unified validation and type conversion.
@@ -440,13 +446,13 @@ class WebhookService:
ValueError: If required parameters are missing or validation fails
"""
processed = {}
configured_params = {config.get("name", ""): config for config in param_configs}
configured_params = {config.name: config for config in param_configs}
# Process configured parameters
for param_config in param_configs:
name = param_config.get("name", "")
param_type = param_config.get("type", SegmentType.STRING)
required = param_config.get("required", False)
name = param_config.name
param_type = param_config.type
required = param_config.required
# Check required parameters
if required and name not in raw_params:
@@ -465,7 +471,10 @@ class WebhookService:
@classmethod
def _process_body_parameters(
cls, raw_body: dict[str, Any], body_configs: list, content_type: str
cls,
raw_body: dict[str, Any],
body_configs: Sequence[WebhookBodyParameter],
content_type: ContentType,
) -> dict[str, Any]:
"""Process body parameters based on content type and configuration.
@@ -480,25 +489,28 @@ class WebhookService:
Raises:
ValueError: If required body parameters are missing or validation fails
"""
if content_type in ["text/plain", "application/octet-stream"]:
# For text/plain and octet-stream, validate required content exists
if body_configs and any(config.get("required", False) for config in body_configs):
raw_content = raw_body.get("raw")
if not raw_content:
raise ValueError(f"Required body content missing for {content_type} request")
return raw_body
match content_type:
case ContentType.TEXT | ContentType.BINARY:
# For text/plain and octet-stream, validate required content exists
if body_configs and any(config.required for config in body_configs):
raw_content = raw_body.get("raw")
if not raw_content:
raise ValueError(f"Required body content missing for {content_type} request")
return raw_body
case _:
pass
# For structured data (JSON, form-data, etc.)
processed = {}
configured_params = {config.get("name", ""): config for config in body_configs}
configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs}
for body_config in body_configs:
name = body_config.get("name", "")
param_type = body_config.get("type", SegmentType.STRING)
required = body_config.get("required", False)
name = body_config.name
param_type = body_config.type
required = body_config.required
# Handle file parameters for multipart data
if param_type == SegmentType.FILE and content_type == "multipart/form-data":
if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA:
# File validation is handled separately in extract phase
continue
@@ -508,7 +520,7 @@ class WebhookService:
if name in raw_body:
raw_value = raw_body[name]
is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
is_form_data = content_type in (ContentType.FORM_URLENCODED, ContentType.FORM_DATA)
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
# Include unconfigured parameters
@@ -519,7 +531,9 @@ class WebhookService:
return processed
@classmethod
def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
def _validate_and_convert_value(
cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool
) -> Any:
"""Unified validation and type conversion for parameter values.
Args:
@@ -545,7 +559,7 @@ class WebhookService:
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")
@classmethod
def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any:
"""Convert form data string values to specified types.
Args:
@@ -559,24 +573,28 @@ class WebhookService:
Raises:
ValueError: If the value cannot be converted to the specified type
"""
if param_type == SegmentType.STRING:
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
if param_type_enum is None:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
if param_type_enum == SegmentType.STRING:
return value
elif param_type == SegmentType.NUMBER:
elif param_type_enum == SegmentType.NUMBER:
if not cls._can_convert_to_number(value):
raise ValueError(f"Cannot convert '{value}' to number")
numeric_value = float(value)
return int(numeric_value) if numeric_value.is_integer() else numeric_value
elif param_type == SegmentType.BOOLEAN:
elif param_type_enum == SegmentType.BOOLEAN:
lower_value = value.lower()
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
if lower_value not in bool_map:
raise ValueError(f"Cannot convert '{value}' to boolean")
return bool_map[lower_value]
else:
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
raise ValueError(f"Unsupported type '{param_type_enum.value}' for form data parameter '{param_name}'")
@classmethod
def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
"""Validate JSON values against expected types.
Args:
@@ -590,43 +608,42 @@ class WebhookService:
Raises:
ValueError: If the value type doesn't match the expected type
"""
type_validators = {
SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
SegmentType.ARRAY_STRING: (
lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
"array of strings",
),
SegmentType.ARRAY_NUMBER: (
lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
"array of numbers",
),
SegmentType.ARRAY_BOOLEAN: (
lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
"array of booleans",
),
SegmentType.ARRAY_OBJECT: (
lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
"array of objects",
),
}
validator_info = type_validators.get(SegmentType(param_type))
if not validator_info:
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
if param_type_enum is None:
return value
validator, expected_type = validator_info
if not validator(value):
if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL):
actual_type = type(value).__name__
expected_type = cls._expected_type_label(param_type_enum)
raise ValueError(f"Expected {expected_type}, got {actual_type}")
return value
@classmethod
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None:
if isinstance(param_type, SegmentType):
return param_type
try:
return SegmentType(param_type)
except Exception:
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
return None
@staticmethod
def _expected_type_label(param_type: SegmentType) -> str:
match param_type:
case SegmentType.ARRAY_STRING:
return "array of strings"
case SegmentType.ARRAY_NUMBER:
return "array of numbers"
case SegmentType.ARRAY_BOOLEAN:
return "array of booleans"
case SegmentType.ARRAY_OBJECT:
return "array of objects"
case _:
return param_type.value
@classmethod
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None:
"""Validate required headers are present.
Args:
@@ -639,14 +656,14 @@ class WebhookService:
headers_lower = {k.lower(): v for k, v in headers.items()}
headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
for header_config in header_configs:
if header_config.get("required", False):
header_name = header_config.get("name", "")
if header_config.required:
header_name = header_config.name
sanitized_name = cls._sanitize_key(header_name).lower()
if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
raise ValueError(f"Required header missing: {header_name}")
@classmethod
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
"""Validate HTTP method and content-type.
Args:
@@ -657,13 +674,13 @@ class WebhookService:
dict[str, Any]: Validation result with 'valid' key and optional 'error' key
"""
# Validate HTTP method
configured_method = node_data.get("method", "get").upper()
configured_method = node_data.method.value.upper()
request_method = webhook_data["method"].upper()
if configured_method != request_method:
return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")
# Validate Content-type
configured_content_type = node_data.get("content_type", "application/json").lower()
configured_content_type = node_data.content_type.value.lower()
request_content_type = cls._extract_content_type(webhook_data["headers"])
if configured_content_type != request_content_type:
@@ -788,7 +805,7 @@ class WebhookService:
raise
@classmethod
def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]:
"""Generate HTTP response based on node configuration.
Args:
@@ -797,11 +814,11 @@ class WebhookService:
Returns:
tuple[dict[str, Any], int]: Response data and HTTP status code
"""
node_data = node_config.get("data", {})
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
# Get configured status code and response body
status_code = node_data.get("status_code", 200)
response_body = node_data.get("response_body", "")
status_code = node_data.status_code
response_body = node_data.response_body
# Parse response body as JSON if it's valid JSON, otherwise return as text
try:

View File

@@ -16,6 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -694,7 +695,7 @@ class WorkflowService:
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
node_data = node_config["data"]
if node_type.is_start_node:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
@@ -704,7 +705,7 @@ class WorkflowService:
workflow=draft_workflow,
)
if node_type is NodeType.START:
start_data = StartNodeData.model_validate(node_data)
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
)
@@ -1063,6 +1064,7 @@ class WorkflowService:
node_config: Mapping[str, Any],
variable_pool: VariablePool,
) -> HumanInputNode:
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
@@ -1078,8 +1080,8 @@ class WorkflowService:
start_at=time.perf_counter(),
)
node = HumanInputNode(
id=node_config.get("id", str(uuid.uuid4())),
config=node_config,
id=typed_node_config["id"],
config=typed_node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
@@ -1093,6 +1095,7 @@ class WorkflowService:
node_config: Mapping[str, Any],
manual_inputs: Mapping[str, Any],
) -> VariablePool:
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
@@ -1111,7 +1114,7 @@ class WorkflowService:
)
variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=node_config,
config=typed_node_config,
)
normalized_user_inputs: dict[str, Any] = dict(manual_inputs)

View File

@@ -190,6 +190,7 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
from core.workflow.enums import NodeType
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -200,7 +201,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
# Create variable pool
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="test", files=[]),
user_inputs={},
@@ -208,8 +208,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
conversation_variables=[],
)
# Create node data with custom auth and empty api_key
node_data = HttpRequestNodeData(
type=NodeType.HTTP_REQUEST,
title="http",
desc="",
url="http://example.com",
@@ -228,7 +228,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
ssl_verify=True,
)
# Create executor should raise AuthorizationConfigError
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,

View File

@@ -172,7 +172,7 @@ class TestWebhookService:
assert workflow.app_id == test_data["app"].id
assert node_config is not None
assert node_config["id"] == "webhook_node"
assert node_config["data"]["title"] == "Test Webhook"
assert node_config["data"].title == "Test Webhook"
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
"""Test webhook trigger not found scenario."""

View File

@@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing
# load dotenv file with pydantic-settings
config = DifyConfig()
# Disable `.env` loading to ensure test stability across environments
config = DifyConfig(_env_file=None)
# constant values
assert config.COMMIT_SHA == ""
@@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
config = DifyConfig()
# Disable `.env` loading to ensure test stability across environments
config = DifyConfig(_env_file=None)
# Verify default timeout values
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
@@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*")
monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/")
flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore
# Disable `.env` loading to ensure test stability across environments
flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore
config = flask_app.config
# configs read from pydantic-settings

View File

@@ -126,6 +126,33 @@ def test_graph_initialization_runs_default_validators(
assert "answer" in graph.nodes
def test_graph_initialization_filters_custom_note_nodes_before_validation(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
):
node_factory, graph_config = graph_init_dependencies
graph_config["nodes"] = [
{
"id": "note",
"type": "custom-note",
"data": {
"type": "",
"title": "",
"text": "UI-only note node",
},
},
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
{"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
]
graph_config["edges"] = [
{"source": "start", "target": "answer", "sourceHandle": "success"},
]
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
assert "note" not in graph.nodes
assert graph.root_node.id == "start"
def test_graph_validation_fails_for_unknown_edge_targets(
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
) -> None:

View File

@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -39,7 +40,7 @@ def test_abort_command():
# Create mock nodes with required attributes - using shared runtime state
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -149,7 +150,7 @@ def test_pause_command():
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -205,7 +206,7 @@ def test_update_variables_command_updates_pool():
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",

View File

@@ -9,6 +9,8 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities.base_node import BaseNodeData
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.nodes.base.node import Node
@@ -75,34 +77,36 @@ class MockNodeFactory(DifyNodeFactory):
NodeType.CODE: MockCodeNode,
}
def create_node(self, node_config: Mapping[str, Any]) -> Node:
def create_node(self, node_config: Mapping[str, Any] | NodeConfigDict) -> Node:
"""
Create a node instance, using mock implementations for third-party service nodes.
:param node_config: Node configuration dictionary
:return: Node instance (real or mocked)
"""
# Get node type from config
node_data = node_config.get("data", {})
node_type_str = node_data.get("type")
node_data = node_config.get("data")
if not node_type_str:
# Fall back to parent implementation for nodes without type
# Support both dict-based and BaseNodeData-based configurations.
node_type: NodeType | None = None
if isinstance(node_data, BaseNodeData):
node_type = node_data.type
elif isinstance(node_data, Mapping):
node_type_str = node_data.get("type")
if node_type_str:
try:
node_type = NodeType(node_type_str)
except ValueError:
return super().create_node(node_config)
if node_type is None:
return super().create_node(node_config)
try:
node_type = NodeType(node_type_str)
except ValueError:
# Unknown node type, use parent implementation
return super().create_node(node_config)
# Check if this node type should be mocked
# Check if this node type should be mocked.
if node_type in self._mock_node_types:
node_id = node_config.get("id")
if not node_id:
raise ValueError("Node config missing id")
# Create mock node instance
mock_class = self._mock_node_types[node_type]
if node_type == NodeType.CODE:
mock_instance = mock_class(
@@ -147,7 +151,7 @@ class MockNodeFactory(DifyNodeFactory):
return mock_instance
# For non-mocked node types, use parent implementation
# For non-mocked node types, use parent implementation.
return super().create_node(node_config)
def should_mock_node(self, node_type: NodeType) -> bool:

View File

@@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
@@ -63,7 +64,7 @@ class TestStopEventPropagation:
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -111,7 +112,7 @@ class TestStopEventPropagation:
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -195,7 +196,10 @@ class TestNodeStopCheck:
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
config={
"id": "answer",
"data": {"type": NodeType.ANSWER, "title": "answer", "answer": "{{#start.result#}}"},
},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -225,7 +229,7 @@ class TestNodeStopCheck:
# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
config={"id": "answer", "data": {"type": NodeType.ANSWER, "title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -276,7 +280,7 @@ class TestStopEventIntegration:
# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -292,7 +296,7 @@ class TestStopEventIntegration:
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
config={"id": "answer", "data": {"type": NodeType.ANSWER, "title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -343,7 +347,10 @@ class TestStopEventIntegration:
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
config={
"id": f"answer_{i}",
"data": {"type": NodeType.ANSWER, "title": f"answer_{i}", "answer": f"test{i}"},
},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@@ -447,7 +454,7 @@ class TestStopEventResumeBehavior:
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",

View File

@@ -49,7 +49,7 @@ def test_node_hydrates_data_during_initialization():
node = _SampleNode(
id="node-1",
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
config={"id": "node-1", "data": {"type": NodeType.ANSWER, "title": "Sample", "foo": "bar"}},
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
@@ -58,6 +58,20 @@ def test_node_hydrates_data_during_initialization():
assert node.title == "Sample"
def test_node_initialization_falls_back_to_node_type_when_data_type_is_missing():
graph_config: dict[str, object] = {}
init_params, runtime_state = _build_context(graph_config)
node = _SampleNode(
id="node-1",
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
graph_init_params=init_params,
graph_runtime_state=runtime_state,
)
assert node.node_data.type == NodeType.ANSWER
def test_missing_generic_argument_raises_type_error():
graph_config: dict[str, object] = {}

View File

@@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias():
def test_webhook_data_validation_errors():
"""Test WebhookData validation errors."""
# Title is required (inherited from BaseNodeData)
with pytest.raises(ValidationError):
WebhookData()
# Invalid method
with pytest.raises(ValidationError):

View File

@@ -8,6 +8,7 @@ from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
)
from core.workflow.entities.graph_config import NodeConfigDictAdapter
from core.workflow.file.enums import FileType
from core.workflow.file.models import File, FileTransferMethod
from core.workflow.nodes.code.code_node import CodeNode
@@ -124,7 +125,7 @@ class TestWorkflowEntry:
def get_node_config_by_id(self, target_id: str):
assert target_id == node_id
return node_config
return NodeConfigDictAdapter.validate_python(node_config)
workflow = StubWorkflow()
variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})