mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 21:15:10 +00:00
Compare commits
51 Commits
move-core-
...
yanli/pyda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
217f402da4 | ||
|
|
465f36e691 | ||
|
|
93e4c108f8 | ||
|
|
13d45769d7 | ||
|
|
94b31b101e | ||
|
|
262a720673 | ||
|
|
b3d074f04e | ||
|
|
af7c2ca5e9 | ||
|
|
834a7bedf6 | ||
|
|
03fb284031 | ||
|
|
994402b510 | ||
|
|
c5ed658c1e | ||
|
|
d15502d7ac | ||
|
|
747badbccd | ||
|
|
93cbaac910 | ||
|
|
b98e71e52a | ||
|
|
1575c0db02 | ||
|
|
98fc3cda3c | ||
|
|
4e18718dc8 | ||
|
|
b6bdd9996a | ||
|
|
f1a1791954 | ||
|
|
3ae27377ec | ||
|
|
4aefb98711 | ||
|
|
e8b8320a3d | ||
|
|
4ab0f3c1d3 | ||
|
|
1c242800b0 | ||
|
|
2b4e4905bf | ||
|
|
4fdd621e1b | ||
|
|
afadf54a2b | ||
|
|
59c729c376 | ||
|
|
36996d44a4 | ||
|
|
07eb7504e3 | ||
|
|
4a6798790d | ||
|
|
c0e3cf6f18 | ||
|
|
3b82ef08f1 | ||
|
|
da26cd1c12 | ||
|
|
b17ddd0ea7 | ||
|
|
4f6f110014 | ||
|
|
98f302c720 | ||
|
|
3f957f7c0f | ||
|
|
1a8b9ee1aa | ||
|
|
ae65f0909c | ||
|
|
f3c6ca4a6f | ||
|
|
f1b6c70893 | ||
|
|
c3bf1ac541 | ||
|
|
1b1c1424ae | ||
|
|
05b109841d | ||
|
|
23e00c1397 | ||
|
|
fdff4b15bb | ||
|
|
4809ad9bf1 | ||
|
|
8d85f51a3a |
@@ -31,6 +31,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -310,7 +311,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=target_node_config
|
||||
graph_config=workflow.graph_dict, config=NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
@@ -17,7 +17,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
@@ -115,7 +115,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data using the traditional mapping.
|
||||
|
||||
@@ -124,14 +124,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
:raises ValueError: if node type is unknown or configuration is invalid
|
||||
"""
|
||||
# Get node_id from config
|
||||
node_id = node_config["id"]
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
|
||||
# Get node type from config
|
||||
node_data = node_config["data"]
|
||||
try:
|
||||
node_type = NodeType(node_data["type"])
|
||||
except ValueError:
|
||||
raise ValueError(f"Unknown node type: {node_data['type']}")
|
||||
node_data = typed_node_config["data"]
|
||||
node_type = node_data.type
|
||||
|
||||
# Get node class
|
||||
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
|
||||
@@ -139,7 +137,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
raise ValueError(f"No class mapping found for node type: {node_type}")
|
||||
|
||||
latest_node_class = node_mapping.get(LATEST_VERSION)
|
||||
node_version = str(node_data.get("version", "1"))
|
||||
node_version = str(node_data.version)
|
||||
matched_node_class = node_mapping.get(node_version)
|
||||
node_class = matched_node_class or latest_node_class
|
||||
if not node_class:
|
||||
@@ -149,7 +147,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.CODE:
|
||||
return CodeNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
code_executor=self._code_executor,
|
||||
@@ -159,7 +157,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.TEMPLATE_TRANSFORM:
|
||||
return TemplateTransformNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
@@ -169,7 +167,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
return HttpRequestNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
http_request_config=self._http_request_config,
|
||||
@@ -179,11 +177,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
node_data_dict = node_data.model_dump()
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data_dict)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data_dict, model_instance=model_instance)
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
@@ -195,7 +194,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
@@ -204,7 +203,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
@@ -213,17 +212,17 @@ class DifyNodeFactory(NodeFactory):
|
||||
if node_type == NodeType.DOCUMENT_EXTRACTOR:
|
||||
return DocumentExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data.model_dump())
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
@@ -232,10 +231,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data.model_dump())
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
@@ -245,7 +244,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.trigger.debug.events import (
|
||||
build_plugin_pool_key,
|
||||
build_webhook_pool_key,
|
||||
)
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
@@ -41,10 +42,10 @@ class TriggerDebugEventPoller(ABC):
|
||||
app_id: str
|
||||
user_id: str
|
||||
tenant_id: str
|
||||
node_config: Mapping[str, Any]
|
||||
node_config: NodeConfigDict
|
||||
node_id: str
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str):
|
||||
def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str):
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.app_id = app_id
|
||||
@@ -60,7 +61,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller):
|
||||
def poll(self) -> TriggerDebugEvent | None:
|
||||
from services.trigger.trigger_service import TriggerService
|
||||
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {}))
|
||||
plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True)
|
||||
provider_id = TriggerProviderID(plugin_trigger_data.provider_id)
|
||||
pool_key: str = build_plugin_pool_key(
|
||||
name=plugin_trigger_data.event_name,
|
||||
|
||||
195
api/core/workflow/entities/base_node.py
Normal file
195
api/core/workflow/entities/base_node.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeType
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultValueTypeError(BaseNodeError):
|
||||
"""Raised when the default value type is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class OutputVariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
SECRET = "secret"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY = "array"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ANY = "any"
|
||||
ARRAY_ANY = "array[any]"
|
||||
|
||||
|
||||
class OutputVariableEntity(BaseModel):
|
||||
"""
|
||||
Output Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_type: OutputVariableType = OutputVariableType.ANY
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize value_type to handle case-insensitive array types.
|
||||
Converts 'Array[...]' to 'array[...]' for backward compatibility.
|
||||
"""
|
||||
if isinstance(v, str) and v.startswith("Array["):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
type: NodeType
|
||||
title: str = ""
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = Field(default_factory=RetryConfig)
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
@@ -4,21 +4,18 @@ import sys
|
||||
|
||||
from pydantic import TypeAdapter, with_config
|
||||
|
||||
from core.workflow.entities.base_node import BaseNodeData
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigData(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
@with_config(extra="allow")
|
||||
class NodeConfigDict(TypedDict):
|
||||
id: str
|
||||
data: NodeConfigData
|
||||
data: BaseNodeData
|
||||
|
||||
|
||||
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
|
||||
|
||||
@@ -3,12 +3,12 @@ from __future__ import annotations
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Protocol, cast, final
|
||||
from typing import Any, Protocol, cast, final
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.typing import is_str
|
||||
|
||||
@@ -28,7 +28,7 @@ class NodeFactory(Protocol):
|
||||
allowing for different node creation strategies while maintaining type safety.
|
||||
"""
|
||||
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a Node instance from node configuration data.
|
||||
|
||||
@@ -115,10 +115,7 @@ class Graph:
|
||||
start_node_id = None
|
||||
for nid in root_candidates:
|
||||
node_data = node_configs_map[nid]["data"]
|
||||
node_type = node_data["type"]
|
||||
if not isinstance(node_type, str):
|
||||
continue
|
||||
if NodeType(node_type).is_start_node:
|
||||
if node_data.type.is_start_node:
|
||||
start_node_id = nid
|
||||
break
|
||||
|
||||
@@ -197,6 +194,27 @@ class Graph:
|
||||
|
||||
return nodes
|
||||
|
||||
@staticmethod
|
||||
def _is_custom_note_node(node_config: object) -> bool:
|
||||
"""
|
||||
Check whether a raw graph node is a UI-only custom note node.
|
||||
|
||||
Custom note nodes are not executable workflow nodes, and their `data.type`
|
||||
field may be empty, so they must be filtered out before strict node schema
|
||||
validation.
|
||||
"""
|
||||
if not isinstance(node_config, Mapping):
|
||||
return False
|
||||
|
||||
if node_config.get("type") == "custom-note":
|
||||
return True
|
||||
|
||||
node_data = node_config.get("data")
|
||||
if not isinstance(node_data, Mapping):
|
||||
return False
|
||||
|
||||
return node_data.get("type") == "custom-note"
|
||||
|
||||
@classmethod
|
||||
def new(cls) -> GraphBuilder:
|
||||
"""Create a fluent builder for assembling a graph programmatically."""
|
||||
@@ -302,13 +320,13 @@ class Graph:
|
||||
node_configs = graph_config.get("nodes", [])
|
||||
|
||||
edge_configs = cast(list[dict[str, object]], edge_configs)
|
||||
node_configs = cast(list[object], node_configs)
|
||||
node_configs = [node_config for node_config in node_configs if not cls._is_custom_note_node(node_config)]
|
||||
node_configs = _ListNodeConfigDict.validate_python(node_configs)
|
||||
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
|
||||
@@ -367,12 +367,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AgentNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
result: dict[str, Any] = {}
|
||||
typed_node_data = node_data
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
|
||||
@@ -5,10 +5,12 @@ from pydantic import BaseModel
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.AGENT
|
||||
agent_strategy_provider_name: str # redundancy
|
||||
agent_strategy_name: str
|
||||
agent_strategy_label: str # redundancy
|
||||
|
||||
@@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: AnswerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = AnswerNodeData.model_validate(node_data)
|
||||
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import StrEnum, auto
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData):
|
||||
Answer Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ANSWER
|
||||
answer: str = Field(..., description="answer template string")
|
||||
|
||||
|
||||
|
||||
@@ -1,185 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from builtins import type as type_
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from core.workflow.enums import ErrorStrategy
|
||||
|
||||
from .exc import DefaultValueTypeError
|
||||
|
||||
_NumberType = Union[int, float]
|
||||
|
||||
|
||||
class RetryConfig(BaseModel):
|
||||
"""node retry config"""
|
||||
|
||||
max_retries: int = 0 # max retry times
|
||||
retry_interval: int = 0 # retry interval in milliseconds
|
||||
retry_enabled: bool = False # whether retry is enabled
|
||||
|
||||
@property
|
||||
def retry_interval_seconds(self) -> float:
|
||||
return self.retry_interval / 1000
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_selector: Sequence[str]
|
||||
|
||||
|
||||
class OutputVariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
SECRET = "secret"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
ARRAY = "array"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ANY = "any"
|
||||
ARRAY_ANY = "array[any]"
|
||||
|
||||
|
||||
class OutputVariableEntity(BaseModel):
|
||||
"""
|
||||
Output Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
value_type: OutputVariableType = OutputVariableType.ANY
|
||||
value_selector: Sequence[str]
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_value_type(cls, v: Any) -> Any:
|
||||
"""
|
||||
Normalize value_type to handle case-insensitive array types.
|
||||
Converts 'Array[...]' to 'array[...]' for backward compatibility.
|
||||
"""
|
||||
if isinstance(v, str) and v.startswith("Array["):
|
||||
return v.lower()
|
||||
return v
|
||||
|
||||
|
||||
class DefaultValueType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
OBJECT = "object"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILES = "array[file]"
|
||||
|
||||
|
||||
class DefaultValue(BaseModel):
|
||||
value: Any = None
|
||||
type: DefaultValueType
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
def _parse_json(value: str):
|
||||
"""Unified JSON parsing handler"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||
|
||||
@staticmethod
|
||||
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||
"""Unified array type validation"""
|
||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(value: str) -> float:
|
||||
"""Unified number conversion handler"""
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_value_type(self) -> DefaultValue:
|
||||
# Type validation configuration
|
||||
type_validators: dict[DefaultValueType, dict[str, Any]] = {
|
||||
DefaultValueType.STRING: {
|
||||
"type": str,
|
||||
"converter": lambda x: x,
|
||||
},
|
||||
DefaultValueType.NUMBER: {
|
||||
"type": _NumberType,
|
||||
"converter": self._convert_number,
|
||||
},
|
||||
DefaultValueType.OBJECT: {
|
||||
"type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_NUMBER: {
|
||||
"type": list,
|
||||
"element_type": _NumberType,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_STRING: {
|
||||
"type": list,
|
||||
"element_type": str,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
DefaultValueType.ARRAY_OBJECT: {
|
||||
"type": list,
|
||||
"element_type": dict,
|
||||
"converter": self._parse_json,
|
||||
},
|
||||
}
|
||||
|
||||
validator: dict[str, Any] = type_validators.get(self.type, {})
|
||||
if not validator:
|
||||
if self.type == DefaultValueType.ARRAY_FILES:
|
||||
# Handle files type
|
||||
return self
|
||||
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
|
||||
|
||||
# Handle string input cases
|
||||
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
|
||||
self.value = validator["converter"](self.value)
|
||||
|
||||
# Validate base type
|
||||
if not isinstance(self.value, validator["type"]):
|
||||
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
|
||||
|
||||
# Validate array element types
|
||||
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
|
||||
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: str | None = None
|
||||
version: str = "1"
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
default_value: list[DefaultValue] | None = None
|
||||
retry_config: RetryConfig = RetryConfig()
|
||||
|
||||
@property
|
||||
def default_value_dict(self) -> dict[str, Any]:
|
||||
if self.default_value:
|
||||
return {item.key: item.value for item in self.default_value}
|
||||
return {}
|
||||
from core.workflow.entities.base_node import (
|
||||
BaseNodeData,
|
||||
BaseNodeError,
|
||||
DefaultValue,
|
||||
DefaultValueType,
|
||||
DefaultValueTypeError,
|
||||
OutputVariableEntity,
|
||||
OutputVariableType,
|
||||
RetryConfig,
|
||||
VariableSelector,
|
||||
)
|
||||
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
@@ -210,3 +43,20 @@ class BaseLoopState(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNodeData",
|
||||
"BaseNodeError",
|
||||
"DefaultValue",
|
||||
"DefaultValueType",
|
||||
"DefaultValueTypeError",
|
||||
"OutputVariableEntity",
|
||||
"OutputVariableType",
|
||||
"RetryConfig",
|
||||
"VariableSelector",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
from core.workflow.entities.base_node import BaseNodeError, DefaultValueTypeError
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DefaultValueTypeError(BaseNodeError):
|
||||
"""Raised when the default value type is invalid."""
|
||||
|
||||
pass
|
||||
__all__ = [
|
||||
"BaseNodeError",
|
||||
"DefaultValueTypeError",
|
||||
]
|
||||
|
||||
@@ -13,6 +13,7 @@ from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
@@ -76,7 +77,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
_node_data_type: ClassVar[type[NodeDataT]] # type: ignore[misc] # assigned per-subclass in __init_subclass__
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
"""
|
||||
@@ -130,11 +131,11 @@ class Node(Generic[NodeDataT]):
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
config["data"] ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
@@ -182,7 +183,7 @@ class Node(Generic[NodeDataT]):
|
||||
bucket["latest"] = bucket[latest_key]
|
||||
|
||||
@classmethod
|
||||
def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None:
|
||||
def _extract_node_data_type_from_generic(cls) -> type[NodeDataT] | None:
|
||||
"""
|
||||
Extract the node data type from the generic parameter `Node[T]`.
|
||||
|
||||
@@ -208,7 +209,7 @@ class Node(Generic[NodeDataT]):
|
||||
if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData):
|
||||
raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype")
|
||||
|
||||
return candidate
|
||||
return candidate # type: ignore[return-value]
|
||||
|
||||
return None
|
||||
|
||||
@@ -218,10 +219,16 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: dict[str, Any] | NodeConfigDict, # NodeConfigDict
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a node and hydrate typed node data from config.
|
||||
|
||||
For backward compatibility, when incoming mapping-based `data` misses a
|
||||
`type`, this constructor fills it with `self.node_type` before validation.
|
||||
"""
|
||||
self._graph_init_params = graph_init_params
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
@@ -235,19 +242,25 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
if "id" not in config:
|
||||
raise ValueError("node config missing required 'id' field")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
if "data" not in config:
|
||||
raise ValueError(f"node config for node {node_id} missing required 'data' field")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
if isinstance(config["data"], BaseNodeData):
|
||||
self._node_data = self._node_data_type.model_validate(config["data"], from_attributes=True)
|
||||
elif isinstance(config["data"], dict):
|
||||
if "type" not in config["data"]:
|
||||
config["data"]["type"] = self.node_type
|
||||
self._node_data = self._node_data_type.model_validate(config["data"])
|
||||
else:
|
||||
raise TypeError(f"node config 'data' field must be a dict or {self._node_data_type.__name__} instance")
|
||||
|
||||
self.post_init()
|
||||
|
||||
@@ -291,9 +304,6 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@@ -342,8 +352,6 @@ class Node(Generic[NodeDataT]):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.nodes.agent.agent_node import AgentNode
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData
|
||||
|
||||
@@ -410,7 +418,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@@ -448,13 +456,13 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
node_id = config["id"]
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_data = cls._node_data_type.model_validate(config["data"], from_attributes=True)
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -464,7 +472,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
@@ -519,23 +527,23 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
"""Get the error strategy for this node."""
|
||||
return self._node_data.error_strategy
|
||||
return self.node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
"""Get the retry configuration for this node."""
|
||||
return self._node_data.retry_config
|
||||
return self.node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
"""Get the node title."""
|
||||
return self._node_data.title
|
||||
return self.node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
"""Get the node description."""
|
||||
return self._node_data.desc
|
||||
return self.node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
"""Get the default values dictionary for this node."""
|
||||
return self._node_data.default_value_dict
|
||||
return self.node_data.default_value_dict
|
||||
|
||||
# Public interface properties that delegate to abstract methods
|
||||
@property
|
||||
|
||||
@@ -3,6 +3,7 @@ from decimal import Decimal
|
||||
from textwrap import dedent
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -77,7 +78,7 @@ class CodeNode(Node[CodeNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: CodeNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = CodeNodeData.model_validate(node_data)
|
||||
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Annotated, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.variables.types import SegmentType
|
||||
@@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.CODE
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: dict[str, "CodeNodeData.Output"] | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
@@ -34,7 +35,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
@@ -180,7 +181,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DatasourceNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -189,7 +190,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
typed_node_data = DatasourceNodeData.model_validate(node_data)
|
||||
typed_node_data = node_data
|
||||
result = {}
|
||||
if typed_node_data.datasource_parameters:
|
||||
for parameter_name in typed_node_data.datasource_parameters:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
@@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel):
|
||||
|
||||
|
||||
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||
type: NodeType = NodeType.DATASOURCE
|
||||
|
||||
class DatasourceInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class DocumentExtractorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.DOCUMENT_EXTRACTOR
|
||||
variable_selector: Sequence[str]
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.file import File, FileTransferMethod, file_manager
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@@ -53,7 +54,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -118,12 +119,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: DocumentExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
|
||||
|
||||
return {node_id + ".files": typed_node_data.variable_selector}
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
return {node_id + ".files": node_data.variable_selector}
|
||||
|
||||
|
||||
def _extract_text_by_mime_type(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity
|
||||
|
||||
|
||||
@@ -8,6 +9,7 @@ class EndNodeData(BaseNodeData):
|
||||
END Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.END
|
||||
outputs: list[OutputVariableEntity]
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import charset_normalizer
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config"
|
||||
@@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
|
||||
Code Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.HTTP_REQUEST
|
||||
method: Literal[
|
||||
"get",
|
||||
"post",
|
||||
|
||||
@@ -3,6 +3,7 @@ import mimetypes
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.file import File, FileTransferMethod
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@@ -37,7 +38,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -163,18 +164,16 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HttpRequestNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = HttpRequestNodeData.model_validate(node_data)
|
||||
|
||||
selectors: list[VariableSelector] = []
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
|
||||
if typed_node_data.body:
|
||||
body_type = typed_node_data.body.type
|
||||
data = typed_node_data.body.data
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
|
||||
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
|
||||
if node_data.body:
|
||||
body_type = node_data.body.type
|
||||
data = node_data.body.data
|
||||
match body_type:
|
||||
case "none":
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import (
|
||||
@@ -64,7 +65,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: dict[str, Any] | NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
@@ -342,7 +343,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: HumanInputNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
@@ -351,5 +352,4 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
return node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
@@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData):
|
||||
If Else Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.IF_ELSE
|
||||
|
||||
class Case(BaseModel):
|
||||
"""
|
||||
Case entity representing a single logical condition group
|
||||
|
||||
@@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IfElseNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IfElseNodeData.model_validate(node_data)
|
||||
|
||||
var_mapping: dict[str, list[str]] = {}
|
||||
for case in typed_node_data.cases or []:
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
for case in node_data.cases or []:
|
||||
for condition in case.conditions:
|
||||
key = f"{node_id}.#{'.'.join(condition.variable_selector)}#"
|
||||
var_mapping[key] = condition.variable_selector
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
@@ -17,6 +18,7 @@ class IterationNodeData(BaseIterationNodeData):
|
||||
Iteration Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.ITERATION
|
||||
parent_loop_id: str | None = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
@@ -31,7 +33,7 @@ class IterationStartNodeData(BaseNodeData):
|
||||
Iteration Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.ITERATION_START
|
||||
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -460,21 +461,24 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: IterationNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = IterationNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {
|
||||
f"{node_id}.input_selector": typed_node_data.iterator_selector,
|
||||
f"{node_id}.input_selector": node_data.iterator_selector,
|
||||
}
|
||||
iteration_node_ids = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
if node_data.get("iteration_id") == node_id:
|
||||
nodes_value = graph_config.get("nodes", [])
|
||||
for node in nodes_value:
|
||||
if not isinstance(node, Mapping):
|
||||
continue
|
||||
|
||||
node_data_value = node.get("data", {})
|
||||
if not isinstance(node_data_value, Mapping):
|
||||
continue
|
||||
|
||||
if node_data_value.get("iteration_id") == node_id:
|
||||
in_iteration_node_id = node.get("id")
|
||||
if in_iteration_node_id:
|
||||
iteration_node_ids.add(in_iteration_node_id)
|
||||
@@ -497,7 +501,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Literal, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -155,7 +156,7 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
Knowledge index Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-index"
|
||||
type: NodeType = NodeType.KNOWLEDGE_INDEX
|
||||
chunk_structure: str
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
@@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
|
||||
type: str = "knowledge-retrieval"
|
||||
type: NodeType = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
query_variable_selector: list[str] | None | str = None
|
||||
query_attachment_selector: list[str] | None | str = None
|
||||
dataset_ids: list[str]
|
||||
|
||||
@@ -6,6 +6,7 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
@@ -48,7 +49,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
@@ -260,15 +261,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: KnowledgeRetrievalNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
if typed_node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
|
||||
if typed_node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
if node_data.query_variable_selector:
|
||||
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
||||
if node_data.query_attachment_selector:
|
||||
variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector
|
||||
return variable_mapping
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -62,6 +63,7 @@ class ExtractConfig(BaseModel):
|
||||
|
||||
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LIST_OPERATOR
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderByConfig
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.LLM
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
|
||||
@@ -44,6 +44,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -119,7 +120,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
@@ -951,14 +952,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LLMNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LLMNodeData.model_validate(node_data)
|
||||
|
||||
prompt_template = typed_node_data.prompt_template
|
||||
prompt_template = node_data.prompt_template
|
||||
variable_selectors = []
|
||||
if isinstance(prompt_template, list):
|
||||
for prompt in prompt_template:
|
||||
@@ -976,7 +974,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
memory = typed_node_data.memory
|
||||
memory = node_data.memory
|
||||
if memory and memory.query_prompt_template:
|
||||
query_variable_selectors = VariableTemplateParser(
|
||||
template=memory.query_prompt_template
|
||||
@@ -984,16 +982,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
for variable_selector in query_variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
if typed_node_data.context.enabled:
|
||||
variable_mapping["#context#"] = typed_node_data.context.variable_selector
|
||||
if node_data.context.enabled:
|
||||
variable_mapping["#context#"] = node_data.context.variable_selector
|
||||
|
||||
if typed_node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
|
||||
if node_data.vision.enabled:
|
||||
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
|
||||
|
||||
if typed_node_data.memory:
|
||||
if node_data.memory:
|
||||
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY]
|
||||
|
||||
if typed_node_data.prompt_config:
|
||||
if node_data.prompt_config:
|
||||
enable_jinja = False
|
||||
|
||||
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
@@ -1006,7 +1004,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
break
|
||||
|
||||
if enable_jinja:
|
||||
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable_mapping[variable_selector.variable] = variable_selector.value_selector
|
||||
|
||||
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.variables.types import SegmentType
|
||||
@@ -39,6 +40,7 @@ class LoopVariableData(BaseModel):
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
type: NodeType = NodeType.LOOP
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
@@ -58,7 +60,7 @@ class LoopStartNodeData(BaseNodeData):
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_START
|
||||
|
||||
|
||||
class LoopEndNodeData(BaseNodeData):
|
||||
@@ -66,7 +68,7 @@ class LoopEndNodeData(BaseNodeData):
|
||||
Loop End Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
type: NodeType = NodeType.LOOP_END
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = LoopNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
@@ -327,7 +325,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
graph_config=graph_config, config=NodeConfigDictAdapter.validate_python(sub_node_config)
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
@@ -342,7 +340,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
for loop_variable in typed_node_data.loop_variables or []:
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
from core.workflow.variables.types import SegmentType
|
||||
@@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.PARAMETER_EXTRACTOR
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
|
||||
@@ -24,6 +24,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.file import File
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
@@ -101,7 +102,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -834,15 +835,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ParameterExtractorNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
|
||||
|
||||
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
|
||||
|
||||
if typed_node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
|
||||
for selector in selectors:
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
|
||||
@@ -11,6 +12,7 @@ class ClassConfig(BaseModel):
|
||||
|
||||
|
||||
class QuestionClassifierNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.QUESTION_CLASSIFIER
|
||||
query_variable_selector: list[str]
|
||||
model: ModelConfig
|
||||
classes: list[ClassConfig]
|
||||
|
||||
@@ -10,6 +10,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
NodeType,
|
||||
@@ -60,7 +61,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -246,16 +247,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: QuestionClassifierNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# graph_config is not used in this node type
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
|
||||
|
||||
variable_mapping = {"query": typed_node_data.query_variable_selector}
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
variable_mapping = {"query": node_data.query_variable_selector}
|
||||
variable_selectors: list[VariableSelector] = []
|
||||
if typed_node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
|
||||
for variable_selector in variable_selectors:
|
||||
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -11,4 +12,5 @@ class StartNodeData(BaseNodeData):
|
||||
Start Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.START
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@@ -7,5 +8,6 @@ class TemplateTransformNodeData(BaseNodeData):
|
||||
Template Transform Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TEMPLATE_TRANSFORM
|
||||
variables: list[VariableSelector]
|
||||
template: str
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -26,7 +27,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
@@ -87,12 +88,10 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
|
||||
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
return {
|
||||
node_id + "." + variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in typed_node_data.variables
|
||||
for variable_selector in node_data.variables
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
@@ -32,6 +33,8 @@ class ToolEntity(BaseModel):
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = NodeType.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
|
||||
@@ -467,7 +467,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: ToolNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -476,9 +476,8 @@ class ToolNode(Node[ToolNodeData]):
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
typed_node_data = node_data
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Literal, Union
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.trigger.entities.entities import EventParameter
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
|
||||
@@ -11,6 +12,8 @@ from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError
|
||||
class TriggerEventNodeData(BaseNodeData):
|
||||
"""Plugin trigger node data"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_PLUGIN
|
||||
|
||||
class TriggerEventInput(BaseModel):
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
@@ -38,7 +41,6 @@ class TriggerEventNodeData(BaseNodeData):
|
||||
raise ValueError("value must be a string, int, float, bool or dict")
|
||||
return type
|
||||
|
||||
title: str
|
||||
desc: str | None = None
|
||||
plugin_id: str = Field(..., description="Plugin ID")
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -10,6 +11,7 @@ class TriggerScheduleNodeData(BaseNodeData):
|
||||
Trigger Schedule Node Data
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_SCHEDULE
|
||||
mode: str = Field(default="visual", description="Schedule mode: visual or cron")
|
||||
frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly")
|
||||
cron_expression: str | None = Field(default=None, description="Cron expression for cron mode")
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
_WEBHOOK_ALLOWED_TYPES = frozenset(
|
||||
{
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.FILE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Method(StrEnum):
|
||||
@@ -28,26 +43,31 @@ class WebhookParameter(BaseModel):
|
||||
"""Parameter definition for headers, query params, or body."""
|
||||
|
||||
name: str
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookBodyParameter(BaseModel):
|
||||
"""Body parameter with type information."""
|
||||
|
||||
name: str
|
||||
type: Literal[
|
||||
"string",
|
||||
"number",
|
||||
"boolean",
|
||||
"object",
|
||||
"array[string]",
|
||||
"array[number]",
|
||||
"array[boolean]",
|
||||
"array[object]",
|
||||
"file",
|
||||
] = "string"
|
||||
type: SegmentType = SegmentType.STRING
|
||||
required: bool = False
|
||||
|
||||
@field_validator("type", mode="after")
|
||||
@classmethod
|
||||
def validate_type(cls, v: SegmentType) -> SegmentType:
|
||||
if v not in _WEBHOOK_ALLOWED_TYPES:
|
||||
raise ValueError(f"Unsupported webhook body parameter type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class WebhookData(BaseNodeData):
|
||||
"""
|
||||
@@ -57,6 +77,7 @@ class WebhookData(BaseNodeData):
|
||||
class SyncMode(StrEnum):
|
||||
SYNC = "async" # only support
|
||||
|
||||
type: NodeType = NodeType.TRIGGER_WEBHOOK
|
||||
method: Method = Method.GET
|
||||
content_type: ContentType = Field(default=ContentType.JSON)
|
||||
headers: Sequence[WebhookParameter] = Field(default_factory=list)
|
||||
|
||||
@@ -151,7 +151,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
||||
outputs[param_name] = raw_data
|
||||
continue
|
||||
|
||||
if param_type == "file":
|
||||
if param_type == SegmentType.FILE:
|
||||
# Get File object (already processed by webhook controller)
|
||||
files = webhook_data.get("files", {})
|
||||
if files and isinstance(files, dict):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.variables.types import SegmentType
|
||||
|
||||
@@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData):
|
||||
Variable Aggregator Node Data.
|
||||
"""
|
||||
|
||||
type: NodeType = NodeType.VARIABLE_AGGREGATOR
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_settings: AdvancedSettings | None = None
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -22,7 +23,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
@@ -52,21 +53,19 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: VariableAssignerData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = VariableAssignerData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
mapping = {}
|
||||
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
|
||||
assigned_variable_node_id = node_data.assigned_variable_selector[0]
|
||||
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
selector_key = ".".join(typed_node_data.assigned_variable_selector)
|
||||
selector_key = ".".join(node_data.assigned_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.assigned_variable_selector
|
||||
mapping[key] = node_data.assigned_variable_selector
|
||||
|
||||
selector_key = ".".join(typed_node_data.input_variable_selector)
|
||||
selector_key = ".".join(node_data.input_variable_selector)
|
||||
key = f"{node_id}.#{selector_key}#"
|
||||
mapping[key] = typed_node_data.input_variable_selector
|
||||
mapping[key] = node_data.input_variable_selector
|
||||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
@@ -11,6 +12,7 @@ class WriteMode(StrEnum):
|
||||
|
||||
|
||||
class VariableAssignerData(BaseNodeData):
|
||||
type: NodeType = NodeType.VARIABLE_ASSIGNER
|
||||
assigned_variable_selector: Sequence[str]
|
||||
write_mode: WriteMode
|
||||
input_variable_selector: Sequence[str]
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
from .enums import InputType, Operation
|
||||
@@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel):
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
type: NodeType = NodeType.VARIABLE_ASSIGNER
|
||||
version: str = "2"
|
||||
items: Sequence[VariableOperationItem] = Field(default_factory=list)
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -56,7 +57,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
):
|
||||
@@ -94,13 +95,11 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: VariableAssignerNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
|
||||
|
||||
_ = graph_config # Explicitly mark as unused
|
||||
var_mapping: dict[str, Sequence[str]] = {}
|
||||
for item in typed_node_data.items:
|
||||
for item in node_data.items:
|
||||
_target_mapping_from_item(var_mapping, node_id, item)
|
||||
_source_mapping_from_item(var_mapping, node_id, item)
|
||||
return var_mapping
|
||||
|
||||
@@ -10,7 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.graph import Graph
|
||||
@@ -148,7 +148,7 @@ class WorkflowEntry:
|
||||
node_config_data = node_config["data"]
|
||||
|
||||
# Get node type
|
||||
node_type = NodeType(node_config_data["type"])
|
||||
node_type = node_config_data.type
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -303,10 +303,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config: NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"data": cast(NodeConfigData, node_data),
|
||||
}
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -234,7 +234,7 @@ class Workflow(Base): # bug
|
||||
def get_node_config_by_id(self, node_id: str) -> NodeConfigDict:
|
||||
"""Extract a node configuration from the workflow graph by node ID.
|
||||
A node configuration is a dictionary containing the node's properties, including
|
||||
the node's id, title, and its data as a dict.
|
||||
the node's id and its validated `data` payload as a `BaseNodeData` model.
|
||||
"""
|
||||
workflow_graph = self.graph_dict
|
||||
|
||||
@@ -252,12 +252,9 @@ class Workflow(Base): # bug
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
@staticmethod
|
||||
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
|
||||
def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType:
|
||||
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
|
||||
node_config_data = node_config.get("data", {})
|
||||
# Get node class
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
return node_type
|
||||
return node_config["data"].type
|
||||
|
||||
@staticmethod
|
||||
def get_enclosing_node_type_and_id(
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from core.workflow.nodes.trigger_schedule.entities import (
|
||||
ScheduleConfig,
|
||||
SchedulePlanUpdate,
|
||||
TriggerScheduleNodeData,
|
||||
VisualConfig,
|
||||
)
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
@@ -176,26 +180,26 @@ class ScheduleService:
|
||||
return next_run_at
|
||||
|
||||
@staticmethod
|
||||
def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig:
|
||||
def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig:
|
||||
"""
|
||||
Converts user-friendly visual schedule settings to cron expression.
|
||||
Maintains consistency with frontend UI expectations while supporting croniter's extended syntax.
|
||||
"""
|
||||
node_data = node_config.get("data", {})
|
||||
mode = node_data.get("mode", "visual")
|
||||
timezone = node_data.get("timezone", "UTC")
|
||||
node_id = node_config.get("id", "start")
|
||||
node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True)
|
||||
mode = node_data.mode
|
||||
timezone = node_data.timezone
|
||||
node_id = node_config["id"]
|
||||
|
||||
cron_expression = None
|
||||
if mode == "cron":
|
||||
cron_expression = node_data.get("cron_expression")
|
||||
cron_expression = node_data.cron_expression
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for cron mode")
|
||||
elif mode == "visual":
|
||||
frequency = str(node_data.get("frequency"))
|
||||
frequency = str(node_data.frequency or "")
|
||||
if not frequency:
|
||||
raise ScheduleConfigError("Frequency is required for visual mode")
|
||||
visual_config = VisualConfig(**node_data.get("visual_config", {}))
|
||||
visual_config = VisualConfig.model_validate(node_data.visual_config or {})
|
||||
cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config)
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for visual mode")
|
||||
@@ -239,19 +243,21 @@ class ScheduleService:
|
||||
if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value:
|
||||
continue
|
||||
|
||||
mode = node_data.get("mode", "visual")
|
||||
timezone = node_data.get("timezone", "UTC")
|
||||
node_id = node.get("id", "start")
|
||||
trigger_data = TriggerScheduleNodeData.model_validate(node_data)
|
||||
mode = trigger_data.mode
|
||||
timezone = trigger_data.timezone
|
||||
|
||||
cron_expression = None
|
||||
if mode == "cron":
|
||||
cron_expression = node_data.get("cron_expression")
|
||||
cron_expression = trigger_data.cron_expression
|
||||
if not cron_expression:
|
||||
raise ScheduleConfigError("Cron expression is required for cron mode")
|
||||
elif mode == "visual":
|
||||
frequency = node_data.get("frequency")
|
||||
visual_config_dict = node_data.get("visual_config", {})
|
||||
visual_config = VisualConfig(**visual_config_dict)
|
||||
frequency = trigger_data.frequency
|
||||
if not frequency:
|
||||
raise ScheduleConfigError("Frequency is required for visual mode")
|
||||
visual_config = VisualConfig.model_validate(trigger_data.visual_config or {})
|
||||
cron_expression = ScheduleService.visual_to_cron(frequency, visual_config)
|
||||
else:
|
||||
raise ScheduleConfigError(f"Invalid schedule mode: {mode}")
|
||||
|
||||
@@ -16,6 +16,7 @@ from core.trigger.debug.events import PluginTriggerDebugEvent
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
|
||||
from extensions.ext_database import db
|
||||
@@ -41,7 +42,7 @@ class TriggerService:
|
||||
|
||||
@classmethod
|
||||
def invoke_trigger_event(
|
||||
cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent
|
||||
cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent
|
||||
) -> TriggerInvokeEventResponse:
|
||||
"""Invoke a trigger event."""
|
||||
subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id(
|
||||
@@ -50,7 +51,7 @@ class TriggerService:
|
||||
)
|
||||
if not subscription:
|
||||
raise ValueError("Subscription not found")
|
||||
node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {}))
|
||||
node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True)
|
||||
request = TriggerHttpRequestCachingService.get_request(event.request_id)
|
||||
payload = TriggerHttpRequestCachingService.get_payload(event.request_id)
|
||||
# invoke triger
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import secrets
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
@@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.file.models import FileTransferMethod
|
||||
from core.workflow.variables.types import SegmentType
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from core.workflow.variables.types import ArrayValidation, SegmentType
|
||||
from enums.quota_type import QuotaType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@@ -57,7 +64,7 @@ class WebhookService:
|
||||
@classmethod
|
||||
def get_webhook_trigger_and_workflow(
|
||||
cls, webhook_id: str, is_debug: bool = False
|
||||
) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]:
|
||||
) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]:
|
||||
"""Get webhook trigger, workflow, and node configuration.
|
||||
|
||||
Args:
|
||||
@@ -129,13 +136,13 @@ class WebhookService:
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
|
||||
|
||||
node_config = workflow.get_node_config_by_id(webhook_trigger.node_id)
|
||||
node_config = NodeConfigDictAdapter.validate_python(workflow.get_node_config_by_id(webhook_trigger.node_id))
|
||||
|
||||
return webhook_trigger, workflow, node_config
|
||||
|
||||
@classmethod
|
||||
def extract_and_validate_webhook_data(
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any]
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict
|
||||
) -> dict[str, Any]:
|
||||
"""Extract and validate webhook data in a single unified process.
|
||||
|
||||
@@ -153,7 +160,7 @@ class WebhookService:
|
||||
raw_data = cls.extract_webhook_data(webhook_trigger)
|
||||
|
||||
# Validate HTTP metadata (method, content-type)
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
|
||||
validation_result = cls._validate_http_metadata(raw_data, node_data)
|
||||
if not validation_result["valid"]:
|
||||
raise ValueError(validation_result["error"])
|
||||
@@ -192,7 +199,7 @@ class WebhookService:
|
||||
content_type = cls._extract_content_type(dict(request.headers))
|
||||
|
||||
# Route to appropriate extractor based on content type
|
||||
extractors = {
|
||||
extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = {
|
||||
"application/json": cls._extract_json_body,
|
||||
"application/x-www-form-urlencoded": cls._extract_form_body,
|
||||
"multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger),
|
||||
@@ -214,7 +221,7 @@ class WebhookService:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
"""Process and validate webhook data according to node configuration.
|
||||
|
||||
Args:
|
||||
@@ -230,18 +237,13 @@ class WebhookService:
|
||||
result = raw_data.copy()
|
||||
|
||||
# Validate and process headers
|
||||
cls._validate_required_headers(raw_data["headers"], node_data.get("headers", []))
|
||||
cls._validate_required_headers(raw_data["headers"], node_data.headers)
|
||||
|
||||
# Process query parameters with type conversion and validation
|
||||
result["query_params"] = cls._process_parameters(
|
||||
raw_data["query_params"], node_data.get("params", []), is_form_data=True
|
||||
)
|
||||
result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True)
|
||||
|
||||
# Process body parameters based on content type
|
||||
configured_content_type = node_data.get("content_type", "application/json").lower()
|
||||
result["body"] = cls._process_body_parameters(
|
||||
raw_data["body"], node_data.get("body", []), configured_content_type
|
||||
)
|
||||
result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type)
|
||||
|
||||
return result
|
||||
|
||||
@@ -424,7 +426,11 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def _process_parameters(
|
||||
cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False
|
||||
cls,
|
||||
raw_params: dict[str, str],
|
||||
param_configs: Sequence[WebhookParameter],
|
||||
*,
|
||||
is_form_data: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Process parameters with unified validation and type conversion.
|
||||
|
||||
@@ -440,13 +446,13 @@ class WebhookService:
|
||||
ValueError: If required parameters are missing or validation fails
|
||||
"""
|
||||
processed = {}
|
||||
configured_params = {config.get("name", ""): config for config in param_configs}
|
||||
configured_params = {config.name: config for config in param_configs}
|
||||
|
||||
# Process configured parameters
|
||||
for param_config in param_configs:
|
||||
name = param_config.get("name", "")
|
||||
param_type = param_config.get("type", SegmentType.STRING)
|
||||
required = param_config.get("required", False)
|
||||
name = param_config.name
|
||||
param_type = param_config.type
|
||||
required = param_config.required
|
||||
|
||||
# Check required parameters
|
||||
if required and name not in raw_params:
|
||||
@@ -465,7 +471,10 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def _process_body_parameters(
|
||||
cls, raw_body: dict[str, Any], body_configs: list, content_type: str
|
||||
cls,
|
||||
raw_body: dict[str, Any],
|
||||
body_configs: Sequence[WebhookBodyParameter],
|
||||
content_type: ContentType,
|
||||
) -> dict[str, Any]:
|
||||
"""Process body parameters based on content type and configuration.
|
||||
|
||||
@@ -480,25 +489,28 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If required body parameters are missing or validation fails
|
||||
"""
|
||||
if content_type in ["text/plain", "application/octet-stream"]:
|
||||
# For text/plain and octet-stream, validate required content exists
|
||||
if body_configs and any(config.get("required", False) for config in body_configs):
|
||||
raw_content = raw_body.get("raw")
|
||||
if not raw_content:
|
||||
raise ValueError(f"Required body content missing for {content_type} request")
|
||||
return raw_body
|
||||
match content_type:
|
||||
case ContentType.TEXT | ContentType.BINARY:
|
||||
# For text/plain and octet-stream, validate required content exists
|
||||
if body_configs and any(config.required for config in body_configs):
|
||||
raw_content = raw_body.get("raw")
|
||||
if not raw_content:
|
||||
raise ValueError(f"Required body content missing for {content_type} request")
|
||||
return raw_body
|
||||
case _:
|
||||
pass
|
||||
|
||||
# For structured data (JSON, form-data, etc.)
|
||||
processed = {}
|
||||
configured_params = {config.get("name", ""): config for config in body_configs}
|
||||
configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs}
|
||||
|
||||
for body_config in body_configs:
|
||||
name = body_config.get("name", "")
|
||||
param_type = body_config.get("type", SegmentType.STRING)
|
||||
required = body_config.get("required", False)
|
||||
name = body_config.name
|
||||
param_type = body_config.type
|
||||
required = body_config.required
|
||||
|
||||
# Handle file parameters for multipart data
|
||||
if param_type == SegmentType.FILE and content_type == "multipart/form-data":
|
||||
if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA:
|
||||
# File validation is handled separately in extract phase
|
||||
continue
|
||||
|
||||
@@ -508,7 +520,7 @@ class WebhookService:
|
||||
|
||||
if name in raw_body:
|
||||
raw_value = raw_body[name]
|
||||
is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"]
|
||||
is_form_data = content_type in (ContentType.FORM_URLENCODED, ContentType.FORM_DATA)
|
||||
processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data)
|
||||
|
||||
# Include unconfigured parameters
|
||||
@@ -519,7 +531,9 @@ class WebhookService:
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any:
|
||||
def _validate_and_convert_value(
|
||||
cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool
|
||||
) -> Any:
|
||||
"""Unified validation and type conversion for parameter values.
|
||||
|
||||
Args:
|
||||
@@ -545,7 +559,7 @@ class WebhookService:
|
||||
raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any:
|
||||
def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any:
|
||||
"""Convert form data string values to specified types.
|
||||
|
||||
Args:
|
||||
@@ -559,24 +573,28 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If the value cannot be converted to the specified type
|
||||
"""
|
||||
if param_type == SegmentType.STRING:
|
||||
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
|
||||
if param_type_enum is None:
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
|
||||
if param_type_enum == SegmentType.STRING:
|
||||
return value
|
||||
elif param_type == SegmentType.NUMBER:
|
||||
elif param_type_enum == SegmentType.NUMBER:
|
||||
if not cls._can_convert_to_number(value):
|
||||
raise ValueError(f"Cannot convert '{value}' to number")
|
||||
numeric_value = float(value)
|
||||
return int(numeric_value) if numeric_value.is_integer() else numeric_value
|
||||
elif param_type == SegmentType.BOOLEAN:
|
||||
elif param_type_enum == SegmentType.BOOLEAN:
|
||||
lower_value = value.lower()
|
||||
bool_map = {"true": True, "false": False, "1": True, "0": False, "yes": True, "no": False}
|
||||
if lower_value not in bool_map:
|
||||
raise ValueError(f"Cannot convert '{value}' to boolean")
|
||||
return bool_map[lower_value]
|
||||
else:
|
||||
raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'")
|
||||
raise ValueError(f"Unsupported type '{param_type_enum.value}' for form data parameter '{param_name}'")
|
||||
|
||||
@classmethod
|
||||
def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any:
|
||||
def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any:
|
||||
"""Validate JSON values against expected types.
|
||||
|
||||
Args:
|
||||
@@ -590,43 +608,42 @@ class WebhookService:
|
||||
Raises:
|
||||
ValueError: If the value type doesn't match the expected type
|
||||
"""
|
||||
type_validators = {
|
||||
SegmentType.STRING: (lambda v: isinstance(v, str), "string"),
|
||||
SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"),
|
||||
SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"),
|
||||
SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"),
|
||||
SegmentType.ARRAY_STRING: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v),
|
||||
"array of strings",
|
||||
),
|
||||
SegmentType.ARRAY_NUMBER: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v),
|
||||
"array of numbers",
|
||||
),
|
||||
SegmentType.ARRAY_BOOLEAN: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v),
|
||||
"array of booleans",
|
||||
),
|
||||
SegmentType.ARRAY_OBJECT: (
|
||||
lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v),
|
||||
"array of objects",
|
||||
),
|
||||
}
|
||||
|
||||
validator_info = type_validators.get(SegmentType(param_type))
|
||||
if not validator_info:
|
||||
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
|
||||
param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name)
|
||||
if param_type_enum is None:
|
||||
return value
|
||||
|
||||
validator, expected_type = validator_info
|
||||
if not validator(value):
|
||||
if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL):
|
||||
actual_type = type(value).__name__
|
||||
expected_type = cls._expected_type_label(param_type_enum)
|
||||
raise ValueError(f"Expected {expected_type}, got {actual_type}")
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None:
|
||||
def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None:
|
||||
if isinstance(param_type, SegmentType):
|
||||
return param_type
|
||||
try:
|
||||
return SegmentType(param_type)
|
||||
except Exception:
|
||||
logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _expected_type_label(param_type: SegmentType) -> str:
|
||||
match param_type:
|
||||
case SegmentType.ARRAY_STRING:
|
||||
return "array of strings"
|
||||
case SegmentType.ARRAY_NUMBER:
|
||||
return "array of numbers"
|
||||
case SegmentType.ARRAY_BOOLEAN:
|
||||
return "array of booleans"
|
||||
case SegmentType.ARRAY_OBJECT:
|
||||
return "array of objects"
|
||||
case _:
|
||||
return param_type.value
|
||||
|
||||
@classmethod
|
||||
def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None:
|
||||
"""Validate required headers are present.
|
||||
|
||||
Args:
|
||||
@@ -639,14 +656,14 @@ class WebhookService:
|
||||
headers_lower = {k.lower(): v for k, v in headers.items()}
|
||||
headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()}
|
||||
for header_config in header_configs:
|
||||
if header_config.get("required", False):
|
||||
header_name = header_config.get("name", "")
|
||||
if header_config.required:
|
||||
header_name = header_config.name
|
||||
sanitized_name = cls._sanitize_key(header_name).lower()
|
||||
if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized:
|
||||
raise ValueError(f"Required header missing: {header_name}")
|
||||
|
||||
@classmethod
|
||||
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]:
|
||||
"""Validate HTTP method and content-type.
|
||||
|
||||
Args:
|
||||
@@ -657,13 +674,13 @@ class WebhookService:
|
||||
dict[str, Any]: Validation result with 'valid' key and optional 'error' key
|
||||
"""
|
||||
# Validate HTTP method
|
||||
configured_method = node_data.get("method", "get").upper()
|
||||
configured_method = node_data.method.value.upper()
|
||||
request_method = webhook_data["method"].upper()
|
||||
if configured_method != request_method:
|
||||
return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}")
|
||||
|
||||
# Validate Content-type
|
||||
configured_content_type = node_data.get("content_type", "application/json").lower()
|
||||
configured_content_type = node_data.content_type.value.lower()
|
||||
request_content_type = cls._extract_content_type(webhook_data["headers"])
|
||||
|
||||
if configured_content_type != request_content_type:
|
||||
@@ -788,7 +805,7 @@ class WebhookService:
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]:
|
||||
def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]:
|
||||
"""Generate HTTP response based on node configuration.
|
||||
|
||||
Args:
|
||||
@@ -797,11 +814,11 @@ class WebhookService:
|
||||
Returns:
|
||||
tuple[dict[str, Any], int]: Response data and HTTP status code
|
||||
"""
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = WebhookData.model_validate(node_config["data"], from_attributes=True)
|
||||
|
||||
# Get configured status code and response body
|
||||
status_code = node_data.get("status_code", 200)
|
||||
response_body = node_data.get("response_body", "")
|
||||
status_code = node_data.status_code
|
||||
response_body = node_data.response_body
|
||||
|
||||
# Parse response body as JSON if it's valid JSON, otherwise return as text
|
||||
try:
|
||||
|
||||
@@ -16,6 +16,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities import GraphInitParams, WorkflowNodeExecution
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -694,7 +695,7 @@ class WorkflowService:
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id)
|
||||
node_type = Workflow.get_node_type_from_node_config(node_config)
|
||||
node_data = node_config.get("data", {})
|
||||
node_data = node_config["data"]
|
||||
if node_type.is_start_node:
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
@@ -704,7 +705,7 @@ class WorkflowService:
|
||||
workflow=draft_workflow,
|
||||
)
|
||||
if node_type is NodeType.START:
|
||||
start_data = StartNodeData.model_validate(node_data)
|
||||
start_data = StartNodeData.model_validate(node_data, from_attributes=True)
|
||||
user_inputs = _rebuild_file_for_user_inputs_in_start_node(
|
||||
tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
|
||||
)
|
||||
@@ -1063,6 +1064,7 @@ class WorkflowService:
|
||||
node_config: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> HumanInputNode:
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
@@ -1078,8 +1080,8 @@ class WorkflowService:
|
||||
start_at=time.perf_counter(),
|
||||
)
|
||||
node = HumanInputNode(
|
||||
id=node_config.get("id", str(uuid.uuid4())),
|
||||
config=node_config,
|
||||
id=typed_node_config["id"],
|
||||
config=typed_node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
@@ -1093,6 +1095,7 @@ class WorkflowService:
|
||||
node_config: Mapping[str, Any],
|
||||
manual_inputs: Mapping[str, Any],
|
||||
) -> VariablePool:
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(node_config)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
|
||||
draft_var_srv = WorkflowDraftVariableService(session)
|
||||
draft_var_srv.prefill_conversation_variable_default_values(workflow)
|
||||
@@ -1111,7 +1114,7 @@ class WorkflowService:
|
||||
)
|
||||
variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=node_config,
|
||||
config=typed_node_config,
|
||||
)
|
||||
normalized_user_inputs: dict[str, Any] = dict(manual_inputs)
|
||||
|
||||
|
||||
@@ -190,6 +190,7 @@ def test_custom_authorization_header(setup_http_mock):
|
||||
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
|
||||
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.http_request.entities import (
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeData,
|
||||
@@ -200,7 +201,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
# Create variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="test", files=[]),
|
||||
user_inputs={},
|
||||
@@ -208,8 +208,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create node data with custom auth and empty api_key
|
||||
node_data = HttpRequestNodeData(
|
||||
type=NodeType.HTTP_REQUEST,
|
||||
title="http",
|
||||
desc="",
|
||||
url="http://example.com",
|
||||
@@ -228,7 +228,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
|
||||
ssl_verify=True,
|
||||
)
|
||||
|
||||
# Create executor should raise AuthorizationConfigError
|
||||
with pytest.raises(AuthorizationConfigError, match="API key is required"):
|
||||
Executor(
|
||||
node_data=node_data,
|
||||
|
||||
@@ -172,7 +172,7 @@ class TestWebhookService:
|
||||
assert workflow.app_id == test_data["app"].id
|
||||
assert node_config is not None
|
||||
assert node_config["id"] == "webhook_node"
|
||||
assert node_config["data"]["title"] == "Test Webhook"
|
||||
assert node_config["data"].title == "Test Webhook"
|
||||
|
||||
def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers):
|
||||
"""Test webhook trigger not found scenario."""
|
||||
|
||||
@@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing
|
||||
|
||||
# load dotenv file with pydantic-settings
|
||||
config = DifyConfig()
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
# constant values
|
||||
assert config.COMMIT_SHA == ""
|
||||
@@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
|
||||
config = DifyConfig()
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
config = DifyConfig(_env_file=None)
|
||||
|
||||
# Verify default timeout values
|
||||
assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10
|
||||
@@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*")
|
||||
monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/")
|
||||
|
||||
flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore
|
||||
# Disable `.env` loading to ensure test stability across environments
|
||||
flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore
|
||||
config = flask_app.config
|
||||
|
||||
# configs read from pydantic-settings
|
||||
|
||||
@@ -126,6 +126,33 @@ def test_graph_initialization_runs_default_validators(
|
||||
assert "answer" in graph.nodes
|
||||
|
||||
|
||||
def test_graph_initialization_filters_custom_note_nodes_before_validation(
|
||||
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
|
||||
):
|
||||
node_factory, graph_config = graph_init_dependencies
|
||||
graph_config["nodes"] = [
|
||||
{
|
||||
"id": "note",
|
||||
"type": "custom-note",
|
||||
"data": {
|
||||
"type": "",
|
||||
"title": "",
|
||||
"text": "UI-only note node",
|
||||
},
|
||||
},
|
||||
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
|
||||
{"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}},
|
||||
]
|
||||
graph_config["edges"] = [
|
||||
{"source": "start", "target": "answer", "sourceHandle": "success"},
|
||||
]
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
assert "note" not in graph.nodes
|
||||
assert graph.root_node.id == "start"
|
||||
|
||||
|
||||
def test_graph_validation_fails_for_unknown_edge_targets(
|
||||
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
|
||||
) -> None:
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@@ -39,7 +40,7 @@ def test_abort_command():
|
||||
# Create mock nodes with required attributes - using shared runtime state
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -149,7 +150,7 @@ def test_pause_command():
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -205,7 +206,7 @@ def test_update_variables_command_updates_pool():
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
|
||||
@@ -9,6 +9,8 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities.base_node import BaseNodeData
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
@@ -75,34 +77,36 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
NodeType.CODE: MockCodeNode,
|
||||
}
|
||||
|
||||
def create_node(self, node_config: Mapping[str, Any]) -> Node:
|
||||
def create_node(self, node_config: Mapping[str, Any] | NodeConfigDict) -> Node:
|
||||
"""
|
||||
Create a node instance, using mock implementations for third-party service nodes.
|
||||
|
||||
:param node_config: Node configuration dictionary
|
||||
:return: Node instance (real or mocked)
|
||||
"""
|
||||
# Get node type from config
|
||||
node_data = node_config.get("data", {})
|
||||
node_type_str = node_data.get("type")
|
||||
node_data = node_config.get("data")
|
||||
|
||||
if not node_type_str:
|
||||
# Fall back to parent implementation for nodes without type
|
||||
# Support both dict-based and BaseNodeData-based configurations.
|
||||
node_type: NodeType | None = None
|
||||
if isinstance(node_data, BaseNodeData):
|
||||
node_type = node_data.type
|
||||
elif isinstance(node_data, Mapping):
|
||||
node_type_str = node_data.get("type")
|
||||
if node_type_str:
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
return super().create_node(node_config)
|
||||
|
||||
if node_type is None:
|
||||
return super().create_node(node_config)
|
||||
|
||||
try:
|
||||
node_type = NodeType(node_type_str)
|
||||
except ValueError:
|
||||
# Unknown node type, use parent implementation
|
||||
return super().create_node(node_config)
|
||||
|
||||
# Check if this node type should be mocked
|
||||
# Check if this node type should be mocked.
|
||||
if node_type in self._mock_node_types:
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node config missing id")
|
||||
|
||||
# Create mock node instance
|
||||
mock_class = self._mock_node_types[node_type]
|
||||
if node_type == NodeType.CODE:
|
||||
mock_instance = mock_class(
|
||||
@@ -147,7 +151,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
# For non-mocked node types, use parent implementation.
|
||||
return super().create_node(node_config)
|
||||
|
||||
def should_mock_node(self, node_type: NodeType) -> bool:
|
||||
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
@@ -63,7 +64,7 @@ class TestStopEventPropagation:
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -111,7 +112,7 @@ class TestStopEventPropagation:
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -195,7 +196,10 @@ class TestNodeStopCheck:
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
|
||||
config={
|
||||
"id": "answer",
|
||||
"data": {"type": NodeType.ANSWER, "title": "answer", "answer": "{{#start.result#}}"},
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -225,7 +229,7 @@ class TestNodeStopCheck:
|
||||
# Create a simple node
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
config={"id": "answer", "data": {"type": NodeType.ANSWER, "title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -276,7 +280,7 @@ class TestStopEventIntegration:
|
||||
# Create start and answer nodes
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -292,7 +296,7 @@ class TestStopEventIntegration:
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
config={"id": "answer", "data": {"type": NodeType.ANSWER, "title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -343,7 +347,10 @@ class TestStopEventIntegration:
|
||||
for i in range(3):
|
||||
answer_node = AnswerNode(
|
||||
id=f"answer_{i}",
|
||||
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
|
||||
config={
|
||||
"id": f"answer_{i}",
|
||||
"data": {"type": NodeType.ANSWER, "title": f"answer_{i}", "answer": f"test{i}"},
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@@ -447,7 +454,7 @@ class TestStopEventResumeBehavior:
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
config={"id": "start", "data": {"type": NodeType.START, "title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
|
||||
@@ -49,7 +49,7 @@ def test_node_hydrates_data_during_initialization():
|
||||
|
||||
node = _SampleNode(
|
||||
id="node-1",
|
||||
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||
config={"id": "node-1", "data": {"type": NodeType.ANSWER, "title": "Sample", "foo": "bar"}},
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
@@ -58,6 +58,20 @@ def test_node_hydrates_data_during_initialization():
|
||||
assert node.title == "Sample"
|
||||
|
||||
|
||||
def test_node_initialization_falls_back_to_node_type_when_data_type_is_missing():
|
||||
graph_config: dict[str, object] = {}
|
||||
init_params, runtime_state = _build_context(graph_config)
|
||||
|
||||
node = _SampleNode(
|
||||
id="node-1",
|
||||
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert node.node_data.type == NodeType.ANSWER
|
||||
|
||||
|
||||
def test_missing_generic_argument_raises_type_error():
|
||||
graph_config: dict[str, object] = {}
|
||||
|
||||
|
||||
@@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias():
|
||||
|
||||
def test_webhook_data_validation_errors():
|
||||
"""Test WebhookData validation errors."""
|
||||
# Title is required (inherited from BaseNodeData)
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData()
|
||||
|
||||
# Invalid method
|
||||
with pytest.raises(ValidationError):
|
||||
|
||||
@@ -8,6 +8,7 @@ from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
)
|
||||
from core.workflow.entities.graph_config import NodeConfigDictAdapter
|
||||
from core.workflow.file.enums import FileType
|
||||
from core.workflow.file.models import File, FileTransferMethod
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
@@ -124,7 +125,7 @@ class TestWorkflowEntry:
|
||||
|
||||
def get_node_config_by_id(self, target_id: str):
|
||||
assert target_id == node_id
|
||||
return node_config
|
||||
return NodeConfigDictAdapter.validate_python(node_config)
|
||||
|
||||
workflow = StubWorkflow()
|
||||
variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={})
|
||||
|
||||
Reference in New Issue
Block a user