feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
-LAN- authored 2025-09-18 12:49:10 +08:00, committed by GitHub
parent 7dadb33003, commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions


@@ -1,4 +1,3 @@
-from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
 from .graph_engine import GraphEngine
 
-__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["GraphEngine"]


@@ -0,0 +1,33 @@
# Command Channels

Channel implementations for external workflow control.

## Components

### InMemoryChannel

Thread-safe in-memory queue for single-process deployments.

- `fetch_commands()` - Get pending commands
- `send_command()` - Add command to queue

### RedisChannel

Redis-based queue for distributed deployments.

- `fetch_commands()` - Get commands with JSON deserialization
- `send_command()` - Store commands with TTL

## Usage

```python
# Local execution
channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="user requested abort"))

# Distributed execution
redis_channel = RedisChannel(
    redis_client=redis_client,
    channel_key="workflow:123:commands",
)
redis_channel.send_command(AbortCommand(reason="user requested abort"))
```
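
For reference, a minimal consumer-side sketch (hypothetical import paths inferred from the package's relative imports; the GraphEngine normally performs this polling itself):

```python
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType

channel = InMemoryChannel()
channel.send_command(AbortCommand(reason="timeout"))

# fetch_commands() drains the queue, so each command is seen exactly once.
for command in channel.fetch_commands():
    if command.command_type == CommandType.ABORT:
        print(f"abort requested: {command.reason}")
```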


@@ -0,0 +1,6 @@
"""Command channel implementations for GraphEngine."""
from .in_memory_channel import InMemoryChannel
from .redis_channel import RedisChannel
__all__ = ["InMemoryChannel", "RedisChannel"]


@@ -0,0 +1,53 @@
"""
In-memory implementation of CommandChannel for local/testing scenarios.
This implementation uses a thread-safe queue for command communication
within a single process. Each instance handles commands for one workflow execution.
"""
from queue import Queue
from typing import final
from ..entities.commands import GraphEngineCommand
@final
class InMemoryChannel:
"""
In-memory command channel implementation using a thread-safe queue.
Each instance is dedicated to a single GraphEngine/workflow execution.
Suitable for local development, testing, and single-instance deployments.
"""
def __init__(self) -> None:
"""Initialize the in-memory channel with a single queue."""
self._queue: Queue[GraphEngineCommand] = Queue()
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from the queue.
Returns:
List of pending commands (drains the queue)
"""
commands: list[GraphEngineCommand] = []
# Drain all available commands from the queue
while not self._queue.empty():
try:
command = self._queue.get_nowait()
commands.append(command)
except Exception:
break
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to this channel's queue.
Args:
command: The command to send
"""
self._queue.put(command)
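
A brief usage sketch (hypothetical import paths) showing the cross-thread hand-off and the drain-on-fetch semantics:

```python
import threading

from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

channel = InMemoryChannel()

# A controller thread can enqueue commands while the engine thread runs.
sender = threading.Thread(target=lambda: channel.send_command(AbortCommand(reason="stop")))
sender.start()
sender.join()

assert [c.reason for c in channel.fetch_commands()] == ["stop"]
assert channel.fetch_commands() == []  # the queue is drained by the first fetch
```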


@@ -0,0 +1,114 @@
"""
Redis-based implementation of CommandChannel for distributed scenarios.
This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each instance uses a unique key for its command queue.
"""
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@final
class RedisChannel:
"""
Redis-based command channel implementation for distributed systems.
Each instance uses a unique Redis key for its command queue.
Commands are JSON-serialized for transport.
"""
def __init__(
self,
redis_client: "RedisClientWrapper",
channel_key: str,
command_ttl: int = 3600,
) -> None:
"""
Initialize the Redis channel.
Args:
redis_client: Redis client instance
channel_key: Unique key for this channel's command queue
command_ttl: TTL for command keys in seconds (default: 3600)
"""
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch all pending commands from Redis.
Returns:
List of pending commands (drains the Redis list)
"""
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
with self._redis.pipeline() as pipe:
# Get all commands and clear the list atomically
pipe.lrange(self._key, 0, -1)
pipe.delete(self._key)
results = pipe.execute()
# Parse commands from JSON
if results[0]:
for command_json in results[0]:
try:
command_data = json.loads(command_json)
command = self._deserialize_command(command_data)
if command:
commands.append(command)
except (json.JSONDecodeError, ValueError):
# Skip invalid commands
continue
return commands
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to Redis.
Args:
command: The command to send
"""
command_json = json.dumps(command.model_dump())
# Push to list and set expiry
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
"""
Deserialize a command from dictionary data.
Args:
data: Command data dictionary
Returns:
Deserialized command or None if invalid
"""
command_type_value = data.get("command_type")
if not isinstance(command_type_value, str):
return None
try:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
except (ValueError, TypeError):
return None
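
A cross-process sketch under assumed imports (`redis_client` stands in for the shared client from `extensions.ext_redis`; the channel-key scheme follows the README above):

```python
from extensions.ext_redis import redis_client

from core.workflow.graph_engine.command_channels import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

# Both sides must construct the channel with the same key for a given run.
channel = RedisChannel(redis_client=redis_client, channel_key="workflow:123:commands")

# API side: enqueue the command (RPUSH + EXPIRE in one pipeline).
channel.send_command(AbortCommand(reason="user cancelled"))

# Engine side, possibly on another server: LRANGE + DEL atomically, then decode.
for command in channel.fetch_commands():
    print(command.command_type, getattr(command, "reason", None))
```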


@@ -0,0 +1,14 @@
"""
Command processing subsystem for graph engine.
This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
]


@@ -0,0 +1,32 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
"""Handles abort commands."""
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
"""
Handle an abort command.
Args:
command: The abort command
execution: Graph execution to abort
"""
assert isinstance(command, AbortCommand)
logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
execution.abort(command.reason or "User requested abort")


@@ -0,0 +1,79 @@
"""
Main command processor for handling external commands.
"""
import logging
from typing import Protocol, final
from ..domain.graph_execution import GraphExecution
from ..entities.commands import GraphEngineCommand
from ..protocols.command_channel import CommandChannel
logger = logging.getLogger(__name__)
class CommandHandler(Protocol):
"""Protocol for command handlers."""
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ...
@final
class CommandProcessor:
"""
Processes external commands sent to the engine.
This polls the command channel and dispatches commands to
appropriate handlers.
"""
def __init__(
self,
command_channel: CommandChannel,
graph_execution: GraphExecution,
) -> None:
"""
Initialize the command processor.
Args:
command_channel: Channel for receiving commands
graph_execution: Graph execution aggregate
"""
self._command_channel = command_channel
self._graph_execution = graph_execution
self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {}
def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None:
"""
Register a handler for a command type.
Args:
command_type: Type of command to handle
handler: Handler for the command
"""
self._handlers[command_type] = handler
def process_commands(self) -> None:
"""Check for and process any pending commands."""
try:
commands = self._command_channel.fetch_commands()
for command in commands:
self._handle_command(command)
except Exception as e:
logger.warning("Error processing commands: %s", e)
def _handle_command(self, command: GraphEngineCommand) -> None:
"""
Handle a single command.
Args:
command: The command to handle
"""
handler = self._handlers.get(type(command))
if handler:
try:
handler.handle(command, self._graph_execution)
except Exception:
logger.exception("Error handling command %s", command.__class__.__name__)
else:
logger.warning("No handler registered for command: %s", command.__class__.__name__)
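
A wiring sketch (import paths assumed; `CommandProcessor`, `AbortCommandHandler`, and `GraphExecution` come from the modules shown in this diff):

```python
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand

execution = GraphExecution(workflow_id="wf-1")
execution.start()

channel = InMemoryChannel()
processor = CommandProcessor(command_channel=channel, graph_execution=execution)
# Dispatch is keyed on the concrete command class.
processor.register_handler(AbortCommand, AbortCommandHandler())

channel.send_command(AbortCommand(reason="user cancelled"))
processor.process_commands()  # the engine calls this between node executions

assert execution.aborted and not execution.is_running
```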


@@ -1,25 +0,0 @@
from abc import ABC, abstractmethod

from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class RunConditionHandler(ABC):
    def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition):
        self.init_params = init_params
        self.graph = graph
        self.condition = condition

    @abstractmethod
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        raise NotImplementedError


@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState


class BranchIdentifyRunConditionHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool:
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.branch_identify:
            raise Exception("Branch identify is required")

        run_result = previous_route_node_state.node_run_result
        if not run_result:
            return False

        if not run_result.edge_source_handle:
            return False

        return self.condition.branch_identify == run_result.edge_source_handle


@@ -1,27 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor


class ConditionRunConditionHandlerHandler(RunConditionHandler):
    def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState):
        """
        Check if the condition can be executed

        :param graph_runtime_state: graph runtime state
        :param previous_route_node_state: previous route node state
        :return: bool
        """
        if not self.condition.conditions:
            return True

        # process condition
        condition_processor = ConditionProcessor()
        _, _, final_result = condition_processor.process_conditions(
            variable_pool=graph_runtime_state.variable_pool,
            conditions=self.condition.conditions,
            operator="and",
        )

        return final_result


@@ -1,25 +0,0 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition


class ConditionManager:
    @staticmethod
    def get_condition_handler(
        init_params: GraphInitParams, graph: Graph, run_condition: RunCondition
    ) -> RunConditionHandler:
        """
        Get condition handler

        :param init_params: init params
        :param graph: graph
        :param run_condition: run condition
        :return: condition handler
        """
        if run_condition.type == "branch_identify":
            return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition)
        else:
            return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition)


@@ -0,0 +1,14 @@
"""
Domain models for graph engine.
This package contains the core domain entities, value objects, and aggregates
that represent the business concepts of workflow graph execution.
"""
from .graph_execution import GraphExecution
from .node_execution import NodeExecution
__all__ = [
"GraphExecution",
"NodeExecution",
]


@@ -0,0 +1,207 @@
"""GraphExecution aggregate root managing the overall graph execution state."""
from __future__ import annotations
from dataclasses import dataclass, field
from importlib import import_module
from typing import Literal
from pydantic import BaseModel, Field
from core.workflow.enums import NodeState
from .node_execution import NodeExecution
class GraphExecutionErrorState(BaseModel):
"""Serializable representation of an execution error."""
module: str = Field(description="Module containing the exception class")
qualname: str = Field(description="Qualified name of the exception class")
message: str | None = Field(default=None, description="Exception message string")
class NodeExecutionState(BaseModel):
"""Serializable representation of a node execution entity."""
node_id: str
state: NodeState = Field(default=NodeState.UNKNOWN)
retry_count: int = Field(default=0)
execution_id: str | None = Field(default=None)
error: str | None = Field(default=None)
class GraphExecutionState(BaseModel):
"""Pydantic model describing serialized GraphExecution state."""
type: Literal["GraphExecution"] = Field(default="GraphExecution")
version: str = Field(default="1.0")
workflow_id: str
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
error: GraphExecutionErrorState | None = Field(default=None)
node_executions: list[NodeExecutionState] = Field(default_factory=list)
def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None:
"""Convert an exception into its serializable representation."""
if error is None:
return None
return GraphExecutionErrorState(
module=error.__class__.__module__,
qualname=error.__class__.__qualname__,
message=str(error),
)
def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]:
"""Locate an exception class from its module and qualified name."""
module = import_module(module_name)
attr: object = module
for part in qualname.split("."):
attr = getattr(attr, part)
if isinstance(attr, type) and issubclass(attr, Exception):
return attr
raise TypeError(f"{qualname} in {module_name} is not an Exception subclass")
def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None:
"""Reconstruct an exception instance from serialized data."""
if state is None:
return None
try:
exception_class = _resolve_exception_class(state.module, state.qualname)
if state.message is None:
return exception_class()
return exception_class(state.message)
except Exception:
# Fallback to RuntimeError when reconstruction fails
if state.message is None:
return RuntimeError(state.qualname)
return RuntimeError(state.message)
@dataclass
class GraphExecution:
"""
Aggregate root for graph execution.
This manages the overall execution state of a workflow graph,
coordinating between multiple node executions.
"""
workflow_id: str
started: bool = False
completed: bool = False
aborted: bool = False
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict)
def start(self) -> None:
"""Mark the graph execution as started."""
if self.started:
raise RuntimeError("Graph execution already started")
self.started = True
def complete(self) -> None:
"""Mark the graph execution as completed."""
if not self.started:
raise RuntimeError("Cannot complete execution that hasn't started")
if self.completed:
raise RuntimeError("Graph execution already completed")
self.completed = True
def abort(self, reason: str) -> None:
"""Abort the graph execution."""
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
self.error = error
self.completed = True
def get_or_create_node_execution(self, node_id: str) -> NodeExecution:
"""Get or create a node execution entity."""
if node_id not in self.node_executions:
self.node_executions[node_id] = NodeExecution(node_id=node_id)
return self.node_executions[node_id]
@property
def is_running(self) -> bool:
"""Check if the execution is currently running."""
return self.started and not self.completed and not self.aborted
@property
def has_error(self) -> bool:
"""Check if the execution has encountered an error."""
return self.error is not None
@property
def error_message(self) -> str | None:
"""Get the error message if an error exists."""
if not self.error:
return None
return str(self.error)
def dumps(self) -> str:
"""Serialize the aggregate state into a JSON string."""
node_states = [
NodeExecutionState(
node_id=node_id,
state=node_execution.state,
retry_count=node_execution.retry_count,
execution_id=node_execution.execution_id,
error=node_execution.error,
)
for node_id, node_execution in sorted(self.node_executions.items())
]
state = GraphExecutionState(
workflow_id=self.workflow_id,
started=self.started,
completed=self.completed,
aborted=self.aborted,
error=_serialize_error(self.error),
node_executions=node_states,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore aggregate state from a serialized JSON string."""
state = GraphExecutionState.model_validate_json(data)
if state.type != "GraphExecution":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
if self.workflow_id != state.workflow_id:
raise ValueError("Serialized workflow_id does not match aggregate identity")
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.error = _deserialize_error(state.error)
self.node_executions = {
item.node_id: NodeExecution(
node_id=item.node_id,
state=item.state,
retry_count=item.retry_count,
execution_id=item.execution_id,
error=item.error,
)
for item in state.node_executions
}


@@ -0,0 +1,45 @@
"""
NodeExecution entity representing a node's execution state.
"""
from dataclasses import dataclass
from core.workflow.enums import NodeState
@dataclass
class NodeExecution:
"""
Entity representing the execution state of a single node.
This is a mutable entity that tracks the runtime state of a node
during graph execution.
"""
node_id: str
state: NodeState = NodeState.UNKNOWN
retry_count: int = 0
execution_id: str | None = None
error: str | None = None
def mark_started(self, execution_id: str) -> None:
"""Mark the node as started with an execution ID."""
self.state = NodeState.TAKEN
self.execution_id = execution_id
def mark_taken(self) -> None:
"""Mark the node as successfully completed."""
self.state = NodeState.TAKEN
self.error = None
def mark_failed(self, error: str) -> None:
"""Mark the node as failed with an error."""
self.error = error
def mark_skipped(self) -> None:
"""Mark the node as skipped."""
self.state = NodeState.SKIPPED
def increment_retry(self) -> None:
"""Increment the retry count for this node."""
self.retry_count += 1
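
Typical bookkeeping on the entity, as a sketch (node id and execution id are hypothetical):

```python
node = NodeExecution(node_id="http-1")
node.mark_started(execution_id="exec-001")

node.mark_failed("connection reset")
node.increment_retry()  # first retry attempt

node.mark_taken()  # success clears the stored error
assert node.error is None and node.retry_count == 1
```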


@@ -1,6 +0,0 @@
from .graph import Graph
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .runtime_route_state import RuntimeRouteState

__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]


@@ -0,0 +1,33 @@
"""
GraphEngine command entities for external control.
This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
RESUME = "resume"
class GraphEngineCommand(BaseModel):
"""Base class for all GraphEngine commands."""
command_type: CommandType = Field(..., description="Type of command")
payload: dict[str, Any] | None = Field(default=None, description="Optional command payload")
class AbortCommand(GraphEngineCommand):
"""Command to abort a running workflow execution."""
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for abort")
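
The JSON wire format used by RedisChannel, sketched with the classes above:

```python
import json

command = AbortCommand(reason="user cancelled")
wire = json.dumps(command.model_dump())
# e.g. {"command_type": "abort", "payload": null, "reason": "user cancelled"}

decoded = AbortCommand(**json.loads(wire))
assert decoded.command_type == CommandType.ABORT
assert decoded.reason == "user cancelled"
```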


@@ -1,277 +0,0 @@
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field

from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData


class GraphEngineEvent(BaseModel):
    pass


###########################################
# Graph Events
###########################################


class BaseGraphEvent(GraphEngineEvent):
    pass


class GraphRunStartedEvent(BaseGraphEvent):
    pass


class GraphRunSucceededEvent(BaseGraphEvent):
    outputs: dict[str, Any] | None = None
    """outputs"""


class GraphRunFailedEvent(BaseGraphEvent):
    error: str = Field(..., description="failed reason")
    exceptions_count: int = Field(description="exception count", default=0)


class GraphRunPartialSucceededEvent(BaseGraphEvent):
    exceptions_count: int = Field(..., description="exception count")
    outputs: dict[str, Any] | None = None


###########################################
# Node Events
###########################################


class BaseNodeEvent(GraphEngineEvent):
    id: str = Field(..., description="node execution id")
    node_id: str = Field(..., description="node id")
    node_type: NodeType = Field(..., description="node type")
    node_data: BaseNodeData = Field(..., description="node data")
    route_node_state: RouteNodeState = Field(..., description="route node state")
    parallel_id: str | None = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: str | None = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: str | None = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: str | None = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: str | None = None
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    # The version of the node, or "1" if not specified.
    node_version: str = "1"


class NodeRunStartedEvent(BaseNodeEvent):
    predecessor_node_id: str | None = None
    """predecessor node id"""
    parallel_mode_run_id: str | None = None
    """iteration node parallel mode run id"""
    agent_strategy: AgentNodeStrategyInit | None = None


class NodeRunStreamChunkEvent(BaseNodeEvent):
    chunk_content: str = Field(..., description="chunk content")
    from_variable_selector: list[str] | None = None
    """from variable selector"""


class NodeRunRetrieverResourceEvent(BaseNodeEvent):
    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
    context: str = Field(..., description="context")


class NodeRunSucceededEvent(BaseNodeEvent):
    pass


class NodeRunFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeRunExceptionEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeInIterationFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeInLoopFailedEvent(BaseNodeEvent):
    error: str = Field(..., description="error")


class NodeRunRetryEvent(NodeRunStartedEvent):
    error: str = Field(..., description="error")
    retry_index: int = Field(..., description="which retry attempt is about to be performed")
    start_at: datetime = Field(..., description="retry start time")


###########################################
# Parallel Branch Events
###########################################


class BaseParallelBranchEvent(GraphEngineEvent):
    parallel_id: str = Field(..., description="parallel id")
    """parallel id"""
    parallel_start_node_id: str = Field(..., description="parallel start node id")
    """parallel start node id"""
    parent_parallel_id: str | None = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: str | None = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: str | None = None
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""


class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
    pass


class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
    error: str = Field(..., description="failed reason")


###########################################
# Iteration Events
###########################################


class BaseIterationEvent(GraphEngineEvent):
    iteration_id: str = Field(..., description="iteration node execution id")
    iteration_node_id: str = Field(..., description="iteration node id")
    iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
    iteration_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: str | None = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: str | None = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: str | None = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: str | None = None
    """parent parallel start node id if node is in parallel"""
    parallel_mode_run_id: str | None = None
    """iteration run in parallel mode run id"""


class IterationRunStartedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    predecessor_node_id: str | None = None


class IterationRunNextEvent(BaseIterationEvent):
    index: int = Field(..., description="index")
    pre_iteration_output: Any | None = None
    duration: float | None = None


class IterationRunSucceededEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    outputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    steps: int = 0
    iteration_duration_map: dict[str, float] | None = None


class IterationRunFailedEvent(BaseIterationEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    outputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    steps: int = 0
    error: str = Field(..., description="failed reason")


###########################################
# Loop Events
###########################################


class BaseLoopEvent(GraphEngineEvent):
    loop_id: str = Field(..., description="loop node execution id")
    loop_node_id: str = Field(..., description="loop node id")
    loop_node_type: NodeType = Field(..., description="node type, loop")
    loop_node_data: BaseNodeData = Field(..., description="node data")
    parallel_id: str | None = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: str | None = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: str | None = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: str | None = None
    """parent parallel start node id if node is in parallel"""
    parallel_mode_run_id: str | None = None
    """loop run in parallel mode run id"""


class LoopRunStartedEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    predecessor_node_id: str | None = None


class LoopRunNextEvent(BaseLoopEvent):
    index: int = Field(..., description="index")
    pre_loop_output: Any | None = None
    duration: float | None = None


class LoopRunSucceededEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    outputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    steps: int = 0
    loop_duration_map: dict[str, float] | None = None


class LoopRunFailedEvent(BaseLoopEvent):
    start_at: datetime = Field(..., description="start at")
    inputs: Mapping[str, Any] | None = None
    outputs: Mapping[str, Any] | None = None
    metadata: Mapping[str, Any] | None = None
    steps: int = 0
    error: str = Field(..., description="failed reason")


###########################################
# Agent Events
###########################################


class BaseAgentEvent(GraphEngineEvent):
    pass


class AgentLogEvent(BaseAgentEvent):
    id: str = Field(..., description="id")
    label: str = Field(..., description="label")
    node_execution_id: str = Field(..., description="node execution id")
    parent_id: str | None = Field(..., description="parent id")
    error: str | None = Field(..., description="error")
    status: str = Field(..., description="status")
    data: Mapping[str, Any] = Field(..., description="data")
    metadata: Mapping[str, Any] | None = Field(default=None, description="metadata")
    node_id: str = Field(..., description="agent node id")


InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
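
Consumers of this (now removed) event module typically dispatched on the concrete event class; a hypothetical sketch using the classes defined above:

```python
def handle_event(event: GraphEngineEvent) -> None:
    match event:
        case NodeRunStreamChunkEvent():
            print(event.chunk_content, end="")  # stream text to the client
        case NodeRunFailedEvent():
            print(f"node {event.node_id} failed: {event.error}")
        case GraphRunSucceededEvent():
            print(f"outputs: {event.outputs}")
        case _:
            pass  # ignore events this consumer does not handle
```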

View File

@@ -1,674 +0,0 @@
import uuid
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, cast

from pydantic import BaseModel, Field

from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam


class GraphEdge(BaseModel):
    source_node_id: str = Field(..., description="source node id")
    target_node_id: str = Field(..., description="target node id")
    run_condition: RunCondition | None = None
    """run condition"""


class GraphParallel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
    start_from_node_id: str = Field(..., description="start from node id")
    parent_parallel_id: str | None = None
    """parent parallel id"""
    parent_parallel_start_node_id: str | None = None
    """parent parallel start node id"""
    end_to_node_id: str | None = None
    """end to node id"""


class Graph(BaseModel):
    root_node_id: str = Field(..., description="root node id of the graph")
    node_ids: list[str] = Field(default_factory=list, description="graph node ids")
    node_id_config_mapping: dict[str, dict] = Field(
        default_factory=dict, description="node configs mapping (node id: node config)"
    )
    edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="graph edge mapping (source node id: edges)"
    )
    reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
        default_factory=dict, description="reverse graph edge mapping (target node id: edges)"
    )
    parallel_mapping: dict[str, GraphParallel] = Field(
        default_factory=dict, description="graph parallel mapping (parallel id: parallel)"
    )
    node_parallel_mapping: dict[str, str] = Field(
        default_factory=dict, description="graph node parallel mapping (node id: parallel id)"
    )
    answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes")
    end_stream_param: EndStreamParam = Field(..., description="end stream param")

    @classmethod
    def init(cls, graph_config: Mapping[str, Any], root_node_id: str | None = None) -> "Graph":
        """
        Init graph

        :param graph_config: graph config
        :param root_node_id: root node id
        :return: graph
        """
        # edge configs
        edge_configs = graph_config.get("edges")
        if edge_configs is None:
            edge_configs = []

        # node configs
        node_configs = graph_config.get("nodes")
        if not node_configs:
            raise ValueError("Graph must have at least one node")

        edge_configs = cast(list, edge_configs)
        node_configs = cast(list, node_configs)

        # reorganize edges mapping
        edge_mapping: dict[str, list[GraphEdge]] = {}
        reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
        target_edge_ids = set()
        fail_branch_source_node_id = [
            node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
        ]
        for edge_config in edge_configs:
            source_node_id = edge_config.get("source")
            if not source_node_id:
                continue

            if source_node_id not in edge_mapping:
                edge_mapping[source_node_id] = []

            target_node_id = edge_config.get("target")
            if not target_node_id:
                continue

            if target_node_id not in reverse_edge_mapping:
                reverse_edge_mapping[target_node_id] = []

            target_edge_ids.add(target_node_id)

            # parse run condition
            run_condition = None
            if edge_config.get("sourceHandle"):
                if (
                    edge_config.get("source") in fail_branch_source_node_id
                    and edge_config.get("sourceHandle") != "fail-branch"
                ):
                    run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
                elif edge_config.get("sourceHandle") != "source":
                    run_condition = RunCondition(
                        type="branch_identify", branch_identify=edge_config.get("sourceHandle")
                    )

            graph_edge = GraphEdge(
                source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
            )

            edge_mapping[source_node_id].append(graph_edge)
            reverse_edge_mapping[target_node_id].append(graph_edge)

        # fetch nodes that have no predecessor node
        root_node_configs = []
        all_node_id_config_mapping: dict[str, dict] = {}
        for node_config in node_configs:
            node_id = node_config.get("id")
            if not node_id:
                continue

            if node_id not in target_edge_ids:
                root_node_configs.append(node_config)

            all_node_id_config_mapping[node_id] = node_config

        root_node_ids = [node_config.get("id") for node_config in root_node_configs]

        # fetch root node
        if not root_node_id:
            # if no root node id, use the START type node as root node
            root_node_id = next(
                (
                    node_config.get("id")
                    for node_config in root_node_configs
                    if node_config.get("data", {}).get("type", "") == NodeType.START.value
                ),
                None,
            )

        if not root_node_id or root_node_id not in root_node_ids:
            raise ValueError(f"Root node id {root_node_id} not found in the graph")

        # Check whether it is connected to the previous node
        cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping)

        # fetch all node ids from root node
        node_ids = [root_node_id]
        cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id)

        node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}

        # init parallel mapping
        parallel_mapping: dict[str, GraphParallel] = {}
        node_parallel_mapping: dict[str, str] = {}
        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            start_node_id=root_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # Check if it exceeds N layers of parallel
        for parallel in parallel_mapping.values():
            if parallel.parent_parallel_id:
                cls._check_exceed_parallel_limit(
                    parallel_mapping=parallel_mapping,
                    level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
                    parent_parallel_id=parallel.parent_parallel_id,
                )

        # init answer stream generate routes
        answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping
        )

        # init end stream param
        end_stream_param = EndStreamGeneratorRouter.init(
            node_id_config_mapping=node_id_config_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            node_parallel_mapping=node_parallel_mapping,
        )

        # init graph
        graph = cls(
            root_node_id=root_node_id,
            node_ids=node_ids,
            node_id_config_mapping=node_id_config_mapping,
            edge_mapping=edge_mapping,
            reverse_edge_mapping=reverse_edge_mapping,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping,
            answer_stream_generate_routes=answer_stream_generate_routes,
            end_stream_param=end_stream_param,
        )

        return graph

    @classmethod
    def _recursively_add_node_ids(cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str):
        """
        Recursively add node ids

        :param node_ids: node ids
        :param edge_mapping: edge mapping
        :param node_id: node id
        """
        for graph_edge in edge_mapping.get(node_id, []):
            if graph_edge.target_node_id in node_ids:
                continue

            node_ids.append(graph_edge.target_node_id)
            cls._recursively_add_node_ids(
                node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id
            )

    @classmethod
    def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]):
        """
        Check whether it is connected to the previous node
        """
        last_node_id = route[-1]

        for graph_edge in edge_mapping.get(last_node_id, []):
            if not graph_edge.target_node_id:
                continue

            if graph_edge.target_node_id in route:
                raise ValueError(
                    f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
                )

            new_route = route.copy()
            new_route.append(graph_edge.target_node_id)
            cls._check_connected_to_previous_node(
                route=new_route,
                edge_mapping=edge_mapping,
            )

    @classmethod
    def _recursively_add_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        start_node_id: str,
        parallel_mapping: dict[str, GraphParallel],
        node_parallel_mapping: dict[str, str],
        parent_parallel: GraphParallel | None = None,
    ):
        """
        Recursively add parallel ids

        :param edge_mapping: edge mapping
        :param start_node_id: start from node id
        :param parallel_mapping: parallel mapping
        :param node_parallel_mapping: node parallel mapping
        :param parent_parallel: parent parallel
        """
        target_node_edges = edge_mapping.get(start_node_id, [])
        parallel = None
        if len(target_node_edges) > 1:
            # fetch all node ids in current parallels
            parallel_branch_node_ids = defaultdict(list)
            condition_edge_mappings = defaultdict(list)
            for graph_edge in target_node_edges:
                if graph_edge.run_condition is None:
                    parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
                else:
                    condition_hash = graph_edge.run_condition.hash
                    condition_edge_mappings[condition_hash].append(graph_edge)

            for condition_hash, graph_edges in condition_edge_mappings.items():
                if len(graph_edges) > 1:
                    for graph_edge in graph_edges:
                        parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)

            condition_parallels = {}
            for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
                # any target node id in node_parallel_mapping
                parallel = None
                if condition_parallel_branch_node_ids:
                    parent_parallel_id = parent_parallel.id if parent_parallel else None

                    parallel = GraphParallel(
                        start_from_node_id=start_node_id,
                        parent_parallel_id=parent_parallel_id,
                        parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
                    )
                    parallel_mapping[parallel.id] = parallel
                    condition_parallels[condition_hash] = parallel

                    in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        parallel_branch_node_ids=condition_parallel_branch_node_ids,
                    )

                    # collect all branches node ids
                    parallel_node_ids = []
                    for _, node_ids in in_branch_node_ids.items():
                        for node_id in node_ids:
                            in_parent_parallel = True
                            if parent_parallel_id:
                                in_parent_parallel = False
                                for parallel_node_id, parallel_id in node_parallel_mapping.items():
                                    if parallel_id == parent_parallel_id and parallel_node_id == node_id:
                                        in_parent_parallel = True
                                        break

                            if in_parent_parallel:
                                parallel_node_ids.append(node_id)
                                node_parallel_mapping[node_id] = parallel.id

                    outside_parallel_target_node_ids = set()
                    for node_id in parallel_node_ids:
                        if node_id == parallel.start_from_node_id:
                            continue

                        node_edges = edge_mapping.get(node_id)
                        if not node_edges:
                            continue

                        if len(node_edges) > 1:
                            continue

                        target_node_id = node_edges[0].target_node_id
                        if target_node_id in parallel_node_ids:
                            continue

                        if parent_parallel_id:
                            parent_parallel = parallel_mapping.get(parent_parallel_id)
                            if not parent_parallel:
                                continue

                        if (
                            (
                                node_parallel_mapping.get(target_node_id)
                                and node_parallel_mapping.get(target_node_id) == parent_parallel_id
                            )
                            or (
                                parent_parallel
                                and parent_parallel.end_to_node_id
                                and target_node_id == parent_parallel.end_to_node_id
                            )
                            or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
                        ):
                            outside_parallel_target_node_ids.add(target_node_id)

                    if len(outside_parallel_target_node_ids) == 1:
                        if (
                            parent_parallel
                            and parent_parallel.end_to_node_id
                            and parallel.end_to_node_id == parent_parallel.end_to_node_id
                        ):
                            parallel.end_to_node_id = None
                        else:
                            parallel.end_to_node_id = outside_parallel_target_node_ids.pop()

            if condition_edge_mappings:
                for condition_hash, graph_edges in condition_edge_mappings.items():
                    for graph_edge in graph_edges:
                        current_parallel = cls._get_current_parallel(
                            parallel_mapping=parallel_mapping,
                            graph_edge=graph_edge,
                            parallel=condition_parallels.get(condition_hash),
                            parent_parallel=parent_parallel,
                        )

                        cls._recursively_add_parallels(
                            edge_mapping=edge_mapping,
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=graph_edge.target_node_id,
                            parallel_mapping=parallel_mapping,
                            node_parallel_mapping=node_parallel_mapping,
                            parent_parallel=current_parallel,
                        )
            else:
                for graph_edge in target_node_edges:
                    current_parallel = cls._get_current_parallel(
                        parallel_mapping=parallel_mapping,
                        graph_edge=graph_edge,
                        parallel=parallel,
                        parent_parallel=parent_parallel,
                    )

                    cls._recursively_add_parallels(
                        edge_mapping=edge_mapping,
                        reverse_edge_mapping=reverse_edge_mapping,
                        start_node_id=graph_edge.target_node_id,
                        parallel_mapping=parallel_mapping,
                        node_parallel_mapping=node_parallel_mapping,
                        parent_parallel=current_parallel,
                    )
        else:
            for graph_edge in target_node_edges:
                current_parallel = cls._get_current_parallel(
                    parallel_mapping=parallel_mapping,
                    graph_edge=graph_edge,
                    parallel=parallel,
                    parent_parallel=parent_parallel,
                )

                cls._recursively_add_parallels(
                    edge_mapping=edge_mapping,
                    reverse_edge_mapping=reverse_edge_mapping,
                    start_node_id=graph_edge.target_node_id,
                    parallel_mapping=parallel_mapping,
                    node_parallel_mapping=node_parallel_mapping,
                    parent_parallel=current_parallel,
                )

    @classmethod
    def _get_current_parallel(
        cls,
        parallel_mapping: dict[str, GraphParallel],
        graph_edge: GraphEdge,
        parallel: GraphParallel | None = None,
        parent_parallel: GraphParallel | None = None,
    ) -> GraphParallel | None:
        """
        Get current parallel
        """
        current_parallel = None
        if parallel:
            current_parallel = parallel
        elif parent_parallel:
            if not parent_parallel.end_to_node_id or (
                parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id
            ):
                current_parallel = parent_parallel
            else:
                # fetch parent parallel's parent parallel
                parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
                if parent_parallel_parent_parallel_id:
                    parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
                    if parent_parallel_parent_parallel and (
                        not parent_parallel_parent_parallel.end_to_node_id
                        or (
                            parent_parallel_parent_parallel.end_to_node_id
                            and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id
                        )
                    ):
                        current_parallel = parent_parallel_parent_parallel

        return current_parallel

    @classmethod
    def _check_exceed_parallel_limit(
        cls,
        parallel_mapping: dict[str, GraphParallel],
        level_limit: int,
        parent_parallel_id: str,
        current_level: int = 1,
    ):
        """
        Check if it exceeds N layers of parallel
        """
        parent_parallel = parallel_mapping.get(parent_parallel_id)
        if not parent_parallel:
            return

        current_level += 1
        if current_level > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        if parent_parallel.parent_parallel_id:
            cls._check_exceed_parallel_limit(
                parallel_mapping=parallel_mapping,
                level_limit=level_limit,
                parent_parallel_id=parent_parallel.parent_parallel_id,
                current_level=current_level,
            )

    @classmethod
    def _recursively_add_parallel_node_ids(
        cls,
        branch_node_ids: list[str],
        edge_mapping: dict[str, list[GraphEdge]],
        merge_node_id: str,
        start_node_id: str,
    ):
        """
        Recursively add node ids

        :param branch_node_ids: in branch node ids
        :param edge_mapping: edge mapping
        :param merge_node_id: merge node id
        :param start_node_id: start node id
        """
        for graph_edge in edge_mapping.get(start_node_id, []):
            if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids:
                branch_node_ids.append(graph_edge.target_node_id)
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=branch_node_ids,
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=graph_edge.target_node_id,
                )

    @classmethod
    def _fetch_all_node_ids_in_parallels(
        cls,
        edge_mapping: dict[str, list[GraphEdge]],
        reverse_edge_mapping: dict[str, list[GraphEdge]],
        parallel_branch_node_ids: list[str],
    ) -> dict[str, list[str]]:
        """
        Fetch all node ids in parallels
        """
        routes_node_ids: dict[str, list[str]] = {}
        for parallel_branch_node_id in parallel_branch_node_ids:
            routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]

            # fetch routes node ids
            cls._recursively_fetch_routes(
                edge_mapping=edge_mapping,
                start_node_id=parallel_branch_node_id,
                routes_node_ids=routes_node_ids[parallel_branch_node_id],
            )

        # fetch leaf node ids from routes node ids
        leaf_node_ids: dict[str, list[str]] = {}
        merge_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            for node_id in node_ids:
                if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
                    if branch_node_id not in leaf_node_ids:
                        leaf_node_ids[branch_node_id] = []

                    leaf_node_ids[branch_node_id].append(node_id)

                for branch_node_id2, inner_route2 in routes_node_ids.items():
                    if (
                        branch_node_id != branch_node_id2
                        and node_id in inner_route2
                        and len(reverse_edge_mapping.get(node_id, [])) > 1
                        and cls._is_node_in_routes(
                            reverse_edge_mapping=reverse_edge_mapping,
                            start_node_id=node_id,
                            routes_node_ids=routes_node_ids,
                        )
                    ):
                        if node_id not in merge_branch_node_ids:
                            merge_branch_node_ids[node_id] = []

                        if branch_node_id2 not in merge_branch_node_ids[node_id]:
                            merge_branch_node_ids[node_id].append(branch_node_id2)

        # sorted merge_branch_node_ids by branch_node_ids length desc
        merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

        duplicate_end_node_ids = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
                if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
                    if (node_id, node_id2) not in duplicate_end_node_ids and (
                        node_id2,
                        node_id,
                    ) not in duplicate_end_node_ids:
                        duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids

        for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
            # check which node is after
            if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id2]
            elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
                if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
                    del merge_branch_node_ids[node_id]

        branches_merge_node_ids: dict[str, str] = {}
        for node_id, branch_node_ids in merge_branch_node_ids.items():
            if len(branch_node_ids) <= 1:
                continue

            for branch_node_id in branch_node_ids:
                if branch_node_id in branches_merge_node_ids:
                    continue

                branches_merge_node_ids[branch_node_id] = node_id

        in_branch_node_ids: dict[str, list[str]] = {}
        for branch_node_id, node_ids in routes_node_ids.items():
            in_branch_node_ids[branch_node_id] = []
            if branch_node_id not in branches_merge_node_ids:
                # all node ids in current branch is in this thread
                in_branch_node_ids[branch_node_id].append(branch_node_id)
                in_branch_node_ids[branch_node_id].extend(node_ids)
            else:
                merge_node_id = branches_merge_node_ids[branch_node_id]
                if merge_node_id != branch_node_id:
                    in_branch_node_ids[branch_node_id].append(branch_node_id)

                # fetch all node ids from branch_node_id and merge_node_id
                cls._recursively_add_parallel_node_ids(
                    branch_node_ids=in_branch_node_ids[branch_node_id],
                    edge_mapping=edge_mapping,
                    merge_node_id=merge_node_id,
                    start_node_id=branch_node_id,
                )

        return in_branch_node_ids

    @classmethod
    def _recursively_fetch_routes(
        cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str]
    ):
        """
        Recursively fetch route
        """
        if start_node_id not in edge_mapping:
            return

        for graph_edge in edge_mapping[start_node_id]:
            # find next node ids
            if graph_edge.target_node_id not in routes_node_ids:
                routes_node_ids.append(graph_edge.target_node_id)

                cls._recursively_fetch_routes(
                    edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids
                )

    @classmethod
    def _is_node_in_routes(
        cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]]
    ) -> bool:
        """
        Recursively check if the node is in the routes
        """
        if start_node_id not in reverse_edge_mapping:
            return False

        parallel_start_node_ids: dict[str, list[str]] = {}
        for branch_node_id in routes_node_ids:
            if branch_node_id in reverse_edge_mapping:
                for graph_edge in reverse_edge_mapping[branch_node_id]:
                    if graph_edge.source_node_id not in parallel_start_node_ids:
                        parallel_start_node_ids[graph_edge.source_node_id] = []

                    parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)

        expected_branch_set = set(routes_node_ids.keys())
        for _, branch_node_ids in parallel_start_node_ids.items():
            if set(branch_node_ids) == expected_branch_set:
                return True

        return False

    @classmethod
    def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool:
        """
        is node2 after node1
        """
        if node1_id not in edge_mapping:
            return False

        for graph_edge in edge_mapping[node1_id]:
            if graph_edge.target_node_id == node2_id:
                return True

            if cls._is_node2_after_node1(
                node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping
            ):
                return True

        return False
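
For orientation, the shape of the `graph_config` mapping that `Graph.init` parsed, as a hypothetical two-node example inferred from the parsing code above:

```python
graph_config = {
    "nodes": [
        {"id": "start", "data": {"type": "start"}},
        {"id": "llm", "data": {"type": "llm"}},
    ],
    "edges": [
        # A "sourceHandle" other than "source" becomes a branch_identify
        # RunCondition; a plain edge like this one carries no condition.
        {"source": "start", "target": "llm", "sourceHandle": "source"},
    ],
}
```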


@@ -1,21 +0,0 @@
from collections.abc import Mapping
from typing import Any

from pydantic import BaseModel, Field

from core.app.entities.app_invoke_entities import InvokeFrom
from models.enums import UserFrom
from models.workflow import WorkflowType


class GraphInitParams(BaseModel):
    # init params
    tenant_id: str = Field(..., description="tenant / workspace id")
    app_id: str = Field(..., description="app id")
    workflow_type: WorkflowType = Field(..., description="workflow type")
    workflow_id: str = Field(..., description="workflow id")
    graph_config: Mapping[str, Any] = Field(..., description="graph config")
    user_id: str = Field(..., description="user id")
    user_from: UserFrom = Field(..., description="user from, account or end-user")
    invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
    call_depth: int = Field(..., description="call depth")


@@ -1,31 +0,0 @@
from typing import Any

from pydantic import BaseModel, Field

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState


class GraphRuntimeState(BaseModel):
    variable_pool: VariablePool = Field(..., description="variable pool")
    """variable pool"""

    start_at: float = Field(..., description="start time")
    """start time"""

    total_tokens: int = 0
    """total tokens"""

    llm_usage: LLMUsage = LLMUsage.empty_usage()
    """llm usage info"""

    # The `outputs` field stores the final output values generated by executing workflows or chatflows.
    #
    # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
    # after a serialization and deserialization round trip.
    outputs: dict[str, Any] = Field(default_factory=dict)

    node_run_steps: int = 0
    """node run steps"""

    node_run_state: RuntimeRouteState = RuntimeRouteState()
    """node run state"""


@@ -1,21 +0,0 @@
import hashlib
from typing import Literal

from pydantic import BaseModel

from core.workflow.utils.condition.entities import Condition


class RunCondition(BaseModel):
    type: Literal["branch_identify", "condition"]
    """condition type"""

    branch_identify: str | None = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

    conditions: list[Condition] | None = None
    """conditions to run the node, required when type is condition"""

    @property
    def hash(self) -> str:
        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
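
Equal conditions hash identically, which is how the old graph builder grouped same-condition edges into one parallel set; a sketch, assuming pydantic's deterministic field-order JSON serialization:

```python
a = RunCondition(type="branch_identify", branch_identify="true")
b = RunCondition(type="branch_identify", branch_identify="true")
c = RunCondition(type="branch_identify", branch_identify="false")

assert a.hash == b.hash
assert a.hash != c.hash
```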


@@ -1,117 +0,0 @@
import uuid
from datetime import datetime
from enum import StrEnum, auto

from pydantic import BaseModel, Field

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from libs.datetime_utils import naive_utc_now


class RouteNodeState(BaseModel):
    class Status(StrEnum):
        RUNNING = auto()
        SUCCESS = auto()
        FAILED = auto()
        PAUSED = auto()
        EXCEPTION = auto()

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """node state id"""

    node_id: str
    """node id"""

    node_run_result: NodeRunResult | None = None
    """node run result"""

    status: Status = Status.RUNNING
    """node status"""

    start_at: datetime
    """start time"""

    paused_at: datetime | None = None
    """paused time"""

    finished_at: datetime | None = None
    """finished time"""

    failed_reason: str | None = None
    """failed reason"""

    paused_by: str | None = None
    """paused by"""

    index: int = 1

    def set_finished(self, run_result: NodeRunResult):
        """
        Node finished

        :param run_result: run result
        """
        if self.status in {
            RouteNodeState.Status.SUCCESS,
            RouteNodeState.Status.FAILED,
            RouteNodeState.Status.EXCEPTION,
        }:
            raise Exception(f"Route state {self.id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
            self.status = RouteNodeState.Status.EXCEPTION
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        self.finished_at = naive_utc_now()


class RuntimeRouteState(BaseModel):
    routes: dict[str, list[str]] = Field(
        default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)"
    )

    node_state_mapping: dict[str, RouteNodeState] = Field(
        default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)"
    )

    def create_node_state(self, node_id: str) -> RouteNodeState:
        """
        Create node state

        :param node_id: node id
        """
        state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
        self.node_state_mapping[state.id] = state
        return state

    def add_route(self, source_node_state_id: str, target_node_state_id: str):
        """
        Add route to the graph state

        :param source_node_state_id: source node state id
        :param target_node_state_id: target node state id
        """
        if source_node_state_id not in self.routes:
            self.routes[source_node_state_id] = []

        self.routes[source_node_state_id].append(target_node_state_id)

    def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]:
        """
        Get routes with node state by source node id

        :param source_node_state_id: source node state id
        :return: routes with node state
        """
        return [
            self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])
        ]

View File

@@ -0,0 +1,211 @@
"""
Main error handler that coordinates error strategies.
"""
import logging
import time
from typing import TYPE_CHECKING, final
from core.workflow.enums import (
ErrorStrategy as ErrorStrategyEnum,
)
from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetryEvent,
)
from core.workflow.node_events import NodeRunResult
if TYPE_CHECKING:
from .domain import GraphExecution
logger = logging.getLogger(__name__)
@final
class ErrorHandler:
"""
Coordinates error handling strategies for node failures.
This acts as a facade for the various error strategies,
selecting and applying the appropriate strategy based on
node configuration.
"""
def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None:
"""
Initialize the error handler.
Args:
graph: The workflow graph
graph_execution: The graph execution state
"""
self._graph = graph
self._graph_execution = graph_execution
def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None:
"""
Handle a node failure event.
Selects and applies the appropriate error strategy based on
the node's configuration.
Args:
event: The node failure event
Returns:
Optional new event to process, or None to abort
"""
node = self._graph.nodes[event.node_id]
# Get retry count from NodeExecution
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
retry_count = node_execution.retry_count
# First check if retry is configured and not exhausted
if node.retry and retry_count < node.retry_config.max_retries:
result = self._handle_retry(event, retry_count)
if result:
# Retry count will be incremented when NodeRunRetryEvent is handled
return result
# Apply configured error strategy
strategy = node.error_strategy
match strategy:
case None:
return self._handle_abort(event)
case ErrorStrategyEnum.FAIL_BRANCH:
return self._handle_fail_branch(event)
case ErrorStrategyEnum.DEFAULT_VALUE:
return self._handle_default_value(event)
def _handle_abort(self, event: NodeRunFailedEvent):
"""
Handle error by aborting execution.
This is the default strategy when no other strategy is specified.
It stops the entire graph execution when a node fails.
Args:
event: The failure event
Returns:
None - signals abortion
"""
logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error)
# Return None to signal that execution should stop
def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int):
"""
Handle error by retrying the node.
This strategy re-attempts node execution up to a configured
maximum number of retries with configurable intervals.
Args:
event: The failure event
retry_count: Current retry attempt count
Returns:
NodeRunRetryEvent if retry should occur, None otherwise
"""
node = self._graph.nodes[event.node_id]
# Check if we've exceeded max retries
if not node.retry or retry_count >= node.retry_config.max_retries:
return None
# Wait for retry interval
time.sleep(node.retry_config.retry_interval_seconds)
# Create retry event
return NodeRunRetryEvent(
id=event.id,
node_title=node.title,
node_id=event.node_id,
node_type=event.node_type,
node_run_result=event.node_run_result,
start_at=event.start_at,
error=event.error,
retry_index=retry_count + 1,
)
def _handle_fail_branch(self, event: NodeRunFailedEvent):
"""
Handle error by taking the fail branch.
This strategy converts failures to exceptions and routes execution
through a designated fail-branch edge.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent to continue via fail branch
"""
outputs = {
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
edge_source_handle="fail-branch",
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH,
},
),
error=event.error,
)
def _handle_default_value(self, event: NodeRunFailedEvent):
"""
Handle error by using default values.
This strategy allows nodes to fail gracefully by providing
predefined default output values.
Args:
event: The failure event
Returns:
NodeRunExceptionEvent with default values
"""
node = self._graph.nodes[event.node_id]
outputs = {
**node.default_value_dict,
"error_message": event.node_run_result.error,
"error_type": event.node_run_result.error_type,
}
return NodeRunExceptionEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
start_at=event.start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.EXCEPTION,
inputs=event.node_run_result.inputs,
process_data=event.node_run_result.process_data,
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE,
},
),
error=event.error,
)
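
The strategy-selection order above (retry while attempts remain, then the configured strategy, otherwise abort) can be summarized in a self-contained sketch; the names below are illustrative stand-ins, not the real node or event classes:

```python
from dataclasses import dataclass
from enum import Enum

class Strategy(Enum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"

@dataclass
class FakeNode:
    retry: bool = False
    max_retries: int = 0
    error_strategy: Strategy | None = None

def select_action(node: FakeNode, retry_count: int) -> str:
    # Retry takes precedence while attempts remain.
    if node.retry and retry_count < node.max_retries:
        return "retry"
    if node.error_strategy is Strategy.FAIL_BRANCH:
        return "fail-branch"
    if node.error_strategy is Strategy.DEFAULT_VALUE:
        return "default-value"
    return "abort"  # no strategy configured

assert select_action(FakeNode(retry=True, max_retries=3), retry_count=1) == "retry"
assert select_action(FakeNode(), retry_count=0) == "abort"
```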

View File

@@ -0,0 +1,14 @@
"""
Event management subsystem for graph engine.
This package handles event routing, collection, and emission for
workflow graph execution events.
"""
from .event_handlers import EventHandler
from .event_manager import EventManager
__all__ = [
"EventHandler",
"EventManager",
]

View File

@@ -0,0 +1,267 @@
"""
Event handler implementations for different event types.
"""
import logging
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from ..error_handler import ErrorHandler
from ..graph_state_manager import GraphStateManager
from ..graph_traversal import EdgeProcessor
from .event_manager import EventManager
logger = logging.getLogger(__name__)
@final
class EventHandler:
"""
Registry of event handlers for different event types.
This centralizes the business logic for handling specific events,
keeping it separate from the routing and collection infrastructure.
"""
def __init__(
self,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
graph_execution: GraphExecution,
response_coordinator: ResponseStreamCoordinator,
event_collector: "EventManager",
edge_processor: "EdgeProcessor",
state_manager: "GraphStateManager",
error_handler: "ErrorHandler",
) -> None:
"""
Initialize the event handler registry.
Args:
graph: The workflow graph
graph_runtime_state: Runtime state with variable pool
graph_execution: Graph execution aggregate
response_coordinator: Response stream coordinator
event_collector: Event manager for collecting events
edge_processor: Edge processor for edge traversal
state_manager: Unified state manager
error_handler: Error handler
"""
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_execution = graph_execution
self._response_coordinator = response_coordinator
self._event_collector = event_collector
self._edge_processor = edge_processor
self._state_manager = state_manager
self._error_handler = error_handler
def dispatch(self, event: GraphNodeEventBase) -> None:
"""
Handle any node event by dispatching to the appropriate handler.
Args:
event: The event to handle
"""
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@singledispatchmethod
def _dispatch(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
logger.warning("Unhandled event type: %s", type(event).__name__)
@_dispatch.register(NodeRunIterationStartedEvent)
@_dispatch.register(NodeRunIterationNextEvent)
@_dispatch.register(NodeRunIterationSucceededEvent)
@_dispatch.register(NodeRunIterationFailedEvent)
@_dispatch.register(NodeRunLoopStartedEvent)
@_dispatch.register(NodeRunLoopNextEvent)
@_dispatch.register(NodeRunLoopSucceededEvent)
@_dispatch.register(NodeRunLoopFailedEvent)
@_dispatch.register(NodeRunAgentLogEvent)
def _(self, event: GraphNodeEventBase) -> None:
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStartedEvent) -> None:
"""
Handle node started event.
Args:
event: The node started event
"""
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_started(event.id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunStreamChunkEvent) -> None:
"""
Handle stream chunk event with full processing.
Args:
event: The stream chunk event
"""
# Process with response coordinator
streaming_events = list(self._response_coordinator.intercept_event(event))
# Collect all events
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
@_dispatch.register
def _(self, event: NodeRunSucceededEvent) -> None:
"""
Handle node success by coordinating subsystems.
This method coordinates between different subsystems to process
node completion, handle edges, and trigger downstream execution.
Args:
event: The node succeeded event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
# Store outputs in variable pool
self._store_node_outputs(event)
# Forward to response coordinator and emit streaming events
streaming_events = self._response_coordinator.intercept_event(event)
for stream_event in streaming_events:
self._event_collector.collect(stream_event)
# Process edges and get ready nodes
node = self._graph.nodes[event.node_id]
if node.execution_type == NodeExecutionType.BRANCH:
ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion(
event.node_id, event.node_run_result.edge_source_handle
)
else:
ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id)
# Collect streaming events from edge processing
for edge_event in edge_streaming_events:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)
# Handle response node outputs
if node.execution_type == NodeExecutionType.RESPONSE:
self._update_response_outputs(event)
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""
Handle node failure using error handler.
Args:
event: The node failed event
"""
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
result = self._error_handler.handle_node_failure(event)
if result:
# Process the resulting event (retry, exception, etc.)
self.dispatch(result)
else:
# Abort execution
self._graph_execution.fail(RuntimeError(event.error))
self._event_collector.collect(event)
self._state_manager.finish_execution(event.node_id)
@_dispatch.register
def _(self, event: NodeRunExceptionEvent) -> None:
"""
Handle node exception event (fail-branch strategy).
Args:
event: The node exception event
"""
# Node continues via fail-branch, so it's technically "succeeded"
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
@_dispatch.register
def _(self, event: NodeRunRetryEvent) -> None:
"""
Handle node retry event.
Args:
event: The node retry event
"""
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.increment_retry()
def _store_node_outputs(self, event: NodeRunSucceededEvent) -> None:
"""
Store node outputs in the variable pool.
Args:
event: The node succeeded event containing outputs
"""
for variable_name, variable_value in event.node_run_result.outputs.items():
self._graph_runtime_state.variable_pool.add((event.node_id, variable_name), variable_value)
def _update_response_outputs(self, event: NodeRunSucceededEvent) -> None:
"""Update response outputs for response nodes."""
# TODO: Design a mechanism for nodes to notify the engine about how to update outputs
# in runtime state, rather than allowing nodes to directly access runtime state.
for key, value in event.node_run_result.outputs.items():
if key == "answer":
existing = self._graph_runtime_state.get_output("answer", "")
if existing:
self._graph_runtime_state.set_output("answer", f"{existing}{value}")
else:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
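
The `answer` handling above appends stream chunks while every other key overwrites; a tiny standalone illustration of that rule:

```python
outputs: dict[str, str] = {}
for key, value in [("answer", "Hello"), ("answer", " world"), ("status", "ok")]:
    if key == "answer":
        outputs[key] = outputs.get(key, "") + value   # chunks are concatenated
    else:
        outputs[key] = value                          # other keys overwrite
assert outputs == {"answer": "Hello world", "status": "ok"}
```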

View File

@@ -0,0 +1,193 @@
"""
Unified event manager for collecting and emitting events.
"""
import threading
import time
from collections.abc import Generator
from typing import final
from core.workflow.graph_events import GraphEngineEvent
from ..layers.base import GraphEngineLayer
@final
class ReadWriteLock:
"""
A read-write lock implementation that allows multiple concurrent readers
but only one writer at a time.
"""
def __init__(self) -> None:
self._read_ready = threading.Condition(threading.RLock())
self._readers = 0
def acquire_read(self) -> None:
"""Acquire a read lock."""
_ = self._read_ready.acquire()
try:
self._readers += 1
finally:
self._read_ready.release()
def release_read(self) -> None:
"""Release a read lock."""
_ = self._read_ready.acquire()
try:
self._readers -= 1
if self._readers == 0:
self._read_ready.notify_all()
finally:
self._read_ready.release()
def acquire_write(self) -> None:
"""Acquire a write lock."""
_ = self._read_ready.acquire()
while self._readers > 0:
_ = self._read_ready.wait()
def release_write(self) -> None:
"""Release a write lock."""
self._read_ready.release()
def read_lock(self) -> "ReadLockContext":
"""Return a context manager for read locking."""
return ReadLockContext(self)
def write_lock(self) -> "WriteLockContext":
"""Return a context manager for write locking."""
return WriteLockContext(self)
@final
class ReadLockContext:
"""Context manager for read locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "ReadLockContext":
self._lock.acquire_read()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_read()
@final
class WriteLockContext:
"""Context manager for write locks."""
def __init__(self, lock: ReadWriteLock) -> None:
self._lock = lock
def __enter__(self) -> "WriteLockContext":
self._lock.acquire_write()
return self
def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: object) -> None:
self._lock.release_write()
@final
class EventManager:
"""
Unified event manager that collects, buffers, and emits events.
This class combines event collection with event emission, providing
thread-safe event management with support for notifying layers and
streaming events to external consumers.
"""
def __init__(self) -> None:
"""Initialize the event manager."""
self._events: list[GraphEngineEvent] = []
self._lock = ReadWriteLock()
self._layers: list[GraphEngineLayer] = []
self._execution_complete = threading.Event()
def set_layers(self, layers: list[GraphEngineLayer]) -> None:
"""
Set the layers to notify on event collection.
Args:
layers: List of layers to notify
"""
self._layers = layers
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.
Args:
event: The event to collect
"""
with self._lock.write_lock():
self._events.append(event)
self._notify_layers(event)
def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]:
"""
Get new events starting from a specific index.
Args:
start_index: The index to start from
Returns:
List of new events
"""
with self._lock.read_lock():
return list(self._events[start_index:])
def _event_count(self) -> int:
"""
Get the current count of collected events.
Returns:
Number of collected events
"""
with self._lock.read_lock():
return len(self._events)
def mark_complete(self) -> None:
"""Mark execution as complete to stop the event emission generator."""
self._execution_complete.set()
def emit_events(self) -> Generator[GraphEngineEvent, None, None]:
"""
Generator that yields events as they're collected.
Yields:
GraphEngineEvent instances as they're processed
"""
yielded_count = 0
while not self._execution_complete.is_set() or yielded_count < self._event_count():
# Get new events since last yield
new_events = self._get_new_events(yielded_count)
# Yield any new events
for event in new_events:
yield event
yielded_count += 1
# Small sleep to avoid busy waiting
if not self._execution_complete.is_set() and not new_events:
time.sleep(0.001)
def _notify_layers(self, event: GraphEngineEvent) -> None:
"""
Notify all layers of an event.
Layer exceptions are caught and logged to prevent disrupting collection.
Args:
event: The event to send to layers
"""
for layer in self._layers:
try:
layer.on_event(event)
except Exception:
# Silently ignore layer errors during collection
pass
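
A hedged usage sketch of the manager's producer/consumer contract: one thread collects events and marks completion while the caller drains `emit_events()`. Import paths follow this package layout, and `GraphRunStartedEvent` is assumed to be default-constructible here:

```python
import threading

from core.workflow.graph_engine.event_management import EventManager
from core.workflow.graph_events import GraphRunStartedEvent  # assumed default-constructible

manager = EventManager()

def produce() -> None:
    manager.collect(GraphRunStartedEvent())
    manager.mark_complete()  # lets emit_events() finish once everything is drained

threading.Thread(target=produce).start()
for event in manager.emit_events():
    print(type(event).__name__)
```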

File diff suppressed because it is too large

View File

@@ -0,0 +1,288 @@
"""
Graph state manager that combines node, edge, and execution tracking.
"""
import threading
from collections.abc import Sequence
from typing import TypedDict, final
from core.workflow.enums import NodeState
from core.workflow.graph import Edge, Graph
from .ready_queue import ReadyQueue
class EdgeStateAnalysis(TypedDict):
"""Analysis result for edge states."""
has_unknown: bool
has_taken: bool
all_skipped: bool
@final
class GraphStateManager:
def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
"""
Initialize the state manager.
Args:
graph: The workflow graph
ready_queue: Queue for nodes ready to execute
"""
self._graph = graph
self._ready_queue = ready_queue
self._lock = threading.RLock()
# Execution tracking state
self._executing_nodes: set[str] = set()
# ============= Node State Operations =============
def enqueue_node(self, node_id: str) -> None:
"""
Mark a node as TAKEN and add it to the ready queue.
This combines the state transition and enqueueing operations
that always occur together when preparing a node for execution.
Args:
node_id: The ID of the node to enqueue
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.TAKEN
self._ready_queue.put(node_id)
def mark_node_skipped(self, node_id: str) -> None:
"""
Mark a node as SKIPPED.
Args:
node_id: The ID of the node to skip
"""
with self._lock:
self._graph.nodes[node_id].state = NodeState.SKIPPED
def is_node_ready(self, node_id: str) -> bool:
"""
Check if a node is ready to be executed.
A node is ready when all its incoming edges from taken branches
have been satisfied.
Args:
node_id: The ID of the node to check
Returns:
True if the node is ready for execution
"""
with self._lock:
# Get all incoming edges to this node
incoming_edges = self._graph.get_incoming_edges(node_id)
# If no incoming edges, node is always ready
if not incoming_edges:
return True
# If any edge is UNKNOWN, node is not ready
if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
return False
# Node is ready if at least one edge is TAKEN
return any(edge.state == NodeState.TAKEN for edge in incoming_edges)
def get_node_state(self, node_id: str) -> NodeState:
"""
Get the current state of a node.
Args:
node_id: The ID of the node
Returns:
The current node state
"""
with self._lock:
return self._graph.nodes[node_id].state
# ============= Edge State Operations =============
def mark_edge_taken(self, edge_id: str) -> None:
"""
Mark an edge as TAKEN.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.TAKEN
def mark_edge_skipped(self, edge_id: str) -> None:
"""
Mark an edge as SKIPPED.
Args:
edge_id: The ID of the edge to mark
"""
with self._lock:
self._graph.edges[edge_id].state = NodeState.SKIPPED
def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis:
"""
Analyze the states of edges and return summary flags.
Args:
edges: List of edges to analyze
Returns:
Analysis result with state flags
"""
with self._lock:
states = {edge.state for edge in edges}
return EdgeStateAnalysis(
has_unknown=NodeState.UNKNOWN in states,
has_taken=NodeState.TAKEN in states,
all_skipped=states == {NodeState.SKIPPED} if states else True,
)
def get_edge_state(self, edge_id: str) -> NodeState:
"""
Get the current state of an edge.
Args:
edge_id: The ID of the edge
Returns:
The current edge state
"""
with self._lock:
return self._graph.edges[edge_id].state
def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
"""
Categorize branch edges into selected and unselected.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
A tuple of (selected_edges, unselected_edges)
"""
with self._lock:
outgoing_edges = self._graph.get_outgoing_edges(node_id)
selected_edges: list[Edge] = []
unselected_edges: list[Edge] = []
for edge in outgoing_edges:
if edge.source_handle == selected_handle:
selected_edges.append(edge)
else:
unselected_edges.append(edge)
return selected_edges, unselected_edges
# ============= Execution Tracking Operations =============
def start_execution(self, node_id: str) -> None:
"""
Mark a node as executing.
Args:
node_id: The ID of the node starting execution
"""
with self._lock:
self._executing_nodes.add(node_id)
def finish_execution(self, node_id: str) -> None:
"""
Mark a node as no longer executing.
Args:
node_id: The ID of the node finishing execution
"""
with self._lock:
self._executing_nodes.discard(node_id)
def is_executing(self, node_id: str) -> bool:
"""
Check if a node is currently executing.
Args:
node_id: The ID of the node to check
Returns:
True if the node is executing
"""
with self._lock:
return node_id in self._executing_nodes
def get_executing_count(self) -> int:
"""
Get the count of currently executing nodes.
Returns:
Number of executing nodes
"""
with self._lock:
return len(self._executing_nodes)
def get_executing_nodes(self) -> set[str]:
"""
Get a copy of the set of executing node IDs.
Returns:
Set of node IDs currently executing
"""
with self._lock:
return self._executing_nodes.copy()
def clear_executing(self) -> None:
"""Clear all executing nodes."""
with self._lock:
self._executing_nodes.clear()
# ============= Composite Operations =============
def is_execution_complete(self) -> bool:
"""
Check if graph execution is complete.
Execution is complete when:
- Ready queue is empty
- No nodes are executing
Returns:
True if execution is complete
"""
with self._lock:
return self._ready_queue.empty() and len(self._executing_nodes) == 0
def get_queue_depth(self) -> int:
"""
Get the current depth of the ready queue.
Returns:
Number of nodes in the ready queue
"""
return self._ready_queue.qsize()
def get_execution_stats(self) -> dict[str, int]:
"""
Get execution statistics.
Returns:
Dictionary with execution statistics
"""
with self._lock:
taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN)
skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED)
unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN)
return {
"queue_depth": self._ready_queue.qsize(),
"executing": len(self._executing_nodes),
"taken_nodes": taken_nodes,
"skipped_nodes": skipped_nodes,
"unknown_nodes": unknown_nodes,
}
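
The readiness rule implemented by `is_node_ready` can be stated as a standalone function (illustrative types only): a node with no incoming edges is ready, any UNKNOWN edge blocks it, and otherwise one TAKEN edge suffices:

```python
from enum import Enum

class EdgeState(Enum):
    UNKNOWN = "unknown"
    TAKEN = "taken"
    SKIPPED = "skipped"

def is_ready(incoming: list[EdgeState]) -> bool:
    if not incoming:
        return True                                  # entry nodes are always ready
    if any(s is EdgeState.UNKNOWN for s in incoming):
        return False                                 # an undecided path blocks execution
    return any(s is EdgeState.TAKEN for s in incoming)

assert is_ready([]) and not is_ready([EdgeState.UNKNOWN, EdgeState.TAKEN])
assert is_ready([EdgeState.SKIPPED, EdgeState.TAKEN]) and not is_ready([EdgeState.SKIPPED])
```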

View File

@@ -0,0 +1,14 @@
"""
Graph traversal subsystem for graph engine.
This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
__all__ = [
"EdgeProcessor",
"SkipPropagator",
]

View File

@@ -0,0 +1,201 @@
"""
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Edge, Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent
from ..graph_state_manager import GraphStateManager
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from .skip_propagator import SkipPropagator
@final
class EdgeProcessor:
"""
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, triggering downstream node execution,
and managing branch node logic.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
response_coordinator: ResponseStreamCoordinator,
skip_propagator: "SkipPropagator",
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
state_manager: Unified state manager
response_coordinator: Response stream coordinator
skip_propagator: Propagator for skip states
"""
self._graph = graph
self._state_manager = state_manager
self._response_coordinator = response_coordinator
self._skip_propagator = skip_propagator
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
Args:
node_id: The ID of the succeeded node
selected_handle: For branch nodes, the selected edge handle
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self._graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
Args:
node_id: The ID of the succeeded node
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no edge was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
self._process_skipped_edge(edge)
# Process selected edges
for edge in selected_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
Args:
edge: The edge to process
Returns:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self._state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self._state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""
Mark edge as skipped.
Args:
edge: The edge to skip
"""
self._state_manager.mark_edge_skipped(edge.id)
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self._skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self._graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles

View File

@@ -0,0 +1,95 @@
"""
Skip state propagation through the graph.
"""
from collections.abc import Sequence
from typing import final
from core.workflow.graph import Edge, Graph
from ..graph_state_manager import GraphStateManager
@final
class SkipPropagator:
"""
Propagates skip states through the graph.
When a node is skipped, this ensures all downstream nodes
that depend solely on it are also skipped.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
state_manager: Unified state manager
"""
self._graph = graph
self._state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
Recursively propagate skip state from a skipped edge.
Rules:
- If a node has any UNKNOWN incoming edges, stop processing
- If all incoming edges are SKIPPED, skip the node and its edges
- If any incoming edge is TAKEN, the node may still execute
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self._graph.edges[edge_id].head
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
return
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
if edge_states["all_skipped"]:
self._propagate_skip_to_node(downstream_node_id)
def _propagate_skip_to_node(self, node_id: str) -> None:
"""
Mark a node and all its outgoing edges as skipped.
Args:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self._state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self._state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.
Args:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self._state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)
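
The three propagation rules reduce to a small decision table; a standalone illustration (string states are stand-ins for the real enums):

```python
def resolve_downstream(incoming_edge_states: list[str]) -> str:
    """Decide a downstream node's fate from its incoming edges (sketch)."""
    if "unknown" in incoming_edge_states:
        return "wait"      # some predecessor has not been processed yet
    if "taken" in incoming_edge_states:
        return "enqueue"   # at least one live path reaches the node
    return "skip"          # every incoming edge was skipped

assert resolve_downstream(["unknown", "skipped"]) == "wait"
assert resolve_downstream(["skipped", "taken"]) == "enqueue"
assert resolve_downstream(["skipped", "skipped"]) == "skip"
```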

View File

@@ -0,0 +1,52 @@
# Layers
Pluggable middleware for engine extensions.
## Components
### GraphEngineLayer (base)
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
### DebugLoggingLayer
Comprehensive execution logging.
- Configurable detail levels
- Tracks execution statistics
- Truncates long values
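### ExecutionLimitsLayer
Enforces execution budgets, sending an abort command when either is exceeded.
- `max_steps` - Maximum node executions
- `max_time` - Maximum wall-clock seconds
A sketch of attaching it (the values are illustrative):
```python
engine.add_layer(ExecutionLimitsLayer(max_steps=500, max_time=1200))
```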
## Usage
```python
debug_layer = DebugLoggingLayer(
level="INFO",
include_outputs=True
)
engine = GraphEngine(graph)
engine.add_layer(debug_layer)
engine.run()
```
## Custom Layers
```python
class MetricsLayer(GraphEngineLayer):
    def __init__(self):
        super().__init__()
        self.metrics = {}

    def on_event(self, event):
        if isinstance(event, NodeRunSucceededEvent):
            self.metrics[event.node_id] = event.elapsed_time
```
## Configuration
**DebugLoggingLayer Options:**
- `level` - Log level (INFO, DEBUG, ERROR)
- `include_inputs/outputs` - Log data values
- `max_value_length` - Truncate long values

View File

@@ -0,0 +1,16 @@
"""
Layer system for GraphEngine extensibility.
This module provides the layer infrastructure for extending GraphEngine functionality
with middleware-like components that can observe events and interact with execution.
"""
from .base import GraphEngineLayer
from .debug_logging import DebugLoggingLayer
from .execution_limits import ExecutionLimitsLayer
__all__ = [
"DebugLoggingLayer",
"ExecutionLimitsLayer",
"GraphEngineLayer",
]

View File

@@ -0,0 +1,85 @@
"""
Base layer class for GraphEngine extensions.
This module provides the abstract base class for implementing layers that can
intercept and respond to GraphEngine events.
"""
from abc import ABC, abstractmethod
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
Layers are middleware-like components that can:
- Observe all events emitted by the GraphEngine
- Access the graph runtime state
- Send commands to control execution
Subclasses should override the constructor to accept configuration parameters,
then implement the three lifecycle methods.
"""
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
This is called after the engine has been initialized but before any nodes
are executed. Layers can use this to set up resources or log start information.
"""
pass
@abstractmethod
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
This method receives all events generated during graph execution, including:
- Graph lifecycle events (start, success, failure)
- Node execution events (start, success, failure, retry)
- Stream events for response nodes
- Container events (iteration, loop)
Args:
event: The event emitted by the engine
"""
pass
@abstractmethod
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.
This is called after all nodes have been executed or when execution is
aborted. Layers can use this to clean up resources or log final state.
Args:
error: The exception that caused execution to fail, or None if successful
"""
pass

View File

@@ -0,0 +1,241 @@
"""
Debug logging layer for GraphEngine.
This module provides a layer that logs all events and state changes during
graph execution for debugging purposes.
"""
import logging
from collections.abc import Mapping
from typing import Any, final
from typing_extensions import override
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .base import GraphEngineLayer
@final
class DebugLoggingLayer(GraphEngineLayer):
"""
A layer that provides comprehensive logging of GraphEngine execution.
This layer logs all events with configurable detail levels, helping developers
debug workflow execution and understand the flow of events.
"""
def __init__(
self,
level: str = "INFO",
include_inputs: bool = False,
include_outputs: bool = True,
include_process_data: bool = False,
logger_name: str = "GraphEngine.Debug",
max_value_length: int = 500,
) -> None:
"""
Initialize the debug logging layer.
Args:
level: Logging level (DEBUG, INFO, WARNING, ERROR)
include_inputs: Whether to log node input values
include_outputs: Whether to log node output values
include_process_data: Whether to log node process data
logger_name: Name of the logger to use
max_value_length: Maximum length of logged values (truncated if longer)
"""
super().__init__()
self.level = level
self.include_inputs = include_inputs
self.include_outputs = include_outputs
self.include_process_data = include_process_data
self.max_value_length = max_value_length
# Set up logger
self.logger = logging.getLogger(logger_name)
log_level = getattr(logging, level.upper(), logging.INFO)
self.logger.setLevel(log_level)
# Track execution stats
self.node_count = 0
self.success_count = 0
self.failure_count = 0
self.retry_count = 0
def _truncate_value(self, value: Any) -> str:
"""Truncate long values for logging."""
str_value = str(value)
if len(str_value) > self.max_value_length:
return str_value[: self.max_value_length] + "... (truncated)"
return str_value
def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str:
"""Format a dictionary or mapping for logging with truncation."""
if not data:
return "{}"
formatted_items: list[str] = []
for key, value in data.items():
formatted_value = self._truncate_value(value)
formatted_items.append(f" {key}: {formatted_value}")
return "{\n" + ",\n".join(formatted_items) + "\n}"
@override
def on_graph_start(self) -> None:
"""Log graph execution start."""
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""
event_class = event.__class__.__name__
# Graph-level events
if isinstance(event, GraphRunStartedEvent):
self.logger.debug("Graph run started event")
elif isinstance(event, GraphRunSucceededEvent):
self.logger.info("✅ Graph run succeeded")
if self.include_outputs and event.outputs:
self.logger.info(" Final outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self.logger.error("❌ Graph run failed: %s", event.error)
if event.exceptions_count > 0:
self.logger.error(" Total exceptions: %s", event.exceptions_count)
elif isinstance(event, GraphRunAbortedEvent):
self.logger.warning("⚠️ Graph run aborted: %s", event.reason)
if event.outputs:
self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs))
# Node-level events
elif isinstance(event, NodeRunStartedEvent):
self.node_count += 1
self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type)
if self.include_inputs and event.node_run_result.inputs:
self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs))
elif isinstance(event, NodeRunSucceededEvent):
self.success_count += 1
self.logger.info("✅ Node succeeded: %s", event.node_id)
if self.include_outputs and event.node_run_result.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs))
if self.include_process_data and event.node_run_result.process_data:
self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data))
elif isinstance(event, NodeRunFailedEvent):
self.failure_count += 1
self.logger.error("❌ Node failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
if event.node_run_result.error:
self.logger.error(" Details: %s", event.node_run_result.error)
elif isinstance(event, NodeRunExceptionEvent):
self.logger.warning("⚠️ Node exception handled: %s", event.node_id)
self.logger.warning(" Error: %s", event.error)
elif isinstance(event, NodeRunRetryEvent):
self.retry_count += 1
self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index)
self.logger.warning(" Previous error: %s", event.error)
elif isinstance(event, NodeRunStreamChunkEvent):
# Log stream chunks at debug level to avoid spam
final_indicator = " (FINAL)" if event.is_final else ""
self.logger.debug(
"📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk)
)
# Iteration events
elif isinstance(event, NodeRunIterationStartedEvent):
self.logger.info("🔁 Iteration started: %s", event.node_id)
elif isinstance(event, NodeRunIterationNextEvent):
self.logger.debug(" Iteration next: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunIterationSucceededEvent):
self.logger.info("✅ Iteration succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunIterationFailedEvent):
self.logger.error("❌ Iteration failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
# Loop events
elif isinstance(event, NodeRunLoopStartedEvent):
self.logger.info("🔄 Loop started: %s", event.node_id)
elif isinstance(event, NodeRunLoopNextEvent):
self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index)
elif isinstance(event, NodeRunLoopSucceededEvent):
self.logger.info("✅ Loop succeeded: %s", event.node_id)
if self.include_outputs and event.outputs:
self.logger.debug(" Outputs: %s", self._format_dict(event.outputs))
elif isinstance(event, NodeRunLoopFailedEvent):
self.logger.error("❌ Loop failed: %s", event.node_id)
self.logger.error(" Error: %s", event.error)
else:
# Log unknown events at debug level
self.logger.debug("Event: %s", event_class)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Log graph execution end with summary statistics."""
self.logger.info("=" * 80)
if error:
self.logger.error("🔴 GRAPH EXECUTION FAILED")
self.logger.error(" Error: %s", error)
else:
self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY")
# Log execution statistics
self.logger.info("Execution Statistics:")
self.logger.info(" Total nodes executed: %s", self.node_count)
self.logger.info(" Successful nodes: %s", self.success_count)
self.logger.info(" Failed nodes: %s", self.failure_count)
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@@ -0,0 +1,150 @@
"""
Execution limits layer for GraphEngine.
This layer monitors workflow execution to enforce limits on:
- Maximum execution steps
- Maximum execution time
When limits are exceeded, the layer automatically aborts execution.
"""
import logging
import time
from enum import Enum
from typing import final
from typing_extensions import override
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
NodeRunStartedEvent,
)
from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
class LimitType(Enum):
"""Types of execution limits that can be exceeded."""
STEP_LIMIT = "step_limit"
TIME_LIMIT = "time_limit"
@final
class ExecutionLimitsLayer(GraphEngineLayer):
"""
Layer that enforces execution limits for workflows.
Monitors:
- Step count: Tracks number of node executions
- Time limit: Monitors total execution time
Automatically aborts execution when limits are exceeded.
"""
def __init__(self, max_steps: int, max_time: int) -> None:
"""
Initialize the execution limits layer.
Args:
max_steps: Maximum number of execution steps allowed
max_time: Maximum execution time in seconds allowed
"""
super().__init__()
self.max_steps = max_steps
self.max_time = max_time
# Runtime tracking
self.start_time: float | None = None
self.step_count = 0
self.logger = logging.getLogger(__name__)
# State tracking
self._execution_started = False
self._execution_ended = False
self._abort_sent = False # Track if abort command has been sent
@override
def on_graph_start(self) -> None:
"""Called when graph execution starts."""
self.start_time = time.time()
self.step_count = 0
self._execution_started = True
self._execution_ended = False
self._abort_sent = False
self.logger.debug("Execution limits monitoring started")
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
Monitors execution progress and enforces limits.
"""
if not self._execution_started or self._execution_ended or self._abort_sent:
return
# Track step count for node execution events
if isinstance(event, NodeRunStartedEvent):
self.step_count += 1
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
# Check step limit when node execution completes
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
if self._reached_step_limitation():
self._send_abort_command(LimitType.STEP_LIMIT)
if self._reached_time_limitation():
self._send_abort_command(LimitType.TIME_LIMIT)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""Called when graph execution ends."""
if self._execution_started and not self._execution_ended:
self._execution_ended = True
if self.start_time:
total_time = time.time() - self.start_time
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
def _reached_step_limitation(self) -> bool:
"""Check if step count limit has been exceeded."""
return self.step_count > self.max_steps
def _reached_time_limitation(self) -> bool:
"""Check if time limit has been exceeded."""
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
def _send_abort_command(self, limit_type: LimitType) -> None:
"""
Send abort command due to limit violation.
Args:
limit_type: Type of limit exceeded
"""
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
return
# Format detailed reason message
if limit_type == LimitType.STEP_LIMIT:
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
self.logger.warning("Execution limit exceeded: %s", reason)
try:
# Send abort command to the engine
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
self.command_channel.send_command(abort_command)
# Mark that abort has been sent to prevent duplicate commands
self._abort_sent = True
self.logger.debug("Abort command sent to engine")
except Exception:
            self.logger.exception("Failed to send abort command")
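
The budget checks above compare the step counter after each node completes and wall-clock time since `on_graph_start`; a standalone sketch of the same predicates:

```python
import time

def limit_violation(step_count: int, max_steps: int, started_at: float, max_time: float) -> str | None:
    """Return a reason string when a budget is blown, else None (sketch)."""
    if step_count > max_steps:
        return f"Maximum execution steps exceeded: {step_count} > {max_steps}"
    if time.time() - started_at > max_time:
        return f"Maximum execution time exceeded: > {max_time}s"
    return None

assert limit_violation(3, 2, time.time(), 60.0).startswith("Maximum execution steps")
assert limit_violation(1, 2, time.time() - 120.0, 60.0).startswith("Maximum execution time")
```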

View File

@@ -0,0 +1,50 @@
"""
GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Currently supports the stop operation.
"""
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from extensions.ext_redis import redis_client
@final
class GraphEngineManager:
"""
Manager for sending control commands to GraphEngine instances.
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
    Currently supports the stop operation.
"""
@staticmethod
def send_stop_command(task_id: str, reason: str | None = None) -> None:
"""
Send a stop command to a running workflow.
Args:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
pass
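
A hedged usage sketch (the module import path is assumed; Redis must be reachable for the command to be delivered):

```python
from core.workflow.graph_engine.manager import GraphEngineManager  # path assumed

# Maps to channel key "workflow:<task_id>:commands"; fails silently if Redis is down.
GraphEngineManager.send_stop_command("some-task-id", reason="operator requested stop")
```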

View File

@@ -0,0 +1,14 @@
"""
Orchestration subsystem for graph engine.
This package coordinates the overall execution flow between
different subsystems.
"""
from .dispatcher import Dispatcher
from .execution_coordinator import ExecutionCoordinator
__all__ = [
"Dispatcher",
"ExecutionCoordinator",
]

View File

@@ -0,0 +1,104 @@
"""
Main dispatcher for processing events from workers.
"""
import logging
import queue
import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
if TYPE_CHECKING:
from ..event_management import EventHandler
logger = logging.getLogger(__name__)
@final
class Dispatcher:
"""
Main dispatcher that processes events from the event queue.
This runs in a separate thread and coordinates event processing
with timeout and completion detection.
"""
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
event_collector: EventManager,
execution_coordinator: ExecutionCoordinator,
event_emitter: EventManager | None = None,
) -> None:
"""
Initialize the dispatcher.
Args:
event_queue: Queue of events from workers
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting unhandled events
execution_coordinator: Coordinator for execution flow
event_emitter: Optional event manager to signal completion
"""
self._event_queue = event_queue
self._event_handler = event_handler
self._event_collector = event_collector
self._execution_coordinator = execution_coordinator
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
"""Start the dispatcher thread."""
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
# Check for scaling
self._execution_coordinator.check_scaling()
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break
except Exception as e:
logger.exception("Dispatcher error")
self._execution_coordinator.mark_failed(e)
finally:
self._execution_coordinator.mark_complete()
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()

View File

@@ -0,0 +1,87 @@
"""
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING, final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@final
class ExecutionCoordinator:
"""
Coordinates overall execution flow between subsystems.
This provides high-level coordination methods used by the
dispatcher to manage execution state.
"""
def __init__(
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
"""
Initialize the execution coordinator.
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
def check_commands(self) -> None:
"""Process any pending commands."""
self._command_processor.process_commands()
def check_scaling(self) -> None:
"""Check and perform worker scaling if needed."""
self._worker_pool.check_and_scale()
def is_execution_complete(self) -> bool:
"""
Check if execution is complete.
Returns:
True if execution is complete
"""
# Check if aborted or failed
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self._state_manager.is_execution_complete()
def mark_complete(self) -> None:
"""Mark execution as complete."""
if not self._graph_execution.completed:
self._graph_execution.complete()
def mark_failed(self, error: Exception) -> None:
"""
Mark execution as failed.
Args:
error: The error that caused failure
"""
self._graph_execution.fail(error)

View File

@@ -0,0 +1,41 @@
"""
CommandChannel protocol for GraphEngine command communication.
This protocol defines the interface for sending and receiving commands
to/from a GraphEngine instance, supporting both local and distributed scenarios.
"""
from typing import Protocol
from ..entities.commands import GraphEngineCommand
class CommandChannel(Protocol):
"""
Protocol for bidirectional command communication with GraphEngine.
Since each GraphEngine instance processes only one workflow execution,
this channel is dedicated to that single execution.
"""
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
Fetch pending commands for this GraphEngine instance.
Called by GraphEngine to poll for commands that need to be processed.
Returns:
List of pending commands (may be empty)
"""
...
def send_command(self, command: GraphEngineCommand) -> None:
"""
Send a command to be processed by this GraphEngine instance.
Called by external systems to send control commands to the running workflow.
Args:
command: The command to send
"""
...
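
Because `CommandChannel` is a structural `Protocol`, any object providing these two methods satisfies it; no inheritance is needed. A minimal sketch, assuming single-threaded use (real deployments would reach for the in-memory or Redis implementations):

```python
class ListChannel:
    """Toy channel backed by a plain list; for illustration only."""

    def __init__(self) -> None:
        self._pending: list[GraphEngineCommand] = []

    def fetch_commands(self) -> list[GraphEngineCommand]:
        # Drain and return everything queued so far
        commands, self._pending = self._pending, []
        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        self._pending.append(command)
```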

View File

@@ -0,0 +1,12 @@
"""
Ready queue implementations for GraphEngine.
This package contains the protocol and implementations for managing
the queue of nodes ready for execution.
"""
from .factory import create_ready_queue_from_state
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueue, ReadyQueueState
__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"]

View File

@@ -0,0 +1,35 @@
"""
Factory for creating ReadyQueue instances from serialized state.
"""
from typing import TYPE_CHECKING
from .in_memory import InMemoryReadyQueue
from .protocol import ReadyQueueState
if TYPE_CHECKING:
from .protocol import ReadyQueue
def create_ready_queue_from_state(state: ReadyQueueState) -> "ReadyQueue":
"""
Create a ReadyQueue instance from a serialized state.
Args:
        state: The serialized queue state as a ReadyQueueState model describing type, version, and items
Returns:
A ReadyQueue instance initialized with the given state
Raises:
ValueError: If the queue type is unknown or version is unsupported
"""
if state.type == "InMemoryReadyQueue":
if state.version != "1.0":
raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}")
queue = InMemoryReadyQueue()
# Always pass as JSON string to loads()
queue.loads(state.model_dump_json())
return queue
else:
raise ValueError(f"Unknown ready queue type: {state.type}")

View File

@@ -0,0 +1,140 @@
"""
In-memory implementation of the ReadyQueue protocol.
This implementation wraps Python's standard queue.Queue and adds
serialization capabilities for state storage.
"""
import queue
from typing import final
from .protocol import ReadyQueue, ReadyQueueState
@final
class InMemoryReadyQueue(ReadyQueue):
"""
In-memory ready queue implementation with serialization support.
This implementation uses Python's queue.Queue internally and provides
methods to serialize and restore the queue state.
"""
def __init__(self, maxsize: int = 0) -> None:
"""
Initialize the in-memory ready queue.
Args:
maxsize: Maximum size of the queue (0 for unlimited)
"""
self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize)
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
self._queue.put(item)
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time to wait for an item (None for blocking)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
if timeout is None:
return self._queue.get(block=True)
return self._queue.get(timeout=timeout)
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
self._queue.task_done()
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
return self._queue.empty()
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
return self._queue.qsize()
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
"""
        # Snapshot all items by draining and refilling the queue in order.
        # Note: this drain/refill is not atomic; quiesce workers before dumping.
        items: list[str] = []
        while not self._queue.empty():
            try:
                items.append(self._queue.get_nowait())
            except queue.Empty:
                break
        # Put items back in the same order
        for item in items:
            self._queue.put(item)
state = ReadyQueueState(
type="InMemoryReadyQueue",
version="1.0",
items=items,
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
state = ReadyQueueState.model_validate_json(data)
if state.type != "InMemoryReadyQueue":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported version: {state.version}")
# Clear the current queue
while not self._queue.empty():
try:
self._queue.get_nowait()
except queue.Empty:
break
# Restore items
for item in state.items:
self._queue.put(item)
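
The wrapper preserves `queue.Queue` FIFO semantics and survives a serialize/restore cycle; a short behavioral sketch (values are illustrative):

```python
rq = InMemoryReadyQueue()
for node_id in ("a", "b", "c"):
    rq.put(node_id)

assert rq.get(timeout=1.0) == "a"  # FIFO order preserved
rq.task_done()                     # pair each get() with a task_done()

snapshot = rq.dumps()              # "b" and "c" survive the round trip
rq.loads(snapshot)
assert rq.qsize() == 2
```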

View File

@@ -0,0 +1,104 @@
"""
ReadyQueue protocol for GraphEngine node execution queue.
This protocol defines the interface for managing the queue of nodes ready
for execution, supporting both in-memory and persistent storage scenarios.
"""
from collections.abc import Sequence
from typing import Protocol
from pydantic import BaseModel, Field
class ReadyQueueState(BaseModel):
"""
Pydantic model for serialized ready queue state.
This defines the structure of the data returned by dumps()
and expected by loads() for ready queue serialization.
"""
type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')")
version: str = Field(description="Serialization format version")
items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue")
class ReadyQueue(Protocol):
"""
Protocol for managing nodes ready for execution in GraphEngine.
This protocol defines the interface that any ready queue implementation
must provide, enabling both in-memory queues and persistent queues
that can be serialized for state storage.
"""
def put(self, item: str) -> None:
"""
Add a node ID to the ready queue.
Args:
item: The node ID to add to the queue
"""
...
def get(self, timeout: float | None = None) -> str:
"""
Retrieve and remove a node ID from the queue.
Args:
timeout: Maximum time to wait for an item (None for blocking)
Returns:
The node ID retrieved from the queue
Raises:
queue.Empty: If timeout expires and no item is available
"""
...
def task_done(self) -> None:
"""
Indicate that a previously retrieved task is complete.
Used by worker threads to signal task completion for
join() synchronization.
"""
...
def empty(self) -> bool:
"""
Check if the queue is empty.
Returns:
True if the queue has no items, False otherwise
"""
...
def qsize(self) -> int:
"""
Get the approximate size of the queue.
Returns:
The approximate number of items in the queue
"""
...
def dumps(self) -> str:
"""
Serialize the queue state to a JSON string for storage.
Returns:
A JSON string containing the serialized queue state
that can be persisted and later restored
"""
...
def loads(self, data: str) -> None:
"""
Restore the queue state from a JSON string.
Args:
data: The JSON string containing the serialized queue state to restore
"""
...
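
A typical consumer loop against this protocol, mirroring how workers poll; the 0.1s timeout is an assumed value:

```python
import queue as stdlib_queue


def drain(ready_queue: ReadyQueue) -> list[str]:
    """Collect all currently available node IDs without blocking forever."""
    items: list[str] = []
    while True:
        try:
            items.append(ready_queue.get(timeout=0.1))
            ready_queue.task_done()
        except stdlib_queue.Empty:
            return items
```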

View File

@@ -0,0 +1,10 @@
"""
ResponseStreamCoordinator - Coordinates streaming output from response nodes
This component manages response streaming sessions and ensures ordered streaming
of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
__all__ = ["ResponseStreamCoordinator"]

View File

@@ -0,0 +1,696 @@
"""
Main ResponseStreamCoordinator implementation.
This module contains the public ResponseStreamCoordinator class that manages
response streaming sessions and ensures ordered streaming of responses.
"""
import logging
from collections import deque
from collections.abc import Sequence
from threading import RLock
from typing import Literal, TypeAlias, final
from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from .path import Path
from .session import ResponseSession
logger = logging.getLogger(__name__)
# Type definitions
NodeID: TypeAlias = str
EdgeID: TypeAlias = str
class ResponseSessionState(BaseModel):
"""Serializable representation of a response session."""
node_id: str
index: int = Field(default=0, ge=0)
class StreamBufferState(BaseModel):
"""Serializable representation of buffered stream chunks."""
selector: tuple[str, ...]
events: list[NodeRunStreamChunkEvent] = Field(default_factory=list)
class StreamPositionState(BaseModel):
"""Serializable representation for stream read positions."""
selector: tuple[str, ...]
position: int = Field(default=0, ge=0)
class ResponseStreamCoordinatorState(BaseModel):
"""Serialized snapshot of ResponseStreamCoordinator."""
type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator")
version: str = Field(default="1.0")
response_nodes: Sequence[str] = Field(default_factory=list)
active_session: ResponseSessionState | None = None
waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list)
node_execution_ids: dict[str, str] = Field(default_factory=dict)
paths_map: dict[str, list[list[str]]] = Field(default_factory=dict)
stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list)
stream_positions: Sequence[StreamPositionState] = Field(default_factory=list)
closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list)
@final
class ResponseStreamCoordinator:
"""
Manages response streaming sessions without relying on global state.
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
"""
Initialize coordinator with variable pool.
Args:
variable_pool: VariablePool instance for accessing node variables
graph: Graph instance for looking up node information
"""
self._variable_pool = variable_pool
self._graph = graph
self._active_session: ResponseSession | None = None
self._waiting_sessions: deque[ResponseSession] = deque()
self._lock = RLock()
# Internal stream management (replacing OutputRegistry)
self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {}
self._stream_positions: dict[tuple[str, ...], int] = {}
self._closed_streams: set[tuple[str, ...]] = set()
# Track response nodes
self._response_nodes: set[NodeID] = set()
# Store paths for each response node
self._paths_maps: dict[NodeID, list[Path]] = {}
# Track node execution IDs and types for proper event forwarding
self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id
# Track response sessions to ensure only one per node
self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session
    def register(self, response_node_id: NodeID) -> None:
        """Register a response node, building its paths map and streaming session."""
        with self._lock:
if response_node_id in self._response_nodes:
return
self._response_nodes.add(response_node_id)
# Build and save paths map for this response node
paths_map = self._build_paths_map(response_node_id)
self._paths_maps[response_node_id] = paths_map
# Create and store response session for this node
response_node = self._graph.nodes[response_node_id]
session = ResponseSession.from_node(response_node)
self._response_sessions[response_node_id] = session
def track_node_execution(self, node_id: NodeID, execution_id: str) -> None:
"""Track the execution ID for a node when it starts executing.
Args:
node_id: The ID of the node
execution_id: The execution ID from NodeRunStartedEvent
"""
with self._lock:
self._node_execution_ids[node_id] = execution_id
def _get_or_create_execution_id(self, node_id: NodeID) -> str:
"""Get the execution ID for a node, creating one if it doesn't exist.
Args:
node_id: The ID of the node
Returns:
The execution ID for the node
"""
with self._lock:
if node_id not in self._node_execution_ids:
self._node_execution_ids[node_id] = str(uuid4())
return self._node_execution_ids[node_id]
def _build_paths_map(self, response_node_id: NodeID) -> list[Path]:
"""
Build a paths map for a response node by finding all paths from root node
to the response node, recording branch edges along each path.
Args:
response_node_id: ID of the response node to analyze
Returns:
List of Path objects, where each path contains branch edge IDs
"""
# Get root node ID
root_node_id = self._graph.root_node.id
# If root is the response node, return empty path
if root_node_id == response_node_id:
return [Path()]
# Extract variable selectors from the response node's template
response_node = self._graph.nodes[response_node_id]
response_session = ResponseSession.from_node(response_node)
template = response_session.template
# Collect all variable selectors from the template
variable_selectors: set[tuple[str, ...]] = set()
for segment in template.segments:
if isinstance(segment, VariableSegment):
variable_selectors.add(tuple(segment.selector[:2]))
# Step 1: Find all complete paths from root to response node
all_complete_paths: list[list[EdgeID]] = []
def find_paths(
current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID]
) -> None:
"""Recursively find all paths from current node to target node."""
if current_node_id == target_node_id:
# Found a complete path, store it
all_complete_paths.append(current_path.copy())
return
# Mark as visited to avoid cycles
visited.add(current_node_id)
# Explore outgoing edges
outgoing_edges = self._graph.get_outgoing_edges(current_node_id)
for edge in outgoing_edges:
edge_id = edge.id
next_node_id = edge.head
# Skip if already visited in this path
if next_node_id not in visited:
# Add edge to path and recurse
new_path = current_path + [edge_id]
find_paths(next_node_id, target_node_id, new_path, visited.copy())
# Start searching from root node
find_paths(root_node_id, response_node_id, [], set())
# Step 2: For each complete path, filter edges based on node blocking behavior
filtered_paths: list[Path] = []
for path in all_complete_paths:
blocking_edges: list[str] = []
for edge_id in path:
edge = self._graph.edges[edge_id]
source_node = self._graph.nodes[edge.tail]
# Check if node is a branch/container (original behavior)
if source_node.execution_type in {
NodeExecutionType.BRANCH,
NodeExecutionType.CONTAINER,
} or source_node.blocks_variable_output(variable_selectors):
blocking_edges.append(edge_id)
# Keep the path even if it's empty
filtered_paths.append(Path(edges=blocking_edges))
return filtered_paths
def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Handle when an edge is taken (selected by a branch node).
This method updates the paths for all response nodes by removing
the taken edge. If any response node has an empty path after removal,
it means the node is now deterministically reachable and should start.
Args:
edge_id: The ID of the edge that was taken
Returns:
List of events to emit from starting new sessions
"""
events: list[NodeRunStreamChunkEvent] = []
with self._lock:
            # Check each registered response node (set iteration order is not guaranteed)
for response_node_id in self._response_nodes:
if response_node_id not in self._paths_maps:
continue
paths = self._paths_maps[response_node_id]
has_reachable_path = False
# Update each path by removing the taken edge
for path in paths:
# Remove the taken edge from this path
path.remove_edge(edge_id)
# Check if this path is now empty (node is reachable)
if path.is_empty():
has_reachable_path = True
# If node is now reachable (has empty path), start/queue session
if has_reachable_path:
# Pass the node_id to the activation method
# The method will handle checking and removing from map
                    events.extend(self._activate_or_queue_session(response_node_id))
return events
    def _activate_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]:
"""
Start a session immediately if no active session, otherwise queue it.
Only activates sessions that exist in the _response_sessions map.
Args:
node_id: The ID of the response node to activate
Returns:
List of events from flush attempt if session started immediately
"""
events: list[NodeRunStreamChunkEvent] = []
# Get the session from our map (only activate if it exists)
session = self._response_sessions.get(node_id)
if not session:
return events
# Remove from map to ensure it won't be activated again
del self._response_sessions[node_id]
if self._active_session is None:
self._active_session = session
# Try to flush immediately
events.extend(self.try_flush())
else:
# Queue the session if another is active
self._waiting_sessions.append(session)
return events
def intercept_event(
self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent
) -> Sequence[NodeRunStreamChunkEvent]:
with self._lock:
if isinstance(event, NodeRunStreamChunkEvent):
self._append_stream_chunk(event.selector, event)
if event.is_final:
self._close_stream(event.selector)
return self.try_flush()
else:
                # Skip because we share the same variable pool.
#
# for variable_name, variable_value in event.node_run_result.outputs.items():
# self._variable_pool.add((event.node_id, variable_name), variable_value)
return self.try_flush()
def _create_stream_chunk_event(
self,
node_id: str,
execution_id: str,
selector: Sequence[str],
chunk: str,
is_final: bool = False,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
# Use the active response node for special selectors
response_node = self._graph.nodes[self._active_session.node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
# Standard case: selector refers to an actual node
node = self._graph.nodes[node_id]
return NodeRunStreamChunkEvent(
id=execution_id,
node_id=node.id,
node_type=node.node_type,
selector=selector,
chunk=chunk,
is_final=is_final,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
"""Process a variable segment. Returns (events, is_complete).
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
is_complete = False
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
current_response_node = self._graph.nodes[self._active_session.node_id]
# Use get_or_create_execution_id to ensure we have a consistent ID
execution_id = self._get_or_create_execution_id(current_response_node.id)
is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1
event = self._create_stream_chunk_event(
node_id=current_response_node.id,
execution_id=execution_id,
selector=[current_response_node.id, "answer"], # FIXME(-LAN-)
chunk=segment.text,
is_final=is_last_segment,
)
return [event]
def try_flush(self) -> list[NodeRunStreamChunkEvent]:
with self._lock:
if not self._active_session:
return []
template = self._active_session.template
response_node_id = self._active_session.node_id
events: list[NodeRunStreamChunkEvent] = []
# Process segments sequentially from current index
while self._active_session.index < len(template.segments):
segment = template.segments[self._active_session.index]
if isinstance(segment, VariableSegment):
# Check if the source node for this variable is skipped
# Only check for actual nodes, not special selectors (sys, env, conversation)
source_selector_prefix = segment.selector[0] if segment.selector else ""
if source_selector_prefix in self._graph.nodes:
source_node = self._graph.nodes[source_selector_prefix]
if source_node.state == NodeState.SKIPPED:
# Skip this variable segment if the source node is skipped
self._active_session.index += 1
continue
segment_events, is_complete = self._process_variable_segment(segment)
events.extend(segment_events)
# Only advance index if this variable segment is complete
if is_complete:
self._active_session.index += 1
else:
# Wait for more data
break
else:
segment_events = self._process_text_segment(segment)
events.extend(segment_events)
self._active_session.index += 1
if self._active_session.is_complete():
# End current session and get events from starting next session
next_session_events = self.end_session(response_node_id)
events.extend(next_session_events)
return events
def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]:
"""
End the active session for a response node.
Automatically starts the next waiting session if available.
Args:
node_id: ID of the response node ending its session
Returns:
List of events from starting the next session
"""
with self._lock:
events: list[NodeRunStreamChunkEvent] = []
if self._active_session and self._active_session.node_id == node_id:
self._active_session = None
# Try to start next waiting session
if self._waiting_sessions:
next_session = self._waiting_sessions.popleft()
self._active_session = next_session
# Immediately try to flush any available segments
events = self.try_flush()
return events
# ============= Internal Stream Management Methods =============
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.
Args:
selector: List of strings identifying the stream location
event: The NodeRunStreamChunkEvent to append
Raises:
ValueError: If the stream is already closed
"""
key = tuple(selector)
if key in self._closed_streams:
raise ValueError(f"Stream {'.'.join(selector)} is already closed")
if key not in self._stream_buffers:
self._stream_buffers[key] = []
self._stream_positions[key] = 0
self._stream_buffers[key].append(event)
def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None:
"""
Pop the next unread stream chunk from the buffer.
Args:
selector: List of strings identifying the stream location
Returns:
The next event, or None if no unread events available
"""
key = tuple(selector)
if key not in self._stream_buffers:
return None
position = self._stream_positions.get(key, 0)
buffer = self._stream_buffers[key]
if position >= len(buffer):
return None
event = buffer[position]
self._stream_positions[key] = position + 1
return event
def _has_unread_stream(self, selector: Sequence[str]) -> bool:
"""
Check if the stream has unread events.
Args:
selector: List of strings identifying the stream location
Returns:
True if there are unread events, False otherwise
"""
key = tuple(selector)
if key not in self._stream_buffers:
return False
position = self._stream_positions.get(key, 0)
return position < len(self._stream_buffers[key])
def _close_stream(self, selector: Sequence[str]) -> None:
"""
Mark a stream as closed (no more chunks can be appended).
Args:
selector: List of strings identifying the stream location
"""
key = tuple(selector)
self._closed_streams.add(key)
def _is_stream_closed(self, selector: Sequence[str]) -> bool:
"""
Check if a stream is closed.
Args:
selector: List of strings identifying the stream location
Returns:
True if the stream is closed, False otherwise
"""
key = tuple(selector)
return key in self._closed_streams
def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None:
"""Convert an in-memory session into its serializable form."""
if session is None:
return None
return ResponseSessionState(node_id=session.node_id, index=session.index)
def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession:
"""Rebuild a response session from serialized data."""
node = self._graph.nodes.get(session_state.node_id)
if node is None:
raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state")
session = ResponseSession.from_node(node)
session.index = session_state.index
return session
def dumps(self) -> str:
"""Serialize coordinator state to JSON."""
with self._lock:
state = ResponseStreamCoordinatorState(
response_nodes=sorted(self._response_nodes),
active_session=self._serialize_session(self._active_session),
waiting_sessions=[
session_state
for session in list(self._waiting_sessions)
if (session_state := self._serialize_session(session)) is not None
],
pending_sessions=[
session_state
for _, session in sorted(self._response_sessions.items())
if (session_state := self._serialize_session(session)) is not None
],
node_execution_ids=dict(sorted(self._node_execution_ids.items())),
paths_map={
node_id: [path.edges.copy() for path in paths]
for node_id, paths in sorted(self._paths_maps.items())
},
stream_buffers=[
StreamBufferState(
selector=selector,
events=[event.model_copy(deep=True) for event in events],
)
for selector, events in sorted(self._stream_buffers.items())
],
stream_positions=[
StreamPositionState(selector=selector, position=position)
for selector, position in sorted(self._stream_positions.items())
],
closed_streams=sorted(self._closed_streams),
)
return state.model_dump_json()
def loads(self, data: str) -> None:
"""Restore coordinator state from JSON."""
state = ResponseStreamCoordinatorState.model_validate_json(data)
if state.type != "ResponseStreamCoordinator":
raise ValueError(f"Invalid serialized data type: {state.type}")
if state.version != "1.0":
raise ValueError(f"Unsupported serialized version: {state.version}")
with self._lock:
self._response_nodes = set(state.response_nodes)
self._paths_maps = {
node_id: [Path(edges=list(path_edges)) for path_edges in paths]
for node_id, paths in state.paths_map.items()
}
self._node_execution_ids = dict(state.node_execution_ids)
self._stream_buffers = {
tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events]
for buffer in state.stream_buffers
}
self._stream_positions = {
tuple(position.selector): position.position for position in state.stream_positions
}
for selector in self._stream_buffers:
self._stream_positions.setdefault(selector, 0)
self._closed_streams = {tuple(selector) for selector in state.closed_streams}
self._waiting_sessions = deque(
self._session_from_state(session_state) for session_state in state.waiting_sessions
)
self._response_sessions = {
session_state.node_id: self._session_from_state(session_state)
for session_state in state.pending_sessions
}
self._active_session = self._session_from_state(state.active_session) if state.active_session else None
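
A lifecycle sketch, assuming `graph` and `variable_pool` are already built; the node and edge IDs are placeholders:

```python
coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph)
coordinator.register("answer-node")

# When a branch resolves, paths shrink; any response node left with an
# empty path starts (or queues) its streaming session.
chunk_events = list(coordinator.on_edge_taken("edge-from-branch"))

# Upstream stream chunks are intercepted and re-emitted in template order
# once the owning session is active:
#   chunk_events += coordinator.intercept_event(stream_chunk_event)

# The full coordinator state can be snapshotted and restored across restarts.
coordinator.loads(coordinator.dumps())
```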

View File

@@ -0,0 +1,35 @@
"""
Internal path representation for response coordinator.
This module contains the private Path class used internally by ResponseStreamCoordinator
to track execution paths to response nodes.
"""
from dataclasses import dataclass, field
from typing import TypeAlias
EdgeID: TypeAlias = str
@dataclass
class Path:
"""
Represents a path of branch edges that must be taken to reach a response node.
Note: This is an internal class not exposed in the public API.
"""
edges: list[EdgeID] = field(default_factory=list[EdgeID])
def contains_edge(self, edge_id: EdgeID) -> bool:
"""Check if this path contains the given edge."""
return edge_id in self.edges
def remove_edge(self, edge_id: EdgeID) -> None:
"""Remove the given edge from this path in place."""
if self.contains_edge(edge_id):
self.edges.remove(edge_id)
def is_empty(self) -> bool:
"""Check if the path has no edges (node is reachable)."""
return len(self.edges) == 0
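
The semantics in miniature (edge IDs are placeholders):

```python
path = Path(edges=["edge-a", "edge-b"])
path.remove_edge("edge-a")
assert not path.is_empty()

path.remove_edge("edge-b")
assert path.is_empty()  # the response node is now deterministically reachable
```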

View File

@@ -0,0 +1,52 @@
"""
Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
"""
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@dataclass
class ResponseSession:
"""
Represents an active response streaming session.
Note: This is an internal class not exposed in the public API.
"""
node_id: str
template: Template # Template object from the response node
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: Node) -> "ResponseSession":
"""
        Create a ResponseSession from an AnswerNode, EndNode, or KnowledgeIndexNode.
        Args:
            node: Must be an AnswerNode, EndNode, or KnowledgeIndexNode instance
        Returns:
            ResponseSession configured with the node's streaming template
        Raises:
            TypeError: If node is not one of the supported response node types
        """
        if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
            raise TypeError(f"Expected AnswerNode, EndNode, or KnowledgeIndexNode, got {type(node).__name__}")
return cls(
node_id=node.id,
template=node.get_streaming_template(),
)
def is_complete(self) -> bool:
"""Check if all segments in the template have been processed."""
return self.index >= len(self.template.segments)
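
How the coordinator walks a session, sketched with a pre-built `template` (real sessions are created via `from_node()`):

```python
session = ResponseSession(node_id="answer-1", template=template)
while not session.is_complete():
    segment = session.template.segments[session.index]
    # ... stream the segment, or stop here if its variable isn't ready ...
    session.index += 1
```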

View File

@@ -0,0 +1,142 @@
"""
Worker - Thread implementation for queue-based node execution
Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from contextlib import nullcontext
from datetime import datetime
from typing import final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
@final
class Worker(threading.Thread):
"""
Worker thread that executes nodes from the ready queue.
Workers continuously pull node IDs from the ready_queue, execute the
corresponding nodes, and push the resulting events to the event_queue
for the dispatcher to process.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
) -> None:
"""
Initialize worker thread.
Args:
ready_queue: Ready queue containing node IDs ready for execution
event_queue: Queue for pushing execution events
graph: Graph containing nodes to execute
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._stop_event = threading.Event()
self._last_task_time = time.time()
def stop(self) -> None:
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:
"""Check if the worker is currently idle."""
        # Worker counts as idle if it has not picked up a task within the last 0.2 seconds
return (time.time() - self._last_task_time) > 0.2
@property
def idle_duration(self) -> float:
"""Get the duration in seconds since the worker last processed a task."""
return time.time() - self._last_task_time
@property
def worker_id(self) -> int:
"""Get the worker's ID."""
return self._worker_id
@override
def run(self) -> None:
"""
Main worker loop.
Continuously pulls node IDs from ready_queue, executes them,
and pushes events to event_queue until stopped.
"""
while not self._stop_event.is_set():
# Try to get a node ID from the ready queue (with timeout)
try:
node_id = self._ready_queue.get(timeout=0.1)
except queue.Empty:
continue
self._last_task_time = time.time()
node = self._graph.nodes[node_id]
            try:
                self._execute_node(node)
            except Exception as e:
                # Attribute the failure to the node that was actually running
                error_event = NodeRunFailedEvent(
                    id=str(uuid4()),
                    node_id=node.id,
                    node_type=node.node_type,
                    in_iteration_id=None,
                    error=str(e),
                    start_at=datetime.now(),
                )
                self._event_queue.put(error_event)
            finally:
                # Mark the task done even on failure so ready_queue.join() cannot hang
                self._ready_queue.task_done()
    def _execute_node(self, node: Node) -> None:
        """
        Execute a single node and handle its events.
        Args:
            node: The node instance to execute
        """
        # Preserve Flask app and context variables when provided; otherwise
        # run the node without a context wrapper.
        if self._flask_app and self._context_vars:
            context = preserve_flask_contexts(flask_app=self._flask_app, context_vars=self._context_vars)
        else:
            context = nullcontext()
        with context:
            for event in node.run():
                # Forward each event to the dispatcher immediately for streaming
                self._event_queue.put(event)
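
A wiring sketch for a standalone worker, assuming `graph` is an already-built `Graph`; Flask context preservation is omitted for brevity:

```python
import queue as stdlib_queue

ready_queue = InMemoryReadyQueue()  # from the ready_queue package
event_queue: stdlib_queue.Queue[GraphNodeEventBase] = stdlib_queue.Queue()

worker = Worker(ready_queue=ready_queue, event_queue=event_queue, graph=graph)
worker.start()

ready_queue.put("start-node")    # worker picks this up and runs the node
first_event = event_queue.get()  # resulting events surface here

worker.stop()
worker.join()
```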

View File

@@ -0,0 +1,12 @@
"""
Worker management subsystem for graph engine.
This package manages the worker pool, including creation,
scaling, and activity tracking.
"""
from .worker_pool import WorkerPool
__all__ = [
"WorkerPool",
]

View File

@@ -0,0 +1,291 @@
"""
Simple worker pool that consolidates functionality.
This is a simpler implementation that merges WorkerPool, ActivityTracker,
DynamicScaler, and WorkerFactory into a single class.
"""
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
from configs import dify_config
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
from ..ready_queue import ReadyQueue
from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
"""
Simple worker pool with integrated management.
This class consolidates all worker management functionality into
a single, simpler implementation without excessive abstraction.
"""
def __init__(
self,
ready_queue: ReadyQueue,
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""
Initialize the simple worker pool.
Args:
ready_queue: Ready queue for nodes ready for execution
event_queue: Queue for worker events
graph: The workflow graph
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
scale_down_idle_time: Seconds before scaling down idle workers
"""
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
# Scaling parameters with defaults
self._min_workers = min_workers or dify_config.GRAPH_ENGINE_MIN_WORKERS
self._max_workers = max_workers or dify_config.GRAPH_ENGINE_MAX_WORKERS
self._scale_up_threshold = scale_up_threshold or dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD
self._scale_down_idle_time = scale_down_idle_time or dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME
# Worker management
self._workers: list[Worker] = []
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
# No longer tracking worker states with callbacks to avoid lock contention
def start(self, initial_count: int | None = None) -> None:
"""
Start the worker pool.
Args:
initial_count: Number of workers to start with (auto-calculated if None)
"""
with self._lock:
if self._running:
return
self._running = True
# Calculate initial worker count
if initial_count is None:
node_count = len(self._graph.nodes)
if node_count < 10:
initial_count = self._min_workers
elif node_count < 50:
initial_count = min(self._min_workers + 1, self._max_workers)
else:
initial_count = min(self._min_workers + 2, self._max_workers)
logger.debug(
"Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)",
initial_count,
node_count,
self._min_workers,
self._max_workers,
)
# Create initial workers
for _ in range(initial_count):
self._create_worker()
def stop(self) -> None:
"""Stop all workers in the pool."""
with self._lock:
self._running = False
worker_count = len(self._workers)
if worker_count > 0:
logger.debug("Stopping worker pool: %d workers", worker_count)
# Stop all workers
for worker in self._workers:
worker.stop()
# Wait for workers to finish
for worker in self._workers:
if worker.is_alive():
worker.join(timeout=10.0)
self._workers.clear()
def _create_worker(self) -> None:
"""Create and start a new worker."""
worker_id = self._worker_counter
self._worker_counter += 1
worker = Worker(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
)
worker.start()
self._workers.append(worker)
def _remove_worker(self, worker: Worker, worker_id: int) -> None:
"""Remove a specific worker from the pool."""
# Stop the worker
worker.stop()
# Wait for it to finish
if worker.is_alive():
worker.join(timeout=2.0)
# Remove from list
if worker in self._workers:
self._workers.remove(worker)
def _try_scale_up(self, queue_depth: int, current_count: int) -> bool:
"""
Try to scale up workers if needed.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
Returns:
True if scaled up, False otherwise
"""
if queue_depth > self._scale_up_threshold and current_count < self._max_workers:
old_count = current_count
self._create_worker()
logger.debug(
"Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)",
old_count,
len(self._workers),
queue_depth,
self._scale_up_threshold,
)
return True
return False
def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool:
"""
Try to scale down workers if we have excess capacity.
Args:
queue_depth: Current queue depth
current_count: Current number of workers
active_count: Number of active workers
idle_count: Number of idle workers
Returns:
True if scaled down, False otherwise
"""
# Skip if we're at minimum or have no idle workers
if current_count <= self._min_workers or idle_count == 0:
return False
# Check if we have excess capacity
has_excess_capacity = (
queue_depth <= active_count # Active workers can handle current queue
or idle_count > active_count # More idle than active workers
or (queue_depth == 0 and idle_count > 0) # No work and have idle workers
)
if not has_excess_capacity:
return False
# Find and remove idle workers that have been idle long enough
workers_to_remove: list[tuple[Worker, int]] = []
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold
if worker.is_idle and worker.idle_duration >= self._scale_down_idle_time:
# Don't remove if it would leave us unable to handle the queue
remaining_workers = current_count - len(workers_to_remove) - 1
if remaining_workers >= self._min_workers and remaining_workers >= max(1, queue_depth // 2):
workers_to_remove.append((worker, worker.worker_id))
# Only remove one worker per check to avoid aggressive scaling
break
# Remove idle workers if any found
if workers_to_remove:
old_count = current_count
for worker, worker_id in workers_to_remove:
self._remove_worker(worker, worker_id)
logger.debug(
"Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, "
"queue_depth=%d, active=%d, idle=%d)",
old_count,
len(self._workers),
len(workers_to_remove),
self._scale_down_idle_time,
queue_depth,
active_count,
idle_count - len(workers_to_remove),
)
return True
return False
def check_and_scale(self) -> None:
"""Check and perform scaling if needed."""
with self._lock:
if not self._running:
return
current_count = len(self._workers)
queue_depth = self._ready_queue.qsize()
# Count active vs idle workers by querying their state directly
idle_count = sum(1 for worker in self._workers if worker.is_idle)
active_count = current_count - idle_count
# Try to scale up if queue is backing up
self._try_scale_up(queue_depth, current_count)
# Try to scale down if we have excess capacity
self._try_scale_down(queue_depth, current_count, active_count, idle_count)
def get_worker_count(self) -> int:
"""Get current number of workers."""
with self._lock:
return len(self._workers)
def get_status(self) -> dict[str, int]:
"""
Get pool status information.
Returns:
Dictionary with status information
"""
with self._lock:
return {
"total_workers": len(self._workers),
"queue_depth": self._ready_queue.qsize(),
"min_workers": self._min_workers,
"max_workers": self._max_workers,
}
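
An ownership sketch: the engine starts the pool, periodically triggers scaling checks while dispatching, and stops it when done. The 0.5s interval and the completion predicate are assumptions:

```python
import time

pool = WorkerPool(ready_queue=ready_queue, event_queue=event_queue, graph=graph)
pool.start()
try:
    while not execution_complete():  # hypothetical completion check
        pool.check_and_scale()
        time.sleep(0.5)
finally:
    pool.stop()

print(pool.get_status())  # e.g. {"total_workers": 2, "queue_depth": 0, ...}
```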