Compare commits

...

1 Commits

Author SHA1 Message Date
Novice
e93040f63c fix: iteration and loop node single step run 2025-09-02 09:40:24 +08:00
4 changed files with 55 additions and 83 deletions

View File

@@ -89,6 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
) )
graph_runtime_state.variable_pool = variable_pool
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
@@ -101,6 +102,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
) )
graph_runtime_state.variable_pool = variable_pool
else: else:
inputs = self.application_generate_entity.inputs inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query query = self.application_generate_entity.query

View File

@@ -149,7 +149,9 @@ class WorkflowBasedAppRunner:
node_configs = [ node_configs = [
node node
for node in graph_config.get("nodes", []) for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id if node.get("id") == node_id
or node.get("data", {}).get("iteration_id", "") == node_id
or node.get("id") == f"{node_id}start"
] ]
graph_config["nodes"] = node_configs graph_config["nodes"] = node_configs
@@ -264,7 +266,9 @@ class WorkflowBasedAppRunner:
node_configs = [ node_configs = [
node node
for node in graph_config.get("nodes", []) for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id if node.get("id") == node_id
or node.get("data", {}).get("loop_id", "") == node_id
or node.get("id") == f"{node_id}start"
] ]
graph_config["nodes"] = node_configs graph_config["nodes"] = node_configs

View File

@@ -208,43 +208,16 @@ class IterationNode(Node):
variable_mapping: dict[str, Sequence[str]] = { variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": typed_node_data.iterator_selector, f"{node_id}.input_selector": typed_node_data.iterator_selector,
} }
iteration_node_ids = set()
# init graph # Find all nodes that belong to this loop
from core.workflow.entities import GraphInitParams, GraphRuntimeState nodes = graph_config.get("nodes", [])
from core.workflow.graph import Graph for node in nodes:
from core.workflow.nodes.node_factory import DifyNodeFactory node_data = node.get("data", {})
if node_data.get("iteration_id") == node_id:
# Create minimal GraphInitParams for static analysis in_iteration_node_id = node.get("id")
graph_init_params = GraphInitParams( if in_iteration_node_id:
tenant_id="", iteration_node_ids.add(in_iteration_node_id)
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis
from core.workflow.entities import VariablePool
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
iteration_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping # Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
@@ -280,9 +253,7 @@ class IterationNode(Node):
variable_mapping.update(sub_node_variable_mapping) variable_mapping.update(sub_node_variable_mapping)
# remove variable out from iteration # remove variable out from iteration
variable_mapping = { variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids}
key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids
}
return variable_mapping return variable_mapping

View File

@@ -1,3 +1,4 @@
import contextlib
import json import json
import logging import logging
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
@@ -126,11 +127,13 @@ class LoopNode(Node):
try: try:
reach_break_condition = False reach_break_condition = False
if break_conditions: if break_conditions:
_, _, reach_break_condition = condition_processor.process_conditions( with contextlib.suppress(ValueError):
variable_pool=self.graph_runtime_state.variable_pool, _, _, reach_break_condition = condition_processor.process_conditions(
conditions=break_conditions, variable_pool=self.graph_runtime_state.variable_pool,
operator=logical_operator, conditions=break_conditions,
) operator=logical_operator,
)
if reach_break_condition: if reach_break_condition:
loop_count = 0 loop_count = 0
cost_tokens = 0 cost_tokens = 0
@@ -294,42 +297,11 @@ class LoopNode(Node):
variable_mapping = {} variable_mapping = {}
# init graph # Extract loop node IDs statically from graph_config
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
# Create minimal GraphInitParams for static analysis loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
graph_init_params = GraphInitParams(
tenant_id="",
app_id="",
workflow_id="",
graph_config=graph_config,
user_id="",
user_from="",
invoke_from="",
call_depth=0,
)
# Create minimal GraphRuntimeState for static analysis # Get node configs from graph_config
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(),
start_at=0,
)
# Create node factory for static analysis
node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state)
loop_graph = Graph.init(
graph_config=graph_config,
node_factory=node_factory,
root_node_id=typed_node_data.start_node_id,
)
if not loop_graph:
raise ValueError("loop graph not found")
# Get node configs from graph_config instead of non-existent node_id_config_mapping
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
for sub_node_id, sub_node_config in node_configs.items(): for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id: if sub_node_config.get("data", {}).get("loop_id") != node_id:
@@ -370,12 +342,35 @@ class LoopNode(Node):
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop # remove variable out from loop
variable_mapping = { variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids}
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
}
return variable_mapping return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
This method statically analyzes the graph configuration to find all nodes
that are part of the specified loop, without creating actual node instances.
:param graph_config: the complete graph configuration
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
loop_node_ids.add(node_id)
return loop_node_ids
@staticmethod @staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
"""Get the appropriate segment type for a constant value.""" """Get the appropriate segment type for a constant value."""