refactor(graph_engine): Merge state managers into unified_state_manager
Signed-off-by: -LAN- <laipz8200@outlook.com>

@@ -34,7 +34,7 @@ ignore_imports =
 [importlinter:contract:rsc]
 name = RSC
 type = layers
-layers =
+layers =
     graph_engine
     response_coordinator
     output_registry
@@ -44,7 +44,7 @@ containers =
 [importlinter:contract:worker]
 name = Worker
 type = layers
-layers =
+layers =
     graph_engine
     worker
 containers =
@@ -77,18 +77,8 @@ forbidden_modules =
     core.workflow.graph_engine.layers
     core.workflow.graph_engine.protocols
 
-[importlinter:contract:state-management-layers]
-name = State Management Layers
-type = layers
-layers =
-    execution_tracker
-    node_state_manager
-    edge_state_manager
-containers =
-    core.workflow.graph_engine.state_management
-
 [importlinter:contract:worker-management-layers]
-name = Worker Management Layers
+name = Worker Management Layers
 type = layers
 layers =
     worker_pool
@@ -119,4 +109,4 @@ name = Command Channels Independence
 type = independence
 modules =
     core.workflow.graph_engine.command_channels.in_memory_channel
-    core.workflow.graph_engine.command_channels.redis_channel
+    core.workflow.graph_engine.command_channels.redis_channel
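Note on the config above: an import-linter "layers" contract lets higher layers import lower ones but not the reverse. Once the three state modules collapse into a single unified_state_manager there is no intra-package ordering left to enforce, so the state-management-layers contract is deleted outright. For reference, a minimal layers contract looks like this (illustrative names, not from this repository):

    [importlinter]
    root_package = mypackage

    [importlinter:contract:example-layers]
    name = Example Layers
    type = layers
    layers =
        high_level
        low_level
    containers =
        mypackage.subsystem

Running lint-imports then fails if mypackage.subsystem.low_level imports from mypackage.subsystem.high_level.
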
@@ -32,7 +32,7 @@ from ..response_coordinator import ResponseStreamCoordinator
 if TYPE_CHECKING:
     from ..error_handling import ErrorHandler
     from ..graph_traversal import BranchHandler, EdgeProcessor
-    from ..state_management import ExecutionTracker, NodeStateManager
+    from ..state_management import UnifiedStateManager
     from .event_collector import EventCollector
 
 logger = logging.getLogger(__name__)
@@ -56,8 +56,8 @@ class EventHandlerRegistry:
         event_collector: "EventCollector",
         branch_handler: "BranchHandler",
         edge_processor: "EdgeProcessor",
-        node_state_manager: "NodeStateManager",
-        execution_tracker: "ExecutionTracker",
+        node_state_manager: "UnifiedStateManager",
+        execution_tracker: "UnifiedStateManager",
         error_handler: "ErrorHandler",
     ) -> None:
         """
@@ -39,7 +39,7 @@ from .orchestration import Dispatcher, ExecutionCoordinator
 from .output_registry import OutputRegistry
 from .protocols.command_channel import CommandChannel
 from .response_coordinator import ResponseStreamCoordinator
-from .state_management import EdgeStateManager, ExecutionTracker, NodeStateManager
+from .state_management import UnifiedStateManager
 from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
 
 logger = logging.getLogger(__name__)
@@ -119,10 +119,8 @@ class GraphEngine:
     def _initialize_subsystems(self) -> None:
         """Initialize all subsystems with proper dependency injection."""
 
-        # State management
-        self.node_state_manager = NodeStateManager(self.graph, self.ready_queue)
-        self.edge_state_manager = EdgeStateManager(self.graph)
-        self.execution_tracker = ExecutionTracker()
+        # Unified state management - single instance handles all state operations
+        self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
 
         # Response coordination
         self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
@@ -139,20 +137,20 @@ class GraphEngine:
         self.node_readiness_checker = NodeReadinessChecker(self.graph)
         self.edge_processor = EdgeProcessor(
             graph=self.graph,
-            edge_state_manager=self.edge_state_manager,
-            node_state_manager=self.node_state_manager,
+            edge_state_manager=self.state_manager,
+            node_state_manager=self.state_manager,
             response_coordinator=self.response_coordinator,
         )
         self.skip_propagator = SkipPropagator(
             graph=self.graph,
-            edge_state_manager=self.edge_state_manager,
-            node_state_manager=self.node_state_manager,
+            edge_state_manager=self.state_manager,
+            node_state_manager=self.state_manager,
         )
         self.branch_handler = BranchHandler(
             graph=self.graph,
             edge_processor=self.edge_processor,
             skip_propagator=self.skip_propagator,
-            edge_state_manager=self.edge_state_manager,
+            edge_state_manager=self.state_manager,
         )
 
         # Event handler registry with all dependencies
@@ -164,8 +162,8 @@ class GraphEngine:
             event_collector=self.event_collector,
             branch_handler=self.branch_handler,
             edge_processor=self.edge_processor,
-            node_state_manager=self.node_state_manager,
-            execution_tracker=self.execution_tracker,
+            node_state_manager=self.state_manager,
+            execution_tracker=self.state_manager,
             error_handler=self.error_handler,
         )
 
@@ -182,8 +180,8 @@ class GraphEngine:
         # Orchestration
         self.execution_coordinator = ExecutionCoordinator(
             graph_execution=self.graph_execution,
-            node_state_manager=self.node_state_manager,
-            execution_tracker=self.execution_tracker,
+            node_state_manager=self.state_manager,
+            execution_tracker=self.state_manager,
             event_handler=self.event_handler_registry,
             event_collector=self.event_collector,
             command_processor=self.command_processor,
@@ -335,8 +333,8 @@ class GraphEngine:
 
         # Enqueue root node
         root_node = self.graph.root_node
-        self.node_state_manager.enqueue_node(root_node.id)
-        self.execution_tracker.add(root_node.id)
+        self.state_manager.enqueue_node(root_node.id)
+        self.state_manager.add(root_node.id)
 
         # Start dispatcher
         self.dispatcher.start()
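The rewiring above works by duck typing: UnifiedStateManager exposes the union of the three retired interfaces, so the same instance satisfies every edge_state_manager, node_state_manager, and execution_tracker parameter. A condensed sketch of the resulting startup path, with names taken from the diff and all unrelated setup elided:

    state_manager = UnifiedStateManager(graph, ready_queue)

    edge_processor = EdgeProcessor(
        graph=graph,
        edge_state_manager=state_manager,  # the same object fills
        node_state_manager=state_manager,  # every state-related slot
        response_coordinator=response_coordinator,
    )

    # Enqueueing the root node now touches two merged responsibilities:
    state_manager.enqueue_node(root_node.id)  # node state: mark TAKEN, push to ready queue
    state_manager.add(root_node.id)           # execution tracking: record as executing
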
@@ -8,7 +8,7 @@ from typing import final
 from core.workflow.graph import Graph
 from core.workflow.graph_events.node import NodeRunStreamChunkEvent
 
-from ..state_management import EdgeStateManager
+from ..state_management import UnifiedStateManager
 from .edge_processor import EdgeProcessor
 from .skip_propagator import SkipPropagator
 
@@ -27,7 +27,7 @@ class BranchHandler:
         graph: Graph,
         edge_processor: EdgeProcessor,
         skip_propagator: SkipPropagator,
-        edge_state_manager: EdgeStateManager,
+        edge_state_manager: UnifiedStateManager,
     ) -> None:
         """
         Initialize the branch handler.
@@ -10,7 +10,7 @@ from core.workflow.graph import Edge, Graph
 from core.workflow.graph_events import NodeRunStreamChunkEvent
 
 from ..response_coordinator import ResponseStreamCoordinator
-from ..state_management import EdgeStateManager, NodeStateManager
+from ..state_management import UnifiedStateManager
 
 
 @final
@@ -25,8 +25,8 @@ class EdgeProcessor:
     def __init__(
         self,
         graph: Graph,
-        edge_state_manager: EdgeStateManager,
-        node_state_manager: NodeStateManager,
+        edge_state_manager: UnifiedStateManager,
+        node_state_manager: UnifiedStateManager,
         response_coordinator: ResponseStreamCoordinator,
     ) -> None:
         """
@@ -7,7 +7,7 @@ from typing import final
 
 from core.workflow.graph import Edge, Graph
 
-from ..state_management import EdgeStateManager, NodeStateManager
+from ..state_management import UnifiedStateManager
 
 
 @final
@@ -22,8 +22,8 @@ class SkipPropagator:
     def __init__(
         self,
        graph: Graph,
-        edge_state_manager: EdgeStateManager,
-        node_state_manager: NodeStateManager,
+        edge_state_manager: UnifiedStateManager,
+        node_state_manager: UnifiedStateManager,
     ) -> None:
         """
         Initialize the skip propagator.
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, final
 from ..command_processing import CommandProcessor
 from ..domain import GraphExecution
 from ..event_management import EventCollector
-from ..state_management import ExecutionTracker, NodeStateManager
+from ..state_management import UnifiedStateManager
 from ..worker_management import WorkerPool
 
 if TYPE_CHECKING:
@@ -26,8 +26,8 @@ class ExecutionCoordinator:
     def __init__(
         self,
         graph_execution: GraphExecution,
-        node_state_manager: NodeStateManager,
-        execution_tracker: ExecutionTracker,
+        node_state_manager: UnifiedStateManager,
+        execution_tracker: UnifiedStateManager,
         event_handler: "EventHandlerRegistry",
         event_collector: EventCollector,
         command_processor: CommandProcessor,
@@ -5,12 +5,8 @@ This package manages node states, edge states, and execution tracking
 during workflow graph execution.
 """
 
-from .edge_state_manager import EdgeStateManager
-from .execution_tracker import ExecutionTracker
-from .node_state_manager import NodeStateManager
+from .unified_state_manager import UnifiedStateManager
 
 __all__ = [
-    "EdgeStateManager",
-    "ExecutionTracker",
-    "NodeStateManager",
+    "UnifiedStateManager",
 ]
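Callers that imported the retired names must move to the single export. With the __init__.py above:

    # Before this commit:
    #   from core.workflow.graph_engine.state_management import NodeStateManager
    # After it, the old names raise ImportError; only the merged class remains:
    from core.workflow.graph_engine.state_management import UnifiedStateManager
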
@@ -1,114 +0,0 @@
-"""
-Manager for edge states during graph execution.
-"""
-
-import threading
-from collections.abc import Sequence
-from typing import TypedDict, final
-
-from core.workflow.enums import NodeState
-from core.workflow.graph import Edge, Graph
-
-
-class EdgeStateAnalysis(TypedDict):
-    """Analysis result for edge states."""
-
-    has_unknown: bool
-    has_taken: bool
-    all_skipped: bool
-
-
-@final
-class EdgeStateManager:
-    """
-    Manages edge states and transitions during graph execution.
-
-    This handles edge state changes and provides analysis of edge
-    states for decision making during execution.
-    """
-
-    def __init__(self, graph: Graph) -> None:
-        """
-        Initialize the edge state manager.
-
-        Args:
-            graph: The workflow graph
-        """
-        self.graph = graph
-        self._lock = threading.RLock()
-
-    def mark_edge_taken(self, edge_id: str) -> None:
-        """
-        Mark an edge as TAKEN.
-
-        Args:
-            edge_id: The ID of the edge to mark
-        """
-        with self._lock:
-            self.graph.edges[edge_id].state = NodeState.TAKEN
-
-    def mark_edge_skipped(self, edge_id: str) -> None:
-        """
-        Mark an edge as SKIPPED.
-
-        Args:
-            edge_id: The ID of the edge to mark
-        """
-        with self._lock:
-            self.graph.edges[edge_id].state = NodeState.SKIPPED
-
-    def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
-        """
-        Analyze the states of edges and return summary flags.
-
-        Args:
-            edges: List of edges to analyze
-
-        Returns:
-            Analysis result with state flags
-        """
-        with self._lock:
-            states = {edge.state for edge in edges}
-
-            return EdgeStateAnalysis(
-                has_unknown=NodeState.UNKNOWN in states,
-                has_taken=NodeState.TAKEN in states,
-                all_skipped=states == {NodeState.SKIPPED} if states else True,
-            )
-
-    def get_edge_state(self, edge_id: str) -> NodeState:
-        """
-        Get the current state of an edge.
-
-        Args:
-            edge_id: The ID of the edge
-
-        Returns:
-            The current edge state
-        """
-        with self._lock:
-            return self.graph.edges[edge_id].state
-
-    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
-        """
-        Categorize branch edges into selected and unselected.
-
-        Args:
-            node_id: The ID of the branch node
-            selected_handle: The handle of the selected edge
-
-        Returns:
-            A tuple of (selected_edges, unselected_edges)
-        """
-        with self._lock:
-            outgoing_edges = self.graph.get_outgoing_edges(node_id)
-            selected_edges: list[Edge] = []
-            unselected_edges: list[Edge] = []
-
-            for edge in outgoing_edges:
-                if edge.source_handle == selected_handle:
-                    selected_edges.append(edge)
-                else:
-                    unselected_edges.append(edge)
-
-            return selected_edges, unselected_edges
@@ -1,89 +0,0 @@
-"""
-Tracker for currently executing nodes.
-"""
-
-import threading
-from typing import final
-
-
-@final
-class ExecutionTracker:
-    """
-    Tracks nodes that are currently being executed.
-
-    This replaces the ExecutingNodesManager with a cleaner interface
-    focused on tracking which nodes are in progress.
-    """
-
-    def __init__(self) -> None:
-        """Initialize the execution tracker."""
-        self._executing_nodes: set[str] = set()
-        self._lock = threading.RLock()
-
-    def add(self, node_id: str) -> None:
-        """
-        Mark a node as executing.
-
-        Args:
-            node_id: The ID of the node starting execution
-        """
-        with self._lock:
-            self._executing_nodes.add(node_id)
-
-    def remove(self, node_id: str) -> None:
-        """
-        Mark a node as no longer executing.
-
-        Args:
-            node_id: The ID of the node finishing execution
-        """
-        with self._lock:
-            self._executing_nodes.discard(node_id)
-
-    def is_executing(self, node_id: str) -> bool:
-        """
-        Check if a node is currently executing.
-
-        Args:
-            node_id: The ID of the node to check
-
-        Returns:
-            True if the node is executing
-        """
-        with self._lock:
-            return node_id in self._executing_nodes
-
-    def is_empty(self) -> bool:
-        """
-        Check if no nodes are currently executing.
-
-        Returns:
-            True if no nodes are executing
-        """
-        with self._lock:
-            return len(self._executing_nodes) == 0
-
-    def count(self) -> int:
-        """
-        Get the count of currently executing nodes.
-
-        Returns:
-            Number of executing nodes
-        """
-        with self._lock:
-            return len(self._executing_nodes)
-
-    def get_executing_nodes(self) -> set[str]:
-        """
-        Get a copy of the set of executing node IDs.
-
-        Returns:
-            Set of node IDs currently executing
-        """
-        with self._lock:
-            return self._executing_nodes.copy()
-
-    def clear(self) -> None:
-        """Clear all executing nodes."""
-        with self._lock:
-            self._executing_nodes.clear()
@@ -1,97 +0,0 @@
-"""
-Manager for node states during graph execution.
-"""
-
-import queue
-import threading
-from typing import final
-
-from core.workflow.enums import NodeState
-from core.workflow.graph import Graph
-
-
-@final
-class NodeStateManager:
-    """
-    Manages node states and the ready queue for execution.
-
-    This centralizes node state transitions and enqueueing logic,
-    ensuring thread-safe operations on node states.
-    """
-
-    def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
-        """
-        Initialize the node state manager.
-
-        Args:
-            graph: The workflow graph
-            ready_queue: Queue for nodes ready to execute
-        """
-        self.graph = graph
-        self.ready_queue = ready_queue
-        self._lock = threading.RLock()
-
-    def enqueue_node(self, node_id: str) -> None:
-        """
-        Mark a node as TAKEN and add it to the ready queue.
-
-        This combines the state transition and enqueueing operations
-        that always occur together when preparing a node for execution.
-
-        Args:
-            node_id: The ID of the node to enqueue
-        """
-        with self._lock:
-            self.graph.nodes[node_id].state = NodeState.TAKEN
-            self.ready_queue.put(node_id)
-
-    def mark_node_skipped(self, node_id: str) -> None:
-        """
-        Mark a node as SKIPPED.
-
-        Args:
-            node_id: The ID of the node to skip
-        """
-        with self._lock:
-            self.graph.nodes[node_id].state = NodeState.SKIPPED
-
-    def is_node_ready(self, node_id: str) -> bool:
-        """
-        Check if a node is ready to be executed.
-
-        A node is ready when all its incoming edges from taken branches
-        have been satisfied.
-
-        Args:
-            node_id: The ID of the node to check
-
-        Returns:
-            True if the node is ready for execution
-        """
-        with self._lock:
-            # Get all incoming edges to this node
-            incoming_edges = self.graph.get_incoming_edges(node_id)
-
-            # If no incoming edges, node is always ready
-            if not incoming_edges:
-                return True
-
-            # If any edge is UNKNOWN, node is not ready
-            if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
-                return False
-
-            # Node is ready if at least one edge is TAKEN
-            return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
-
-    def get_node_state(self, node_id: str) -> NodeState:
-        """
-        Get the current state of a node.
-
-        Args:
-            node_id: The ID of the node
-
-        Returns:
-            The current node state
-        """
-        with self._lock:
-            return self.graph.nodes[node_id].state
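The readiness rule in the deleted manager above survives verbatim in the unified class: a node with no incoming edges is always ready, any UNKNOWN incoming edge blocks it, and otherwise at least one TAKEN edge is required. A standalone illustration of that rule (stand-in enum, not dify code):

    from enum import Enum

    class NodeState(Enum):  # stand-in for core.workflow.enums.NodeState
        UNKNOWN = "unknown"
        TAKEN = "taken"
        SKIPPED = "skipped"

    def is_ready(incoming: list[NodeState]) -> bool:
        if not incoming:
            return True  # root-like node: nothing to wait for
        if any(s == NodeState.UNKNOWN for s in incoming):
            return False  # some upstream branch is still undecided
        return any(s == NodeState.TAKEN for s in incoming)

    assert is_ready([])                                          # no incoming edges
    assert not is_ready([NodeState.UNKNOWN, NodeState.TAKEN])    # still undecided
    assert is_ready([NodeState.SKIPPED, NodeState.TAKEN])        # one branch taken
    assert not is_ready([NodeState.SKIPPED, NodeState.SKIPPED])  # all branches skipped
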
@@ -0,0 +1,343 @@
+"""
+Unified state manager that combines node, edge, and execution tracking.
+
+This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
+and ExecutionTracker into a single cohesive class.
+"""
+
+import queue
+import threading
+from collections.abc import Sequence
+from typing import TypedDict, final
+
+from core.workflow.enums import NodeState
+from core.workflow.graph import Edge, Graph
+
+
+class EdgeStateAnalysis(TypedDict):
+    """Analysis result for edge states."""
+
+    has_unknown: bool
+    has_taken: bool
+    all_skipped: bool
+
+
+@final
+class UnifiedStateManager:
+    """
+    Unified manager for all graph state operations.
+
+    This class combines the responsibilities of:
+    - NodeStateManager: Node state transitions and ready queue
+    - EdgeStateManager: Edge state transitions and analysis
+    - ExecutionTracker: Tracking executing nodes
+
+    Benefits:
+    - Single lock for all state operations (reduced contention)
+    - Cohesive state management interface
+    - Simplified dependency injection
+    """
+
+    def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
+        """
+        Initialize the unified state manager.
+
+        Args:
+            graph: The workflow graph
+            ready_queue: Queue for nodes ready to execute
+        """
+        self.graph = graph
+        self.ready_queue = ready_queue
+        self._lock = threading.RLock()
+
+        # Execution tracking state
+        self._executing_nodes: set[str] = set()
+
+    # ============= Node State Operations =============
+
+    def enqueue_node(self, node_id: str) -> None:
+        """
+        Mark a node as TAKEN and add it to the ready queue.
+
+        This combines the state transition and enqueueing operations
+        that always occur together when preparing a node for execution.
+
+        Args:
+            node_id: The ID of the node to enqueue
+        """
+        with self._lock:
+            self.graph.nodes[node_id].state = NodeState.TAKEN
+            self.ready_queue.put(node_id)
+
+    def mark_node_skipped(self, node_id: str) -> None:
+        """
+        Mark a node as SKIPPED.
+
+        Args:
+            node_id: The ID of the node to skip
+        """
+        with self._lock:
+            self.graph.nodes[node_id].state = NodeState.SKIPPED
+
+    def is_node_ready(self, node_id: str) -> bool:
+        """
+        Check if a node is ready to be executed.
+
+        A node is ready when all its incoming edges from taken branches
+        have been satisfied.
+
+        Args:
+            node_id: The ID of the node to check
+
+        Returns:
+            True if the node is ready for execution
+        """
+        with self._lock:
+            # Get all incoming edges to this node
+            incoming_edges = self.graph.get_incoming_edges(node_id)
+
+            # If no incoming edges, node is always ready
+            if not incoming_edges:
+                return True
+
+            # If any edge is UNKNOWN, node is not ready
+            if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
+                return False
+
+            # Node is ready if at least one edge is TAKEN
+            return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
+
+    def get_node_state(self, node_id: str) -> NodeState:
+        """
+        Get the current state of a node.
+
+        Args:
+            node_id: The ID of the node
+
+        Returns:
+            The current node state
+        """
+        with self._lock:
+            return self.graph.nodes[node_id].state
+
+    # ============= Edge State Operations =============
+
+    def mark_edge_taken(self, edge_id: str) -> None:
+        """
+        Mark an edge as TAKEN.
+
+        Args:
+            edge_id: The ID of the edge to mark
+        """
+        with self._lock:
+            self.graph.edges[edge_id].state = NodeState.TAKEN
+
+    def mark_edge_skipped(self, edge_id: str) -> None:
+        """
+        Mark an edge as SKIPPED.
+
+        Args:
+            edge_id: The ID of the edge to mark
+        """
+        with self._lock:
+            self.graph.edges[edge_id].state = NodeState.SKIPPED
+
+    def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
+        """
+        Analyze the states of edges and return summary flags.
+
+        Args:
+            edges: List of edges to analyze
+
+        Returns:
+            Analysis result with state flags
+        """
+        with self._lock:
+            states = {edge.state for edge in edges}
+
+            return EdgeStateAnalysis(
+                has_unknown=NodeState.UNKNOWN in states,
+                has_taken=NodeState.TAKEN in states,
+                all_skipped=states == {NodeState.SKIPPED} if states else True,
+            )
+
+    def get_edge_state(self, edge_id: str) -> NodeState:
+        """
+        Get the current state of an edge.
+
+        Args:
+            edge_id: The ID of the edge
+
+        Returns:
+            The current edge state
+        """
+        with self._lock:
+            return self.graph.edges[edge_id].state
+
+    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
+        """
+        Categorize branch edges into selected and unselected.
+
+        Args:
+            node_id: The ID of the branch node
+            selected_handle: The handle of the selected edge
+
+        Returns:
+            A tuple of (selected_edges, unselected_edges)
+        """
+        with self._lock:
+            outgoing_edges = self.graph.get_outgoing_edges(node_id)
+            selected_edges: list[Edge] = []
+            unselected_edges: list[Edge] = []
+
+            for edge in outgoing_edges:
+                if edge.source_handle == selected_handle:
+                    selected_edges.append(edge)
+                else:
+                    unselected_edges.append(edge)
+
+            return selected_edges, unselected_edges
+
+    # ============= Execution Tracking Operations =============
+
+    def start_execution(self, node_id: str) -> None:
+        """
+        Mark a node as executing.
+
+        Args:
+            node_id: The ID of the node starting execution
+        """
+        with self._lock:
+            self._executing_nodes.add(node_id)
+
+    def finish_execution(self, node_id: str) -> None:
+        """
+        Mark a node as no longer executing.
+
+        Args:
+            node_id: The ID of the node finishing execution
+        """
+        with self._lock:
+            self._executing_nodes.discard(node_id)
+
+    def is_executing(self, node_id: str) -> bool:
+        """
+        Check if a node is currently executing.
+
+        Args:
+            node_id: The ID of the node to check
+
+        Returns:
+            True if the node is executing
+        """
+        with self._lock:
+            return node_id in self._executing_nodes
+
+    def get_executing_count(self) -> int:
+        """
+        Get the count of currently executing nodes.
+
+        Returns:
+            Number of executing nodes
+        """
+        with self._lock:
+            return len(self._executing_nodes)
+
+    def get_executing_nodes(self) -> set[str]:
+        """
+        Get a copy of the set of executing node IDs.
+
+        Returns:
+            Set of node IDs currently executing
+        """
+        with self._lock:
+            return self._executing_nodes.copy()
+
+    def clear_executing(self) -> None:
+        """Clear all executing nodes."""
+        with self._lock:
+            self._executing_nodes.clear()
+
+    # ============= Composite Operations =============
+
+    def is_execution_complete(self) -> bool:
+        """
+        Check if graph execution is complete.
+
+        Execution is complete when:
+        - Ready queue is empty
+        - No nodes are executing
+
+        Returns:
+            True if execution is complete
+        """
+        with self._lock:
+            return self.ready_queue.empty() and len(self._executing_nodes) == 0
+
+    def get_queue_depth(self) -> int:
+        """
+        Get the current depth of the ready queue.
+
+        Returns:
+            Number of nodes in the ready queue
+        """
+        return self.ready_queue.qsize()
+
+    def get_execution_stats(self) -> dict[str, int]:
+        """
+        Get execution statistics.
+
+        Returns:
+            Dictionary with execution statistics
+        """
+        with self._lock:
+            taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
+            skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
+            unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
+
+            return {
+                "queue_depth": self.ready_queue.qsize(),
+                "executing": len(self._executing_nodes),
+                "taken_nodes": taken_nodes,
+                "skipped_nodes": skipped_nodes,
+                "unknown_nodes": unknown_nodes,
+            }
+
+    # ============= Backward Compatibility Methods =============
+    # These methods provide compatibility with existing code
+
+    @property
+    def execution_tracker(self) -> "UnifiedStateManager":
+        """Compatibility property for ExecutionTracker access."""
+        return self
+
+    @property
+    def node_state_manager(self) -> "UnifiedStateManager":
+        """Compatibility property for NodeStateManager access."""
+        return self
+
+    @property
+    def edge_state_manager(self) -> "UnifiedStateManager":
+        """Compatibility property for EdgeStateManager access."""
+        return self
+
+    # ExecutionTracker compatibility methods
+    def add(self, node_id: str) -> None:
+        """Compatibility method for ExecutionTracker.add()."""
+        self.start_execution(node_id)
+
+    def remove(self, node_id: str) -> None:
+        """Compatibility method for ExecutionTracker.remove()."""
+        self.finish_execution(node_id)
+
+    def is_empty(self) -> bool:
+        """Compatibility method for ExecutionTracker.is_empty()."""
+        return len(self._executing_nodes) == 0
+
+    def count(self) -> int:
+        """Compatibility method for ExecutionTracker.count()."""
+        return self.get_executing_count()
+
+    def clear(self) -> None:
+        """Compatibility method for ExecutionTracker.clear()."""
+        self.clear_executing()
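The central design choice in the new class is a single RLock guarding node states, edge states, and the executing set, so a composite read such as is_execution_complete() cannot interleave with a state transition that also takes the lock. A self-contained sketch of the pattern (hypothetical names, not dify code):

    import queue
    import threading

    class MiniStateManager:
        # One lock protects several kinds of state, so multi-part checks
        # see a consistent snapshot of the manager's own data.
        def __init__(self) -> None:
            self._lock = threading.RLock()
            self.ready_queue: queue.Queue[str] = queue.Queue()
            self._executing: set[str] = set()

        def enqueue(self, node_id: str) -> None:
            with self._lock:
                self.ready_queue.put(node_id)

        def start(self, node_id: str) -> None:
            with self._lock:
                self._executing.add(node_id)

        def finish(self, node_id: str) -> None:
            with self._lock:
                self._executing.discard(node_id)

        def is_complete(self) -> bool:
            with self._lock:
                # Both conditions are read under the same lock, so no
                # enqueue() or start() that also takes the lock can slip
                # between the two checks.
                return self.ready_queue.empty() and not self._executing

A single lock also removes any lock-ordering questions among the three former managers, which is what makes injecting one instance everywhere safe.
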
@@ -0,0 +1,360 @@
+"""
+Enhanced worker pool with integrated activity tracking and dynamic scaling.
+
+This is a proposed simplification that merges WorkerPool, ActivityTracker,
+and DynamicScaler into a single cohesive class.
+"""
+
+import queue
+import threading
+import time
+from typing import TYPE_CHECKING, final
+
+from configs import dify_config
+from core.workflow.graph import Graph
+from core.workflow.graph_events import GraphNodeEventBase
+
+from ..worker import Worker
+
+if TYPE_CHECKING:
+    from contextvars import Context
+
+    from flask import Flask
+
+
+@final
+class EnhancedWorkerPool:
+    """
+    Enhanced worker pool with integrated features.
+
+    This class combines the responsibilities of:
+    - WorkerPool: Managing worker threads
+    - ActivityTracker: Tracking worker activity
+    - DynamicScaler: Making scaling decisions
+
+    Benefits:
+    - Simplified interface with fewer classes
+    - Direct integration of related features
+    - Reduced inter-class communication overhead
+    """
+
+    def __init__(
+        self,
+        ready_queue: queue.Queue[str],
+        event_queue: queue.Queue[GraphNodeEventBase],
+        graph: Graph,
+        flask_app: "Flask | None" = None,
+        context_vars: "Context | None" = None,
+        min_workers: int | None = None,
+        max_workers: int | None = None,
+        scale_up_threshold: int | None = None,
+        scale_down_idle_time: float | None = None,
+    ) -> None:
+        """
+        Initialize the enhanced worker pool.
+
+        Args:
+            ready_queue: Queue of nodes ready for execution
+            event_queue: Queue for worker events
+            graph: The workflow graph
+            flask_app: Optional Flask app for context preservation
+            context_vars: Optional context variables
+            min_workers: Minimum number of workers
+            max_workers: Maximum number of workers
+            scale_up_threshold: Queue depth to trigger scale up
+            scale_down_idle_time: Seconds before scaling down idle workers
+        """
+        self.ready_queue = ready_queue
+        self.event_queue = event_queue
+        self.graph = graph
+        self.flask_app = flask_app
+        self.context_vars = context_vars
+
+        # Scaling parameters
+        self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
+        self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
+        self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
+        self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
+
+        # Worker management
+        self.workers: list[Worker] = []
+        self._worker_counter = 0
+        self._lock = threading.RLock()
+        self._running = False
+
+        # Activity tracking (integrated)
+        self._worker_activity: dict[int, tuple[bool, float]] = {}
+
+        # Scaling control
+        self._last_scale_check = time.time()
+        self._scale_check_interval = 1.0  # Check scaling every second
+
+    def start(self, initial_count: int | None = None) -> None:
+        """
+        Start the worker pool with initial workers.
+
+        Args:
+            initial_count: Number of workers to start with (auto-calculated if None)
+        """
+        with self._lock:
+            if self._running:
+                return
+
+            self._running = True
+
+            # Calculate initial worker count if not specified
+            if initial_count is None:
+                initial_count = self._calculate_initial_workers()
+
+            # Create initial workers
+            for _ in range(initial_count):
+                self._add_worker()
+
+    def stop(self) -> None:
+        """Stop all workers in the pool."""
+        with self._lock:
+            self._running = False
+
+            # Stop all workers
+            for worker in self.workers:
+                worker.stop()
+
+            # Wait for workers to finish
+            for worker in self.workers:
+                if worker.is_alive():
+                    worker.join(timeout=10.0)
+
+            self.workers.clear()
+            self._worker_activity.clear()
+
+    def check_and_scale(self) -> None:
+        """
+        Check and perform scaling if needed.
+
+        This method should be called periodically to adjust pool size.
+        """
+        current_time = time.time()
+
+        # Rate limit scaling checks
+        if current_time - self._last_scale_check < self._scale_check_interval:
+            return
+
+        self._last_scale_check = current_time
+
+        with self._lock:
+            if not self._running:
+                return
+
+            current_count = len(self.workers)
+            queue_depth = self.ready_queue.qsize()
+
+            # Check for scale up
+            if self._should_scale_up(current_count, queue_depth):
+                self._add_worker()
+
+            # Check for scale down
+            idle_workers = self._get_idle_workers(current_time)
+            if idle_workers and self._should_scale_down(current_count):
+                # Remove the most idle worker
+                self._remove_worker(idle_workers[0])
+
+    # ============= Private Methods =============
+
+    def _calculate_initial_workers(self) -> int:
+        """
+        Calculate initial number of workers based on graph complexity.
+
+        Returns:
+            Initial worker count
+        """
+        # Simple heuristic: start with min_workers, scale based on graph size
+        node_count = len(self.graph.nodes)
+
+        if node_count < 10:
+            return self.min_workers
+        elif node_count < 50:
+            return min(self.min_workers + 1, self.max_workers)
+        else:
+            return min(self.min_workers + 2, self.max_workers)
+
+    def _should_scale_up(self, current_count: int, queue_depth: int) -> bool:
+        """
+        Determine if pool should scale up.
+
+        Args:
+            current_count: Current number of workers
+            queue_depth: Current queue depth
+
+        Returns:
+            True if should scale up
+        """
+        if current_count >= self.max_workers:
+            return False
+
+        # Scale up if queue is deep
+        if queue_depth > self.scale_up_threshold:
+            return True
+
+        # Scale up if all workers are busy and queue is not empty
+        active_count = self._get_active_count()
+        if active_count == current_count and queue_depth > 0:
+            return True
+
+        return False
+
+    def _should_scale_down(self, current_count: int) -> bool:
+        """
+        Determine if pool should scale down.
+
+        Args:
+            current_count: Current number of workers
+
+        Returns:
+            True if should scale down
+        """
+        return current_count > self.min_workers
+
+    def _add_worker(self) -> None:
+        """Add a new worker to the pool."""
+        worker_id = self._worker_counter
+        self._worker_counter += 1
+
+        # Create worker with activity callbacks
+        worker = Worker(
+            ready_queue=self.ready_queue,
+            event_queue=self.event_queue,
+            graph=self.graph,
+            worker_id=worker_id,
+            flask_app=self.flask_app,
+            context_vars=self.context_vars,
+            on_idle_callback=self._on_worker_idle,
+            on_active_callback=self._on_worker_active,
+        )
+
+        worker.start()
+        self.workers.append(worker)
+        self._worker_activity[worker_id] = (False, time.time())
+
+    def _remove_worker(self, worker_id: int) -> None:
+        """
+        Remove a specific worker from the pool.
+
+        Args:
+            worker_id: ID of worker to remove
+        """
+        worker_to_remove = None
+        for worker in self.workers:
+            if worker.worker_id == worker_id:
+                worker_to_remove = worker
+                break
+
+        if worker_to_remove:
+            worker_to_remove.stop()
+            self.workers.remove(worker_to_remove)
+            self._worker_activity.pop(worker_id, None)
+
+            if worker_to_remove.is_alive():
+                worker_to_remove.join(timeout=1.0)
+
+    def _on_worker_idle(self, worker_id: int) -> None:
+        """
+        Callback when worker becomes idle.
+
+        Args:
+            worker_id: ID of the idle worker
+        """
+        with self._lock:
+            self._worker_activity[worker_id] = (False, time.time())
+
+    def _on_worker_active(self, worker_id: int) -> None:
+        """
+        Callback when worker becomes active.
+
+        Args:
+            worker_id: ID of the active worker
+        """
+        with self._lock:
+            self._worker_activity[worker_id] = (True, time.time())
+
+    def _get_idle_workers(self, current_time: float) -> list[int]:
+        """
+        Get list of workers that have been idle too long.
+
+        Args:
+            current_time: Current timestamp
+
+        Returns:
+            List of idle worker IDs sorted by idle time (longest first)
+        """
+        idle_workers: list[tuple[int, float]] = []
+
+        for worker_id, (is_active, last_change) in self._worker_activity.items():
+            if not is_active:
+                idle_time = current_time - last_change
+                if idle_time > self.scale_down_idle_time:
+                    idle_workers.append((worker_id, idle_time))
+
+        # Sort by idle time (longest first)
+        idle_workers.sort(key=lambda x: x[1], reverse=True)
+        return [worker_id for worker_id, _ in idle_workers]
+
+    def _get_active_count(self) -> int:
+        """
+        Get count of currently active workers.
+
+        Returns:
+            Number of active workers
+        """
+        return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
+
+    # ============= Public Status Methods =============
+
+    def get_worker_count(self) -> int:
+        """Get current number of workers."""
+        with self._lock:
+            return len(self.workers)
+
+    def get_status(self) -> dict[str, int]:
+        """
+        Get pool status information.
+
+        Returns:
+            Dictionary with status information
+        """
+        with self._lock:
+            return {
+                "total_workers": len(self.workers),
+                "active_workers": self._get_active_count(),
+                "idle_workers": len(self.workers) - self._get_active_count(),
+                "queue_depth": self.ready_queue.qsize(),
+                "min_workers": self.min_workers,
+                "max_workers": self.max_workers,
+            }
+
+    # ============= Backward Compatibility =============
+
+    def scale_up(self) -> None:
+        """Compatibility method for manual scale up."""
+        with self._lock:
+            if self._running and len(self.workers) < self.max_workers:
+                self._add_worker()
+
+    def scale_down(self, worker_ids: list[int]) -> None:
+        """Compatibility method for manual scale down."""
+        with self._lock:
+            if not self._running:
+                return
+
+            for worker_id in worker_ids:
+                if len(self.workers) > self.min_workers:
+                    self._remove_worker(worker_id)
+
+    def check_scaling(self, queue_depth: int, executing_count: int) -> None:
+        """
+        Compatibility method for checking scaling.
+
+        Args:
+            queue_depth: Current queue depth (ignored, we check directly)
+            executing_count: Number of executing nodes (ignored)
+        """
+        self.check_and_scale()
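The integrated scaling logic stores an (is_active, last_change) pair per worker; scale-down candidates are idle workers whose idle time exceeds scale_down_idle_time, ordered longest-idle first, and the pool never shrinks below min_workers. A standalone sketch of the selection step (illustrative timings, not dify code):

    import time

    # worker_id -> (is_active, timestamp of last state change)
    activity: dict[int, tuple[bool, float]] = {
        0: (True, time.time()),           # busy
        1: (False, time.time() - 45.0),   # idle for 45 s
        2: (False, time.time() - 120.0),  # idle for 120 s
    }

    def idle_workers(now: float, idle_limit: float) -> list[int]:
        idle = [
            (worker_id, now - changed)
            for worker_id, (active, changed) in activity.items()
            if not active and now - changed > idle_limit
        ]
        idle.sort(key=lambda pair: pair[1], reverse=True)  # longest idle first
        return [worker_id for worker_id, _ in idle]

    print(idle_workers(time.time(), idle_limit=30.0))  # -> [2, 1]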