refactor(graph_engine): Merge state managers into unified_state_manager

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN-
2025-09-01 02:08:08 +08:00
parent 546d75d84d
commit e2f4c9ba8d
13 changed files with 737 additions and 350 deletions


@@ -34,7 +34,7 @@ ignore_imports =
[importlinter:contract:rsc]
name = RSC
type = layers
layers =
graph_engine
response_coordinator
output_registry
@@ -44,7 +44,7 @@ containers =
[importlinter:contract:worker]
name = Worker
type = layers
layers =
graph_engine
worker
containers =
@@ -77,18 +77,8 @@ forbidden_modules =
core.workflow.graph_engine.layers
core.workflow.graph_engine.protocols
[importlinter:contract:state-management-layers]
name = State Management Layers
type = layers
layers =
execution_tracker
node_state_manager
edge_state_manager
containers =
core.workflow.graph_engine.state_management
[importlinter:contract:worker-management-layers]
name = Worker Management Layers
type = layers
layers =
worker_pool
@@ -119,4 +109,4 @@ name = Command Channels Independence
type = independence
modules =
core.workflow.graph_engine.command_channels.in_memory_channel
core.workflow.graph_engine.command_channels.redis_channel


@@ -32,7 +32,7 @@ from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handling import ErrorHandler
from ..graph_traversal import BranchHandler, EdgeProcessor
from ..state_management import ExecutionTracker, NodeStateManager
from ..state_management import UnifiedStateManager
from .event_collector import EventCollector
logger = logging.getLogger(__name__)
@@ -56,8 +56,8 @@ class EventHandlerRegistry:
event_collector: "EventCollector",
branch_handler: "BranchHandler",
edge_processor: "EdgeProcessor",
node_state_manager: "NodeStateManager",
execution_tracker: "ExecutionTracker",
node_state_manager: "UnifiedStateManager",
execution_tracker: "UnifiedStateManager",
error_handler: "ErrorHandler",
) -> None:
"""


@@ -39,7 +39,7 @@ from .orchestration import Dispatcher, ExecutionCoordinator
from .output_registry import OutputRegistry
from .protocols.command_channel import CommandChannel
from .response_coordinator import ResponseStreamCoordinator
from .state_management import EdgeStateManager, ExecutionTracker, NodeStateManager
from .state_management import UnifiedStateManager
from .worker_management import ActivityTracker, DynamicScaler, WorkerFactory, WorkerPool
logger = logging.getLogger(__name__)
@@ -119,10 +119,8 @@ class GraphEngine:
def _initialize_subsystems(self) -> None:
"""Initialize all subsystems with proper dependency injection."""
# State management
self.node_state_manager = NodeStateManager(self.graph, self.ready_queue)
self.edge_state_manager = EdgeStateManager(self.graph)
self.execution_tracker = ExecutionTracker()
# Unified state management - single instance handles all state operations
self.state_manager = UnifiedStateManager(self.graph, self.ready_queue)
# Response coordination
self.output_registry = OutputRegistry(self.graph_runtime_state.variable_pool)
@@ -139,20 +137,20 @@ class GraphEngine:
self.node_readiness_checker = NodeReadinessChecker(self.graph)
self.edge_processor = EdgeProcessor(
graph=self.graph,
edge_state_manager=self.edge_state_manager,
node_state_manager=self.node_state_manager,
edge_state_manager=self.state_manager,
node_state_manager=self.state_manager,
response_coordinator=self.response_coordinator,
)
self.skip_propagator = SkipPropagator(
graph=self.graph,
edge_state_manager=self.edge_state_manager,
node_state_manager=self.node_state_manager,
edge_state_manager=self.state_manager,
node_state_manager=self.state_manager,
)
self.branch_handler = BranchHandler(
graph=self.graph,
edge_processor=self.edge_processor,
skip_propagator=self.skip_propagator,
edge_state_manager=self.edge_state_manager,
edge_state_manager=self.state_manager,
)
# Event handler registry with all dependencies
@@ -164,8 +162,8 @@ class GraphEngine:
event_collector=self.event_collector,
branch_handler=self.branch_handler,
edge_processor=self.edge_processor,
node_state_manager=self.node_state_manager,
execution_tracker=self.execution_tracker,
node_state_manager=self.state_manager,
execution_tracker=self.state_manager,
error_handler=self.error_handler,
)
@@ -182,8 +180,8 @@ class GraphEngine:
# Orchestration
self.execution_coordinator = ExecutionCoordinator(
graph_execution=self.graph_execution,
node_state_manager=self.node_state_manager,
execution_tracker=self.execution_tracker,
node_state_manager=self.state_manager,
execution_tracker=self.state_manager,
event_handler=self.event_handler_registry,
event_collector=self.event_collector,
command_processor=self.command_processor,
@@ -335,8 +333,8 @@ class GraphEngine:
# Enqueue root node
root_node = self.graph.root_node
self.node_state_manager.enqueue_node(root_node.id)
self.execution_tracker.add(root_node.id)
self.state_manager.enqueue_node(root_node.id)
self.state_manager.add(root_node.id)
# Start dispatcher
self.dispatcher.start()
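The wiring above shows the main payoff: three constructions collapse into one, and the startup sequence calls two methods on the same object. A runnable toy of that sequence, assuming cut-down stand-ins for the real Graph and manager types:

import queue
import threading

class TinyStateManager:
    """Cut-down illustration of the UnifiedStateManager surface used at startup."""

    def __init__(self, ready_queue: queue.Queue[str]) -> None:
        self.ready_queue = ready_queue
        self._lock = threading.RLock()
        self._executing: set[str] = set()

    def enqueue_node(self, node_id: str) -> None:
        # In the real class this also marks the node TAKEN in the graph.
        with self._lock:
            self.ready_queue.put(node_id)

    def add(self, node_id: str) -> None:
        # ExecutionTracker-compatible alias, mirroring the diff above.
        with self._lock:
            self._executing.add(node_id)

rq: queue.Queue[str] = queue.Queue()
manager = TinyStateManager(rq)
manager.enqueue_node("root")  # state_manager.enqueue_node(root_node.id)
manager.add("root")           # state_manager.add(root_node.id)
assert rq.get_nowait() == "root"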


@@ -8,7 +8,7 @@ from typing import final
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunStreamChunkEvent
from ..state_management import EdgeStateManager
from ..state_management import UnifiedStateManager
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
@@ -27,7 +27,7 @@ class BranchHandler:
graph: Graph,
edge_processor: EdgeProcessor,
skip_propagator: SkipPropagator,
edge_state_manager: EdgeStateManager,
edge_state_manager: UnifiedStateManager,
) -> None:
"""
Initialize the branch handler.


@@ -10,7 +10,7 @@ from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..response_coordinator import ResponseStreamCoordinator
from ..state_management import EdgeStateManager, NodeStateManager
from ..state_management import UnifiedStateManager
@final
@@ -25,8 +25,8 @@ class EdgeProcessor:
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
edge_state_manager: UnifiedStateManager,
node_state_manager: UnifiedStateManager,
response_coordinator: ResponseStreamCoordinator,
) -> None:
"""


@@ -7,7 +7,7 @@ from typing import final
from core.workflow.graph import Edge, Graph
from ..state_management import EdgeStateManager, NodeStateManager
from ..state_management import UnifiedStateManager
@final
@@ -22,8 +22,8 @@ class SkipPropagator:
def __init__(
self,
graph: Graph,
edge_state_manager: EdgeStateManager,
node_state_manager: NodeStateManager,
edge_state_manager: UnifiedStateManager,
node_state_manager: UnifiedStateManager,
) -> None:
"""
Initialize the skip propagator.
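BranchHandler, EdgeProcessor, and SkipPropagator now all accept the same UnifiedStateManager under their old parameter names. The skip rule they cooperate on, as suggested by mark_edge_skipped, mark_node_skipped, and the edge analysis further down, fits in a few lines; this is an illustration, not the actual SkipPropagator code:

from enum import Enum, auto

class State(Enum):
    UNKNOWN = auto()
    TAKEN = auto()
    SKIPPED = auto()

def should_skip_node(incoming: list[State]) -> bool:
    # Assumes the node has at least one incoming edge: a node is skipped
    # only when every incoming edge was skipped.
    return set(incoming) == {State.SKIPPED}

assert should_skip_node([State.SKIPPED, State.SKIPPED]) is True
assert should_skip_node([State.SKIPPED, State.TAKEN]) is False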


@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventCollector
from ..state_management import ExecutionTracker, NodeStateManager
from ..state_management import UnifiedStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
@@ -26,8 +26,8 @@ class ExecutionCoordinator:
def __init__(
self,
graph_execution: GraphExecution,
node_state_manager: NodeStateManager,
execution_tracker: ExecutionTracker,
node_state_manager: UnifiedStateManager,
execution_tracker: UnifiedStateManager,
event_handler: "EventHandlerRegistry",
event_collector: EventCollector,
command_processor: CommandProcessor,
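The coordinator previously had to consult the node manager's queue and the tracker separately; with one object, the shutdown condition becomes a single call (see is_execution_complete in the new class below). A self-contained sketch of that check, with simplified types:

import queue
import threading

class MiniState:
    def __init__(self) -> None:
        self.ready_queue: queue.Queue[str] = queue.Queue()
        self._executing: set[str] = set()
        self._lock = threading.RLock()

    def is_execution_complete(self) -> bool:
        # Complete when nothing is queued and nothing is running,
        # evaluated under one lock.
        with self._lock:
            return self.ready_queue.empty() and not self._executing

state = MiniState()
assert state.is_execution_complete()
state.ready_queue.put("node-1")
assert not state.is_execution_complete()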


@@ -5,12 +5,8 @@ This package manages node states, edge states, and execution tracking
during workflow graph execution.
"""
from .edge_state_manager import EdgeStateManager
from .execution_tracker import ExecutionTracker
from .node_state_manager import NodeStateManager
from .unified_state_manager import UnifiedStateManager
__all__ = [
"EdgeStateManager",
"ExecutionTracker",
"NodeStateManager",
"UnifiedStateManager",
]
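Call sites now import one name from the package root; the path matches the containers entry in the import-linter config above:

# The three old imports collapse into one.
from core.workflow.graph_engine.state_management import UnifiedStateManager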


@@ -1,114 +0,0 @@
"""
Manager for edge states during graph execution.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class EdgeStateManager:
"""
Manages edge states and transitions during graph execution.
This handles edge state changes and provides analysis of edge
states for decision making during execution.
"""
def __init__(self, graph: Graph) -> None:
"""
Initialize the edge state manager.
Args:
graph: The workflow graph
"""
self.graph = graph
self._lock = threading.RLock()
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
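One contract worth noting before it moves into UnifiedStateManager below: analyze_edge_states reports all_skipped=True for an empty edge list. A standalone re-implementation (illustration only, not the Dify class) pinning down that behavior:

from enum import Enum, auto

class State(Enum):
    UNKNOWN = auto()
    TAKEN = auto()
    SKIPPED = auto()

def analyze(states: set[State]) -> dict[str, bool]:
    # Mirrors the EdgeStateAnalysis flags from the deleted manager.
    return {
        "has_unknown": State.UNKNOWN in states,
        "has_taken": State.TAKEN in states,
        "all_skipped": states == {State.SKIPPED} if states else True,
    }

assert analyze(set())["all_skipped"] is True
assert analyze({State.TAKEN, State.SKIPPED}) == {
    "has_unknown": False,
    "has_taken": True,
    "all_skipped": False,
}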


@@ -1,89 +0,0 @@
"""
Tracker for currently executing nodes.
"""
import threading
from typing import final
@final
class ExecutionTracker:
"""
Tracks nodes that are currently being executed.
This replaces the ExecutingNodesManager with a cleaner interface
focused on tracking which nodes are in progress.
"""
def __init__(self) -> None:
"""Initialize the execution tracker."""
self._executing_nodes: set[str] = set()
self._lock = threading.RLock()
def add(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def remove(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def is_empty(self) -> bool:
"""
Check if no nodes are currently executing.
Returns:
True if no nodes are executing
"""
with self._lock:
return len(self._executing_nodes) == 0
def count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
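The tracker's remove() uses set.discard(), so finishing a node twice is harmless, and the compatibility aliases in UnifiedStateManager keep that property. A short demonstration of the underlying set semantics:

executing: set[str] = set()
executing.add("n1")
executing.discard("n1")
executing.discard("n1")  # idempotent: discard() never raises KeyError
assert not executing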


@@ -1,97 +0,0 @@
"""
Manager for node states during graph execution.
"""
import queue
import threading
from typing import final
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
@final
class NodeStateManager:
"""
Manages node states and the ready queue for execution.
This centralizes node state transitions and enqueueing logic,
ensuring thread-safe operations on node states.
"""
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
"""
Initialize the node state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self.graph = graph
self.ready_queue = ready_queue
self._lock = threading.RLock()
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.TAKEN
self.ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when all its incoming edges from taken branches
have been satisfied.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self.graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self.graph.nodes[node_id].state
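The readiness rule above survives verbatim in UnifiedStateManager: any UNKNOWN incoming edge blocks the node, otherwise a single TAKEN edge suffices, and a node with no incoming edges is always ready. A standalone restatement with the three cases checked (illustration, not the real method):

from enum import Enum, auto

class State(Enum):
    UNKNOWN = auto()
    TAKEN = auto()
    SKIPPED = auto()

def is_ready(incoming: list[State]) -> bool:
    if not incoming:
        return True   # no dependencies
    if State.UNKNOWN in incoming:
        return False  # still waiting on an undecided branch
    return State.TAKEN in incoming

assert is_ready([]) is True
assert is_ready([State.UNKNOWN, State.TAKEN]) is False
assert is_ready([State.SKIPPED, State.TAKEN]) is True
assert is_ready([State.SKIPPED]) is False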


@@ -0,0 +1,343 @@
"""
Unified state manager that combines node, edge, and execution tracking.
This is a proposed simplification that merges NodeStateManager, EdgeStateManager,
and ExecutionTracker into a single cohesive class.
"""
import queue
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class UnifiedStateManager:
"""
Unified manager for all graph state operations.
This class combines the responsibilities of:
- NodeStateManager: Node state transitions and ready queue
- EdgeStateManager: Edge state transitions and analysis
- ExecutionTracker: Tracking executing nodes
Benefits:
- Single lock for all state operations (reduced contention)
- Cohesive state management interface
- Simplified dependency injection
"""
def __init__(self, graph: Graph, ready_queue: queue.Queue[str]) -> None:
"""
Initialize the unified state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self.graph = graph
self.ready_queue = ready_queue
self._lock = threading.RLock()
# Execution tracking state
self._executing_nodes: set[str] = set()
# ============= Node State Operations =============
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.TAKEN
self.ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self.graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when all its incoming edges from taken branches
have been satisfied.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self.graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self.graph.nodes[node_id].state
# ============= Edge State Operations =============
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self.graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self.graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self.graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
# ============= Execution Tracking Operations =============
def start_execution(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def finish_execution(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def get_executing_count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear_executing(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
# ============= Composite Operations =============
def is_execution_complete(self) -> bool:
"""
Check if graph execution is complete.
Execution is complete when:
- Ready queue is empty
- No nodes are executing
Returns:
True if execution is complete
"""
with self._lock:
return self.ready_queue.empty() and len(self._executing_nodes) == 0
def get_queue_depth(self) -> int:
"""
Get the current depth of the ready queue.
Returns:
Number of nodes in the ready queue
"""
return self.ready_queue.qsize()
def get_execution_stats(self) -> dict[str, int]:
"""
Get execution statistics.
Returns:
Dictionary with execution statistics
"""
with self._lock:
taken_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.TAKEN)
skipped_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.SKIPPED)
unknown_nodes = sum(1 for node in self.graph.nodes.values() if node.state == NodeState.UNKNOWN)
return {
"queue_depth": self.ready_queue.qsize(),
"executing": len(self._executing_nodes),
"taken_nodes": taken_nodes,
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}
# ============= Backward Compatibility Methods =============
# These methods provide compatibility with existing code
@property
def execution_tracker(self) -> "UnifiedStateManager":
"""Compatibility property for ExecutionTracker access."""
return self
@property
def node_state_manager(self) -> "UnifiedStateManager":
"""Compatibility property for NodeStateManager access."""
return self
@property
def edge_state_manager(self) -> "UnifiedStateManager":
"""Compatibility property for EdgeStateManager access."""
return self
# ExecutionTracker compatibility methods
def add(self, node_id: str) -> None:
"""Compatibility method for ExecutionTracker.add()."""
self.start_execution(node_id)
def remove(self, node_id: str) -> None:
"""Compatibility method for ExecutionTracker.remove()."""
self.finish_execution(node_id)
def is_empty(self) -> bool:
"""Compatibility method for ExecutionTracker.is_empty()."""
with self._lock:
return len(self._executing_nodes) == 0
def count(self) -> int:
"""Compatibility method for ExecutionTracker.count()."""
return self.get_executing_count()
def clear(self) -> None:
"""Compatibility method for ExecutionTracker.clear()."""
self.clear_executing()
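A usage sketch for the new class, assuming it runs inside the Dify codebase so the import resolves. The stand-in graph is hypothetical: the manager only touches graph.nodes[id].state and graph.edges[id].state, so a duck-typed object is enough to exercise the API here:

import queue
from types import SimpleNamespace

from core.workflow.graph_engine.state_management import UnifiedStateManager

graph = SimpleNamespace(
    nodes={"n1": SimpleNamespace(state=None)},
    edges={"e1": SimpleNamespace(state=None)},
)
manager = UnifiedStateManager(graph, queue.Queue())  # type: ignore[arg-type]

manager.enqueue_node("n1")      # marks the node TAKEN and queues it
manager.mark_edge_taken("e1")
manager.add("n1")               # compatibility alias for start_execution()
assert manager.is_executing("n1")
assert manager.execution_tracker is manager   # compat properties return self
assert manager.node_state_manager is manager
manager.remove("n1")            # compatibility alias for finish_execution()
assert manager.get_execution_stats()["executing"] == 0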


@@ -0,0 +1,360 @@
"""
Enhanced worker pool with integrated activity tracking and dynamic scaling.
This is a proposed simplification that merges WorkerPool, ActivityTracker,
and DynamicScaler into a single cohesive class.
"""
import queue
import threading
import time
from typing import TYPE_CHECKING, final
from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..worker import Worker
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class EnhancedWorkerPool:
"""
Enhanced worker pool with integrated features.
This class combines the responsibilities of:
- WorkerPool: Managing worker threads
- ActivityTracker: Tracking worker activity
- DynamicScaler: Making scaling decisions
Benefits:
- Simplified interface with fewer classes
- Direct integration of related features
- Reduced inter-class communication overhead
"""
def __init__(
self,
ready_queue: queue.Queue[str],
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the enhanced worker pool.
Args:
ready_queue: Queue of nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self.ready_queue = ready_queue
self.event_queue = event_queue
self.graph = graph
self.flask_app = flask_app
self.context_vars = context_vars
# Scaling parameters
self.min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
self.max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
self.scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self.scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
# Worker management
self.workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
# Activity tracking (integrated)
self._worker_activity: dict[int, tuple[bool, float]] = {}
# Scaling control
self._last_scale_check = time.time()
self._scale_check_interval = 1.0 # Check scaling every second
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool with initial workers.
Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return
self._running = True
# Calculate initial worker count if not specified
if initial_count is None:
initial_count = self._calculate_initial_workers()
# Create initial workers
for _ in range(initial_count):
self._add_worker()
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
# Stop all workers
for worker in self.workers:
worker.stop()
# Wait for workers to finish
for worker in self.workers:
if worker.is_alive():
worker.join(timeout=10.0)
self.workers.clear()
self._worker_activity.clear()
def check_and_scale(self) -> None:
"""
Check and perform scaling if needed.
This method should be called periodically to adjust pool size.
"""
current_time = time.time()
# Rate limit scaling checks
if current_time - self._last_scale_check < self._scale_check_interval:
return
self._last_scale_check = current_time
with self._lock:
if not self._running:
return
current_count = len(self.workers)
queue_depth = self.ready_queue.qsize()
# Check for scale up
if self._should_scale_up(current_count, queue_depth):
self._add_worker()
# Check for scale down
idle_workers = self._get_idle_workers(current_time)
if idle_workers and self._should_scale_down(current_count):
# Remove the most idle worker
self._remove_worker(idle_workers[0])
# ============= Private Methods =============
def _calculate_initial_workers(self) -> int:
"""
Calculate initial number of workers based on graph complexity.
Returns:
Initial worker count
"""
# Simple heuristic: start with min_workers, scale based on graph size
node_count = len(self.graph.nodes)
if node_count < 10:
return self.min_workers
elif node_count < 50:
return min(self.min_workers + 1, self.max_workers)
else:
return min(self.min_workers + 2, self.max_workers)
def _should_scale_up(self, current_count: int, queue_depth: int) -> bool:
"""
Determine if pool should scale up.
Args:
current_count: Current number of workers
queue_depth: Current queue depth
Returns:
True if should scale up
"""
if current_count >= self.max_workers:
return False
# Scale up if queue is deep
if queue_depth > self.scale_up_threshold:
return True
# Scale up if all workers are busy and queue is not empty
active_count = self._get_active_count()
if active_count == current_count and queue_depth > 0:
return True
return False
def _should_scale_down(self, current_count: int) -> bool:
"""
Determine if pool should scale down.
Args:
current_count: Current number of workers
Returns:
True if should scale down
"""
return current_count > self.min_workers
def _add_worker(self) -> None:
"""Add a new worker to the pool."""
worker_id = self._worker_counter
self._worker_counter += 1
# Create worker with activity callbacks
worker = Worker(
ready_queue=self.ready_queue,
event_queue=self.event_queue,
graph=self.graph,
worker_id=worker_id,
flask_app=self.flask_app,
context_vars=self.context_vars,
on_idle_callback=self._on_worker_idle,
on_active_callback=self._on_worker_active,
)
worker.start()
self.workers.append(worker)
self._worker_activity[worker_id] = (False, time.time())
def _remove_worker(self, worker_id: int) -> None:
"""
Remove a specific worker from the pool.
Args:
worker_id: ID of worker to remove
"""
worker_to_remove = None
for worker in self.workers:
if worker.worker_id == worker_id:
worker_to_remove = worker
break
if worker_to_remove:
worker_to_remove.stop()
self.workers.remove(worker_to_remove)
self._worker_activity.pop(worker_id, None)
if worker_to_remove.is_alive():
worker_to_remove.join(timeout=1.0)
def _on_worker_idle(self, worker_id: int) -> None:
"""
Callback when worker becomes idle.
Args:
worker_id: ID of the idle worker
"""
with self._lock:
self._worker_activity[worker_id] = (False, time.time())
def _on_worker_active(self, worker_id: int) -> None:
"""
Callback when worker becomes active.
Args:
worker_id: ID of the active worker
"""
with self._lock:
self._worker_activity[worker_id] = (True, time.time())
def _get_idle_workers(self, current_time: float) -> list[int]:
"""
Get list of workers that have been idle too long.
Args:
current_time: Current timestamp
Returns:
List of idle worker IDs sorted by idle time (longest first)
"""
idle_workers: list[tuple[int, float]] = []
for worker_id, (is_active, last_change) in self._worker_activity.items():
if not is_active:
idle_time = current_time - last_change
if idle_time > self.scale_down_idle_time:
idle_workers.append((worker_id, idle_time))
# Sort by idle time (longest first)
idle_workers.sort(key=lambda x: x[1], reverse=True)
return [worker_id for worker_id, _ in idle_workers]
def _get_active_count(self) -> int:
"""
Get count of currently active workers.
Returns:
Number of active workers
"""
return sum(1 for is_active, _ in self._worker_activity.values() if is_active)
# ============= Public Status Methods =============
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self.workers)
def get_status(self) -> dict[str, int]:
"""
Get pool status information.
Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self.workers),
"active_workers": self._get_active_count(),
"idle_workers": len(self.workers) - self._get_active_count(),
"queue_depth": self.ready_queue.qsize(),
"min_workers": self.min_workers,
"max_workers": self.max_workers,
}
# ============= Backward Compatibility =============
def scale_up(self) -> None:
"""Compatibility method for manual scale up."""
with self._lock:
if self._running and len(self.workers) < self.max_workers:
self._add_worker()
def scale_down(self, worker_ids: list[int]) -> None:
"""Compatibility method for manual scale down."""
with self._lock:
if not self._running:
return
for worker_id in worker_ids:
if len(self.workers) > self.min_workers:
self._remove_worker(worker_id)
def check_scaling(self, queue_depth: int, executing_count: int) -> None:
"""
Compatibility method for checking scaling.
Args:
queue_depth: Current queue depth (ignored, we check directly)
executing_count: Number of executing nodes (ignored)
"""
self.check_and_scale()
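The scaling policy is easiest to see stripped of the worker machinery. A standalone restatement of _calculate_initial_workers and _should_scale_up with the thresholds passed in explicitly (the values in the asserts are made up for illustration):

def initial_workers(node_count: int, min_workers: int, max_workers: int) -> int:
    # Heuristic from the pool: small graphs start at the floor,
    # larger graphs pre-warm one or two extra workers.
    if node_count < 10:
        return min_workers
    if node_count < 50:
        return min(min_workers + 1, max_workers)
    return min(min_workers + 2, max_workers)

def should_scale_up(current: int, active: int, queue_depth: int,
                    max_workers: int, threshold: int) -> bool:
    if current >= max_workers:
        return False          # hard ceiling
    if queue_depth > threshold:
        return True           # backlog is deep
    return active == current and queue_depth > 0  # all busy, work waiting

assert initial_workers(node_count=5, min_workers=1, max_workers=4) == 1
assert initial_workers(node_count=60, min_workers=1, max_workers=4) == 3
assert should_scale_up(current=2, active=2, queue_depth=1, max_workers=4, threshold=5)
assert not should_scale_up(current=4, active=4, queue_depth=99, max_workers=4, threshold=5)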