diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 768ad6a130..91f9ef95fe 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,7 +3,7 @@ from typing import Optional from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode -from models.workflow import Workflow +from models.workflow import Workflow, WorkflowType class WorkflowNodeAndResult: @@ -16,7 +16,11 @@ class WorkflowNodeAndResult: class WorkflowRunState: - workflow: Workflow + tenant_id: str + app_id: str + workflow_id: str + workflow_type: WorkflowType + start_at: float variable_pool: VariablePool @@ -25,6 +29,10 @@ class WorkflowRunState: workflow_nodes_and_results: list[WorkflowNodeAndResult] = [] def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool): - self.workflow = workflow + self.workflow_id = workflow.id + self.tenant_id = workflow.tenant_id + self.app_id = workflow.app_id + self.workflow_type = WorkflowType.value_of(workflow.type) + self.start_at = start_at self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3f2e806433..6db25bea7e 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -12,14 +12,25 @@ class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType + tenant_id: str + app_id: str + workflow_id: str + node_id: str node_data: BaseNodeData node_run_result: Optional[NodeRunResult] = None callbacks: list[BaseWorkflowCallback] - def __init__(self, config: dict, + def __init__(self, tenant_id: str, + app_id: str, + workflow_id: str, + config: dict, callbacks: list[BaseWorkflowCallback] = None) -> None: + self.tenant_id = tenant_id + self.app_id = app_id + self.workflow_id = workflow_id + self.node_id = config.get("id") if not self.node_id: raise ValueError("Node ID is required.") diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 50f79df1f0..d01746ceb8 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -122,6 +122,7 @@ class WorkflowEngineManager: while True: # get next node, multiple target nodes in the future next_node = self._get_next_node( + workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, callbacks=callbacks @@ -198,7 +199,8 @@ class WorkflowEngineManager: error=error ) - def _get_next_node(self, graph: dict, + def _get_next_node(self, workflow_run_state: WorkflowRunState, + graph: dict, predecessor_node: Optional[BaseNode] = None, callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]: """ @@ -216,7 +218,13 @@ class WorkflowEngineManager: if not predecessor_node: for node_config in nodes: if node_config.get('data', {}).get('type', '') == NodeType.START.value: - return StartNode(config=node_config) + return StartNode( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + config=node_config, + callbacks=callbacks + ) else: edges = graph.get('edges') source_node_id = predecessor_node.node_id @@ -256,6 +264,9 @@ class WorkflowEngineManager: target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) return target_node( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, config=target_node_config, callbacks=callbacks ) @@ -354,7 +365,7 @@ class WorkflowEngineManager: :param node_run_result: node run result :return: """ - if workflow_run_state.workflow.type == WorkflowType.CHAT.value and node.node_type == NodeType.END: + if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END: workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2] if workflow_nodes_and_result_before_end: if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM: